From b3c8849188cc6d00f7cf67f22458e443578f4733 Mon Sep 17 00:00:00 2001 From: flashmouse Date: Tue, 16 Jun 2026 03:53:02 +0800 Subject: [PATCH 01/23] [Fix] nn.attention support dynamic batch_size (#19779) This PR try to fix #19696 , ``nn.attention`` support dynamic batch_size Co-authored-by: flashmouse --- python/tvm/relax/transform/legalize_ops/nn.py | 13 +++---- .../relax/test_transform_legalize_ops_nn.py | 35 +++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 51d23de0f761..35d81f968b37 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -714,10 +714,11 @@ def _te_attention( q = topi.transpose(q, [0, 2, 1, 3]) k = topi.transpose(k, [0, 2, 1, 3]) v = topi.transpose(v, [0, 2, 1, 3]) - q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim]) - k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim]) - v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v]) - p = topi.nn.batch_matmul(q, k) + bs = batch_size * num_head + q = topi.reshape(q, [bs, seq_len, head_dim]) + k = topi.reshape(k, [bs, seq_len_kv, head_dim]) + v = topi.reshape(v, [bs, seq_len_kv, head_dim_v]) + p = topi.nn.batch_matmul(q, k, oshape=[bs, seq_len, seq_len_kv]) if scale is not None: p = topi.multiply(p, scale) else: @@ -725,7 +726,7 @@ def _te_attention( if bias is not None: p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv]) p = topi.add(p, bias) - p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv]) + p = topi.reshape(p, [bs, seq_len, seq_len_kv]) if causal_mask is None: s = topi.nn.softmax(p) else: @@ -741,7 +742,7 @@ def _te_attention( ) p_masked_sum = topi.sum(p_masked_exp, axis=-1, keepdims=True) s = topi.divide(p_masked_exp, p_masked_sum) - o = topi.nn.batch_matmul(s, v, transpose_b=False) + o = topi.nn.batch_matmul(s, v, transpose_b=False, oshape=[bs, seq_len, head_dim_v]) o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v]) return topi.transpose(o, [0, 2, 1, 3]) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 4a708b5da1f4..8136997cf66c 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -3727,6 +3727,41 @@ def main( LegalizeOps()(Attention) +def test_dynamic_batch_attention(): + """The batch dimension may be dynamic (symbolic). + + fix https://github.com/apache/tvm/issues/19696 + """ + + @tvm.script.ir_module + class Attention: + @R.function + def main( + q: R.Tensor(("batch_size", 16, 32, 8), "float32"), + k: R.Tensor(("batch_size", 8, 32, 8), "float32"), + v: R.Tensor(("batch_size", 8, 32, 16), "float32"), + ): + gv = R.nn.attention(q, k, v) + return gv + + LegalizeOps()(Attention) + + @tvm.script.ir_module + class AttentionBias: + @R.function + def main( + q: R.Tensor(("batch_size", 16, 32, 8), "float32"), + k: R.Tensor(("batch_size", 8, 32, 8), "float32"), + v: R.Tensor(("batch_size", 8, 32, 16), "float32"), + bias: R.Tensor(("batch_size", 32, 16, 8), "float32"), + ): + scale = T.FloatImm("float32", 0.1) + gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight") + return gv + + LegalizeOps()(AttentionBias) + + def test_nll_loss(): # fmt: off @tvm.script.ir_module From 1c1afe3e7bd0f49f5953213bc0b8054fde921509 Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Tue, 16 Jun 2026 04:19:50 +0800 Subject: [PATCH 02/23] [Relax][ONNX] Make ReduceMax/ReduceMin NaN propagation order-independent(numpy semantics) (#19755) Hi Committers, This PR addresses the `ReduceMax`/ `ReduceMin` part of issue https://github.com/apache/tvm/issues/19572. Any suggestions would be appreciated if you are available. ### Root cause: The ONNX frontend ReduceMax / ReduceMin converters return relax.op.max / relax.op.min. After legalization these map to topi.max / topi.min, which fold with a commutative reducer whose combiner is Max(x, y) / Min(x, y). In codegen, Max(a, b) lowers to select(a > b, a, b) using an **ordered** float comparison (fcmp ogt), which is false for NaN. As a left-fold (acc = Max(acc, elem)), NaN propagation becomes **position-dependent** - a later non-NaN element silently overwrites an earlier NaN. ### Solution: Adopt the well-defined, **order-independent numpy/IEEE convention** (matching numpy.max/min and torch.amax/amin): the reduction yields NaN whenever **any** reduced element is NaN. Minimal, ONNX-frontend-only change: - Add a shared helper _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims). - For floating-pint inputs, detect NaN along the reduced axes via `sum(astype(isnan(data), dtype), axes, keepdims) > 0` and force those outputs to `NaN` with `where(has_nan, nan, reduce(data))`. The mask reduces over the **same axes/keepdims**, so it aligns in shape with the reduced result. - Keep non-floating(integer) inputs unchanged. - Route all reduce paths(`_impl_v11`and both reduce branches of `_impl_v18`) through the helper; the `noop_with_empty_axes` passthrough is left untouched since it performs no reduction. ### Note on scope (re: #19589 ): The underlying NaN behavior of Max/Min is the same family of ops discussed in #19589. Per review comments there, enforcing NaN semantics at the IR / LLVM-IR level is undesirable(backward-compat with older LLVM, and portability to CUDA/OpenCL/Vulkan), and a dedicated portable nanmin/nanmax TIRx intrinsic(like `nearbyint`) would be the preferred long-term mechanism. This PR deliberately: - does not touch the IR-level Max/Min lowering, and - does not rely on the bool reduction of the NaN mask - it uses `sum(isnan) > 0`, fully sidestepping Max/Min NaN behavior. --------- Co-authored-by: cchung100m --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 34 +++++++++++++--- tests/python/relax/test_frontend_onnx.py | 40 +++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index d64020bfc772..2d1cc47377c4 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3887,6 +3887,28 @@ def _impl_v23(cls, bb, inputs, attr, params): return output +def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims): + """Apply a min/max reduction with well-defined, order-independent NaN propagation. + + relax.op.max/min legalize to a max/min fold implemented as select(x > y, x, y) with an + ordered float comparison, so NaN propagation depends on the fold position (a later non-NaN + element silently overwrites an earlier NaN). ONNX Runtime is also order-independent (it only + yields NaN when the first reduced element is NaN), which is an implementation artifact rather + than a defined semantics and is impractical to replicate portably. We instead adopt the + numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for floating pint inputs, + detect NaN along the reduced axes and force the output to NaN whenever any reduced element is + NaN. + """ + y = reduce_op(data, axes, keepdims) + dtype = data.struct_info.dtype if isinstance(data.struct_info, relax.TensorStructInfo) else None + if dtype is None or not _relax_dtype_is_floating_point(dtype): + return y + nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype), axes, keepdims) + has_nan = relax.op.greater(nan_count, relax.const(0, dtype)) + nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype)) + return relax.op.where(has_nan, nan_filled, y) + + class ReduceMax(OnnxOpConverter): """Converts an onnx ReduceMax node into an equivalent Relax expression.""" @@ -3895,7 +3917,7 @@ def _impl_v11(cls, bb, inputs, attr, params): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.max(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims) @classmethod def _impl_v18(cls, bb, inputs, attr, params): @@ -3912,13 +3934,13 @@ def _impl_v18(cls, bb, inputs, attr, params): # If axes is empty and noop_with_empty_axes is False, reduce all dims if not axes and not noop_with_empty_axes: - return relax.op.max(data, None, keepdims) + return _reduce_min_max_preserve_nan(relax.op.max, data, None, keepdims) # If axes is empty and noop_with_empty_axes is True, return input unchanged elif not axes and noop_with_empty_axes: return data # Otherwise reduce over specified axes else: - return relax.op.max(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims) class ReduceMin(OnnxOpConverter): @@ -3929,7 +3951,7 @@ def _impl_v11(cls, bb, inputs, attr, params): data = inputs[0] axes = attr.get("axes", None) keepdims = attr.get("keepdims", 1) - return relax.op.min(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims) @classmethod def _impl_v18(cls, bb, inputs, attr, params): @@ -3946,13 +3968,13 @@ def _impl_v18(cls, bb, inputs, attr, params): # If axes is empty and noop_with_empty_axes is False, reduce all dims if not axes and not noop_with_empty_axes: - return relax.op.min(data, None, keepdims) + return _reduce_min_max_preserve_nan(relax.op.min, data, None, keepdims) # If axes is empty and noop_with_empty_axes is True, return input unchanged elif not axes and noop_with_empty_axes: return data # Otherwise reduce over specified axes else: - return relax.op.min(data, axes, keepdims) + return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims) class ReduceSum(OnnxOpConverter): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index a83333e7d7ba..db8b977efcbb 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -804,6 +804,46 @@ def test_sign_nan_preserve(): ) +@pytest.mark.parametrize("op_name", ["ReduceMax", "ReduceMin"]) +@pytest.mark.parametrize( + "x", + [ + # NaN in different positions. TVM's max/min fold previously dropped NaN depending on + # position, ONNX Runtime only propagates NaN when it is the first reduced element, which + # is an order-dependent implementation artifact. We instead adopt the well-defined, + # order-independent numpy/IEEE semantics: any NaN in the reduced range yields NaN. + np.array([np.nan, 1.0, 2.0], dtype=np.float32), + np.array([2.0, 1.0, np.nan], dtype=np.float32), + np.array([1.0, np.nan, 2.0], dtype=np.float32), + np.array([1.0, 2.0, 3.0], dtype=np.float32), + ], +) +def test_reduce_min_max_nan_preserve(op_name, x): + reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0) + graph = helper.make_graph( + [reduce_node], + "reduce_nan_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x.shape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])], + ) + model = helper.make_model(graph, producer_name="reduce_nan_test") + model.ir_version = 8 + for opset_import in model.opset_import: + if opset_import.domain in ["", "ai.onnx"]: + opset_import.version = 18 + break + + # Reference is numpy (NaN propagates if any element is NaN), not ONNX Runtime. + ref_out = (np.max if op_name == "ReduceMax" else np.min)(x) + + tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18) + out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else tvm_out).numpy() + + np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out)) + if not np.isnan(ref_out): + np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5) + + @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"]) def test_softmax_family_opset11_default_axis_semantics(op_name: str): verify_unary(op_name, [2, 3, 4], opset=11) From 668d119894fc9bd1418351ad06133ca082ef3b12 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 15 Jun 2026 16:52:00 -0400 Subject: [PATCH 03/23] [Docs][CI] Bump tlcpack-sphinx-addon to restore search result summaries (#19782) --- docker/install/ubuntu_install_sphinx.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/install/ubuntu_install_sphinx.sh b/docker/install/ubuntu_install_sphinx.sh index f6f2d4093182..e40aff3e37d5 100755 --- a/docker/install/ubuntu_install_sphinx.sh +++ b/docker/install/ubuntu_install_sphinx.sh @@ -30,4 +30,4 @@ uv pip install \ sphinx_autodoc_annotation~=1.0 \ sphinx-gallery==0.20.0 \ sphinx_rtd_theme==3.1.0 \ - git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@99c6947b05b1ae26ff5003277fb7cff57ce78353 + git+https://github.com/tlc-pack/tlcpack-sphinx-addon.git@dded1a3fbaf549485d8f7bb3f79ecb0484a11629 From 1086fc9394aa41d73db99795c569681df5091d3b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 15 Jun 2026 16:57:57 -0400 Subject: [PATCH 04/23] [REFACTOR][IR] Cleanup IR naming utilities (#19781) IR module cleanup benefits from using a single unique-name primitive directly at module call sites. This PR renames NameSupply to UniqueNameSupply and removes redundant wrappers around global variable naming. Main changes: - Rename the public name supply API and header to UniqueNameSupply - Replace GlobalVarSupply with direct iterator-seeded UniqueNameSupply usage - Remove obsolete access-path repr registration now covered by tvm-ffi --- include/tvm/ir/global_var_supply.h | 128 ------------- include/tvm/ir/name_supply.h | 169 ------------------ include/tvm/ir/unique_name_supply.h | 143 +++++++++++++++ include/tvm/relax/binding_rewrite.h | 4 +- include/tvm/relax/block_builder.h | 8 +- python/tvm/ir/supply.py | 97 ++-------- .../tvm/relax/frontend/onnx/onnx_frontend.py | 4 +- python/tvm/runtime/__init__.py | 2 +- python/tvm/runtime/_ffi_node_api.py | 5 +- src/ir/access_path_repr.cc | 49 ----- src/ir/global_var_supply.cc | 111 ------------ src/ir/module.cc | 12 +- src/ir/name_supply.cc | 108 ----------- src/ir/unique_name_supply.cc | 114 ++++++++++++ src/relax/backend/contrib/cutlass/codegen.cc | 6 +- src/relax/ir/binding_rewrite.cc | 4 +- src/relax/ir/block_builder.cc | 6 +- src/relax/ir/dataflow_expr_rewriter.cc | 2 +- src/relax/transform/allocate_workspace.cc | 4 +- src/relax/transform/normalize.cc | 2 +- src/target/source/codegen_c.cc | 2 +- src/target/source/codegen_c.h | 4 +- src/target/source/codegen_source_base.cc | 2 +- src/target/source/codegen_source_base.h | 6 +- src/te/operation/create_primfunc.cc | 6 +- src/tirx/ir/index_map.cc | 4 +- src/tirx/transform/bind_target.cc | 7 +- src/tirx/transform/split_host_device.cc | 9 +- ...e_supply.py => test_unique_name_supply.py} | 27 ++- 29 files changed, 347 insertions(+), 698 deletions(-) delete mode 100644 include/tvm/ir/global_var_supply.h delete mode 100644 include/tvm/ir/name_supply.h create mode 100644 include/tvm/ir/unique_name_supply.h delete mode 100644 src/ir/access_path_repr.cc delete mode 100644 src/ir/global_var_supply.cc delete mode 100644 src/ir/name_supply.cc create mode 100644 src/ir/unique_name_supply.cc rename tests/python/ir/{test_name_supply.py => test_unique_name_supply.py} (62%) diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h deleted file mode 100644 index 2241385167e2..000000000000 --- a/include/tvm/ir/global_var_supply.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir/global_var_supply.h - * \brief GlobalVarSupply that can be used to generate unique \class GlobalVar. - */ -#ifndef TVM_IR_GLOBAL_VAR_SUPPLY_H_ -#define TVM_IR_GLOBAL_VAR_SUPPLY_H_ - -#include -#include -#include -#include - -#include -#include - -namespace tvm { - -/*! - * \brief GlobalVarSupply can be used to generate unique GlobalVars. - */ -class GlobalVarSupplyNode : public ffi::Object { - public: - /*! - * \brief Empty constructor. Will use an empty NameSupply. - */ - GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply()) {} - - /*! - * \brief Constructor. - * \param name_supply The NameSupply to use for generating the names of fresh GlobalVars. - * \param name_to_var_map An optional map. - */ - explicit GlobalVarSupplyNode(NameSupply name_supply, - std::unordered_map name_to_var_map = {}); - - /*! - * \brief Generates a unique GlobalVar from this supply. - * \param name The name from which the name of the GlobalVar is derived. - * \param add_prefix If set to true, then the prefix of the contained NameSupply will be prepended - * to the name. \return A unique GlobalVar. - */ - GlobalVar FreshGlobal(ffi::String name, bool add_prefix = true); - - /*! - * \brief Looks up for a GlobalVar with the given name in this supply. - * If no entry is found, creates one, places it in the cache and returns it. - * \param name The name of the GlobalVar to search for. - * \param add_prefix If set to true, the prefix of the contained NameSupply will be prepended to - * the name before performing the search. \return A cached GlobalVar. - */ - GlobalVar UniqueGlobalFor(const ffi::String& name, bool add_prefix = true); - - /*! - * \brief Reserves an existing GlobalVar with this supply. - * \param var The GlobalVar to be registered. - * \param allow_conflict Allow conflict with other GlobalVars that have the same name. - */ - void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false); - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - - /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ - NameSupply name_supply_; - - static constexpr const bool _type_mutable = true; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.GlobalVarSupply", GlobalVarSupplyNode, ffi::Object); - - private: - std::unordered_map name_to_var_map_; -}; - -/*! - * \brief Managed reference class to GlobalVarSupplyNode. - * \sa GlobalVarSupplyNode - */ -class GlobalVarSupply : public ffi::ObjectRef { - public: - /*! - * \brief Constructor. - * \param name_supply The NameSupply to be used when generating new GlobalVars. - * \param name_to_var_map An optional map. - */ - TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(), - std::unordered_map name_to_var_map = {}); - - /*! - * \brief Constructs a supply from an array of IRModules. GlobalVars generated by this supply are - * guaranteed not to conflict with any GlobalVars that belong to the modules. \param modules Array - * of IRModules. - */ - TVM_DLL explicit GlobalVarSupply(const ffi::Array& modules); - - /*! - * \brief Constructs a GlobalVarSupply from an IRModule. GlobalVars generated by this supply are - * guaranteed not to conflict with GlobalVars that belong to the modules. \param module The - * IRModule. - */ - TVM_DLL explicit GlobalVarSupply(const IRModule module); - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(GlobalVarSupply, ffi::ObjectRef, - GlobalVarSupplyNode); -}; - -} // namespace tvm - -#endif // TVM_IR_GLOBAL_VAR_SUPPLY_H_ diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h deleted file mode 100644 index 54bac2afc3b5..000000000000 --- a/include/tvm/ir/name_supply.h +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/ir/name_supply.h - * \brief NameSupply that can be used to generate unique variable names. - */ -#ifndef TVM_IR_NAME_SUPPLY_H_ -#define TVM_IR_NAME_SUPPLY_H_ - -#include -#include - -#include -#include -#include -#include -#include - -namespace tvm { - -/*! - * \brief NameSupply can be used to generate unique names. - */ -class NameSupplyNode : public ffi::Object { - public: - /*! - * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro. - */ - NameSupplyNode() = default; - - /*! - * \brief Constructor. - * \param prefix The prefix to be used with this NameSupply. - * \param name_map The map used to guarantee uniqueness. - */ - NameSupplyNode(const ffi::String& prefix, std::unordered_map name_map) - : prefix_(prefix), name_map(std::move(name_map)) {} - - /*! - * \brief Generates a unique name from this NameSupply. - * \param name The name from which the generated name is derived. - * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the - * name. - * \param add_underscore If set to true, add '_' between prefix and a digit. - * \return A unique name. - */ - ffi::String FreshName(const ffi::String& name, bool add_prefix = true, - bool add_underscore = true); - - /*! - * \brief Reserves an existing name with this NameSupply. - * \param name The name to be reserved. - * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the - * name before reserving it. \return The name that was reserved with the NameSupply. It can be - * different if a prefix is added. - */ - ffi::String ReserveName(const ffi::String& name, bool add_prefix = true); - - /*! - * \brief Checks if this NameSupply already generated a name. - * \param name The name to check. - * \param add_prefix If set to true, then the prefix of this NameSupply will be prepended to the - * name before checking for it. \return True if the name has already been generated. False - * otherwise. - */ - bool ContainsName(const ffi::String& name, bool add_prefix = true); - - // Prefix for all GlobalVar names. It can be empty. - std::string prefix_; - - static constexpr const bool _type_mutable = true; - - static void RegisterReflection() { - namespace refl = tvm::ffi::reflection; - refl::ObjectDef(); - } - - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.NameSupply", NameSupplyNode, ffi::Object); - - private: - /*! \brief Helper function to add the NameSupply prefix to the name. */ - ffi::String add_prefix_to_name(const ffi::String& name); - - /*! - * \brief Function that will generate a unique name. - * \param name The name to be used as a base. - * \param add_underscore If set to true, add '_' between prefix and a digit. - * \return A unique name. - */ - std::string GetUniqueName(std::string name, bool add_underscore = true); - - /*! \brief A map that is used to generate unique names. */ - std::unordered_map name_map; -}; - -/*! - * \brief Managed reference class to NameSupplyNode. - * \sa NameSupplyNode - */ -class NameSupply : public ffi::ObjectRef { - public: - /*! - * \brief Constructor. - * \param prefix The prefix to be used with this NameSupply. - * \param name_map An optional map. - */ - TVM_DLL explicit NameSupply(const ffi::String& prefix = "", - std::unordered_map name_map = {}); - - /*! - * \brief Construct NameSupply with a name map created from the given iterator range and - * the functor. - * - * The functor should return the name of the dereferenced object. - */ - template - TVM_DLL explicit NameSupply(Iter begin, Iter end, Lambda f) - : NameSupply("", GetNameMap(begin, end, f)) {} - - TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(NameSupply, ffi::ObjectRef, NameSupplyNode); - - private: - template - static std::unordered_map GetNameMap(Iter begin, Iter end, Lambda f) { - // static_assert is more reader-friendly than SFINAE when template specialization is not needed. - static_assert(std::is_convertible::value, - "Lambda f must has a signature of [?](*it) -> string {}"); - std::unordered_map name_map; - for (auto it = begin; it != end; ++it) { - const std::string& name = f(*it); - const size_t idx_last_first_num = std::distance( - std::find_if(name.rbegin(), name.rend(), [](char c) { return !std::isdigit(c); }), - name.rend()); - // name = {O = others}{D = consecutive digits} - // let O -> prefix; - std::string prefix = name.substr(0, idx_last_first_num); - TVM_FFI_ICHECK(prefix.size() > 0 && std::isalpha(prefix[0])) - << "Invalid variable name: " << name; - if (0 == name_map.count(prefix)) name_map[prefix] = 0; - if (idx_last_first_num < name.size()) { // has some digits. - // let D's nearest natural number -> idx; - // note: stoul("000123") = 123; - name_map[prefix] = std::max(name_map[prefix], std::stoi(name.substr(idx_last_first_num))); - } - } - return name_map; - } -}; - -} // namespace tvm - -#endif // TVM_IR_NAME_SUPPLY_H_ diff --git a/include/tvm/ir/unique_name_supply.h b/include/tvm/ir/unique_name_supply.h new file mode 100644 index 000000000000..0f79318d1dc8 --- /dev/null +++ b/include/tvm/ir/unique_name_supply.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/unique_name_supply.h + * \brief UniqueNameSupply that can be used to generate unique variable names. + */ +#ifndef TVM_IR_UNIQUE_NAME_SUPPLY_H_ +#define TVM_IR_UNIQUE_NAME_SUPPLY_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { + +/*! + * \brief UniqueNameSupply can be used to generate unique names. + */ +class UniqueNameSupplyNode : public ffi::Object { + public: + /*! + * \brief Empty constructor. Needed by the TVM_REGISTER_NODE_TYPE macro. + */ + UniqueNameSupplyNode() = default; + + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this UniqueNameSupply. + * \param name_map The map used to guarantee uniqueness. + */ + UniqueNameSupplyNode(const ffi::String& prefix, ffi::Map name_map) + : prefix_(prefix), name_map(std::move(name_map)) {} + + /*! + * \brief Generates a unique name from this UniqueNameSupply. + * \param name The name from which the generated name is derived. + * \param add_prefix If set to true, then the prefix of this UniqueNameSupply will be prepended to + * the name. + * \param add_underscore If set to true, add '_' between prefix and a digit. + * \return A unique name. + */ + ffi::String FreshName(const ffi::String& name, bool add_prefix = true, + bool add_underscore = true); + + /*! + * \brief Reserves an existing name with this UniqueNameSupply. + * \param name The name to be reserved. + * \param add_prefix If set to true, then the prefix of this UniqueNameSupply will be prepended to + * the name before reserving it. \return The name that was reserved with the UniqueNameSupply. It + * can be different if a prefix is added. + */ + ffi::String ReserveName(const ffi::String& name, bool add_prefix = true); + + /*! + * \brief Checks if this UniqueNameSupply already generated a name. + * \param name The name to check. + * \param add_prefix If set to true, then the prefix of this UniqueNameSupply will be prepended to + * the name before checking for it. \return True if the name has already been generated. False + * otherwise. + */ + bool ContainsName(const ffi::String& name, bool add_prefix = true); + + // Prefix for all GlobalVar names. It can be empty. + std::string prefix_; + + static constexpr const bool _type_mutable = true; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("ir.UniqueNameSupply", UniqueNameSupplyNode, ffi::Object); + + private: + /*! \brief Helper function to add the UniqueNameSupply prefix to the name. */ + ffi::String AddPrefixToName(const ffi::String& name); + + /*! + * \brief Function that will generate a unique name. + * \param name The name to be used as a base. + * \param add_underscore If set to true, add '_' between prefix and a digit. + * \return A unique name. + */ + std::string GetUniqueName(std::string name, bool add_underscore = true); + + /*! \brief A map that is used to generate unique names. */ + ffi::Map name_map; +}; + +/*! + * \brief Managed reference class to UniqueNameSupplyNode. + * \sa UniqueNameSupplyNode + */ +class UniqueNameSupply : public ffi::ObjectRef { + public: + /*! + * \brief Constructor. + * \param prefix The prefix to be used with this UniqueNameSupply. + * \param name_map An optional map. + */ + TVM_DLL explicit UniqueNameSupply(const ffi::String& prefix = "", + ffi::Map name_map = {}); + + /*! + * \brief Construct UniqueNameSupply by reserving names from the given iterator range. + * + * The functor should return the name of the dereferenced object. + */ + template + TVM_DLL UniqueNameSupply(Iter begin, Iter end, Lambda f) : UniqueNameSupply("") { + for (auto it = begin; it != end; ++it) { + this->operator->()->ReserveName(f(*it), false); + } + } + + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(UniqueNameSupply, ffi::ObjectRef, + UniqueNameSupplyNode); +}; + +} // namespace tvm + +#endif // TVM_IR_UNIQUE_NAME_SUPPLY_H_ diff --git a/include/tvm/relax/binding_rewrite.h b/include/tvm/relax/binding_rewrite.h index 69092726b474..740e8ed01fda 100644 --- a/include/tvm/relax/binding_rewrite.h +++ b/include/tvm/relax/binding_rewrite.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAX_BINDING_REWRITE_H_ #include -#include +#include #include #include @@ -87,7 +87,7 @@ class DataflowBlockRewriteNode : public ffi::Object { ffi::Array fn_outputs_; //!< Variables required by function outputs. private: - NameSupply name_supply_; //!< Name supply for tracking and generating unique names. + UniqueNameSupply name_supply_; //!< Unique name supply for tracking and generating unique names. }; /*! diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 68d6fc7bfa2c..8413686dc9df 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -25,7 +25,7 @@ #define TVM_RELAX_BLOCK_BUILDER_H_ #include -#include +#include #include #include #include @@ -68,11 +68,11 @@ class BlockBuilderNode : public ffi::Object { // Global Context management //------------------------------- /*! - * \brief Get the name supply for generating unique names. + * \brief Get the unique name supply for generating unique names. * - * \return The name supply. + * \return The unique name supply. */ - virtual NameSupply name_supply() = 0; + virtual UniqueNameSupply name_supply() = 0; /*! * \brief Get the context IRModule in this builder. diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index 183e20f25789..07b91e1a86e9 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -14,18 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Suppliers that are used to guarantee uniqueness of names and GlobalVars.""" +"""Suppliers that are used to guarantee uniqueness of names.""" import tvm_ffi -from tvm import IRModule, Object +from tvm import Object from . import _ffi_api -@tvm_ffi.register_object("ir.NameSupply") -class NameSupply(Object): - """NameSupply that can be used to generate unique names. +@tvm_ffi.register_object("ir.UniqueNameSupply") +class UniqueNameSupply(Object): + """UniqueNameSupply that can be used to generate unique names. Parameters ---------- @@ -33,10 +33,10 @@ class NameSupply(Object): """ def __init__(self, prefix=""): - self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) + self.__init_handle_by_constructor__(_ffi_api.UniqueNameSupply, prefix) def fresh_name(self, name, add_prefix=True, add_underscore=True): - """Generates a unique name from this NameSupply. + """Generates a unique name from this UniqueNameSupply. Parameters ---------- @@ -44,15 +44,15 @@ def fresh_name(self, name, add_prefix=True, add_underscore=True): The name from which the generated name is derived. add_prefix: bool - If set to true, then the prefix of this NameSupply will be prepended to the name. + If set to true, then the prefix of this UniqueNameSupply will be prepended to the name. add_underscore: bool If set to True, adds '_' between prefix and digit. """ - return _ffi_api.NameSupply_FreshName(self, name, add_prefix, add_underscore) + return _ffi_api.UniqueNameSupply_FreshName(self, name, add_prefix, add_underscore) def reserve_name(self, name, add_prefix=True): - """Reserves an existing name with this NameSupply. + """Reserves an existing name with this UniqueNameSupply. Parameters ---------- @@ -60,13 +60,13 @@ def reserve_name(self, name, add_prefix=True): The name to be reserved. add_prefix: bool - If set to true, then the prefix of this NameSupply will be prepended to the name + If set to true, then the prefix of this UniqueNameSupply will be prepended to the name before reserving it. """ - return _ffi_api.NameSupply_ReserveName(self, name, add_prefix) + return _ffi_api.UniqueNameSupply_ReserveName(self, name, add_prefix) def contains_name(self, name, add_prefix=True): - """Checks if this NameSupply already generated a name. + """Checks if this UniqueNameSupply already generated a name. Parameters ---------- @@ -74,74 +74,7 @@ def contains_name(self, name, add_prefix=True): The name to check. add_prefix: bool - If set to true, then the prefix of this NameSupply will be prepended to the name + If set to true, then the prefix of this UniqueNameSupply will be prepended to the name before checking for it. """ - return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) - - -@tvm_ffi.register_object("ir.GlobalVarSupply") -class GlobalVarSupply(Object): - """GlobalVarSupply that holds a mapping between names and GlobalVars. - - GlobalVarSupply can be used to generate new GlobalVars with a unique name. - It also can be used to retrieve previously generated GlobalVars based on a name. - - Parameters - ---------- - value: Union[List[IRModule], IRModule, NameSupply] - The IRModules used to build this GlobalVarSupply or a NameSupply. - """ - - def __init__(self, value=None): - if value is None: - name_supply = NameSupply("") - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, name_supply) - elif isinstance(value, NameSupply): - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_NameSupply, value) - elif isinstance(value, list | tvm_ffi.Array): - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModules, value) - elif isinstance(value, IRModule): - self.__init_handle_by_constructor__(_ffi_api.GlobalVarSupply_IRModule, value) - - def fresh_global(self, name, add_prefix=True): - """Generates a unique GlobalVar from this supply. - - Parameters - ---------- - name: String - The name from which the name of the GlobalVar is derived. - - add_prefix: bool - If set to true, then the prefix of the contained NameSupply will be prepended - to the name. - """ - return _ffi_api.GlobalVarSupply_FreshGlobal(self, name, add_prefix) - - def unique_global_for(self, name, add_prefix=True): - """Looks up for a GlobalVar with the given name in this supply. If no entry is found - , creates one, places it in the cache and returns it. - - Parameters - ---------- - name: String - The name of the GlobalVar to search for. - - add_prefix: bool - If set to true, the prefix of the contained NameSupply will be prepended to the - name before performing the search. - """ - return _ffi_api.GlobalVarSupply_UniqueGlobalFor(self, name, add_prefix) - - def reserve_global(self, global_var, allow_conflict=False): - """Reserves an existing GlobalVar with this supply. - - Parameters - ---------- - global_var: GlobalVar - The GlobalVar to be registered. - - allow_conflict: bool - Allow conflict with other GlobalVars that have the same name - """ - return _ffi_api.GlobalVarSupply_ReserveGlobalVar(self, global_var, allow_conflict) + return _ffi_api.UniqueNameSupply_ContainsName(self, name, add_prefix) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 2d1cc47377c4..11485659fb87 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -59,7 +59,7 @@ import tvm from tvm import relax, tirx, topi from tvm.ir import IRModule -from tvm.ir.supply import NameSupply +from tvm.ir.supply import UniqueNameSupply from tvm.runtime import DataType, DataTypeCode from tvm.tirx.generic import cast from tvm.topi.utils import get_const_tuple @@ -5359,7 +5359,7 @@ def __init__( self._input_names: list[str] = [] self._dtype = dtype_dict self.opset: int = None - self._name_supply = NameSupply() + self._name_supply = UniqueNameSupply() self._keep_params_in_input = keep_params_in_input self._sanitize: bool = sanitize self.bb: relax.BlockBuilder = relax.BlockBuilder() # pylint: disable=invalid-name diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index ee5f3e1dd43c..c51cb05dc4e8 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -21,7 +21,7 @@ from tvm_ffi._dtype import dtype as DataType, DataTypeCode # Import _ffi_node_api for its side effect of installing AsRepr as -# tvm_ffi.core.__object_repr__ so TVM IR objects use the rich C++ ReprPrinter. +# tvm_ffi.core.__object_repr__. from . import _ffi_node_api # class exposures diff --git a/python/tvm/runtime/_ffi_node_api.py b/python/tvm/runtime/_ffi_node_api.py index 18af61ec7563..1c87b989b69f 100644 --- a/python/tvm/runtime/_ffi_node_api.py +++ b/python/tvm/runtime/_ffi_node_api.py @@ -40,8 +40,5 @@ def LoadJSON(json_str): # Exports functions registered in node namespace. tvm_ffi.init_ffi_api("node", __name__) -# Override the default repr function for tvm_ffi.core.Object so TVM IR -# objects use the rich C++ ReprPrinter (registered above via init_ffi_api), -# falling back to the runtime-only AsRepr defined in this file when libtvm.so -# is not available. +# Override the default repr function for tvm_ffi.core.Object. tvm_ffi.core.__object_repr__ = AsRepr diff --git a/src/ir/access_path_repr.cc b/src/ir/access_path_repr.cc deleted file mode 100644 index 8225452c5446..000000000000 --- a/src/ir/access_path_repr.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file ir/access_path_repr.cc - * \brief FFI registration for ffi-repr-based printing. - * - * This file: - * - Registers node.AsRepr (for backward Python compatibility) via ffi::ReprPrint. - * - * Note: __ffi_repr__ hooks for ffi::reflection::AccessPath and AccessStep are - * registered by tvm-ffi itself (src/ffi/extra/reflection_extra.cc, landed in - * apache/tvm-ffi#598). The duplicate registrations that previously lived here - * were removed when bumping tvm-ffi to 59da4c0 to avoid a double-registration - * abort at library load time. - * - * Note: tvm::Dump() has been removed (zero in-tree callers). Use - * tvm::ffi::ReprPrint(any) directly from gdb instead. - */ -#include -#include -#include - -namespace tvm { - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - // node.AsRepr: backward-compatible Python entry point. - // Python's tvm.runtime._ffi_node_api sets __object_repr__ = AsRepr via init_ffi_api. - refl::GlobalDef().def("node.AsRepr", - [](ffi::Any obj) -> ffi::String { return ffi::ReprPrint(obj); }); -} -} // namespace tvm diff --git a/src/ir/global_var_supply.cc b/src/ir/global_var_supply.cc deleted file mode 100644 index 700c3ef84038..000000000000 --- a/src/ir/global_var_supply.cc +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file global_var_supply.cc - * \brief GlobalVarSupply that can be used to generate unique GlobalVars. - */ -#include "tvm/ir/global_var_supply.h" - -#include -#include - -#include - -#include "tvm/ir/expr.h" - -namespace tvm { - -TVM_FFI_STATIC_INIT_BLOCK() { GlobalVarSupplyNode::RegisterReflection(); } - -GlobalVarSupply::GlobalVarSupply(const NameSupply& name_supply, - std::unordered_map name_to_var_map) { - auto n = ffi::make_object(name_supply, name_to_var_map); - data_ = std::move(n); -} - -std::string GetModuleName(const IRModule& module) { - return module->GetAttr(tvm::attr::kModuleName).value_or("tvmgen_default"); -} - -GlobalVarSupply::GlobalVarSupply(const ffi::Array& modules) : GlobalVarSupply() { - if (!modules.empty()) { - IRModule first_mod = modules.front(); - this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod); - } - for (auto& mod : modules) { - for (auto kv : mod->functions) { - this->operator->()->ReserveGlobalVar(kv.first); - } - } -} - -GlobalVarSupply::GlobalVarSupply(const IRModule module) - : GlobalVarSupply(ffi::Array{module}) {} - -void GlobalVarSupplyNode::ReserveGlobalVar(const GlobalVar& var, bool allow_conflict) { - name_supply_->ReserveName(var->name_hint, false); - if (!allow_conflict) { - TVM_FFI_ICHECK(name_to_var_map_.count(var->name_hint) == 0) - << "GlobalVar " << var << " conflicts by name in this supply."; - } - name_to_var_map_[var->name_hint] = var; -} - -GlobalVarSupplyNode::GlobalVarSupplyNode(NameSupply name_supply, - std::unordered_map name_to_var_map) - : name_supply_(std::move(name_supply)), name_to_var_map_(std::move(name_to_var_map)) {} - -GlobalVar GlobalVarSupplyNode::UniqueGlobalFor(const ffi::String& name, bool add_prefix) { - ffi::String final_name = name_supply_->ReserveName(name, add_prefix); - - auto it = name_to_var_map_.find(final_name); - if (it != name_to_var_map_.end()) { - return it->second; - } else { - GlobalVar var = GlobalVar(final_name); - name_to_var_map_.emplace(final_name, var); - return var; - } -} - -GlobalVar GlobalVarSupplyNode::FreshGlobal(ffi::String name, bool add_prefix) { - ffi::String final_name = name_supply_->FreshName(name, add_prefix); - TVM_FFI_ICHECK(name_to_var_map_.find(final_name) == name_to_var_map_.end()) - << "GlobalVar already exists for name " << final_name; - GlobalVar var = GlobalVar(final_name); - name_to_var_map_.emplace(final_name, var); - return var; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("ir.GlobalVarSupply_NameSupply", - [](const NameSupply& name_supply) { return GlobalVarSupply(name_supply); }) - .def("ir.GlobalVarSupply_IRModule", - [](IRModule mod) { return GlobalVarSupply(std::move(mod)); }) - .def("ir.GlobalVarSupply_IRModules", - [](const ffi::Array& mods) { return GlobalVarSupply(mods); }) - .def_method("ir.GlobalVarSupply_FreshGlobal", &GlobalVarSupplyNode::FreshGlobal) - .def_method("ir.GlobalVarSupply_UniqueGlobalFor", &GlobalVarSupplyNode::UniqueGlobalFor) - .def_method("ir.GlobalVarSupply_ReserveGlobalVar", &GlobalVarSupplyNode::ReserveGlobalVar); -} - -} // namespace tvm diff --git a/src/ir/module.cc b/src/ir/module.cc index a09780d94dc5..156ca17c1255 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -28,9 +28,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -219,13 +219,17 @@ IRModule IRModule::FromExpr(const RelaxExpr& expr, } } + UniqueNameSupply global_names(mod->functions.begin(), mod->functions.end(), + [](const auto& kv) { return kv.first->name_hint; }); GlobalVar main_gv; - auto global_var_supply = GlobalVarSupply(mod); if (gv_name.empty()) { // Bind function to 'main' (though rename if would clash with existing 'main'). - main_gv = global_var_supply->FreshGlobal("main", false); + main_gv = GlobalVar(global_names->FreshName("main", false)); + } else if (mod->ContainGlobalVar(gv_name)) { + main_gv = mod->GetGlobalVar(gv_name); } else { - main_gv = global_var_supply->UniqueGlobalFor(gv_name, false); + global_names->ReserveName(gv_name, false); + main_gv = GlobalVar(gv_name); } mod->Add(main_gv, func); return mod; diff --git a/src/ir/name_supply.cc b/src/ir/name_supply.cc deleted file mode 100644 index 2f7bf501e55a..000000000000 --- a/src/ir/name_supply.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file name_supply.cc - * \brief NameSupply that can be used to generate unique variable names. - */ -#include "tvm/ir/name_supply.h" - -#include -#include - -#include - -namespace tvm { - -NameSupply::NameSupply(const ffi::String& prefix, std::unordered_map name_map) { - auto n = ffi::make_object(prefix, std::move(name_map)); - data_ = std::move(n); -} - -ffi::String NameSupplyNode::ReserveName(const ffi::String& name, bool add_prefix) { - ffi::String final_name = name; - if (add_prefix) { - final_name = add_prefix_to_name(name); - } - name_map[final_name] = 0; - return final_name; -} - -ffi::String NameSupplyNode::FreshName(const ffi::String& name, bool add_prefix, - bool add_underscore) { - ffi::String unique_name = name; - if (unique_name.empty()) { - // Special case for empty name, set to "v". - unique_name = "v"; - } - if (add_prefix) { - unique_name = add_prefix_to_name(unique_name); - } - unique_name = GetUniqueName(unique_name, add_underscore); - return unique_name; -} - -bool NameSupplyNode::ContainsName(const ffi::String& name, bool add_prefix) { - ffi::String unique_name = name; - if (add_prefix) { - unique_name = add_prefix_to_name(name); - } - - return name_map.count(unique_name); -} - -ffi::String NameSupplyNode::add_prefix_to_name(const ffi::String& name) { - if (prefix_.empty()) { - return name; - } - - std::ostringstream ss; - ss << prefix_ << "_" << name; - return ss.str(); -} - -std::string NameSupplyNode::GetUniqueName(std::string name, bool add_underscore) { - for (size_t i = 0; i < name.size(); ++i) { - if (name[i] == '.') name[i] = '_'; - } - auto it = name_map.find(name); - if (it != name_map.end()) { - auto new_name = name; - while (!name_map.insert({new_name, 0}).second) { - std::ostringstream os; - os << name << (add_underscore ? "_" : "") << (++it->second); - new_name = os.str(); - } - return new_name; - } - name_map[name] = 0; - return name; -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - NameSupplyNode::RegisterReflection(); - refl::GlobalDef() - .def("ir.NameSupply", [](ffi::String prefix) { return NameSupply(prefix); }) - .def_method("ir.NameSupply_FreshName", &NameSupplyNode::FreshName) - .def_method("ir.NameSupply_ReserveName", &NameSupplyNode::ReserveName) - .def_method("ir.NameSupply_ContainsName", &NameSupplyNode::ContainsName); -} - -} // namespace tvm diff --git a/src/ir/unique_name_supply.cc b/src/ir/unique_name_supply.cc new file mode 100644 index 000000000000..481edcac89ce --- /dev/null +++ b/src/ir/unique_name_supply.cc @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unique_name_supply.cc + * \brief UniqueNameSupply that can be used to generate unique variable names. + */ +#include "tvm/ir/unique_name_supply.h" + +#include +#include + +#include +#include + +namespace tvm { + +UniqueNameSupply::UniqueNameSupply(const ffi::String& prefix, + ffi::Map name_map) { + if (!name_map.defined()) { + name_map = ffi::Map(); + } + auto n = ffi::make_object(prefix, std::move(name_map)); + data_ = std::move(n); +} + +ffi::String UniqueNameSupplyNode::ReserveName(const ffi::String& name, bool add_prefix) { + ffi::String final_name = name; + if (add_prefix) { + final_name = AddPrefixToName(name); + } + name_map.Set(final_name, 0); + return final_name; +} + +ffi::String UniqueNameSupplyNode::FreshName(const ffi::String& name, bool add_prefix, + bool add_underscore) { + ffi::String unique_name = name; + if (unique_name.empty()) { + unique_name = "v"; + } + if (add_prefix) { + unique_name = AddPrefixToName(unique_name); + } + return GetUniqueName(unique_name, add_underscore); +} + +bool UniqueNameSupplyNode::ContainsName(const ffi::String& name, bool add_prefix) { + ffi::String unique_name = name; + if (add_prefix) { + unique_name = AddPrefixToName(name); + } + return name_map.count(unique_name); +} + +ffi::String UniqueNameSupplyNode::AddPrefixToName(const ffi::String& name) { + if (prefix_.empty()) { + return name; + } + + std::ostringstream ss; + ss << prefix_ << "_" << name; + return ss.str(); +} + +std::string UniqueNameSupplyNode::GetUniqueName(std::string name, bool add_underscore) { + for (size_t i = 0; i < name.size(); ++i) { + if (name[i] == '.') name[i] = '_'; + } + ffi::String name_key = name; + auto it = name_map.find(name_key); + if (it != name_map.end()) { + auto new_name = name; + int64_t suffix = (*it).second; + while (name_map.count(ffi::String(new_name))) { + std::ostringstream os; + os << name << (add_underscore ? "_" : "") << (++suffix); + new_name = os.str(); + } + name_map.Set(name_key, suffix); + name_map.Set(ffi::String(new_name), 0); + return new_name; + } + name_map.Set(name_key, 0); + return name; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + UniqueNameSupplyNode::RegisterReflection(); + refl::GlobalDef() + .def("ir.UniqueNameSupply", [](ffi::String prefix) { return UniqueNameSupply(prefix); }) + .def_method("ir.UniqueNameSupply_FreshName", &UniqueNameSupplyNode::FreshName) + .def_method("ir.UniqueNameSupply_ReserveName", &UniqueNameSupplyNode::ReserveName) + .def_method("ir.UniqueNameSupply_ContainsName", &UniqueNameSupplyNode::ContainsName); +} + +} // namespace tvm diff --git a/src/relax/backend/contrib/cutlass/codegen.cc b/src/relax/backend/contrib/cutlass/codegen.cc index 91840f6936e5..6de72397dc52 100644 --- a/src/relax/backend/contrib/cutlass/codegen.cc +++ b/src/relax/backend/contrib/cutlass/codegen.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include #include #include @@ -333,8 +333,8 @@ class CodegenCutlass : public relax::MemoizedExprTranslator, * name_hint. */ std::unordered_map var_name_map_; - /*! \brief A name supply to generate a unique name for each parameter. */ - NameSupply name_sup_; + /*! \brief A unique name supply to generate a unique name for each parameter. */ + UniqueNameSupply name_sup_; }; class CutlassModuleCodegen { diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index 9fad59f4e374..85fcfef1ea56 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -48,8 +48,8 @@ DataflowBlockRewrite::DataflowBlockRewrite(DataflowBlock dfb, Function root_fn) auto p = FunctionUseDef(root_fn); n->to_users_ = std::move(p.first); n->fn_outputs_ = std::move(p.second); - n->name_supply_ = NameSupply(n->to_users_.begin(), n->to_users_.end(), - [](const auto& p) { return p.first->name_hint(); }); + n->name_supply_ = UniqueNameSupply(n->to_users_.begin(), n->to_users_.end(), + [](const auto& p) { return p.first->name_hint(); }); data_ = std::move(n); } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 344b09024e59..f9360c6c4246 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -70,7 +70,7 @@ class BlockBuilderImpl : public BlockBuilderNode { //------------------------------- // Global Context management //------------------------------- - NameSupply name_supply() final { return name_supply_; } + UniqueNameSupply name_supply() final { return name_supply_; } IRModule GetContextIRModule() const final { return context_mod_; } @@ -346,8 +346,8 @@ class BlockBuilderImpl : public BlockBuilderNode { /*! \brief A binding table that maps var to value. */ std::unordered_map binding_table_; - /*! \brief A name supply to get unique names for IR construction. */ - NameSupply name_supply_; + /*! \brief A unique name supply to get unique names for IR construction. */ + UniqueNameSupply name_supply_; /*! \brief The IRModule being built by the BlockBuilder. */ IRModule context_mod_; diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 9efa92bd8490..625ae1e76416 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -154,7 +154,7 @@ void RewriteSpec::Append(RewriteSpec other) { return; } - NameSupply gvar_name_supply(""); + UniqueNameSupply gvar_name_supply(""); for (const auto& [gvar, func] : new_subroutines) { gvar_name_supply->ReserveName(gvar->name_hint); } diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 718214d49157..6bbe86d148f9 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -25,7 +25,7 @@ #include #include -#include +#include #include #include #include @@ -96,7 +96,7 @@ class ExternFunctionRewriter : ExprMutator { } private: - NameSupply name_sup_; + UniqueNameSupply name_sup_; /*! \brief A variable that represents the workspace parameter passed from main. */ Var workspace_var_param_; size_t max_workspace_size_ = 0; diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc index 7c8f6b3854a3..ac3f0611db48 100644 --- a/src/relax/transform/normalize.cc +++ b/src/relax/transform/normalize.cc @@ -268,7 +268,7 @@ class GlobalVarNormalizer : private ExprMutator { } IRModule module_; - NameSupply name_supply_; + UniqueNameSupply name_supply_; ffi::Map gvar_map_; }; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 8762c83ee4f7..ddedd9ee355a 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -24,7 +24,7 @@ #include #include -#include +#include #include #include diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 934d1af83a36..a023277ed19c 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -357,8 +357,8 @@ class CodeGenC : public ExprFunctor, */ std::unordered_map internal_functions_; - /* \brief Name supply to generate unique function names */ - NameSupply func_name_supply_; + /* \brief Unique unique name supply to generate unique function names */ + UniqueNameSupply func_name_supply_; }; } // namespace codegen diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index 5a07e3c7aa07..2646a6597ef4 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -28,7 +28,7 @@ namespace tvm { namespace codegen { void CodeGenSourceBase::ClearFuncState() { - name_supply_ = NameSupply(); + name_supply_ = UniqueNameSupply(); ssa_assign_map_.clear(); var_idmap_.clear(); scope_mark_.clear(); diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 9283944c1b0d..f6e58cc9efba 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -25,7 +25,7 @@ #ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ #define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ -#include +#include #include #include #include @@ -123,8 +123,8 @@ class CodeGenSourceBase { std::ostringstream fwd_decl_stream; /*! \brief name of each variable */ std::unordered_map var_idmap_; - /*! \brief NameSupply for allocation */ - NameSupply name_supply_; + /*! \brief Unique name supply for allocation */ + UniqueNameSupply name_supply_; /*! \brief The current indentation value */ int indent_{0}; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a4ce62812a08..5a7223430ed5 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include #include @@ -112,8 +112,8 @@ struct CreateFuncInfo { ProducerToBufferTransformer transformer; /*! \brief The buffers should be allocated at function root. */ ffi::Array root_alloc; - /*! \brief The NameSupply to make block name unique. */ - NameSupply name_supply; + /*! \brief The unique name supply to make block name unique. */ + UniqueNameSupply name_supply; ffi::String FreshName(ffi::String base_name) { return name_supply->FreshName(base_name); } diff --git a/src/tirx/ir/index_map.cc b/src/tirx/ir/index_map.cc index b26ccca248d6..cde0370f7f9d 100644 --- a/src/tirx/ir/index_map.cc +++ b/src/tirx/ir/index_map.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include #include @@ -347,7 +347,7 @@ IndexMap IndexMap::RenameVariables( const std::function(const Var& var)>& f_name_map) const { std::unordered_set used_names; ffi::Map var_remap; - NameSupply name_supply; + UniqueNameSupply name_supply; const IndexMapNode* n = this->get(); if (f_name_map != nullptr) { // Collect variables with pre-defined names provided by f_name_map. diff --git a/src/tirx/transform/bind_target.cc b/src/tirx/transform/bind_target.cc index 7a5627c80bcb..16bf74015200 100644 --- a/src/tirx/transform/bind_target.cc +++ b/src/tirx/transform/bind_target.cc @@ -36,7 +36,7 @@ #include #include -#include +#include #include #include #include @@ -261,7 +261,8 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Track duplicated functions for call replacement ffi::Map host_function_replacements; - GlobalVarSupply gvar_supply(new_mod); + UniqueNameSupply global_names(new_mod->functions.begin(), new_mod->functions.end(), + [](const auto& kv) { return kv.first->name_hint; }); for (auto [gvar, func] : mod->functions) { const auto* prim_func_node = func.as(); @@ -313,7 +314,7 @@ IRModule BindTarget(IRModule mod, const Target& target) { // Create duplicate with host target for host callers host_func = WithAttr(std::move(host_func), tvm::attr::kTarget, target_host); ffi::String host_func_name = gvar->name_hint + "_host"; - GlobalVar host_gvar = gvar_supply->FreshGlobal(host_func_name, false); + GlobalVar host_gvar = GlobalVar(global_names->FreshName(host_func_name, false)); new_mod->Add(host_gvar, host_func); host_function_replacements.Set(gvar, host_gvar); diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index 7ec104765f3a..acc5e473afb8 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -678,7 +678,8 @@ namespace transform { Pass SplitHostDevice() { auto pass_func = [](IRModule mod, PassContext ctx) { - GlobalVarSupply global_var_supply(mod); + UniqueNameSupply global_names(mod->functions.begin(), mod->functions.end(), + [](const auto& kv) { return kv.first->name_hint; }); IRModule device_mod = IRModule(ffi::Map({})); IRModule updates = IRModule(ffi::Map({})); @@ -691,8 +692,8 @@ Pass SplitHostDevice() { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); auto kernel_name = name_prefix + "_kernel"; - auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar { - return global_var_supply->FreshGlobal(kernel_name, false); + auto var_supply = [&global_names, &kernel_name]() -> GlobalVar { + return GlobalVar(global_names->FreshName(kernel_name, false)); }; func = SplitHostDevice(std::move(func), &device_mod, var_supply); diff --git a/tests/python/ir/test_name_supply.py b/tests/python/ir/test_unique_name_supply.py similarity index 62% rename from tests/python/ir/test_name_supply.py rename to tests/python/ir/test_unique_name_supply.py index bc3283968d3f..f440301e1feb 100644 --- a/tests/python/ir/test_name_supply.py +++ b/tests/python/ir/test_unique_name_supply.py @@ -16,12 +16,17 @@ # under the License. import tvm import tvm.testing -from tvm.ir.supply import NameSupply +from tvm import relax as rx +from tvm.ir.supply import UniqueNameSupply + + +def _empty_relax_func(): + return rx.Function([], rx.Tuple([])) def test_fresh_name_empty_string(): """Empty name should produce a valid variable name, not an empty string.""" - ns = NameSupply("") + ns = UniqueNameSupply("") name = ns.fresh_name("", add_prefix=False) assert name == "v" name2 = ns.fresh_name("", add_prefix=False) @@ -30,12 +35,28 @@ def test_fresh_name_empty_string(): def test_fresh_name_empty_string_with_prefix(): """Empty name with prefix should produce a valid variable name.""" - ns = NameSupply("prefix") + ns = UniqueNameSupply("prefix") name = ns.fresh_name("", add_prefix=True) assert name == "prefix_v" name2 = ns.fresh_name("", add_prefix=True) assert name2 == "prefix_v_1" +def test_ir_module_from_expr_freshens_main_collision(): + main_gv = tvm.ir.GlobalVar("main") + mod = tvm.IRModule.from_expr(_empty_relax_func(), {main_gv: _empty_relax_func()}) + + assert sorted(gvar.name_hint for gvar in mod.get_global_vars()) == ["main", "main_1"] + + +def test_ir_module_from_expr_reuses_existing_global_symbol(): + foo_gv = tvm.ir.GlobalVar("foo") + func = _empty_relax_func().with_attr("global_symbol", "foo") + mod = tvm.IRModule.from_expr(func, {foo_gv: _empty_relax_func()}) + + assert mod.get_global_var("foo").same_as(foo_gv) + assert [gvar.name_hint for gvar in mod.get_global_vars()] == ["foo"] + + if __name__ == "__main__": tvm.testing.main() From 4066127509e22e9c634f734c25bc773686702854 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 15 Jun 2026 18:47:33 -0400 Subject: [PATCH 05/23] [CUDA] Narrow the cuda extra from cuda-python to cuda-bindings (#19784) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TVM's shipped code only uses cuda.bindings — cuda.bindings.nvrtc for the NVRTC JIT path and cuda.bindings.driver for the NVSHMEM link path, both in python/tvm/support/nvcc.py; it never uses cuda.core. cuda-python is now a metapackage that pulls in cuda-bindings + cuda-core (and cuda-pathfinder), so depending on it drags in cuda-core that TVM does not need. Depend directly on cuda-bindings, which provides exactly the nvrtc and driver submodules TVM imports, and update the user-facing 'pip install cuda-python' hints to match. A plain cuda-bindings install pulls no nvidia-* toolkit wheels (those live behind the [all] extra); libnvrtc is loaded from the system / TVM's CUDA install as before. --- pyproject.toml | 2 +- python/tvm/support/nvcc.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3c61fa389fc3..221b3c3383b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dependencies = [ [project.optional-dependencies] torch = ["torch"] -cuda = ["cuda-python"] +cuda = ["cuda-bindings"] meta-schedule = ["xgboost"] popen-pool = ["psutil", "cloudpickle"] rpc = ["tornado", "psutil", "cloudpickle"] diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py index 859dbd3077c7..ea5939fceffc 100644 --- a/python/tvm/support/nvcc.py +++ b/python/tvm/support/nvcc.py @@ -65,7 +65,7 @@ def compile_cuda( Notes ----- - NVRTC is a "runtime" compilation library and can be faster for JIT compilation. - - NVRTC requires cuda-python: pip install cuda-python + - NVRTC requires cuda-bindings: pip install cuda-bindings """ use_nvshmem = "#include " in code or "#include " in code @@ -289,9 +289,9 @@ def _compile_cuda_nvrtc( from cuda.bindings import nvrtc # pylint: disable=import-outside-toplevel except ImportError as e: raise RuntimeError( - "Failed to compile CUDA with NVRTC because the `cuda-python` package " + "Failed to compile CUDA with NVRTC because the `cuda-bindings` package " "is not available.\n" - "Please install it with: pip install cuda-python\n" + "Please install it with: pip install cuda-bindings\n" "See: https://nvidia.github.io/cuda-python/" ) from e @@ -301,9 +301,9 @@ def _compile_cuda_nvrtc( if importlib.util.find_spec("cuda.bindings.driver") is None: raise RuntimeError( - "Failed to compile CUDA with NVRTC+NVSHMEM because the `cuda-python` package " + "Failed to compile CUDA with NVRTC+NVSHMEM because the `cuda-bindings` package " "is not available.\n" - "Please install it with: pip install cuda-python\n" + "Please install it with: pip install cuda-bindings\n" "See: https://nvidia.github.io/cuda-python/" ) @@ -812,7 +812,7 @@ def tvm_callback_cuda_compile(code): TVM_CUDA_COMPILE_MODE : str Compiler backend: "nvcc" (default) or "nvrtc" - "nvcc": Use nvcc subprocess, generates fatbin - - "nvrtc": Use NVRTC via cuda-python for faster JIT, generates cubin + - "nvrtc": Use NVRTC via cuda-bindings for faster JIT, generates cubin TVM_KERNEL_DUMP : str If set, dump generated CUDA/intermediate files and append "-lineinfo" so profilers can correlate SASS back to the dumped source. From 3ea565b482c8394ff790993ee4dc25b9a708548b Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 15 Jun 2026 18:50:14 -0400 Subject: [PATCH 06/23] [AGENT] Migrate agent instructions to vendor-neutral layout (#19783) This PR migrates repository agent instructions away from Claude-specific paths and into a vendor-neutral layout. Changes: - Add root `AGENTS.md` - Move existing command guidance from `.claude/commands` to `.agents/skills/*/SKILL.md` using git renames. - Move the GPU monitor helper from `.claude/scripts` to `.agents/scripts`. - Update the TIR test skill to reference `.agents/scripts/monitor_gpu.sh`. - Replace the ASF header skip entry for `.claude/*` with `.agents/*`. Validation: - `bash -n .agents/scripts/monitor_gpu.sh` - `.agents/scripts/monitor_gpu.sh --help` - `pre-commit run --files AGENTS.md .agents/skills/tir-build/SKILL.md .agents/skills/tir-test/SKILL.md .agents/skills/tir-bench/SKILL.md .agents/scripts/monitor_gpu.sh tests/lint/check_asf_header.py` --- {.claude => .agents}/scripts/monitor_gpu.sh | 13 ++- .../skills/tir-bench/SKILL.md | 0 .../skills/tir-build/SKILL.md | 6 +- .../skills/tir-test/SKILL.md | 4 +- AGENTS.md | 99 +++++++++++++++++++ tests/lint/check_asf_header.py | 2 +- 6 files changed, 118 insertions(+), 6 deletions(-) rename {.claude => .agents}/scripts/monitor_gpu.sh (86%) rename .claude/commands/tir-bench.md => .agents/skills/tir-bench/SKILL.md (100%) rename .claude/commands/tir-build.md => .agents/skills/tir-build/SKILL.md (69%) rename .claude/commands/tir-test.md => .agents/skills/tir-test/SKILL.md (95%) create mode 100644 AGENTS.md diff --git a/.claude/scripts/monitor_gpu.sh b/.agents/scripts/monitor_gpu.sh similarity index 86% rename from .claude/scripts/monitor_gpu.sh rename to .agents/scripts/monitor_gpu.sh index 85963da93089..e1d91ae3b49c 100755 --- a/.claude/scripts/monitor_gpu.sh +++ b/.agents/scripts/monitor_gpu.sh @@ -23,7 +23,18 @@ while [[ $# -gt 0 ]]; do --interval) INTERVAL="$2"; shift 2 ;; --log) LOG="$2"; shift 2 ;; -h|--help) - sed -n '2,12p' "$0" | sed 's/^# \{0,1\}//' + cat <<'EOF' +Watch a single GPU for foreign processes (anyone other than the current +user) appearing during a long-running test. Intended companion to +`/tir-test`: leave this running in a side terminal while pytest runs, and +it will alert if someone else lands on the same GPU. + +Usage: + monitor_gpu.sh # uses $CUDA_VISIBLE_DEVICES, defaults to 0 + monitor_gpu.sh --gpu 3 # watch GPU 3 + monitor_gpu.sh --gpu 3 --interval 2 # poll every 2 seconds + monitor_gpu.sh --log /tmp/gpu.log # also tee to a log file +EOF exit 0 ;; *) echo "unknown arg: $1" >&2; exit 2 ;; esac diff --git a/.claude/commands/tir-bench.md b/.agents/skills/tir-bench/SKILL.md similarity index 100% rename from .claude/commands/tir-bench.md rename to .agents/skills/tir-bench/SKILL.md diff --git a/.claude/commands/tir-build.md b/.agents/skills/tir-build/SKILL.md similarity index 69% rename from .claude/commands/tir-build.md rename to .agents/skills/tir-build/SKILL.md index 21aadbe68563..ee43dfa84bb7 100644 --- a/.claude/commands/tir-build.md +++ b/.agents/skills/tir-build/SKILL.md @@ -4,12 +4,14 @@ Build TVM from the current worktree. 1. Check that `build/` directory exists. If not, run initial setup: ```bash - mkdir -p build && cd build && cmake .. && make -j$(nproc) + mkdir -p build + cmake -S . -B build + cmake --build build --parallel ``` 2. If `build/` already exists, run incremental build: ```bash - cmake --build build -j$(nproc) + cmake --build build --parallel ``` 3. Report success/failure and build time. diff --git a/.claude/commands/tir-test.md b/.agents/skills/tir-test/SKILL.md similarity index 95% rename from .claude/commands/tir-test.md rename to .agents/skills/tir-test/SKILL.md index f6cd25236b38..b4c45069163b 100644 --- a/.claude/commands/tir-test.md +++ b/.agents/skills/tir-test/SKILL.md @@ -10,14 +10,14 @@ Run the full TIRX test suite. 2. Start the GPU monitor in the background so we can detect if anyone else lands on the same GPU mid-run: ```bash GPU_LOG="/tmp/tir_test_gpu_${CUDA_VISIBLE_DEVICES}.log" - bash .claude/scripts/monitor_gpu.sh --gpu "$CUDA_VISIBLE_DEVICES" --interval 5 --log "$GPU_LOG" & + bash .agents/scripts/monitor_gpu.sh --gpu "$CUDA_VISIBLE_DEVICES" --interval 5 --log "$GPU_LOG" & MON_PID=$! trap 'kill $MON_PID 2>/dev/null' EXIT ``` 3. Run the full test suite with xdist parallelism: ```bash - pytest tests/python/tirx/ -n 16 + pytest tests/python/tirx/ -n auto ``` 4. Stop the monitor and check for foreign GPU usage during the run: diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 000000000000..4d70e7c79c43 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,99 @@ + + +# AGENTS.md + +This file provides vendor-neutral guidance for agentic coding tools working +with Apache TVM. + +## Repository Overview + +Apache TVM is an open-source machine learning compiler stack. The repository +contains the C++ compiler/runtime, Python bindings, TIR/Relax IRs, scheduling +and lowering passes, target code generators, runtime integrations, tests, +documentation, and application examples. + +## Repository Structure + +- `include/tvm/` - public C++ headers +- `src/` - C++ implementation +- `python/tvm/` - Python package +- `tests/` - C++, Python, integration, and lint tests +- `cmake/` - CMake modules and default configuration +- `3rdparty/` - vendored dependencies and submodules +- `docs/` - documentation source +- `apps/` - application examples +- `.agents/skills/` - reusable agent workflows for this repository + +## Build + +Use an existing `build/` directory when present: + +```bash +cmake --build build --parallel +``` + +For a fresh checkout, initialize submodules and configure CMake first: + +```bash +git submodule update --init --recursive +mkdir -p build +cp cmake/config.cmake build/config.cmake +cmake -S . -B build -G Ninja -DCMAKE_BUILD_TYPE=RelWithDebInfo +cmake --build build --parallel +``` + +Development should use `PYTHONPATH`, not editable installs: + +```bash +export PYTHONPATH="$(pwd)/python:$(pwd)/.local/python" +``` + +Do not use `pip install -e` for TVM or `tvm-ffi`; editable installs can make +one worktree silently import another worktree's code. + +## Test And Lint + +Run the smallest relevant test first, then broaden as needed. Common examples: + +```bash +python -m pytest tests/python/all-platform-minimal-test/ -xvs +python -m pytest tests/python/tir-base/test_tir_base.py -xvs +./build/cpptest +``` + +For lint validation on a pull request, run pre-commit on the files changed by +the branch instead of the whole repository: + +```bash +pre-commit run --files ... +``` + +Use `.agents/skills/tir-build`, `.agents/skills/tir-test`, and +`.agents/skills/tir-bench` when their workflows apply. + +## Coding Conventions + +- Follow the surrounding style before introducing new abstractions. +- Keep changes scoped to the task and avoid unrelated cleanups. +- Prefer explicit tests that show the IR or behavior being changed. +- Use Apache TVM commit tags such as `[REFACTOR][IR]`, `[FIX][TIR]`, or + `[DOCS]` as appropriate. +- Preserve Apache license headers in new source, script, and documentation + files when the surrounding tree uses them. diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py index 8ba73524f79a..d6237c600a4a 100644 --- a/tests/lint/check_asf_header.py +++ b/tests/lint/check_asf_header.py @@ -186,7 +186,7 @@ "ffi/3rdparty/*", ".github/*", ".txdev/*", - ".claude/*", + ".agents/*", "*.json", "*.txt", "*.svg", From b684868bd8c2d5cf0a9b1bb4d492caa709f424bb Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Mon, 15 Jun 2026 18:50:57 -0400 Subject: [PATCH 07/23] [Tests] Modernize test gating (#19777) This pr modernizes test gating. It replaces the heavy `tvm.testing.Feature` machinery with a thin `tvm.testing.env` module of `has_*()` capability probes, used via standard pytest.mark + skipif. And markers move to `pyproject.toml` --- pyproject.toml | 8 +- python/tvm/contrib/hexagon/_ci_env_check.py | 6 +- python/tvm/contrib/hexagon/pytest_plugin.py | 6 +- python/tvm/testing/__init__.py | 1 + python/tvm/testing/env.py | 518 +++++++++++++ python/tvm/testing/plugin.py | 75 +- python/tvm/testing/utils.py | 721 +----------------- .../test_minimal_target_codegen_llvm.py | 4 +- .../codegen/test_codegen_error_handling.py | 4 +- .../codegen/test_gpu_codegen_allreduce.py | 4 +- tests/python/codegen/test_inject_ptx_ldg32.py | 5 +- .../codegen/test_target_codegen_blob.py | 2 +- .../codegen/test_target_codegen_bool.py | 3 +- .../codegen/test_target_codegen_cross_llvm.py | 4 +- .../codegen/test_target_codegen_cuda.py | 96 +-- .../test_target_codegen_cuda_fastmath.py | 5 +- .../codegen/test_target_codegen_cuda_fp4.py | 10 +- .../codegen/test_target_codegen_cuda_fp8.py | 34 +- .../codegen/test_target_codegen_device.py | 8 +- .../codegen/test_target_codegen_extern.py | 3 +- .../codegen/test_target_codegen_gpu_common.py | 4 +- .../codegen/test_target_codegen_hexagon.py | 7 +- .../codegen/test_target_codegen_llvm.py | 69 +- .../codegen/test_target_codegen_metal.py | 32 +- .../codegen/test_target_codegen_opencl.py | 27 +- .../codegen/test_target_codegen_riscv.py | 7 +- .../codegen/test_target_codegen_rocm.py | 20 +- .../codegen/test_target_codegen_vulkan.py | 14 +- .../python/codegen/test_target_codegen_x86.py | 3 +- tests/python/contrib/test_cutlass_gemm.py | 22 +- .../test_hexagon/test_async_dma_pipeline.py | 5 +- .../test_benchmark_elemwise_add.py | 3 +- .../test_hexagon/test_benchmark_maxpool2d.py | 3 +- .../contrib/test_hexagon/test_dma_builtin.py | 4 +- .../test_hexagon/test_meta_schedule.py | 7 +- .../contrib/test_hexagon/test_parallel_hvx.py | 4 +- .../test_parallel_hvx_load_vtcm.py | 4 +- .../test_hexagon/test_parallel_scalar.py | 4 +- .../test_hexagon/test_relax_integration.py | 5 +- .../test_hexagon/test_run_unit_tests.py | 5 +- .../contrib/test_hexagon/test_sigmoid.py | 4 +- .../test_software_pipeline_async.py | 4 +- .../contrib/test_hexagon/test_thread_pool.py | 6 +- .../python/contrib/test_hexagon/test_vtcm.py | 5 +- .../test_hexagon/test_vtcm_bandwidth.py | 3 +- tests/python/contrib/test_hipblas.py | 8 +- tests/python/contrib/test_random.py | 3 +- .../contrib/test_tir_triton_integration.py | 4 +- tests/python/disco/test_callback.py | 5 +- tests/python/disco/test_loader.py | 5 +- tests/python/disco/test_nvshmem.py | 6 +- .../python/nightly/test_nnapi/test_network.py | 3 +- tests/python/relax/backend/adreno/utils.py | 62 +- tests/python/relax/test_codegen_cublas.py | 15 +- tests/python/relax/test_codegen_cudnn.py | 6 +- tests/python/relax/test_codegen_cutlass.py | 5 +- tests/python/relax/test_codegen_hipblas.py | 6 +- tests/python/relax/test_codegen_tensorrt.py | 6 +- tests/python/relax/test_contrib_vllm.py | 8 +- tests/python/relax/test_frontend_dynamo.py | 19 +- .../test_frontend_from_exported_program.py | 3 +- tests/python/relax/test_frontend_from_fx.py | 7 +- ...frontend_nn_llm_sequence_prefill_masked.py | 29 +- tests/python/relax/test_frontend_nn_op.py | 11 +- tests/python/relax/test_frontend_stablehlo.py | 31 +- tests/python/relax/test_op_vision.py | 31 +- ...uiltin_paged_attention_kv_cache_mla_tir.py | 17 +- ...me_builtin_paged_attention_kv_cache_tir.py | 33 +- .../relax/test_runtime_builtin_rnn_state.py | 13 +- .../relax/test_tir_call_source_kernel.py | 5 +- .../relax/test_transform_codegen_pass.py | 15 +- tests/python/relax/test_vm_build.py | 10 +- tests/python/relax/test_vm_cuda_graph.py | 7 +- tests/python/relax/test_vm_multi_device.py | 6 +- tests/python/relax/texture/test_texture_nd.py | 6 +- .../runtime/test_runtime_module_export.py | 5 +- .../runtime/test_runtime_module_load.py | 8 +- tests/python/runtime/test_runtime_rpc.py | 37 +- tests/python/s_tir/dlight/test_primitives.py | 6 +- .../test_meta_schedule_mma_tensorize.py | 13 +- .../test_meta_schedule_space_post_opt.py | 6 +- .../test_meta_schedule_tune_tir.py | 6 +- ...schedule_tensorize_ldmatrix_mma_numeric.py | 17 +- ...est_tir_schedule_tensorize_mfma_numeric.py | 11 +- ...t_s_tir_transform_inject_ptx_async_copy.py | 13 +- ..._tir_transform_inject_software_pipeline.py | 7 +- .../test_s_tir_transform_thread_sync.py | 9 +- tests/python/target/test_arm_target.py | 9 +- tests/python/target/test_target_target.py | 13 +- tests/python/testing/test_env.py | 205 +++++ tests/python/tirx-base/test_tir_imm_values.py | 13 +- .../python/tirx-base/test_tir_ptx_cp_async.py | 5 +- .../tirx-base/test_tir_ptx_griddepcontrol.py | 5 +- .../python/tirx-base/test_tir_ptx_ldmatrix.py | 5 +- tests/python/tirx-base/test_tir_ptx_mma.py | 61 +- tests/python/tirx-base/test_tir_ptx_mma_sp.py | 8 +- .../tirx-base/test_tir_ptx_scalar_f32_math.py | 5 +- .../test_tir_transform_lower_intrin.py | 8 +- .../test_tir_transform_lower_tvm_builtin.py | 3 +- .../tirx/codegen/test_codegen_ampere.py | 7 +- .../tirx/codegen/test_codegen_blackwell.py | 19 +- .../tirx/codegen/test_codegen_hopper.py | 55 +- .../cuda/copy/test_gmem_smem.py | 4 +- .../cuda/copy/test_ld_stmatrix.py | 13 +- .../cuda/copy_async/test_dsmem.py | 4 +- .../cuda/copy_async/test_smem_tmem.py | 13 +- .../cuda/copy_async/test_tma.py | 16 +- .../cuda/gemm/test_gemm_mma_m16n8k_.py | 16 +- tests/python/tirx/test_bench_utils.py | 17 +- tests/python/tvmscript/test_tvmscript_ops.py | 3 +- .../task_python_integration_gpuonly.sh | 1 + tests/scripts/task_python_unittest_gpuonly.sh | 1 + 112 files changed, 1592 insertions(+), 1228 deletions(-) create mode 100644 python/tvm/testing/env.py create mode 100644 tests/python/testing/test_env.py diff --git a/pyproject.toml b/pyproject.toml index 221b3c3383b0..e3f1038f22b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -160,13 +160,7 @@ addopts = "-v --tb=short" python_files = ["test_*.py", "*_test.py"] python_classes = ["Test*"] python_functions = ["test_*"] -markers = [ - "adreno_clml: Mark a test as using adreno_clml", - "adreno_opencl_vulkan: Mark a test as using adreno_opencl_vulkan", - "adreno_vulkan: Mark a test as using adreno_vulkan", - "adreno_opencl: Mark a test as using adreno_opencl", - "adreno_opencl_real: Mark a test as using adreno_opencl_real", -] +markers = ["gpu: Mark a test as requiring a GPU"] [tool.ruff] include = [ diff --git a/python/tvm/contrib/hexagon/_ci_env_check.py b/python/tvm/contrib/hexagon/_ci_env_check.py index e36dde5d2162..f7f14a23955e 100644 --- a/python/tvm/contrib/hexagon/_ci_env_check.py +++ b/python/tvm/contrib/hexagon/_ci_env_check.py @@ -34,8 +34,7 @@ def _compile_time_check(): """Return True if compile-time support for Hexagon is present, otherwise error string. - Designed for use as a the ``compile_time_check`` argument to - `tvm.testing.Feature`. + Backs :func:`tvm.testing.env.has_hexagon_toolchain`. """ if tvm.runtime.enabled("llvm") and tvm.target.codegen.llvm_version_major() < 7: return "Hexagon requires LLVM 7 or later" @@ -50,8 +49,7 @@ def _run_time_check(): """Return True if run-time support for Hexagon is present, otherwise error string. - Designed for use as a the ``run_time_check`` argument to - `tvm.testing.Feature`. + Backs :func:`tvm.testing.env.has_hexagon`. """ if ANDROID_SERIAL_NUMBER not in os.environ: return f"Missing environment variable {ANDROID_SERIAL_NUMBER}." diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py index 809300ea2fb3..97a644400b51 100644 --- a/python/tvm/contrib/hexagon/pytest_plugin.py +++ b/python/tvm/contrib/hexagon/pytest_plugin.py @@ -69,7 +69,11 @@ def _compose(args, decs): return decs -requires_hexagon_toolchain = tvm.testing.requires_hexagon(support_required="compile-only") +def requires_hexagon_toolchain(func): + """Skip a test unless the Hexagon toolchain is available (compile-only).""" + return pytest.mark.skipif( + not tvm.testing.env.has_hexagon_toolchain(), reason="need hexagon toolchain" + )(func) def android_serial_number() -> str | None: diff --git a/python/tvm/testing/__init__.py b/python/tvm/testing/__init__.py index cca984370020..976c893c8021 100644 --- a/python/tvm/testing/__init__.py +++ b/python/tvm/testing/__init__.py @@ -28,3 +28,4 @@ ) from .runner import local_run, rpc_run from .utils import * +from . import env diff --git a/python/tvm/testing/env.py b/python/tvm/testing/env.py new file mode 100644 index 000000000000..6b39b5f2f674 --- /dev/null +++ b/python/tvm/testing/env.py @@ -0,0 +1,518 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Thin capability probes for test gating. + +This module exposes small ``has_*`` predicates that report whether the +current environment can run a given feature. They are meant to be used +with plain pytest markers and ``skipif``:: + + import pytest + import tvm.testing + + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_my_cuda_kernel(): + ... + +Every probe is memoized with :func:`functools.cache`, so the +underlying device query / ``nvcc`` subprocess runs at most once per +process even though ``skipif`` evaluates the predicate at import time for +every decorated test. Probes never raise: when support is absent they +return ``False`` (or a zero version tuple) rather than propagating an +error out of collection. + +Three kinds of probe live here: + +* **runtime device** probes (``has_cuda``, ``has_gpu`` …) ask whether a + usable device of a given kind is present; +* **build-support** probes (``has_cutlass``, ``has_cudnn`` …) ask whether + an optional library was compiled into the runtime; +* **version / capability** probes (``has_cuda_compute``, + ``has_tensorcore`` …) ask about a finer capability of a present device + or toolchain. +""" + +import functools +import os +import platform + +import tvm + +__all__ = [ + "has_aarch64_sme", + "has_aarch64_sve", + "has_adreno_opencl", + "has_aprofile_aem_fvp", + "has_arm_dot", + "has_arm_fp16", + # cpu features + "has_cpu_feature", + "has_cublas", + # runtime device + "has_cuda", + # version / capability + "has_cuda_compute", + "has_cudagraph", + # build support + "has_cudnn", + "has_cutlass", + "has_gpu", + # toolchain / environment + "has_hexagon", + "has_hexagon_toolchain", + "has_hipblas", + "has_llvm", + "has_llvm_min_version", + "has_matrixcore", + "has_metal", + "has_mrvl", + "has_multi_gpu", + "has_nccl", + "has_nnapi", + "has_nvcc_version", + "has_nvptx", + "has_nvshmem", + "has_opencl", + "has_openclml", + "has_rocm", + "has_rpc", + "has_tensorcore", + "has_vulkan", + "has_x86_amx", + "has_x86_avx512", + "has_x86_vnni", + "is_aarch64", + # host architecture + "is_x86", +] + + +@functools.cache +def _device_exists(kind: str, index: int = 0) -> bool: + """Return whether ``tvm.device(kind, index)`` is present and usable.""" + try: + return bool(tvm.device(kind, index).exist) + except Exception: # pylint: disable=broad-except + # A missing backend / driver must skip the test, not crash collection. + return False + + +@functools.cache +def _build_flag_enabled(flag: str) -> bool: + """Return whether an optional build flag (e.g. ``USE_CUTLASS``) is on. + + Mirrors the historical ``Feature`` check: a flag counts as enabled + unless it is explicitly disabled, so library flags carrying a path + still register as present. + """ + try: + value = tvm.support.libinfo().get(flag, "OFF") + return str(value).lower() not in ("off", "false", "0") + except Exception: # pylint: disable=broad-except + return False + + +@functools.cache +def _target_enabled(kind: str) -> bool: + """True if ``kind`` is selected by ``TVM_TEST_TARGETS`` (or the default set). + + Restores the historical ``target_kind_enabled`` opt-out, so CI can exclude a + flaky backend (e.g. opencl) via ``TVM_TEST_TARGETS`` and have its tests skip + even when a device is physically present. + """ + try: + from tvm.testing.utils import _tvm_test_targets # pylint: disable=import-outside-toplevel + + for target in _tvm_test_targets(): + k = target["kind"] if isinstance(target, dict) else str(target).split()[0] + if k == kind: + return True + return False + except Exception: # pylint: disable=broad-except + return True # fail open: the device check still gates + + +@functools.cache +def _runtime_enabled(kind: str) -> bool: + """True if the runtime was built with support for target ``kind``. + + Used for kinds whose device existence does not imply the backend was + compiled in -- notably ``llvm``, which maps to the always-present CPU + device, so ``tvm.device("llvm").exist`` is True even on ``USE_LLVM=OFF``. + """ + try: + return bool(tvm.runtime.enabled(kind)) + except Exception: # pylint: disable=broad-except + return False + + +def _device_usable(kind: str) -> bool: + """True if ``kind`` is enabled for this run and a ``kind`` device exists. + + The TVM_TEST_TARGETS opt-out is checked first so that an excluded backend + never probes a (possibly crashy) device. + """ + return _target_enabled(kind) and _device_exists(kind) + + +# --- runtime device probes ------------------------------------------------- + + +def has_cuda() -> bool: + """True if a CUDA device is present and enabled in TVM_TEST_TARGETS.""" + return _device_usable("cuda") + + +def has_rocm() -> bool: + """True if a ROCm device is present and enabled in TVM_TEST_TARGETS.""" + return _device_usable("rocm") + + +def has_vulkan() -> bool: + """True if a Vulkan device is present and enabled in TVM_TEST_TARGETS.""" + return _device_usable("vulkan") + + +def has_metal() -> bool: + """True if a Metal device is present and enabled in TVM_TEST_TARGETS.""" + return _device_usable("metal") + + +def has_opencl() -> bool: + """True if an OpenCL device is present and enabled in TVM_TEST_TARGETS.""" + return _device_usable("opencl") + + +def has_nvptx() -> bool: + """True if NVPTX is usable: a (CUDA) device, plus the LLVM backend it needs.""" + return _device_usable("nvptx") and has_llvm() + + +def has_llvm() -> bool: + """True if the LLVM backend was built in and enabled in TVM_TEST_TARGETS. + + Uses ``tvm.runtime.enabled`` rather than device existence: ``llvm`` maps to + the CPU device, which exists even on a ``USE_LLVM=OFF`` build. + """ + return _target_enabled("llvm") and _runtime_enabled("llvm") + + +def has_gpu() -> bool: + """True if any GPU backend (cuda/rocm/opencl/metal/vulkan) is present.""" + return ( + _device_exists("cuda") + or _device_exists("rocm") + or _device_exists("opencl") + or _device_exists("metal") + or _device_exists("vulkan") + ) + + +@functools.cache +def has_multi_gpu(count: int = 2) -> bool: + """True if at least ``count`` devices of a single GPU backend exist.""" + for kind in ("cuda", "rocm", "opencl", "metal", "vulkan"): + if all(_device_exists(kind, index) for index in range(count)): + return True + return False + + +# --- build-support probes -------------------------------------------------- +# +# These wrap the optional-library build flags. Features that extend CUDA / +# ROCm additionally require the parent device to be present. + + +def has_cudnn() -> bool: + """True if cuDNN was built in and a CUDA device is present.""" + return has_cuda() and _build_flag_enabled("USE_CUDNN") + + +def has_cublas() -> bool: + """True if cuBLAS was built in and a CUDA device is present.""" + return has_cuda() and _build_flag_enabled("USE_CUBLAS") + + +def has_nccl() -> bool: + """True if NCCL was built in and a CUDA device is present.""" + return has_cuda() and _build_flag_enabled("USE_NCCL") + + +def has_hipblas() -> bool: + """True if hipBLAS was built in and a ROCm device is present.""" + return has_rocm() and _build_flag_enabled("USE_HIPBLAS") + + +def has_cutlass() -> bool: + """True if CUTLASS support was built into the runtime.""" + return _build_flag_enabled("USE_CUTLASS") + + +def has_rpc() -> bool: + """True if RPC support was built into the runtime.""" + return _build_flag_enabled("USE_RPC") + + +def has_nnapi() -> bool: + """True if NNAPI codegen support was built into the runtime.""" + return _build_flag_enabled("USE_NNAPI_CODEGEN") + + +def has_openclml() -> bool: + """True if OpenCLML (CLML) support was built into the runtime.""" + return _build_flag_enabled("USE_CLML") + + +def has_mrvl() -> bool: + """True if the Marvell (MRVL) backend was built into the runtime.""" + return _build_flag_enabled("USE_MRVL") + + +@functools.cache +def has_nvshmem() -> bool: + """True if the disco NVSHMEM runtime is available (requires CUDA). + + Probes the runtime global function rather than the ``USE_NVSHMEM`` build + flag, since the flag can be set in builds that do not ship the runtime. + """ + try: + return has_cuda() and ( + tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", allow_missing=True) + is not None + ) + except Exception: # pylint: disable=broad-except + return False + + +# --- version / capability probes ------------------------------------------- + + +@functools.cache +def _cuda_compute_version() -> tuple: + """Return the (major, minor) CUDA compute version, or (0, 0) if unknown.""" + try: + from tvm.support import nvcc # pylint: disable=import-outside-toplevel + + arch = nvcc.get_target_compute_version() + return nvcc.parse_compute_version(arch) + except Exception: # pylint: disable=broad-except + return (0, 0) + + +def has_cuda_compute(major: int, minor: int = 0, exact: bool = False) -> bool: + """True if the CUDA compute capability satisfies ``(major, minor)``. + + When ``exact`` is False (default) the check is ``compute >= (major, + minor)``; when True it requires an exact match. Returns False when no + CUDA device is present, so it implies :func:`has_cuda`. + """ + if not has_cuda(): + return False + compute = _cuda_compute_version() + want = (major, minor) + if exact: + return compute == want + return compute >= want + + +@functools.cache +def _nvcc_version() -> tuple: + """Return the (major, minor, release) nvcc version, or (0, 0, 0).""" + try: + from tvm.support import nvcc # pylint: disable=import-outside-toplevel + + return nvcc.get_cuda_version() + except Exception: # pylint: disable=broad-except + return (0, 0, 0) + + +def has_nvcc_version(major: int, minor: int = 0, release: int = 0) -> bool: + """True if a CUDA device is present and nvcc is at least ``(major, minor, release)``. + + Implies :func:`has_cuda`, matching the historical ``requires_nvcc_version`` + decorator which also required the CUDA runtime. + """ + return has_cuda() and _nvcc_version() >= (major, minor, release) + + +@functools.cache +def _llvm_version_major() -> int: + """Return the major LLVM version, or 0 if LLVM is unavailable.""" + try: + return int(tvm.target.codegen.llvm_version_major()) + except Exception: # pylint: disable=broad-except + return 0 + + +def has_llvm_min_version(major: int) -> bool: + """True if LLVM is available and its major version is at least ``major``.""" + return has_llvm() and _llvm_version_major() >= major + + +@functools.cache +def has_tensorcore() -> bool: + """True if a CUDA device with Tensor Core support (compute >= 7) exists.""" + try: + from tvm.support import nvcc # pylint: disable=import-outside-toplevel + + return has_cuda() and bool(nvcc.have_tensorcore(tvm.cuda().compute_version)) + except Exception: # pylint: disable=broad-except + return False + + +@functools.cache +def has_matrixcore() -> bool: + """True if a ROCm device with Matrix Core support (compute >= 8) exists.""" + try: + from tvm.support import rocm # pylint: disable=import-outside-toplevel + + return has_rocm() and bool(rocm.have_matrixcore(tvm.rocm().compute_version)) + except Exception: # pylint: disable=broad-except + return False + + +@functools.cache +def has_cudagraph() -> bool: + """True if a CUDA device is present and the toolkit supports CUDA Graphs. + + Implies :func:`has_cuda`, matching the historical ``requires_cudagraph`` + decorator (``parent_features="cuda"``): ``nvcc.have_cudagraph()`` only + checks the toolkit version, so the device guard must be explicit. + """ + try: + from tvm.support import nvcc # pylint: disable=import-outside-toplevel + + return has_cuda() and bool(nvcc.have_cudagraph()) + except Exception: # pylint: disable=broad-except + return False + + +# --- toolchain / environment probes ---------------------------------------- + + +@functools.cache +def has_hexagon_toolchain() -> bool: + """True if the Hexagon toolchain is available for compilation.""" + try: + from tvm.contrib.hexagon import ( # pylint: disable=import-outside-toplevel + _ci_env_check, + ) + + return _build_flag_enabled("USE_HEXAGON") and _ci_env_check._compile_time_check() is True + except Exception: # pylint: disable=broad-except + return False + + +@functools.cache +def has_hexagon() -> bool: + """True if Hexagon can both compile and run (toolchain + attached device).""" + try: + from tvm.contrib.hexagon import ( # pylint: disable=import-outside-toplevel + _ci_env_check, + ) + + return has_hexagon_toolchain() and _ci_env_check._run_time_check() is True + except Exception: # pylint: disable=broad-except + return False + + +@functools.cache +def has_adreno_opencl() -> bool: + """True if remote Adreno OpenCL testing is configured (RPC_TARGET set).""" + return _build_flag_enabled("USE_OPENCL") and os.environ.get("RPC_TARGET") is not None + + +@functools.cache +def has_aprofile_aem_fvp() -> bool: + """True if the AProfile AEM FVP simulator is on PATH.""" + try: + import shutil # pylint: disable=import-outside-toplevel + + return shutil.which("FVP_Base_RevC-2xAEMvA") is not None + except Exception: # pylint: disable=broad-except + return False + + +# --- cpu feature probes ---------------------------------------------------- + + +@functools.cache +def _has_cpu_feature(features) -> bool: + """True if the host CPU advertises the given LLVM target ``features``.""" + try: + codegen = tvm.target.codegen + cpu = codegen.llvm_get_system_cpu() + triple = codegen.llvm_get_system_triple() + target = tvm.target.Target({"kind": "llvm", "mtriple": triple, "mcpu": cpu}) + return bool(codegen.target_has_features(features, target)) + except Exception: # pylint: disable=broad-except + return False + + +def has_cpu_feature(features) -> bool: + """True if the host CPU supports ``features`` (a name or list of names).""" + if isinstance(features, list): + features = tuple(features) + return _has_cpu_feature(features) + + +def has_arm_dot() -> bool: + """True if the host CPU supports the ARM dot-product instructions.""" + return has_cpu_feature("dotprod") + + +def has_arm_fp16() -> bool: + """True if the host CPU supports ARM Neon FP16 instructions.""" + return has_cpu_feature("fullfp16") + + +def has_aarch64_sve() -> bool: + """True if the host CPU supports AArch64 SVE.""" + return has_cpu_feature("sve") + + +def has_aarch64_sme() -> bool: + """True if the host CPU supports AArch64 SME.""" + return has_cpu_feature("sme") + + +def has_x86_vnni() -> bool: + """True if the host CPU supports x86 VNNI (AVX512-VNNI or AVX-VNNI).""" + return has_cpu_feature("avx512vnni") or has_cpu_feature("avxvnni") + + +def has_x86_avx512() -> bool: + """True if the host CPU supports the x86 AVX512 extensions.""" + return has_cpu_feature(["avx512bw", "avx512cd", "avx512dq", "avx512vl", "avx512f"]) + + +def has_x86_amx() -> bool: + """True if the host CPU supports the x86 AMX (int8) extensions.""" + return has_cpu_feature("amx-int8") + + +# --- host architecture probes ---------------------------------------------- + + +def is_x86() -> bool: + """True if running on an x86_64 host.""" + return platform.machine() == "x86_64" + + +def is_aarch64() -> bool: + """True if running on an aarch64 host.""" + return platform.machine() == "aarch64" diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index d7e096472f4d..bba2da6aee0d 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -36,7 +36,7 @@ import pytest import tvm -from tvm.testing import utils +from tvm.testing import env, utils try: from xdist.scheduler.loadscope import LoadScopeScheduling @@ -46,24 +46,12 @@ HAVE_XDIST = False -MARKERS = { - "gpu": "mark a test as requiring a gpu", - "tensorcore": "mark a test as requiring a tensorcore", - "cuda": "mark a test as requiring CUDA", - "opencl": "mark a test as requiring opencl", - "rocm": "mark a test as requiring rocm", - "vulkan": "mark a test as requiring vulkan", - "metal": "mark a test as requiring metal", - "llvm": "mark a test as requiring llvm", - "hexagon": "mark a test as requiring hexagon", -} - - def pytest_configure(config): - """Runs at pytest configure time, defines marks to be used later.""" + """Runs at pytest configure time. - for feature in utils.Feature._all_features.values(): - feature._register_marker(config) + Hardware/feature markers are declared statically in pyproject.toml; this + hook only reports the active target configuration. + """ print( "enabled targets:", @@ -290,31 +278,42 @@ def sort_key(item): items.sort(key=sort_key) +def _gpu_mark_and_skip(has_fn, reason): + """A GPU-family target: the ``gpu`` selection marker plus an env skip.""" + return [pytest.mark.gpu, pytest.mark.skipif(not has_fn(), reason=reason)] + + +def _skip_only(has_fn, reason): + """A non-GPU target: an env skip with no selection marker.""" + return [pytest.mark.skipif(not has_fn(), reason=reason)] + + def _target_to_requirement(target): if isinstance(target, str | dict): target = tvm.target.Target(target) - # mapping from target to decorator - if target.kind.name == "cuda" and "cudnn" in target.attrs.get("libs", []): - return utils.requires_cudnn.marks() - if target.kind.name == "cuda" and "cublas" in target.attrs.get("libs", []): - return utils.requires_cublas.marks() - if target.kind.name == "cuda": - return utils.requires_cuda.marks() - if target.kind.name == "rocm": - return utils.requires_rocm.marks() - if target.kind.name == "vulkan": - return utils.requires_vulkan.marks() - if target.kind.name == "nvptx": - return utils.requires_nvptx.marks() - if target.kind.name == "metal": - return utils.requires_metal.marks() - if target.kind.name == "opencl": - return utils.requires_opencl.marks() - if target.kind.name == "llvm": - return utils.requires_llvm.marks() - if target.kind.name == "hexagon": - return utils.requires_hexagon.marks() + # GPU-family kinds get the `gpu` selection marker; CPU-family kinds only skip. + kind = target.kind.name + if kind == "cuda" and "cudnn" in target.attrs.get("libs", []): + return _gpu_mark_and_skip(env.has_cudnn, "need cudnn") + if kind == "cuda" and "cublas" in target.attrs.get("libs", []): + return _gpu_mark_and_skip(env.has_cublas, "need cublas") + if kind == "cuda": + return _gpu_mark_and_skip(env.has_cuda, "need cuda") + if kind == "rocm": + return _gpu_mark_and_skip(env.has_rocm, "need rocm") + if kind == "vulkan": + return _gpu_mark_and_skip(env.has_vulkan, "need vulkan") + if kind == "nvptx": + return _gpu_mark_and_skip(env.has_nvptx, "need nvptx") + if kind == "metal": + return _gpu_mark_and_skip(env.has_metal, "need metal") + if kind == "opencl": + return _gpu_mark_and_skip(env.has_opencl, "need opencl") + if kind == "llvm": + return _skip_only(env.has_llvm, "need llvm") + if kind == "hexagon": + return _skip_only(env.has_hexagon, "need hexagon") return [] diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index cfc5a357a36b..c90e610af4d6 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -69,18 +69,14 @@ def test_something(): import ctypes import functools import inspect -import itertools import logging import os import pickle import platform -import shutil import sys import textwrap import time -from collections.abc import Callable from pathlib import Path -from typing import ClassVar import ml_dtypes import numpy as np @@ -88,13 +84,11 @@ def test_something(): import tvm import tvm.arith -import tvm.contrib.hexagon._ci_env_check as hexagon import tvm.support.utils import tvm.te import tvm.tirx from tvm.contrib import cudnn -from tvm.support import nvcc, rocm -from tvm.target import codegen +from tvm.support import nvcc SKIP_SLOW_TESTS = os.getenv("SKIP_SLOW_TESTS", "").lower() in {"true", "1", "yes"} IS_IN_CI = os.getenv("CI", "") == "true" @@ -485,7 +479,7 @@ def device_enabled(target): Example ------- - >>> @tvm.testing.uses_gpu + >>> @pytest.mark.gpu >>> def test_mytest(): >>> for target in ["cuda", "llvm"]: >>> if device_enabled(target): @@ -517,8 +511,8 @@ def enabled_targets(): target exists. If TVM_TEST_TARGETS is not set, it defaults to variable DEFAULT_TEST_TARGETS in this module. - If you use this function in a test, you **must** decorate the test with - :py:func:`tvm.testing.uses_gpu` (otherwise it will never be run on the gpu). + If you use this function in a test, you **must** mark the test with + ``@pytest.mark.gpu`` (otherwise it will never be run on the gpu). Returns ------- @@ -529,589 +523,6 @@ def enabled_targets(): return [(t["target"], tvm.device(t["target_kind"])) for t in _get_targets() if t["is_runnable"]] -class Feature: - """A feature that may be required to run a test. - - Parameters - ---------- - name: str - - The short name of the feature. Should match the name in the - requires_* decorator. This is applied as a mark to all tests - using this feature, and can be used in pytests ``-m`` - argument. - - long_name: Optional[str] - - The long name of the feature, to be used in error messages. - - If None, defaults to the short name. - - cmake_flag: Optional[str] - - The flag that must be enabled in the config.cmake in order to - use this feature. - - If None, no flag is required to use this feature. - - target_kind_enabled: Optional[str] - - The target kind that must be enabled to run tests using this - feature. If present, the target_kind must appear in the - TVM_TEST_TARGETS environment variable, or in - tvm.testing.DEFAULT_TEST_TARGETS if TVM_TEST_TARGETS is - undefined. - - If None, this feature does not require a specific target to be - enabled. - - compile_time_check: Optional[Callable[[], Union[bool,str]]] - - A check that returns True if the feature can be used at - compile-time. (e.g. Validating the version number of the nvcc - compiler.) If the feature does not have support to perform - compile-time tests, the check should returns False to display - a generic error message, or a string to display a more - specific error message. - - If None, no additional check is performed. - - target_kind_hardware: Optional[str] - - The target kind that must have available hardware in order to - run tests using this feature. This is checked using - tvm.device(target_kind_hardware).exist. If a feature requires - a different check, this should be implemented using - run_time_check. - - If None, this feature does not require a specific - tvm.device to exist. - - run_time_check: Optional[Callable[[], Union[bool,str]]] - - A check that returns True if the feature can be used at - run-time. (e.g. Validating the compute version supported by a - GPU.) If the feature does not have support to perform - run-time tests, the check should returns False to display a - generic error message, or a string to display a more specific - error message. - - If None, no additional check is performed. - - parent_features: Optional[Union[str,List[str]]] - - The short name of a feature or features that are required in - order to use this feature. (e.g. Using cuDNN requires using - CUDA) This feature should inherit all checks of the parent - feature, with the exception of the `target_kind_enabled` - checks. - - If None, this feature does not require any other parent - features. - - """ - - _all_features: ClassVar[dict[str, "Feature"]] = {} - - def __init__( - self, - name: str, - long_name: str | None = None, - cmake_flag: str | None = None, - target_kind_enabled: str | None = None, - compile_time_check: Callable[[], bool | str] | None = None, - target_kind_hardware: str | None = None, - run_time_check: Callable[[], bool | str] | None = None, - parent_features: str | list[str] | None = None, - ): - self.name = name - self.long_name = long_name or name - self.cmake_flag = cmake_flag - self.target_kind_enabled = target_kind_enabled - self.compile_time_check = compile_time_check - self.target_kind_hardware = target_kind_hardware - self.run_time_check = run_time_check - - if parent_features is None: - self.parent_features = [] - elif isinstance(parent_features, str): - self.parent_features = [parent_features] - else: - self.parent_features = parent_features - - self._all_features[self.name] = self - - def _register_marker(self, config): - config.addinivalue_line("markers", f"{self.name}: Mark a test as using {self.long_name}") - - def _uses_marks(self): - for parent in self.parent_features: - yield from self._all_features[parent]._uses_marks() - - yield getattr(pytest.mark, self.name) - - def _compile_only_marks(self): - for parent in self.parent_features: - yield from self._all_features[parent]._compile_only_marks() - - if self.compile_time_check is not None: - res = self.compile_time_check() - if isinstance(res, str): - yield pytest.mark.skipif(True, reason=res) - else: - yield pytest.mark.skipif( - not res, reason=f"Compile-time support for {self.long_name} not present" - ) - - if self.target_kind_enabled is not None: - target_kind = self.target_kind_enabled.split()[0] - - def _kind_of(enabled): - return enabled["kind"] if isinstance(enabled, dict) else enabled.split()[0] - - yield pytest.mark.skipif( - all(_kind_of(enabled) != target_kind for enabled in _tvm_test_targets()), - reason=( - f"{self.target_kind_enabled} tests disabled " - f"by TVM_TEST_TARGETS environment variable" - ), - ) - - if self.cmake_flag is not None: - yield pytest.mark.skipif( - not _cmake_flag_enabled(self.cmake_flag), - reason=( - f"{self.long_name} support not enabled. " - f"Set {self.cmake_flag} in config.cmake to enable." - ), - ) - - def _run_only_marks(self): - for parent in self.parent_features: - yield from self._all_features[parent]._run_only_marks() - - if self.run_time_check is not None: - res = self.run_time_check() - if isinstance(res, str): - yield pytest.mark.skipif(True, reason=res) - else: - yield pytest.mark.skipif( - not res, reason=f"Run-time support for {self.long_name} not present" - ) - - if self.target_kind_hardware is not None: - yield pytest.mark.skipif( - not tvm.device(self.target_kind_hardware).exist, - reason=f"No device exists for target {self.target_kind_hardware}", - ) - - def marks(self, support_required="compile-and-run"): - """Return a list of marks to be used - - Parameters - ---------- - - support_required: str - - Allowed values: "compile-and-run" (default), - "compile-only", or "optional". - - See Feature.__call__ for details. - """ - if support_required not in ["compile-and-run", "compile-only", "optional"]: - raise ValueError(f"Unknown feature support type: {support_required}") - - if support_required == "compile-and-run": - marks = itertools.chain( - self._run_only_marks(), self._compile_only_marks(), self._uses_marks() - ) - elif support_required == "compile-only": - marks = itertools.chain(self._compile_only_marks(), self._uses_marks()) - elif support_required == "optional": - marks = self._uses_marks() - else: - raise ValueError(f"Unknown feature support type: {support_required}") - - return list(marks) - - def __call__(self, func=None, *, support_required="compile-and-run"): - """Mark a pytest function as requiring this feature - - Can be used either as a bare decorator, or as a decorator with - arguments. - - Parameters - ---------- - - func: Callable - - The pytest test function to be marked - - support_required: str - - Allowed values: "compile-and-run" (default), - "compile-only", or "optional". - - If "compile-and-run", the test case is marked as using the - feature, and is skipped if the environment lacks either - compile-time or run-time support for the feature. - - If "compile-only", the test case is marked as using the - feature, and is skipped if the environment lacks - compile-time support. - - If "optional", the test case is marked as using the - feature, but isn't skipped. This is kept for backwards - compatibility for tests that use `enabled_targets()`, and - should be avoided in new test code. Instead, prefer - parametrizing over the target using the `target` fixture. - - Examples - -------- - - .. code-block:: python - - @feature - def test_compile_and_run(): - ... - - @feature(compile_only=True) - def test_compile_only(): - ... - - """ - - if support_required not in ["compile-and-run", "compile-only", "optional"]: - raise ValueError(f"Unknown feature support type: {support_required}") - - def wrapper(func): - for mark in self.marks(support_required=support_required): - func = mark(func) - return func - - if func is None: - return wrapper - - return wrapper(func) - - @classmethod - def require(cls, name, support_required="compile-and-run"): - """Returns a decorator that marks a test as requiring a feature - - Parameters - ---------- - - name: str - - The name of the feature that is used by the test - - support_required: str - - Allowed values: "compile-and-run" (default), - "compile-only", or "optional". - - See Feature.__call__ for details. - - Examples - -------- - - .. code-block:: python - - @Feature.require("cuda") - def test_compile_and_run(): - ... - - @Feature.require("cuda", compile_only=True) - def test_compile_only(): - ... - """ - return cls._all_features[name](support_required=support_required) - - -def _any_gpu_exists(): - return ( - tvm.cuda().exist - or tvm.rocm().exist - or tvm.opencl().exist - or tvm.metal().exist - or tvm.vulkan().exist - ) - - -def _multi_gpu_exists(): - return ( - (tvm.cuda(0).exist and tvm.cuda(1).exist) - or (tvm.rocm(0).exist and tvm.rocm(1).exist) - or (tvm.opencl(0).exist and tvm.opencl(1).exist) - or (tvm.metal(0).exist and tvm.metal(1).exist) - or (tvm.vulkan(0).exist and tvm.vulkan(1).exist) - ) - - -# Mark a test as requiring llvm to run -requires_llvm = Feature( - "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm" -) - -# Mark a test as requiring a GPU to run. -requires_gpu = Feature("gpu", run_time_check=_any_gpu_exists) - -# Mark to differentiate tests that use the GPU in some capacity. -# -# These tests will be run on CPU-only test nodes and on test nodes with GPUs. -# To mark a test that must have a GPU present to run, use -# :py:func:`tvm.testing.requires_gpu`. -uses_gpu = requires_gpu(support_required="optional") - -# Mark a test as requiring multiple GPUs to run. -requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists) - -# Mark to differentiate tests that use multiple GPUs in some capacity. -# -# These tests will be run on test nodes with multiple GPUs. -# To mark a test that must have multiple GPUs present to run, use -# :py:func:`tvm.testing.requires_multi_gpu`. -uses_multi_gpu = requires_multi_gpu(support_required="optional") - -# Mark a test as requiring the x86 Architecture to run. -requires_x86 = Feature( - "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64" -) - -# Mark a test as requiring the aarch64 Architecture to run. -requires_aarch64 = Feature( - "AArch64", "AArch64 Architecture", run_time_check=lambda: platform.machine() == "aarch64" -) - -# Mark a test as requiring the CUDA runtime. -requires_cuda = Feature( - "cuda", - "CUDA", - cmake_flag="USE_CUDA", - target_kind_enabled="cuda", - target_kind_hardware="cuda", - parent_features="gpu", -) - -# Mark a test as requiring a tensorcore to run -requires_tensorcore = Feature( - "tensorcore", - "NVIDIA Tensor Core", - run_time_check=lambda: tvm.cuda().exist and nvcc.have_tensorcore(tvm.cuda().compute_version), - parent_features="cuda", -) - -# Mark a test as requiring the cuDNN library. -requires_cudnn = Feature("cudnn", "cuDNN", cmake_flag="USE_CUDNN", parent_features="cuda") - -# Mark a test as requiring the cuBLAS library. -requires_cublas = Feature("cublas", "cuBLAS", cmake_flag="USE_CUBLAS", parent_features="cuda") - -# Mark a test as requiring NCCL support -requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL", parent_features="cuda") - - -def _nvshmem_exists(): - # Probe the runtime function rather than the USE_NVSHMEM cmake flag: the - # flag can be ON in builds that do not ship the disco NVSHMEM runtime. - return ( - tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", allow_missing=True) - is not None - ) - - -# Mark a test as requiring NVSHMEM support -requires_nvshmem = Feature( - "nvshmem", "NVSHMEM", run_time_check=_nvshmem_exists, parent_features="cuda" -) - -# Mark a test as requiring the NVPTX compilation on the CUDA runtime -requires_nvptx = Feature( - "nvptx", - "NVPTX", - target_kind_enabled="nvptx", - target_kind_hardware="nvptx", - parent_features=["llvm", "cuda"], -) - -# Mark a test as requiring the CUDA Graph Feature -requires_cudagraph = Feature( - "cudagraph", - "CUDA Graph", - target_kind_enabled="cuda", - compile_time_check=nvcc.have_cudagraph, - parent_features="cuda", -) - -# Mark a test as requiring the OpenCL runtime on remote RPC -requires_adreno_opencl = Feature( - "opencl", - long_name="Remote Adreno OpenCL", - cmake_flag="USE_OPENCL", - target_kind_enabled="opencl", - target_kind_hardware=None, - parent_features="gpu", - run_time_check=lambda: os.getenv("RPC_TARGET") is not None, -) - -# Mark a test as requiring the OpenCL runtime -requires_opencl = Feature( - "opencl", - "OpenCL", - cmake_flag="USE_OPENCL", - target_kind_enabled="opencl", - target_kind_hardware="opencl" if "RPC_TARGET" not in os.environ else None, - parent_features="gpu" if "RPC_TARGET" not in os.environ else None, -) - -# Mark a test as requiring the rocm runtime -requires_rocm = Feature( - "rocm", - "ROCm", - cmake_flag="USE_ROCM", - target_kind_enabled="rocm", - target_kind_hardware="rocm", - parent_features="gpu", -) - -# Mark a test as requiring a matrixcore to run -requires_matrixcore = Feature( - "matrixcore", - "AMD Matrix Core", - run_time_check=lambda: tvm.rocm().exist and rocm.have_matrixcore(tvm.rocm().compute_version), - parent_features="rocm", -) - -# Mark a test as requiring the hipBLAS library. -requires_hipblas = Feature("hipblas", "hipBLAS", cmake_flag="USE_HIPBLAS", parent_features="rocm") - -# Mark a test as requiring the metal runtime -requires_metal = Feature( - "metal", - "Metal", - cmake_flag="USE_METAL", - target_kind_enabled="metal", - target_kind_hardware="metal", - parent_features="gpu", -) - -# Mark a test as requiring the vulkan runtime -requires_vulkan = Feature( - "vulkan", - "Vulkan", - cmake_flag="USE_VULKAN", - target_kind_enabled="vulkan", - target_kind_hardware="vulkan", - parent_features="gpu", -) - -# Mark a test as requiring OpenCLML support in build. -requires_openclml = Feature("OpenCLML", "CLML", cmake_flag="USE_CLML", target_kind_enabled="opencl") - -# Mark a test as requiring NNAPI support in build. -requires_nnapi = Feature("NNAPI", "NNAPI", cmake_flag="USE_NNAPI_CODEGEN") - -# Mark a test as requiring CUTLASS to run -requires_cutlass = Feature("cutlass", "CUTLASS", cmake_flag="USE_CUTLASS") - -# Mark a test as requiring rpc to run -requires_rpc = Feature("rpc", "RPC", cmake_flag="USE_RPC") - -# Mark a test as requiring the MRVL Library -requires_mrvl = Feature("mrvl", "Marvell", cmake_flag="USE_MRVL") - -# Mark a test as requiring Hexagon to run -requires_hexagon = Feature( - "hexagon", - "Hexagon", - cmake_flag="USE_HEXAGON", - target_kind_enabled="hexagon", - compile_time_check=hexagon._compile_time_check, - run_time_check=hexagon._run_time_check, - parent_features="llvm", -) - - -def _aprofile_aem_fvp_compile_time_check(): - if shutil.which("FVP_Base_RevC-2xAEMvA") is None: - return "AProfile AEM is not available" - return True - - -requires_aprofile_aem_fvp = Feature( - "aprofile-aem-fvp", - "AProfile AEM FVP", - compile_time_check=_aprofile_aem_fvp_compile_time_check, -) - - -# check cpu features -def _has_cpu_feat(features): - cpu = codegen.llvm_get_system_cpu() - triple = codegen.llvm_get_system_triple() - target = {"kind": "llvm", "mtriple": triple, "mcpu": cpu} - has_feat = codegen.target_has_features(features, tvm.target.Target(target)) - - return has_feat - - -requires_arm_dot = Feature( - "arm_dot", - "ARM dot product", - run_time_check=lambda: _has_cpu_feat("dotprod"), -) - - -requires_arm_fp16 = Feature( - "arm_fp16", - "Arm(R) Neon(TM) instructions for FP16", - run_time_check=lambda: _has_cpu_feat("fullfp16"), -) - - -requires_aarch64_sve = Feature( - "arm_sve", - "AArch64 SVE", - run_time_check=lambda: _has_cpu_feat("sve"), -) - - -requires_aarch64_sme = Feature( - "arm_sme", - "AArch64 SME", - run_time_check=lambda: _has_cpu_feat("sme"), -) - - -requires_x86_vnni = Feature( - "x86_vnni", - "x86 VNNI Extensions", - run_time_check=lambda: _has_cpu_feat("avx512vnni") or _has_cpu_feat("avxvnni"), -) - - -requires_x86_avx512 = Feature( - "x86_avx512", - "x86 AVX512 Extensions", - run_time_check=lambda: _has_cpu_feat( - ["avx512bw", "avx512cd", "avx512dq", "avx512vl", "avx512f"] - ), -) - - -requires_x86_amx = Feature( - "x86_amx", "x86 AMX Extensions", run_time_check=lambda: _has_cpu_feat("amx-int8") -) - - -def _cmake_flag_enabled(flag): - flag = tvm.support.libinfo().get(flag, "OFF") - - # Because many of the flags can be library flags, we check if the - # flag is not disabled, rather than checking if it is enabled. - return flag.lower() not in ["off", "false", "0"] - - def _parse_target_entry(entry): """Parse a target entry from TVM_TEST_TARGETS env var. @@ -1164,126 +575,6 @@ def _compose(args, decs): ) -def requires_llvm_minimum_version(major_version): - """Mark a test as requiring at least a specific version of LLVM. - - Unit test marked with this decorator will run only if the - installed version of LLVM is at least `major_version`. - - This also marks the test as requiring LLVM backend support. - - Parameters - ---------- - major_version: int - - - """ - - try: - llvm_version = tvm.target.codegen.llvm_version_major() - except RuntimeError: - llvm_version = 0 - - requires = [ - pytest.mark.skipif( - llvm_version < major_version, reason=f"Requires LLVM >= {major_version}" - ), - *requires_llvm.marks(), - ] - - def inner(func): - return _compose([func], requires) - - return inner - - -def requires_nvcc_version(major_version, minor_version=0, release_version=0): - """Mark a test as requiring at least a specific version of nvcc. - - Unit test marked with this decorator will run only if the - installed version of NVCC is at least `(major_version, - minor_version, release_version)`. - - This also marks the test as requiring a cuda support. - - Parameters - ---------- - major_version: int - - The major version of the (major,minor,release) version tuple. - - minor_version: int - - The minor version of the (major,minor,release) version tuple. - - release_version: int - - The release version of the (major,minor,release) version tuple. - - """ - - try: - nvcc_version = nvcc.get_cuda_version() - except RuntimeError: - nvcc_version = (0, 0, 0) - - min_version = (major_version, minor_version, release_version) - version_str = ".".join(str(v) for v in min_version) - requires = [ - pytest.mark.skipif(nvcc_version < min_version, reason=f"Requires NVCC >= {version_str}"), - *requires_cuda.marks(), - ] - - def inner(func): - return _compose([func], requires) - - return inner - - -def requires_cuda_compute_version(major_version, minor_version=0, exact=False): - """Mark a test as requiring at least a compute architecture - - Unit test marked with this decorator will run only if the CUDA - compute architecture of the GPU is at least `(major_version, - minor_version)`. - - This also marks the test as requiring a cuda support. - - Parameters - ---------- - major_version: int - - The major version of the (major,minor) version tuple. - - minor_version: int - - The minor version of the (major,minor) version tuple. - """ - min_version = (major_version, minor_version) - try: - arch = tvm.support.nvcc.get_target_compute_version() - compute_version = tvm.support.nvcc.parse_compute_version(arch) - except ValueError: - # No GPU present. This test will be skipped from the - # requires_cuda() marks as well. - compute_version = (0, 0) - - min_version_str = ".".join(str(v) for v in min_version) - compute_version_str = ".".join(str(v) for v in compute_version) - requires = [ - pytest.mark.skipif( - compute_version < min_version or (exact and compute_version != min_version), - reason=f"Requires CUDA compute >= {min_version_str}, but have {compute_version_str}", - ), - *requires_cuda.marks(), - ] - - def inner(func): - return _compose([func], requires) - - return inner - - def skip_if_32bit(reason): def decorator(*args): if "32bit" in platform.architecture()[0]: @@ -1872,8 +1163,8 @@ def terminate_self(): def is_ampere_or_newer(): """Check if the target environment has an NVIDIA Ampere GPU or newer.""" - arch = tvm.support.nvcc.get_target_compute_version() - major, minor = tvm.support.nvcc.parse_compute_version(arch) + arch = nvcc.get_target_compute_version() + major, minor = nvcc.parse_compute_version(arch) return major >= 8 and minor != 9 diff --git a/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py b/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py index 117be6e78d61..56ef78c4aad9 100644 --- a/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py +++ b/tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py @@ -22,14 +22,16 @@ import re import numpy as np +import pytest import tvm import tvm.testing from tvm import te, topi from tvm.support import utils +from tvm.testing import env -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_add_pipeline(): """all-platform-minimal-test: Check LLVM enablement.""" nn = 128 diff --git a/tests/python/codegen/test_codegen_error_handling.py b/tests/python/codegen/test_codegen_error_handling.py index 2329b06f3948..3e8cddfca87e 100644 --- a/tests/python/codegen/test_codegen_error_handling.py +++ b/tests/python/codegen/test_codegen_error_handling.py @@ -29,6 +29,7 @@ import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env # Parameterize over both LLVM and C backends codegen_target = tvm.testing.parameter("llvm", "c") @@ -276,7 +277,8 @@ def func(a: T.Buffer((128, 128), "float32"), b: T.Buffer((128, 128), "float32")) # ── Device mismatch errors ───────────────────────────────── -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_device_mismatch_error(): """Passing GPU tensor to CPU function raises ValueError.""" diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index 31fb71706df2..278b8ff88d82 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -23,6 +23,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env def _reduce_sum_module(d1, d2, d3): @@ -118,7 +119,8 @@ def compile_metal(src, target): tvm.register_global_func(name, cached, override=True) -@tvm.testing.requires_metal(support_required="compile-only") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_allreduce_sum_compile(optional_metal_compile_callback): # Disable the parametrization over dims, at least for now dims = (1, 1, 2) diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index 821f987e635b..41f41bd802ed 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -38,7 +40,8 @@ def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> No B[tx] = A_local[tx] + 1.0 -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_inject_ptx_intrin(): f = vector_add arch = tvm.support.nvcc.get_target_compute_version() diff --git a/tests/python/codegen/test_target_codegen_blob.py b/tests/python/codegen/test_target_codegen_blob.py index f481142e19d5..86dcb6dcfb55 100644 --- a/tests/python/codegen/test_target_codegen_blob.py +++ b/tests/python/codegen/test_target_codegen_blob.py @@ -28,7 +28,7 @@ from tvm.support import cc, popen_pool, tar, utils -@tvm.testing.uses_gpu +@pytest.mark.gpu def test_cuda_multi_lib(): pytest.importorskip("cloudpickle") diff --git a/tests/python/codegen/test_target_codegen_bool.py b/tests/python/codegen/test_target_codegen_bool.py index a1ff6f339d0e..3a5a69b02fd6 100644 --- a/tests/python/codegen/test_target_codegen_bool.py +++ b/tests/python/codegen/test_target_codegen_bool.py @@ -17,6 +17,7 @@ """codegen related to bool types""" import numpy as np +import pytest import tvm import tvm.testing @@ -24,7 +25,7 @@ from tvm.script import tirx as T -@tvm.testing.uses_gpu +@pytest.mark.gpu @tvm.testing.exclude_targets("nvptx") def test_cmp_load_store(target, dev): @I.ir_module(s_tir=True) diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py b/tests/python/codegen/test_target_codegen_cross_llvm.py index 204784f031f1..0dc5cfe2345b 100644 --- a/tests/python/codegen/test_target_codegen_cross_llvm.py +++ b/tests/python/codegen/test_target_codegen_cross_llvm.py @@ -21,6 +21,7 @@ import struct import numpy as np +import pytest import tvm import tvm.testing @@ -28,6 +29,7 @@ from tvm.script import ir as I from tvm.script import tirx as T from tvm.support import cc, utils +from tvm.testing import env @I.ir_module(s_tir=True) @@ -48,7 +50,7 @@ def main( C[v_i0] = A[v_i0] + B[v_i0] -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_add_pipeline(): nn = 1024 diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 31a2d4dd1aab..2c0f164baf68 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -26,6 +26,7 @@ from tvm.script import ir as I from tvm.script import tirx as T from tvm.support.nvcc import have_bf16, have_fp16, have_int8 +from tvm.testing import env @pytest.fixture(autouse=True, params=["nvcc", "nvrtc"]) @@ -53,8 +54,8 @@ def compile_mode_wrapper(code): tvm.register_global_func("tvm_callback_cuda_compile", orig_func, override=True) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_vectorize_add(): num_thread = 8 @@ -106,8 +107,8 @@ def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): check_cuda("float16", 64, 8) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_bf16_vectorize_add(): if not have_bf16(tvm.cuda(0).compute_version): print("skip because gpu does not support bf16") @@ -164,8 +165,8 @@ def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): check_cuda(64, 8) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_multiply_add(): num_thread = 8 @@ -211,8 +212,8 @@ def main( check_cuda("int8", 64, 4) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_vectorize_load(): num_thread = 8 @@ -249,8 +250,8 @@ def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): check_cuda("int8", 64, 16) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_make_int8(): def check_cuda(n, value, lanes): dtype = "int8" @@ -288,8 +289,8 @@ def main(A: T.Buffer((n, lanes), dtype)): check_cuda(64, -3, 2) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_inf_nan(): target = "cuda" @@ -427,8 +428,8 @@ def verify(nthdx, nthdy): verify(32, 16) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_reduction_binding(): @I.ir_module(s_tir=True) class Module: @@ -450,8 +451,8 @@ def main(A: T.Buffer((96, 32), "float32"), B: T.Buffer((96,), "float32")): func = tvm.compile(Module, target="cuda") -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_const_float_to_half(): # This import is required to use nvcc to perform code gen; # otherwise it is found that the code gen is done by nvrtc. @@ -486,8 +487,8 @@ def main(a: T.Buffer((2, 3, 4), "float16"), C: T.Buffer((2, 3, 4), "bool")): np.testing.assert_equal(c.numpy(), a_np > 0.5) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_floordiv_with_vectorization(): with tvm.target.Target("cuda"): # B[i] = A[floordiv(i, k)] @@ -519,8 +520,8 @@ def main(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")): tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_floormod_with_vectorization(): with tvm.target.Target("cuda"): # B[i] = A[floormod(i, k)] @@ -552,8 +553,8 @@ def main(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")): tvm.testing.assert_allclose(b_nd.numpy(), b_np, rtol=1e-3) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_vectorized_casts(): def check(t0, t1, factor): if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.cuda(0).compute_version): @@ -647,8 +648,8 @@ def main(A: T.Buffer((n,), dtype), B: T.Buffer((n,), dtype)): return tvm.compile(Module, target="cuda") -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_vectorized_intrin1(): test_funcs = [ (tvm.tirx.floor, lambda x: np.floor(x)), @@ -703,8 +704,8 @@ def run_test(tvm_intrin, np_func, dtype): run_test(*func, "float16") -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_vectorized_intrin2(dtype="float32"): c2 = tvm.tirx.const(2, dtype=dtype) test_funcs = [ @@ -725,8 +726,8 @@ def run_test(tvm_intrin, np_func): run_test(*func) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_vectorized_popcount(): def ref_popcount(x): cnt = 0 @@ -749,8 +750,8 @@ def run_test(dtype): run_test("uint64") -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_vectorize_load_permute_pad(): def check_cuda(dtype, n, l, padding, lanes): if dtype == "float16" and not have_fp16(tvm.cuda(0).compute_version): @@ -801,8 +802,8 @@ def main(A: T.Buffer((n, l), dtype), B: T.Buffer((dim0, dim1, lanes), dtype)): check_cuda("float32", 64, 16, 3, 4) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_try_unaligned_vector_load(): def build(N, C_N, offset): @I.ir_module(s_tir=True) @@ -846,8 +847,8 @@ def main(A: T.Buffer((N,), "float16"), C: T.Buffer((C_N,), "float16")): assert np.allclose(c, expected), f"expected={expected}\nactual={c}" -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_thread_sync_inside_condition(): @T.prim_func(s_tir=True) def func1(A: T.Buffer((4, 4), "float32")) -> None: @@ -893,7 +894,8 @@ def func3(A: T.Buffer((4, 4), "float32")) -> None: tvm.compile(mod, target="cuda") -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_invalid_reinterpret(): @T.prim_func(s_tir=True) def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None: @@ -904,8 +906,8 @@ def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None: tvm.compile(func, target="cuda") -@tvm.testing.requires_cuda -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_cuda_tensormap(): # fmt: off @T.prim_func(s_tir=True) @@ -935,7 +937,8 @@ def main(A_ptr: T.handle): ) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_device_func_call(): @I.ir_module(s_tir=True) class Module: @@ -958,7 +961,8 @@ def main( assert 'extern "C" __device__ float add(float a, float b) {\n return (a + b);\n}' in cuda_code -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_float_const_hex_format(): """Test that float constants are emitted in hexadecimal format for precision""" @@ -977,7 +981,8 @@ def main( assert "0x1.2f684bda12f68p-5f" in cuda_code -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_device_host_call_same_func(): @I.ir_module(s_tir=True) class Module: @@ -1017,7 +1022,8 @@ def main( tvm.testing.assert_allclose(c_tvm.numpy(), a_np + b_np) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_thread_return(): @I.ir_module(s_tir=True) class Module: @@ -1034,8 +1040,8 @@ def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")): assert "return;" in cuda_code -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_loop_step(): @T.prim_func(s_tir=True) def cuda_loop_step( @@ -1066,8 +1072,8 @@ def cuda_loop_step( tvm.testing.assert_allclose(c_nd.numpy(), a_np + b_np) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_export_load_with_fallback(monkeypatch, tmp_path): """Force the codegen wrapper into the fallback branch, then export+load+run.""" n = 1024 diff --git a/tests/python/codegen/test_target_codegen_cuda_fastmath.py b/tests/python/codegen/test_target_codegen_cuda_fastmath.py index 7686dc0dad80..809266bdc8a5 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fastmath.py +++ b/tests/python/codegen/test_target_codegen_cuda_fastmath.py @@ -30,6 +30,7 @@ from tvm.runtime.executable import Executable from tvm.script import tirx as T from tvm.support.nvcc import have_fp16 +from tvm.testing import env VECTOR_N_INPUTS = 8 @@ -286,8 +287,8 @@ def test_cuda_math_intrinsic_lowering_pass_context(enable_fast_math): check_lowered_ir("float32", MATH_CASES[0], enable_fast_math) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize( "dtype", ["float16", "bfloat16", "float32", "float64"], diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index 5c7f9a1b6611..6a24fdf03c17 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -24,6 +24,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env try: from ml_dtypes import float4_e2m1fn @@ -34,7 +35,8 @@ @pytest.mark.parametrize("promoted_dtype", ["float32x2", "float16x2"]) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_e2m1_vector_conversions(promoted_dtype): native_dtype = "float4_e2m1fnx2" vector_length = 64 @@ -180,7 +182,8 @@ def main( return Module -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_e2m1_dequantize(): n = 128 @@ -204,7 +207,8 @@ def test_e2m1_dequantize(): tvm.compile(mod, target=target) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_e2m1_scalar_buffer_offset(): """Regression test: float4_e2m1fn scalar buffer access uses correct byte offset. diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index 23acbd56fc8a..331c96b1d386 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -28,6 +28,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env try: import ml_dtypes @@ -42,7 +43,8 @@ ("float8_e5m2", "__nv_fp8_e5m2"), ], ) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_fp8_conversions(input): dtype, nv_dtype = input @@ -91,7 +93,8 @@ def main( "dtype", ["float8_e4m3fn", "float8_e5m2", "float8_e8m0fnu"], ) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_fp8_packing(dtype): length = 64 vector_length = 4 @@ -156,7 +159,8 @@ def main( ) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_fp8_vector_conversions(native_dtype, promoted_dtype, numpytype): vector_length = 64 @@ -217,7 +221,8 @@ def main( bcast_length = tvm.testing.parameter(2, 4, 6, 8) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_half_broadcast(bcast_length): dtype = "float16" @@ -252,7 +257,8 @@ def main(a: T.Buffer((), dtype), vec: T.Buffer((bcast_length,), dtype)): vector_length = tvm.testing.parameter(2, 4) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_half_misaligned_vector_load(vector_length): dtype = "float16" vec_dtype = dtype + "x" + str(vector_length) @@ -287,7 +293,8 @@ def vector_load( tvm.testing.assert_allclose(b.numpy(), b_np) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_half4_vector_add(): dtype = "float16" length = 64 @@ -790,7 +797,8 @@ def compiled_functions( dev, ) - @tvm.testing.requires_cuda_compute_version(8, 9) + @pytest.mark.gpu + @pytest.mark.skipif(not env.has_cuda_compute(8, 9), reason="need cuda compute >= 8.9") def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): quant, dequant = compiled_functions dev = tvm.device(target_str, 0) @@ -805,7 +813,8 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions): tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["float8_e5m2", "float8_e4m3fn", "float8_e8m0fnu"]) def test_const(dtype): @T.prim_func(s_tir=True) @@ -820,7 +829,8 @@ def func(A: T.Buffer((4,), dtype)) -> None: tvm.compile(mod, target="cuda") -@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8, 9), reason="need cuda compute >= 8.9") @pytest.mark.parametrize("dtype", ["float8_e5m2", "float8_e4m3fn"]) @pytest.mark.parametrize("vec_len", [2, 4, 8, 16]) def test_copy(dtype, vec_len): @@ -854,7 +864,8 @@ def func( spatial_size = 4096 -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes to be installed") def test_moe_gemv_shfl_down_illegal_instr(): global num_experts @@ -965,7 +976,8 @@ def _pipeline(mod: tvm.ir.IRModule) -> tvm.ir.IRModule: @pytest.mark.parametrize("vec_length", [2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) -@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8, 9), reason="need cuda compute >= 8.9") def test_fp8_fp16_bf16_vectorize_arith(vec_length, dtype): def _create_mod(vec_length, dtype): num_threads = 128 // vec_length diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py index aaa29f58091e..1f931649b295 100644 --- a/tests/python/codegen/test_target_codegen_device.py +++ b/tests/python/codegen/test_target_codegen_device.py @@ -15,14 +15,17 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_large_uint_imm(): value = (1 << 63) + 123 value_const = tvm.tirx.const(value, "uint64") @@ -55,7 +58,8 @@ def check_target(target): check_target({"kind": "vulkan", "from_device": 0}) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_add_pipeline(): @I.ir_module(s_tir=True) class Module: diff --git a/tests/python/codegen/test_target_codegen_extern.py b/tests/python/codegen/test_target_codegen_extern.py index 0c3f9e8bf33b..2057c94da919 100644 --- a/tests/python/codegen/test_target_codegen_extern.py +++ b/tests/python/codegen/test_target_codegen_extern.py @@ -16,6 +16,7 @@ # under the License. # ruff: noqa: F841 import numpy as np +import pytest import tvm import tvm.testing @@ -23,7 +24,7 @@ from tvm.script import tirx as T -@tvm.testing.uses_gpu +@pytest.mark.gpu def test_add_pipeline(): """Test extern-style add pipeline with vectorized operations.""" nn = 64 diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py index 59b5e099cabc..067ef98b186d 100644 --- a/tests/python/codegen/test_target_codegen_gpu_common.py +++ b/tests/python/codegen/test_target_codegen_gpu_common.py @@ -23,9 +23,11 @@ import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets( "cuda", "metal", {"kind": "vulkan", "supports_int64": True}, "opencl" ) diff --git a/tests/python/codegen/test_target_codegen_hexagon.py b/tests/python/codegen/test_target_codegen_hexagon.py index 087cecbc3e5f..98afef1ff3ca 100644 --- a/tests/python/codegen/test_target_codegen_hexagon.py +++ b/tests/python/codegen/test_target_codegen_hexagon.py @@ -24,6 +24,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env @pytest.fixture(autouse=True) @@ -36,7 +37,7 @@ def register_linker(): hexagon.register_linker(original_linker) -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_basic(): target = tvm.target.Target("qcom/hexagon-v66") @@ -62,7 +63,7 @@ def main( assert vadds # Check that it's non-empty -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_llvm_target_features(): target = tvm.target.Target("qcom/hexagon-v66") @@ -85,7 +86,7 @@ def add_one(C: T.Buffer((128,), "int32"), A: T.Buffer((128,), "uint8")): assert fs # Check that it's non-empty -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_llvm_options(): target = tvm.target.Target( { diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 8c2c8d6e07bb..7c093f9be27b 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -27,9 +27,10 @@ from tvm.script import tirx as T from tvm.support import clang, utils from tvm.target.codegen import llvm_get_intrinsic_name, llvm_lookup_intrinsic_id +from tvm.testing import env -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_intrin(): @I.ir_module(s_tir=True) class Module: @@ -41,7 +42,7 @@ def main(A: T.handle("float32")): fcode = tvm.compile(Module) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_void_intrin(): @I.ir_module(s_tir=True) class Module: @@ -53,7 +54,7 @@ def main(A: T.handle("uint8")): fcode = tvm.compile(Module) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_intrinsic_id(): orig_name = "llvm.x86.sse2.pmadd.wd" intrin_id = llvm_lookup_intrinsic_id(orig_name) @@ -61,7 +62,7 @@ def test_llvm_intrinsic_id(): assert orig_name == name -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_overloaded_intrin(): # Name lookup for overloaded intrinsics in LLVM 4- requires a name # that includes the overloaded types. @@ -83,7 +84,7 @@ def main(A: T.Buffer((1, 1), "int32"), C: T.Buffer((1, 1), "int32")): f = tvm.compile(Module, target="llvm") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_lookup_intrin(): @I.ir_module(s_tir=True) class Module: @@ -95,7 +96,7 @@ def main(A: T.handle("uint8x8")): fcode = tvm.compile(Module, None) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_large_uintimm(): value = (1 << 63) + 123 large_val = tvm.tirx.const(value, "uint64") @@ -118,7 +119,7 @@ def main(A: T.Buffer((), "uint64")): assert a.numpy() == value + 3 -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_multi_parallel(): @I.ir_module(s_tir=True) class Module: @@ -150,7 +151,7 @@ def main(A: T.Buffer((128,), "float32"), C: T.Buffer((128,), "float32")): tvm.testing.assert_allclose(c.numpy(), np.sqrt(a.numpy() + 1) * 2 + 2, rtol=1e-5) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_flip_pipeline(): def check_llvm(nn, base): @I.ir_module(s_tir=True) @@ -180,7 +181,7 @@ def main(A: T.Buffer((nn + base,), "float32"), C: T.Buffer((nn,), "float32")): check_llvm(128, 1) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_vadd_pipeline(): @I.ir_module(s_tir=True) class Module: @@ -210,7 +211,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_madd_pipeline(): def check_llvm(nn, base, stride): @I.ir_module(s_tir=True) @@ -246,7 +247,7 @@ def main( check_llvm(4, 0, 3) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_temp_space(): @I.ir_module(s_tir=True) class Module: @@ -276,7 +277,7 @@ def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")): tvm.testing.assert_allclose(c.numpy(), a.numpy() + 1 + 1) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_multiple_func(): @I.ir_module(s_tir=True) class Module: @@ -321,7 +322,7 @@ def fadd2(var_A: T.handle, var_B: T.handle, var_C: T.handle): tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_condition(): @I.ir_module(s_tir=True) class Module: @@ -347,7 +348,7 @@ def main(A: T.Buffer((64,), "float32"), C: T.Buffer((64,), "float32")): tvm.testing.assert_allclose(c.numpy(), c_np) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_bool(): @I.ir_module(s_tir=True) class Module: @@ -371,7 +372,7 @@ def main(A: T.Buffer((64,), "int32"), C: T.Buffer((64,), "float32")): tvm.testing.assert_allclose(c.numpy(), c_np) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_cast_float_to_bool(): @I.ir_module(s_tir=True) class Module: @@ -395,7 +396,7 @@ def main(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "bool")): tvm.testing.assert_allclose(c.numpy(), c_np) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_rank_zero(): @I.ir_module(s_tir=True) class Module: @@ -432,7 +433,7 @@ def main( tvm.testing.assert_allclose(d.numpy(), d_np) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_rank_zero_bound_checkers(): @I.ir_module(s_tir=True) class Module: @@ -470,7 +471,7 @@ def main( tvm.testing.assert_allclose(d.numpy(), d_np) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_alignment(): @I.ir_module(s_tir=True) class Module: @@ -518,7 +519,7 @@ def has_call_to_assume(): assert has_call_to_assume() -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_div(): """Check that the semantics of div and mod is correct""" @@ -658,7 +659,7 @@ def _show_info(): check(0, 255, dstart, dend, "uint8", floor_div=True) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_fp_math(): @I.ir_module(s_tir=True) class RecipModule: @@ -709,7 +710,7 @@ def main(var_A: T.handle, var_B: T.handle): tvm.testing.assert_allclose(b.numpy(), np.zeros((n,), "float32")) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_dwarf_debug_information(): @I.ir_module(s_tir=True) class Module: @@ -797,7 +798,7 @@ def check_llvm_ir(): check_llvm_ir() -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_bf16(): def dotest(do_vectorize): loop_kind = T.vectorized if do_vectorize else T.serial @@ -835,7 +836,7 @@ def main( dotest(False) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_crt_static_lib(): @I.ir_module(s_tir=True) class Module: @@ -862,7 +863,7 @@ def main( module.write_to_file(temp.relpath("test.o")) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_order_functions(): """Check that functions in the LLVM module are ordered alphabetically.""" @@ -888,7 +889,7 @@ def Kirby(v: T.float32) -> T.float32: assert matches == sorted(matches) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") @tvm.testing.skip_if_32bit def test_llvm_import(): """all-platform-minimal-test: check shell dependent clang behavior.""" @@ -931,7 +932,7 @@ def main(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")): check_llvm(use_file=False) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_scalar_concat(): @I.ir_module(s_tir=True) class Module: @@ -945,7 +946,7 @@ def main(x: T.int32, y: T.int32, buffer: T.Buffer((1,), "int32x2")): m = tvm.compile(Module, target="llvm") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_raise_exception_during_codegen(): @I.ir_module(s_tir=True) class Module: @@ -962,7 +963,7 @@ def main(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")) -> None assert msg.find("Nested parallel loop is not supported") != -1 -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_target_attributes(): """Check that when LLVM codegen creates new functions, they get the same target attributes as the original function. @@ -1028,7 +1029,7 @@ def test_func(var_A: T.handle, var_B: T.handle, var_C: T.handle, tindex: T.int32 assert re.match('.*"target-features"=".*[+]avx512f.*".*', attribute_definitions[attr_num]) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_assume(): """ Check that LLVM does not error out when generating code with tirx.assume. @@ -1051,7 +1052,7 @@ def main(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), "int32")): m = tvm.compile(Module, target="llvm") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_debug_symbol_for_float64(): """Check that LLVM can define DWARF debug type for float64 @@ -1073,7 +1074,7 @@ def main(a: T.handle("float64"), b: T.handle("float64"), n: T.int64): tvm.compile(Module, target="llvm") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_subroutine_call(): @I.ir_module(s_tir=True) class Module: @@ -1099,7 +1100,7 @@ def subroutine(A_data: T.handle("float32")): assert arr.numpy()[0] == 42.0 -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_call_packed_returning_void(): """Allow codegen of PackedFunc calls returning void @@ -1130,7 +1131,7 @@ def main(): built = tvm.compile(Module, target="llvm") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_call_packed_without_string_arg(): """The first argument to tvm_call_packed must be a string @@ -1150,7 +1151,7 @@ def main(A: T.Buffer(1, "float32")): built = tvm.compile(Module, target="llvm") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_call_extern_returning_void(): """Like test_call_packed_returning_void, but for call_extern""" diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index c1a8054b6087..0237367a57dd 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -15,15 +15,17 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env -@tvm.testing.requires_gpu -@tvm.testing.requires_metal +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_metal_inf_nan(): target = "metal" @@ -59,8 +61,8 @@ def main( check_inf_nan(dev, 1, float("nan"), "float16") -@tvm.testing.requires_gpu -@tvm.testing.requires_metal +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_unaligned_vectorize(): @tvm.script.ir_module class IRModule: @@ -84,8 +86,8 @@ def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")): tvm.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5) -@tvm.testing.requires_gpu -@tvm.testing.requires_metal +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_metal_erf(): target = "metal" @@ -117,8 +119,8 @@ def main( check_erf(dev, 1, "float16") -@tvm.testing.requires_gpu -@tvm.testing.requires_metal +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_ramp(): target = "metal" @@ -140,8 +142,8 @@ def main(A: T.Buffer((1, 2), "int32")): assert tuple(a_nd.numpy()[0, :]) == (0, 3) -@tvm.testing.requires_gpu -@tvm.testing.requires_metal +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_select_vectorize(): @tvm.script.ir_module class IRModule: @@ -165,8 +167,8 @@ def main(A: T.Buffer((6), "float32"), B: T.Buffer((6,), "float32")): tvm.testing.assert_allclose(b_nd.numpy(), a, atol=1e-5, rtol=1e-5) -@tvm.testing.requires_gpu -@tvm.testing.requires_metal +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_vectorized_uint8(): @T.prim_func(s_tir=True) def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): @@ -185,7 +187,8 @@ def func(A: T.Buffer((16), "uint8"), B: T.Buffer((16), "float32")): tvm.testing.assert_allclose(b_nd.numpy(), a.astype("float32"), atol=1e-5, rtol=1e-5) -@tvm.testing.requires_metal(support_required="compile-only") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_func_with_trailing_pod_params(): from tvm.support import xcode # pylint: disable=import-outside-toplevel @@ -208,7 +211,8 @@ def compile_metal(src, target): assert occurrences == 1, occurrences -@tvm.testing.requires_metal(support_required="compile-only") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_metal(), reason="need metal") def test_export_load_with_fallback(monkeypatch, tmp_path): """Force the codegen wrapper into the fallback branch, then export.""" n = 1024 diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index 227dfa626f05..c1d5143756a4 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -17,16 +17,19 @@ # ruff: noqa: E501 import re +import pytest + import tvm import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env target = "opencl" -@tvm.testing.requires_gpu -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") def test_opencl_ternary_expression(): def check_if_then_else(dev, n, dtype): @I.ir_module(s_tir=True) @@ -92,8 +95,8 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): check_select(dev, 1, "uint16") -@tvm.testing.requires_gpu -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") def test_opencl_inf_nan(): def check_inf_nan(dev, n, value, dtype): @I.ir_module(s_tir=True) @@ -124,8 +127,8 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): check_inf_nan(dev, 1, float("nan"), "float64") -@tvm.testing.requires_gpu -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") def test_opencl_max(): def check_max(dev, n, dtype): @I.ir_module(s_tir=True) @@ -183,8 +186,8 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): check_erf(dev, 1, "float64") -@tvm.testing.requires_gpu -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") def test_opencl_type_casting(): @I.ir_module(s_tir=True) class Module: @@ -218,8 +221,8 @@ def check_type_casting(ctx, n, dtype): # check_type_casting(dev, 16, "float16") -@tvm.testing.requires_gpu -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") @tvm.testing.parametrize_targets("opencl", {"kind": "opencl", "device": "adreno"}) def test_opencl_ceil_log2(target): def _check(target, n, dtype): @@ -265,8 +268,8 @@ def get_kernel_args(source): return max_args -@tvm.testing.requires_gpu -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") def test_export_load_with_fallback(monkeypatch, tmp_path): """Force the codegen wrapper into the fallback branch, then export+load+run.""" import numpy as np diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index c13e4e91be7d..5b9b1ecd7707 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -16,13 +16,16 @@ # under the License. # ruff: noqa: E501, F841 +import pytest + import tvm import tvm.testing from tvm.script import tirx as T from tvm.target.codegen import target_has_features +from tvm.testing import env -@tvm.testing.requires_llvm_minimum_version(14) +@pytest.mark.skipif(not env.has_llvm_min_version(14), reason="need llvm >= 14") @tvm.testing.parametrize_targets( { "kind": "llvm", @@ -72,7 +75,7 @@ def load_vec(A: T.Buffer((N,), "int8")): check_rvv_presence(16, 32) -@tvm.testing.requires_llvm_minimum_version(14) +@pytest.mark.skipif(not env.has_llvm_min_version(14), reason="need llvm >= 14") @tvm.testing.parametrize_targets( { "kind": "llvm", diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index 8254f821810d..d8865eb1efb1 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -16,14 +16,17 @@ # under the License. # ruff: noqa: F841 import numpy as np +import pytest import tvm import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_rocm_inf_nan(): def check_inf_nan(dev, n, value, dtype): @I.ir_module(s_tir=True) @@ -56,7 +59,8 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): check_inf_nan(dev, 1, float("nan"), "float64") -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_rocm_copy(): def check_rocm(dtype, n): dev = tvm.rocm(0) @@ -73,7 +77,8 @@ def check_rocm(dtype, n): check_rocm(dtype, int(peturb * (2**logN))) -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_rocm_vectorize_add(): def check_rocm(dtype, n, lanes): vec_dtype = f"{dtype}x{lanes}" @@ -104,7 +109,8 @@ def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): check_rocm("float16", 64, 2) -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_rocm_warp_shuffle(): @T.prim_func(s_tir=True) def func( @@ -130,7 +136,8 @@ def func( tvm.testing.assert_allclose(a.numpy(), np.ones((32,)) * a.numpy()[0]) -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_rocm_vectorized_exp(): @T.prim_func(s_tir=True) def func( @@ -154,7 +161,8 @@ def func( tvm.testing.assert_allclose(b.numpy(), np.exp2(a.numpy())) -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_export_load_with_fallback(monkeypatch, tmp_path): """Force the codegen wrapper into the fallback branch, then export+load+run.""" n = 1024 diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index 1440bc5bca6e..c240cd1d1bc7 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -28,6 +28,7 @@ from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import ir as I_builder from tvm.script.ir_builder import tirx as T_builder +from tvm.testing import env dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") fuzz_seed = tvm.testing.parameter(range(25)) @@ -468,7 +469,8 @@ def main(X: T.Buffer((16, 32), "float16"), W: T.Buffer((32, 16), "float16"), com tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2) -@tvm.testing.requires_vulkan(support_required="compile-only") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_vulkan(), reason="need vulkan") def test_codegen_decl_buffer(): """The codegen should accept DeclBuffer nodes in its input""" @@ -485,7 +487,8 @@ def kernel(): vulkan_codegen(Module, target) -@tvm.testing.requires_vulkan(support_required="compile-only") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_vulkan(), reason="need vulkan") def test_codegen_static_shared_memory(): """The codegen should accept static shared/workgroup allocations.""" @@ -503,8 +506,8 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): tvm.compile(Module, target="vulkan") -@tvm.testing.requires_gpu -@tvm.testing.requires_vulkan +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_vulkan(), reason="need vulkan") def test_unary(): test_funcs = [ (tvm.tirx.sin, lambda x: np.sin(x)), @@ -562,7 +565,8 @@ def main(var_A: T.handle, var_B: T.handle): run_test(*func) -@tvm.testing.requires_vulkan(support_required="compile-only") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_vulkan(), reason="need vulkan") def test_export_load_with_fallback(monkeypatch, tmp_path): """Force the codegen wrapper into the fallback branch, then export.""" n = 1024 diff --git a/tests/python/codegen/test_target_codegen_x86.py b/tests/python/codegen/test_target_codegen_x86.py index bed010cdea61..6ca3b546adbe 100644 --- a/tests/python/codegen/test_target_codegen_x86.py +++ b/tests/python/codegen/test_target_codegen_x86.py @@ -23,6 +23,7 @@ import tvm from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env llvm_version = tvm.target.codegen.llvm_version_major() machine = platform.machine() @@ -31,7 +32,7 @@ pytest.skip(f"Requires x86_64, but machine is {machine}", allow_module_level=True) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") @pytest.mark.skipif(llvm_version < 6, reason=f"Requires LLVM 6+, got {llvm_version}") def test_fp16_to_fp32(): def fp16_to_fp32(target, width, match=None, not_match=None): diff --git a/tests/python/contrib/test_cutlass_gemm.py b/tests/python/contrib/test_cutlass_gemm.py index 327353cca288..19e53c48b0e7 100644 --- a/tests/python/contrib/test_cutlass_gemm.py +++ b/tests/python/contrib/test_cutlass_gemm.py @@ -17,10 +17,12 @@ import ml_dtypes import numpy as np +import pytest import tvm import tvm.testing from tvm.contrib.pickle_memoize import memoize +from tvm.testing import env def get_random_tensor(shape, dtype): @@ -71,8 +73,9 @@ def to_numpy_dtype(dtype): tvm.testing.assert_allclose(c_nd.numpy(), c_np, rtol=rtol, atol=atol) -@tvm.testing.requires_cutlass -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_group_gemm_sm90(): verify_group_gemm( "cutlass.group_gemm", @@ -115,8 +118,9 @@ def test_group_gemm_sm90(): ) -@tvm.testing.requires_cutlass -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_group_gemm_sm100(): verify_group_gemm( "cutlass.group_gemm", @@ -298,8 +302,9 @@ def blockwise_bmm( return o_np -@tvm.testing.requires_cutlass -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_fp8_e4m3_groupwise_scaled_gemm(): M = 16 N = 4608 @@ -331,8 +336,9 @@ def test_fp8_e4m3_groupwise_scaled_gemm(): tvm.testing.assert_allclose(o_tvm, o_np, rtol=1e-4, atol=0.5) -@tvm.testing.requires_cutlass -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_fp8_e4m3_groupwise_scaled_bmm(): B = 16 M = 40 diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 0abdd6c9d236..fba8ccb47dc0 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -23,6 +23,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env VRMPY_SIZE_B = 128 VRMPY_SIZE_INT32 = 32 @@ -390,7 +391,7 @@ def expected_output(self, size_a, size_w, input_a, input_w): ) * np.uint32(input_w[x, index_0 * 4 + r_index]) return expected_result - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_loading_vtcm_for_vrmpy( self, hexagon_session, @@ -839,7 +840,7 @@ def main( ] -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_meta(hexagon_session): """Test meta.""" if tvm.testing.utils.IS_IN_CI: diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py index 6b0bf4824240..39b4eef55568 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py @@ -28,6 +28,7 @@ import tvm.testing from tvm.contrib.hexagon.session import Session from tvm.script import tirx as T +from tvm.testing import env from . import benchmark_util as bu from .infrastructure import get_hexagon_target @@ -369,7 +370,7 @@ def _get_elemwise_add_reference_value_tensors(shape: list, dtype: str): @pytest.mark.skipif(_SHOULD_SKIP_BENCHMARKS, reason=_SKIP_BENCHMARKS_REASON) -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_elemwise_add(hexagon_session: Session): """Main elementwise add test function""" for dtype in [ diff --git a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py index 5431c0ae6964..91830e5d1a9d 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py @@ -55,6 +55,7 @@ from tvm import te, tirx, topi from tvm.contrib.hexagon import allocate_hexagon_array from tvm.contrib.hexagon.session import Session +from tvm.testing import env from tvm.topi import testing from . import benchmark_util as bu @@ -194,7 +195,7 @@ class TestMaxPool2D: io_tensor_mem_scope = tvm.testing.parameter("global.vtcm") @pytest.mark.skipif(_SHOULD_SKIP_BENCHMARKS, reason=_SKIP_BENCHMARKS_REASON) - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_maxpool2d_nhwc( self, n_batch, diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py index bae14da5ed46..961d3bb5d602 100644 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py @@ -20,6 +20,7 @@ """ import numpy as np +import pytest import tvm import tvm.contrib.hexagon @@ -29,6 +30,7 @@ from tvm.script.parser import ir as I from tvm.script.parser import relax as R from tvm.script.parser import tirx as T +from tvm.testing import env # pylint: disable=invalid-name, missing-class-docstring, missing-function-docstring, no-self-argument @@ -165,7 +167,7 @@ class TestDMACopyWait: mode = tvm.testing.parameter("bytecode", "compiled") module = tvm.testing.parameter(Module_1D) - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_vtcm_alloc_compute(self, hexagon_launcher, mode, module): target_hexagon = tvm.target.Target("qcom/hexagon-v69") target = tvm.target.Target(target_hexagon, host=target_hexagon) diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index 0d0c01bcb675..4c9f7d2ee358 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -38,6 +38,7 @@ from tvm.s_tir.meta_schedule.runner import RunnerInput from tvm.s_tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN from tvm.script import tirx as T +from tvm.testing import env from tvm.tirx import FloatImm from .infrastructure import get_hexagon_target @@ -69,7 +70,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore ) -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_builder_runner(hexagon_launcher): """Test builder and runner.""" if hexagon_launcher.is_simulator(): @@ -191,7 +192,7 @@ def verify_dense(sch, target, m_size, n_size, k_size, hexagon_session): print(f"{time_ms:f} ms, {gflops / (time_ms / 1e3):f} GOPS") -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_vrmpy_dense(hexagon_launcher): """Test vector reduce muliply dense.""" if hexagon_launcher.is_simulator(): @@ -300,7 +301,7 @@ def main( # type: ignore ) -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_vrmpy_dense_auto_tensorize(hexagon_launcher): """Test VRMPY dense operator.""" if hexagon_launcher.is_simulator(): diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index fe385c16c3a1..4f8747c4e034 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -20,9 +20,11 @@ """ import numpy as np +import pytest import tvm from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -197,7 +199,7 @@ class TestMatMulVec: # 16384, ) - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test( self, hexagon_session, diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index 0698c0db1b47..39755e28111f 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -18,9 +18,11 @@ """Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np +import pytest import tvm from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -408,7 +410,7 @@ def expected_output(self, operations, input_a, input_b, input_c): ) * np.uint32(input_b[n, i * 4 + r_ind]) return expected_output - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_loading_vtcm_for_vrmpy( self, hexagon_session, diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index 43314cd6a832..e14e6911f05d 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -18,9 +18,11 @@ """Test parallelism for multiple different scalar workloads.""" import numpy as np +import pytest import tvm from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -141,7 +143,7 @@ class TestMatMulVec: split_factor = tvm.testing.parameter(4) - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_add( self, hexagon_session, diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py b/tests/python/contrib/test_hexagon/test_relax_integration.py index 40421abd3b7e..b8947b114c74 100644 --- a/tests/python/contrib/test_hexagon/test_relax_integration.py +++ b/tests/python/contrib/test_hexagon/test_relax_integration.py @@ -27,6 +27,7 @@ from tvm.contrib.hexagon.session import Session from tvm.relax.frontend import onnx from tvm.relax.testing import relay_translator +from tvm.testing import env def get_onnx_mobilenet(): @@ -42,7 +43,7 @@ def get_onnx_mobilenet(): @pytest.mark.skip("takes too long (~20min)") -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_mobilenet_onnx(hexagon_session: Session): """Test MobileNetV2 ONNX model""" onnx_model = get_onnx_mobilenet() @@ -77,7 +78,7 @@ def test_mobilenet_onnx(hexagon_session: Session): @pytest.mark.skip("takes too long (~20min)") -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_mobilenet(hexagon_session: Session): """Test MobileNet workload""" relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") diff --git a/tests/python/contrib/test_hexagon/test_run_unit_tests.py b/tests/python/contrib/test_hexagon/test_run_unit_tests.py index a7795813cd71..f1cec118e4c3 100644 --- a/tests/python/contrib/test_hexagon/test_run_unit_tests.py +++ b/tests/python/contrib/test_hexagon/test_run_unit_tests.py @@ -19,9 +19,12 @@ """capture gtest output and return over FFI""" +import pytest + import tvm import tvm.testing from tvm.contrib.hexagon.session import Session +from tvm.testing import env unit_test_name = tvm.testing.parameter( "HexagonUserDMATest.wait", @@ -146,7 +149,7 @@ # use --gtest_args to pass arguments to gtest # for example to run all "foo" tests twice and observe gtest output run # pytest -sv --gtests_args="--gtest_filter=*foo* --gtest_repeat=2" -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_run_unit_tests(hexagon_session: Session, gtest_args, unit_test_name): """Try running gtest unit tests and capture output and error code""" try: diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py index 805f0b3477da..f9ffe5522097 100644 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -18,11 +18,13 @@ """Sigmoid operator tests.""" import numpy as np +import pytest import tvm import tvm.testing from tvm import te, tirx, topi from tvm.contrib.hexagon import allocate_hexagon_array +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -68,7 +70,7 @@ def ref_output_np(self, input_np): output_np = 1 / (1 + np.exp(-input_np)) return output_np - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_sigmoid( self, in_shape, diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index d66b145d39ba..e06a07fef1ac 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -18,10 +18,12 @@ """Async software pipeline tests.""" import numpy as np +import pytest import tvm from tvm import tirx from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -165,7 +167,7 @@ def schedule(self, comp_type, sched_type, outer, inner, dtype, scope): return sch - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_async_software_pipeline( self, hexagon_launcher, comp_type, data, reference, schedule, verify ): diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py index fc06275b4004..ae37427bc7e6 100644 --- a/tests/python/contrib/test_hexagon/test_thread_pool.py +++ b/tests/python/contrib/test_hexagon/test_thread_pool.py @@ -18,6 +18,7 @@ """Add hexagon thread pool test""" import numpy as np +import pytest import tvm import tvm.contrib.hexagon @@ -25,6 +26,7 @@ import tvm.testing from tvm.contrib.hexagon.session import Session from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -72,7 +74,7 @@ def benchmark_func(mod, name, args, hexagon_session): return evaluator(a, b, c, n).mean -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_speedup(hexagon_session: Session, capsys): """Test speedup""" func = tvm.compile( @@ -88,7 +90,7 @@ def test_speedup(hexagon_session: Session, capsys): print(f"... speedup of {serial_mean / parallel_mean:.2f}", end=" ") -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_elemwise_sum_parallel(hexagon_session: Session): """Test parallel elementwise sum""" func = tvm.compile( diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index 7ac4327cc42a..f52fef0b9b73 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -22,6 +22,7 @@ import tvm.testing from tvm import tirx from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -47,7 +48,7 @@ def get_scale_by_two_schedule(): return sch -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_vtcm_building(): """Test building with vtcm mem scope""" sch = get_scale_by_two_schedule() @@ -56,7 +57,7 @@ def test_vtcm_building(): assert "global.vtcm" in built.inspect_source("asm") -@tvm.testing.requires_hexagon +@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") @pytest.mark.parametrize("vtcm_capacity,limited", [(8192, False), (1024, False), (128, True)]) def test_vtcm_limit(vtcm_capacity, limited): """Test building with vtcm mem scope limit""" diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 3afe27a236bc..9f42a9bbdb9f 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -23,6 +23,7 @@ import tvm from tvm.s_tir.tensor_intrin.hexagon import DMA_READ_128_i8 from tvm.script import tirx as T +from tvm.testing import env from .infrastructure import get_hexagon_target @@ -136,7 +137,7 @@ class TestMatMulVec: unroll_split = tvm.testing.parameter(2) vector_split = tvm.testing.parameter(128) - @tvm.testing.requires_hexagon + @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vector_split): """Test bandwidth.""" diff --git a/tests/python/contrib/test_hipblas.py b/tests/python/contrib/test_hipblas.py index baecb21384d4..d64d2610b785 100644 --- a/tests/python/contrib/test_hipblas.py +++ b/tests/python/contrib/test_hipblas.py @@ -16,11 +16,13 @@ # under the License. # ruff: noqa: E741 import numpy as np +import pytest import tvm import tvm.testing from tvm import te from tvm.contrib import hipblas +from tvm.testing import env def verify_matmul_add(in_dtype, out_dtype, rtol=1e-5): @@ -76,7 +78,8 @@ def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5): ) -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_matmul_add(): verify_matmul_add("float", "float", rtol=1e-3) verify_matmul_add("float16", "float") @@ -84,7 +87,8 @@ def test_matmul_add(): verify_matmul_add("int8", "int32") -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") def test_batch_matmul(): if not tvm.get_global_func("tvm.contrib.hipblas.batch_matmul", True): print("skip because extern function is not available") diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index 5d946f87a89a..7a482284129e 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -21,6 +21,7 @@ import threading import numpy as np +import pytest import tvm import tvm.testing @@ -102,7 +103,7 @@ def verify(target="llvm"): verify() -@tvm.testing.uses_gpu +@pytest.mark.gpu def test_random_fill(): """Tests random_fill function""" diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py index 33d7962e8ff3..c424fe8b591a 100644 --- a/tests/python/contrib/test_tir_triton_integration.py +++ b/tests/python/contrib/test_tir_triton_integration.py @@ -27,6 +27,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env try: import triton @@ -39,7 +40,8 @@ pytestmark = pytest.skip("Triton >= 3.3.0 is required", allow_module_level=True) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_tir_triton_integration(): @triton.jit def add_kernel( diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py index 9fdf9aaae62e..7df8b560bfa2 100644 --- a/tests/python/disco/test_callback.py +++ b/tests/python/disco/test_callback.py @@ -21,14 +21,17 @@ import tempfile import numpy as np +import pytest import tvm import tvm.testing from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env -@tvm.testing.requires_nccl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_nccl(), reason="need nccl") def test_callback(): """Simulate lazy loading of parameters in a callback diff --git a/tests/python/disco/test_loader.py b/tests/python/disco/test_loader.py index 290b9f401f31..2b379dc71137 100644 --- a/tests/python/disco/test_loader.py +++ b/tests/python/disco/test_loader.py @@ -34,6 +34,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.target import Target +from tvm.testing import env # `runtime.disco.compiled_ccl` is registered together with the CCL runtime # functions, so its absence means the disco CCL runtime is not in this build. @@ -42,7 +43,9 @@ pytest.skip("Disco NCCL support is not available", allow_module_level=True) # All tests in this file shard across two GPUs. -pytestmark = tvm.testing.requires_multi_gpu.marks() +pytestmark = [ + pytest.mark.skipif(not env.has_multi_gpu(), reason="need multiple gpus"), +] @register_global_func("tests.disco.shard_dim_0", override=True) diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py index 5d70ccf6bd19..cfa7755915b0 100644 --- a/tests/python/disco/test_nvshmem.py +++ b/tests/python/disco/test_nvshmem.py @@ -36,11 +36,15 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env if di is None: pytest.skip("disco runtime is not available", allow_module_level=True) -pytestmark = tvm.testing.requires_nvshmem.marks() +pytestmark = [ + pytest.mark.gpu, + pytest.mark.skipif(not env.has_nvshmem(), reason="need nvshmem"), +] _SOCKET_SESSION_TESTER = None diff --git a/tests/python/nightly/test_nnapi/test_network.py b/tests/python/nightly/test_nnapi/test_network.py index bbfc01067f7e..85b8b31456af 100644 --- a/tests/python/nightly/test_nnapi/test_network.py +++ b/tests/python/nightly/test_nnapi/test_network.py @@ -28,6 +28,7 @@ from test_nnapi.infrastructure import build_and_run # , build_and_run_vm from tvm.contrib.download import download_testdata from tvm.relax.frontend.onnx import from_onnx +from tvm.testing import env def _build_and_run_network(remote_obj, tracker, mod, input_data): @@ -115,7 +116,7 @@ def create_model(name): "float32", ], ) -@tvm.testing.requires_nnapi +@pytest.mark.skipif(not env.has_nnapi(), reason="need nnapi") def test_network(name, dtype): remote_obj, tracker = remote() print(f"Network evaluating {name} with dtype {dtype}") diff --git a/tests/python/relax/backend/adreno/utils.py b/tests/python/relax/backend/adreno/utils.py index d1153ff41709..608530c32590 100644 --- a/tests/python/relax/backend/adreno/utils.py +++ b/tests/python/relax/backend/adreno/utils.py @@ -19,6 +19,7 @@ import tempfile import numpy as np +import pytest import tvm import tvm.testing @@ -55,50 +56,49 @@ def __call__(self): return self.check +def _adreno_requires(predicate, reason): + """Tag a GPU test with the ``gpu`` marker plus an eager runtime skip. + + The predicate is evaluated when the decorator is applied (at collection + time), so the skip condition is resolved eagerly. + """ + + def decorator(func): + func = pytest.mark.skipif(not predicate(), reason=reason)(func) + return pytest.mark.gpu(func) + + return decorator + + # OpenCL or Vulkan -requires_adreno_opencl_vulkan = tvm.testing.Feature( - "adreno_opencl_vulkan", - "Adreno Vulkan Or OpenCL", - run_time_check=run_time_check("any")(), - parent_features="gpu" if "ADRENO_TARGET" not in os.environ else "rpc", +requires_adreno_opencl_vulkan = _adreno_requires( + run_time_check("any").check, + "need adreno opencl or vulkan", ) # Any Vulkan -requires_adreno_vulkan = tvm.testing.Feature( - "adreno_vulkan", - "Adreno Vulkan", - target_kind_enabled="vulkan", - run_time_check=lambda: tvm.runtime.enabled("vulkan") and run_time_check("vulkan").check(), - parent_features="gpu" if "ADRENO_TARGET" not in os.environ else "rpc", +requires_adreno_vulkan = _adreno_requires( + lambda: tvm.runtime.enabled("vulkan") and run_time_check("vulkan").check(), + "need adreno vulkan", ) # Any OpenCL -requires_adreno_opencl = tvm.testing.Feature( - "adreno_opencl", - "Adreno OpenCL", - target_kind_enabled="opencl", - run_time_check=lambda: tvm.runtime.enabled("opencl") and run_time_check("opencl").check(), - parent_features="gpu" if "ADRENO_TARGET" not in os.environ else "rpc", +requires_adreno_opencl = _adreno_requires( + lambda: tvm.runtime.enabled("opencl") and run_time_check("opencl").check(), + "need adreno opencl", ) # Real Adreno GPU OpenCL Target -requires_adreno_opencl_real = tvm.testing.Feature( - "adreno_opencl_real", - "Adreno OpenCL Real", - target_kind_enabled="opencl", - run_time_check=lambda: tvm.runtime.enabled("opencl") and run_time_check("real").check(), - parent_features="rpc", +requires_adreno_opencl_real = _adreno_requires( + lambda: tvm.runtime.enabled("opencl") and run_time_check("real").check(), + "need real adreno opencl", ) # CLML Codegen -requires_adreno_clml = tvm.testing.Feature( - "adreno_clml", - "Adreno OpenCLML", - run_time_check=lambda: ( - tvm.get_global_func("relax.is_openclml_runtime_enabled", allow_missing=True) is not None - ), - target_kind_enabled="opencl", - parent_features="opencl" if "ADRENO_TARGET" not in os.environ else "rpc", +requires_adreno_clml = _adreno_requires( + lambda: tvm.get_global_func("relax.is_openclml_runtime_enabled", allow_missing=True) + is not None, + "need adreno openclml", ) diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index 76e702482947..6e0700d98ee7 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -20,6 +20,7 @@ import tvm import tvm.testing +from tvm.testing import env pytest.importorskip("scipy") # tvm.topi.testing imports scipy @@ -42,7 +43,10 @@ def reset_seed(): np.random.seed(0) -pytestmark = tvm.testing.requires_cublas.marks() +pytestmark = [ + pytest.mark.gpu, + pytest.mark.skipif(not env.has_cublas(), reason="need cublas"), +] def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): @@ -303,7 +307,8 @@ def test_matmul_igemm_offload( tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed") @pytest.mark.parametrize( "x_shape, y_shape, transpose_y, out_dtype", @@ -341,7 +346,8 @@ def test_matmul_fp8_offload( tvm.testing.assert_allclose(out, ref_out, rtol=1e-3, atol=1e-3) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed") def test_matmul_fp8_dequantize_offload(): x_shape = (10, 32) @@ -367,7 +373,8 @@ def test_matmul_fp8_dequantize_offload(): tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed") def test_matmul_fp8_multiply_offload(): x_shape = (10, 32) diff --git a/tests/python/relax/test_codegen_cudnn.py b/tests/python/relax/test_codegen_cudnn.py index bd8749dc6b96..36b4c54e1f44 100644 --- a/tests/python/relax/test_codegen_cudnn.py +++ b/tests/python/relax/test_codegen_cudnn.py @@ -20,6 +20,7 @@ import tvm import tvm.testing +from tvm.testing import env pytest.importorskip("scipy") # tvm.topi.testing imports scipy @@ -38,7 +39,10 @@ def reset_seed(): np.random.seed(0) -pytestmark = tvm.testing.requires_cudnn.marks() +pytestmark = [ + pytest.mark.gpu, + pytest.mark.skipif(not env.has_cudnn(), reason="need cudnn"), +] _activation_table = { diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index af4f0805e397..ff662f4faaba 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -20,6 +20,7 @@ import tvm import tvm.testing +from tvm.testing import env pytest.importorskip("scipy") # tvm.topi.testing imports scipy @@ -83,7 +84,9 @@ def main( return conv2 -pytestmark = tvm.testing.requires_cutlass.marks() +pytestmark = [ + pytest.mark.skipif(not env.has_cutlass(), reason="need cutlass"), +] def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False): diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py index 156700994678..62c0ffcd9f12 100644 --- a/tests/python/relax/test_codegen_hipblas.py +++ b/tests/python/relax/test_codegen_hipblas.py @@ -19,6 +19,7 @@ import tvm import tvm.testing +from tvm.testing import env pytest.importorskip("scipy") # tvm.topi.testing imports scipy @@ -39,7 +40,10 @@ def reset_seed(): np.random.seed(0) -pytestmark = tvm.testing.requires_hipblas.marks() +pytestmark = [ + pytest.mark.gpu, + pytest.mark.skipif(not env.has_hipblas(), reason="need hipblas"), +] def build_and_run(mod, inputs_np, target, legalize=False): diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py index b8bae635b39d..57390515d72d 100644 --- a/tests/python/relax/test_codegen_tensorrt.py +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: RUF005 import numpy as np import pytest @@ -24,6 +23,7 @@ from tvm.contrib.pickle_memoize import memoize from tvm.relax.dpl import is_op, make_fused_bias_activation_pattern, wildcard from tvm.script import relax as R +from tvm.testing import env @tvm.script.ir_module @@ -59,7 +59,9 @@ def main( pytestmark = [ requires_tensorrt_codegen, requires_tensorrt_runtime, -] + tvm.testing.requires_cuda.marks() + pytest.mark.gpu, + pytest.mark.skipif(not env.has_cuda(), reason="need cuda"), +] def build_and_run(mod, inputs_np, target, legalize=False): diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index 478fbe95c6f5..fc97859ee908 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: RUF005 import numpy as np import pytest import tvm_ffi @@ -25,6 +24,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env has_vllm = tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True) @@ -33,7 +33,11 @@ reason="VLLM not enabled.", ) -pytestmark = [vllm_enabled] + tvm.testing.requires_cuda.marks() +pytestmark = [ + vllm_enabled, + pytest.mark.gpu, + pytest.mark.skipif(not env.has_cuda(), reason="need cuda"), +] def build_and_run(mod, inputs_np, target, legalize=True): diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index e9cb65d6047e..fa5878a006ec 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -32,6 +32,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env torch_version = torch.__version__ @@ -343,7 +344,8 @@ def _convert_data_type(input_type): raise NotImplementedError(f"input_type {input_type} is not handled yet") -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_ones(): import torch from torch.nn import Module @@ -374,7 +376,8 @@ def main( ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_full(): import torch from torch.nn import Module @@ -405,7 +408,8 @@ def main( ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_gelu(): import torch from torch.nn import Module @@ -457,7 +461,8 @@ def main( ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_masked_fill(): import torch from torch.nn import Module @@ -494,7 +499,8 @@ def main( ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_getitem(): import torch from torch.nn import Module @@ -568,7 +574,8 @@ def forward(self, input1): version.parse(torch_version) >= version.parse("2.6.0"), reason="Need to support dynamic arange in Relax", ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_arange(): import torch from torch.nn import Module diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index b4f1c475d901..ee2f4a8f8df6 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -31,6 +31,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env def verify_model( @@ -8319,7 +8320,7 @@ def main( verify_model(SparseMatrixMultiply(), example_args, {}, Expected) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_lstm(): class LSTM(nn.Module): def __init__(self, input_size, hidden_size, batch_first, bidirectional): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index cdb343e73afa..34da69d5f061 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -33,6 +33,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env def verify_model(torch_model, input_info, binding, expected): @@ -900,7 +901,8 @@ def main( verify_model(Outer(), input_infos, {}, expected) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_softplus(): import torch from torch.nn import Module @@ -937,7 +939,8 @@ def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype verify_model(Softplus1(), input_info, {}, expected) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_leakyrelu(): import torch from torch.nn import Module diff --git a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py index d252eeb9d740..d86fa6ca5e2c 100644 --- a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py +++ b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py @@ -43,10 +43,12 @@ import math import numpy as np +import pytest import tvm import tvm.testing from tvm.relax.frontend.nn.llm.kv_cache import _attention_sequence_prefill_with_mask +from tvm.testing import env def _reference_masked_attention(q, k, v, valid_lens, sm_scale): @@ -182,7 +184,8 @@ def _run_case( np.testing.assert_allclose(got[b, pad_q:], ref[b, pad_q:], rtol=rtol, atol=atol) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_valid_len_zero(target, dev): """All samples are fully padded: kernel must not crash and must stay bounded.""" @@ -198,7 +201,8 @@ def test_valid_len_zero(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_valid_len_full(target, dev): """All samples are fully valid: must match a plain unmasked attention.""" @@ -214,7 +218,8 @@ def test_valid_len_full(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_valid_len_mixed(target, dev): """Typical encoder batch with different valid lengths per sample.""" @@ -230,7 +235,8 @@ def test_valid_len_mixed(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_valid_len_mixed_gqa(target, dev): """Grouped-query attention: ``group_size = h_q / h_kv > 1``.""" @@ -246,7 +252,8 @@ def test_valid_len_mixed_gqa(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_causal_padded_left_valid_len_zero(target, dev): """Causal left-pad: all samples are fully padded.""" @@ -263,7 +270,8 @@ def test_causal_padded_left_valid_len_zero(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_causal_padded_left_valid_len_full(target, dev): """Causal left-pad: all samples are fully valid — degenerates to plain causal attention.""" @@ -280,7 +288,8 @@ def test_causal_padded_left_valid_len_full(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_causal_padded_left_valid_len_mixed(target, dev): """Causal left-pad: typical decoder-embedding batch with mixed lengths.""" @@ -297,7 +306,8 @@ def test_causal_padded_left_valid_len_mixed(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_causal_padded_left_valid_len_mixed_gqa(target, dev): """Causal left-pad: grouped-query attention with mixed lengths.""" @@ -314,7 +324,8 @@ def test_causal_padded_left_valid_len_mixed_gqa(target, dev): ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @tvm.testing.parametrize_targets("cuda", "metal") def test_causal_padded_left_qo_len_differs_from_kv_len(target, dev): """Causal left-pad: Q and K/V may have different padded lengths.""" diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index a7db885c4abe..51a0d1e9f0f0 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -17,6 +17,7 @@ # pylint: disable=missing-docstring, invalid-name # ruff: noqa: E501, F841 import numpy as np +import pytest import tvm import tvm.testing @@ -25,6 +26,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env # mypy: disable-error-code="attr-defined,valid-type,name-defined" @@ -927,7 +929,8 @@ def test(self): vm["test"](*effects) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_multinomial_from_uniform(): prob_shape = (3, 5) sample_shape = (6, 1) @@ -999,7 +1002,8 @@ def foo(prob: R.Tensor((3, 5), dtype="float32"), uniform_sample: R.Tensor((6, 1) ) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_sample_top_p_top_k_from_sorted_prob(): prob_shape = (2, 3) sample_shape = (3, 1) @@ -1131,7 +1135,8 @@ def foo(prob: R.Tensor((2, 3), dtype="float32"), index: R.Tensor((2, 3), dtype=" tvm.testing.assert_allclose(res[0].numpy(), np.array([[2], [0], [0]]).astype(np.int64)) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_renormalize_top_p_top_k_prob(): prob_shape = (2, 3) sample_shape = (2, 1) diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py index 88bdbf301087..f1b426d935c2 100644 --- a/tests/python/relax/test_frontend_stablehlo.py +++ b/tests/python/relax/test_frontend_stablehlo.py @@ -35,6 +35,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env def generate_np_inputs( @@ -170,7 +171,8 @@ def get_vm_res( return tvm_output -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_add_dynamic(): add_dyn = """ func.func @test(%arg0: tensor, %arg1: tensor) -> tensor { @@ -201,7 +203,8 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_unary(): import jax @@ -234,7 +237,8 @@ def _round(x): check_correctness(jax.jit(fn), input_shapes) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_binary(): import jax @@ -255,7 +259,8 @@ def fn(x, y): check_correctness(jit_fn, input_shapes) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_const(): import jax @@ -265,7 +270,8 @@ def fn(x): check_correctness(jax.jit(fn), (2,)) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_maximum(): import jax import jax.numpy as jnp @@ -276,7 +282,8 @@ def fn(x, y): check_correctness(jax.jit(fn), ((2, 3), (2, 3))) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_minimum(): import jax import jax.numpy as jnp @@ -287,7 +294,8 @@ def fn(x, y): check_correctness(jax.jit(fn), ((2, 3), (2, 3))) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @pytest.mark.skip( reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." ) @@ -302,7 +310,8 @@ def fn(x): check_correctness(jax.jit(fn), (2, 3, 4, 5)) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @pytest.mark.skip( reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." ) @@ -317,7 +326,8 @@ def fn(x): check_correctness(jax.jit(fn), (2, 3, 4)) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_dot_general(): import jax @@ -328,7 +338,8 @@ def fn(x, y): check_correctness(jax.jit(fn), input_shapes) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @pytest.mark.skip( reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed." ) diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 2e8715456072..075f49e9ca3e 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -20,6 +20,7 @@ import tvm import tvm.testing +from tvm.testing import env pytest.importorskip("scipy") # tvm.topi.testing imports scipy @@ -570,7 +571,7 @@ def main( ) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_get_valid_counts_e2e(): """Run get_valid_counts through legalization and compare with the numpy reference.""" @@ -685,7 +686,7 @@ def _run_nms_e2e( ) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_return_indices(): """Run classic NMS through legalization and compare with the numpy reference.""" @@ -728,7 +729,7 @@ def test_nms_e2e_return_indices(): tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_soft_nms_reorders_by_decayed_score(): """Soft-NMS should re-rank by decayed scores instead of keeping the initial order.""" @@ -779,7 +780,7 @@ def test_nms_e2e_soft_nms_reorders_by_decayed_score(): tvm.testing.assert_allclose(result[2].numpy(), ref_valid_box_count) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_return_indices_with_invalid_to_bottom(): """Validate that invalid_to_bottom is a no-op when returning indices.""" @@ -822,7 +823,7 @@ def test_nms_e2e_return_indices_with_invalid_to_bottom(): tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_top_k(): """Validate that classic NMS honors top_k before suppression.""" @@ -869,7 +870,7 @@ def test_nms_e2e_top_k(): np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32")) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_force_suppress(): """Validate that force_suppress ignores class ids when suppressing overlaps.""" @@ -914,7 +915,7 @@ def test_nms_e2e_force_suppress(): np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32")) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_max_output_size(): """Validate that max_output_size truncates the kept boxes after score sorting.""" @@ -960,7 +961,7 @@ def test_nms_e2e_max_output_size(): np.testing.assert_array_equal(ref_valid_box_count, np.array([[2]], dtype="int32")) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_multi_batch(): """Validate that classic NMS processes each batch independently.""" @@ -1013,7 +1014,7 @@ def test_nms_e2e_multi_batch(): np.testing.assert_array_equal(ref_valid_box_count, np.array([[2], [3]], dtype="int32")) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_invalid_to_bottom(): """Validate that invalid_to_bottom compacts only boxes that remain valid after NMS.""" @@ -1068,7 +1069,7 @@ def test_nms_e2e_invalid_to_bottom(): tvm.testing.assert_allclose(result.numpy(), expected_out_data) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_return_data_without_compaction(): """Validate the return_indices=False path when invalid boxes stay in-place.""" @@ -1123,7 +1124,7 @@ def test_nms_e2e_return_data_without_compaction(): tvm.testing.assert_allclose(result.numpy(), expected_out_data) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_nms_e2e_index_remap(): """Validate that returned indices remap from filtered order back to original order.""" @@ -1349,7 +1350,7 @@ def main( ) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_all_class_non_max_suppression_legalize_e2e(): @tvm.script.ir_module class NMSModule: @@ -1533,7 +1534,7 @@ def _softmax(x, axis): return boxes, scores -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_multibox_transform_loc_legalize_e2e(): @tvm.script.ir_module class Mod: @@ -1581,7 +1582,7 @@ def main( tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_multibox_transform_loc_legalize_e2e_nonunity_variances(): @tvm.script.ir_module class Mod: @@ -1629,7 +1630,7 @@ def main( tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_multibox_transform_loc_legalize_attr_branches(): @tvm.script.ir_module class Mod: diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py index ea63c2d21e65..548abfbe5a32 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -35,6 +35,7 @@ _merge_state_inplace, ) from tvm.s_tir import dlight as dl +from tvm.testing import env reserved_nseq = 32 maximum_total_seq_length = 2048 @@ -412,8 +413,8 @@ def apply_attention( verify_cached_kv(kv_cache, seq_ids, cached_kv) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): (kv_cache,) = kv_cache_and_config fclear(kv_cache) @@ -433,8 +434,8 @@ def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): apply_attention(kv_cache, batch, cached_kv) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): (kv_cache,) = kv_cache_and_config fclear(kv_cache) @@ -454,8 +455,8 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): ) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): (kv_cache,) = kv_cache_and_config fclear(kv_cache) @@ -524,8 +525,8 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): apply_attention(kv_cache, [(10, 1), (12, 1)], cached_kv) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_popn(kv_cache_and_config): (kv_cache,) = kv_cache_and_config fclear(kv_cache) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index aa679c649b1f..b33721e5280e 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -40,6 +40,7 @@ tree_attn_with_paged_kv_cache, ) from tvm.s_tir import dlight as dl +from tvm.testing import env reserved_nseq = 32 maximum_total_seq_length = 2048 @@ -587,8 +588,8 @@ def apply_attention( verify_cached_kv(kv_cache, seq_ids, cached_k, cached_v) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if support_sliding_window and rope_mode == RopeMode.NORMAL: @@ -612,8 +613,8 @@ def test_paged_attention_kv_cache_prefill_and_decode(kv_cache_and_config): apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if support_sliding_window and rope_mode == RopeMode.NORMAL: @@ -639,8 +640,8 @@ def test_paged_attention_kv_cache_remove_sequence(kv_cache_and_config): ) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if support_sliding_window and rope_mode == RopeMode.NORMAL: @@ -717,8 +718,8 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_config): apply_attention(kv_cache, rope_mode, [(10, 1), (12, 1)], cached_k, cached_v) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if support_sliding_window and rope_mode == RopeMode.NORMAL: @@ -768,8 +769,8 @@ def test_paged_attention_kv_cache_unlimited_depth(kv_cache_and_config): assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_popn(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if support_sliding_window and rope_mode == RopeMode.NORMAL: @@ -803,8 +804,8 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_config): assert fis_empty(kv_cache), "The KV cache is not empty after removing all sequences" -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if not support_sliding_window or rope_mode == RopeMode.NORMAL: @@ -855,8 +856,8 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config): ) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if not support_sliding_window or rope_mode == RopeMode.NORMAL: @@ -928,8 +929,8 @@ def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config): # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6] -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): kv_cache, rope_mode, support_sliding_window = kv_cache_and_config if support_sliding_window: diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py index bc54ce1fd163..89276bb8240f 100644 --- a/tests/python/relax/test_runtime_builtin_rnn_state.py +++ b/tests/python/relax/test_runtime_builtin_rnn_state.py @@ -26,6 +26,7 @@ from tvm import tirx from tvm.s_tir import dlight as dl from tvm.script import tirx as T +from tvm.testing import env # pylint: disable=invalid-name @@ -116,7 +117,8 @@ def verify_state(state, seq_ids, expected_values): tvm.testing.assert_allclose(state_value.numpy(), expected_value) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_rnn_state_get(rnn_state): # pylint: disable=redefined-outer-name state = rnn_state f_clear(state) @@ -131,7 +133,8 @@ def test_rnn_state_get(rnn_state): # pylint: disable=redefined-outer-name tvm.testing.assert_allclose(tvm_nd_1.numpy(), np.ones((1, 32, 32), "float32")) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_rnn_state_set(rnn_state): # pylint: disable=redefined-outer-name state = rnn_state f_clear(state) @@ -147,7 +150,8 @@ def test_rnn_state_set(rnn_state): # pylint: disable=redefined-outer-name verify_state(state, [0, 1, 2], expected_values) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_rnn_state_popn(rnn_state): # pylint: disable=redefined-outer-name state = rnn_state f_clear(state) @@ -165,7 +169,8 @@ def test_rnn_state_popn(rnn_state): # pylint: disable=redefined-outer-name f_popn(state, 0, 1) # no available history to pop -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_rnn_state_fork_sequence(rnn_state): # pylint: disable=redefined-outer-name state = rnn_state f_clear(state) diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py index 450b03bb879c..9e406dec7258 100644 --- a/tests/python/relax/test_tir_call_source_kernel.py +++ b/tests/python/relax/test_tir_call_source_kernel.py @@ -16,6 +16,7 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing @@ -23,6 +24,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env add_cuda_source = """ extern "C" __global__ void add_kernel(float* x, float* y, float* output, int n_elements) { @@ -34,7 +36,8 @@ """ -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_tir_call_source_kernel(): @I.ir_module(s_tir=True) class Module: diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 65f2ec144415..c690eac5603b 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E501, F401, RUF005 +# ruff: noqa: E501, F401 import os import tempfile @@ -31,6 +31,7 @@ from tvm.script import relax as R from tvm.script import tirx as T from tvm.support import utils +from tvm.testing import env env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) @@ -45,7 +46,11 @@ ) # Global variable in pytest that applies markers to all tests. -pytestmark = [requires_tensorrt_codegen] + tvm.testing.requires_cuda.marks() +pytestmark = [ + requires_tensorrt_codegen, + pytest.mark.gpu, + pytest.mark.skipif(not env.has_cuda(), reason="need cuda"), +] # Target gpu target_str = "nvidia/nvidia-t4" @@ -121,7 +126,8 @@ def setup_test(): entry_func_name = tvm.testing.parameter("main", "func") -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @requires_tensorrt_runtime def test_tensorrt_only(entry_func_name): mod, inputs, expected = setup_test() @@ -151,7 +157,8 @@ def test_tensorrt_only(entry_func_name): check_roundtrip(ex0, dev, inputs, expected, entry_func_name) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") @requires_tensorrt_runtime def test_mix_use_tensorrt_and_tvm(): mod, inputs, expected = setup_test() diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 2ef1f780ef06..c85d4d8d744e 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -34,6 +34,7 @@ from tvm.script import relax as R from tvm.script import tirx as T from tvm.support import cc, popen_pool, utils +from tvm.testing import env EXEC_MODE = ["bytecode", "compiled"] @@ -471,7 +472,8 @@ def test_vm_emit_te_constant_param_cpu(exec_mode): tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_vm_emit_te_constant_param_gpu(exec_mode): x_np = np.random.rand(2, 2).astype("float32") c_np = np.random.rand(2, 2).astype("float32") @@ -852,7 +854,8 @@ def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor: tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_vm_to_device(exec_mode): @tvm.script.ir_module class TestToVDevice: @@ -1260,7 +1263,8 @@ def test_set_input_get_failure_rpc(exec_mode): run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_relax_module_with_multiple_targets(exec_mode): """Relax functions may contain kernels for multiple targets diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index d558e6f51bec..53450a6fdf67 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -25,6 +25,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tirx as T +from tvm.testing import env # fmt: off @@ -94,7 +95,8 @@ def codegen(mod, target, exec_mode="bytecode"): return relax.vm_build._vmlink(builder, target, tir_mod) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_vm_run(): mod = Module target = tvm.target.Target("cuda", host="llvm") @@ -108,7 +110,8 @@ def test_vm_run(): tvm.testing.assert_allclose(y.numpy(), y_np, rtol=1e-5, atol=1e-5) -@tvm.testing.requires_cudagraph +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cudagraph(), reason="need cudagraph") def test_capture_error_is_recoverable(): """Function calls while capturing cudagraph may throw exceptions diff --git a/tests/python/relax/test_vm_multi_device.py b/tests/python/relax/test_vm_multi_device.py index dbbf8a701440..7ac37a818dbd 100644 --- a/tests/python/relax/test_vm_multi_device.py +++ b/tests/python/relax/test_vm_multi_device.py @@ -26,6 +26,7 @@ from tvm.runtime import Device from tvm.script.parser import ir as I from tvm.script.parser import relax as R +from tvm.testing import env def compile( @@ -86,7 +87,7 @@ def foo( tvm.testing.assert_allclose(res.numpy(), np_res) -@tvm.testing.requires_multi_gpu +@pytest.mark.skipif(not env.has_multi_gpu(), reason="need multiple gpus") def test_multi_gpu(): if not tvm.cuda(2).exist: pytest.skip("requires at least 3 visible CUDA devices") @@ -145,7 +146,8 @@ def foo( tvm.testing.assert_allclose(res.numpy(), np_res) -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_multi_device(): @I.ir_module class Example: diff --git a/tests/python/relax/texture/test_texture_nd.py b/tests/python/relax/texture/test_texture_nd.py index 201faec112f8..3c3447749d8f 100644 --- a/tests/python/relax/texture/test_texture_nd.py +++ b/tests/python/relax/texture/test_texture_nd.py @@ -36,6 +36,7 @@ from tvm.script import tirx as T from tvm.support import ndk from tvm.target import Target +from tvm.testing import env def get_rpc(): @@ -105,8 +106,9 @@ def postprocess_pipeline(mod: IRModule) -> IRModule: return mod -@tvm.testing.requires_rpc -@tvm.testing.requires_adreno_opencl +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_adreno_opencl(), reason="need adreno opencl") @pytest.mark.parametrize("backend", ["opencl"]) @pytest.mark.parametrize("dtype", ["int8", "float16", "int16", "float32", "int32"]) @pytest.mark.parametrize("channel_size", [64, 128]) diff --git a/tests/python/runtime/test_runtime_module_export.py b/tests/python/runtime/test_runtime_module_export.py index bb6727c0f7f4..623e6fe29f17 100644 --- a/tests/python/runtime/test_runtime_module_export.py +++ b/tests/python/runtime/test_runtime_module_export.py @@ -15,12 +15,15 @@ # specific language governing permissions and limitations # under the License. +import pytest + import tvm import tvm.testing from tvm.support import utils +from tvm.testing import env -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_import_static_library(): from tvm import te diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py index 6a1eaa8e7e5a..fd6e2d1198d5 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -24,6 +24,7 @@ import tvm.testing from tvm import te from tvm.support import cc, popen_pool, utils +from tvm.testing import env runtime_py = """ import os @@ -43,7 +44,7 @@ """ -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") @pytest.mark.parametrize("target", ["llvm", {"kind": "llvm", "jit": "mcjit"}]) def test_dso_module_load(target): dtype = "int64" @@ -96,7 +97,8 @@ def save_object(names): assert proc.returncode == 0, f"{proc.args} exited with {proc.returncode}: {proc.stdout}" -@tvm.testing.requires_gpu +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_gpu(), reason="need gpu") def test_device_module_dump(): pytest.importorskip("cloudpickle") # needed by popen_pool.PopenWorker @@ -154,7 +156,7 @@ def check_c(device): check_c(device) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_combine_module_llvm(): """Test combine multiple module into one shared lib.""" pytest.importorskip("cloudpickle") # needed by popen_pool.PopenWorker diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index 079116281f3a..b48d9631dc83 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -38,6 +38,7 @@ from tvm.script import ir as I from tvm.script import tirx as T from tvm.support import cc, utils +from tvm.testing import env if __name__ == "__main__": # NOTE: must live here to avoid registering PackedFunc with libtvm_compiler.so twice. @@ -65,7 +66,7 @@ # to ensure all the remote resources destructs before the server terminates -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_bigendian_rpc(): """Test big endian rpc when there is a PowerPC RPC server available""" host = os.environ.get("TVM_POWERPC_TEST_HOST", None) @@ -96,7 +97,7 @@ def verify_rpc(remote, target, shape, dtype): verify_rpc(remote, target, (10,), dtype) -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_simple(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") @@ -115,7 +116,7 @@ def check_remote(): check_remote() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_runtime_string(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") @@ -129,7 +130,7 @@ def check_remote(): check_remote() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_array(): server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) @@ -145,7 +146,7 @@ def check_remote(): check_remote() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_large_array(): # testcase of large array creation server = rpc.Server() @@ -164,7 +165,7 @@ def check_remote(): @tvm.testing.skip_if_32bit(reason="skipping test for i386.") -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_echo(): def check(remote, local_session): fecho = remote.get_function("testing.echo") @@ -213,7 +214,7 @@ def check_minrpc(): # check_minrpc() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_file_exchange(): server = rpc.Server() remote = rpc.connect("127.0.0.1", server.port) @@ -227,8 +228,8 @@ def check_remote(): check_remote() -@tvm.testing.requires_rpc -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_rpc_remote_module(): # graph n = tvm.runtime.convert(102) @@ -338,7 +339,7 @@ def check_remote_link_cl(remote): check_minrpc() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_return_func(): server = rpc.Server(key="x1") client = rpc.connect("127.0.0.1", server.port, key="x1") @@ -351,7 +352,7 @@ def check_remote(): check_remote() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_session_constructor_args(): # start server server0 = rpc.Server(key="x0") @@ -388,7 +389,7 @@ def check_error_handling(): check_error_handling() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_return_tensor(): def run_arr_test(): server = rpc.Server(key="x1") @@ -409,7 +410,7 @@ def run_arr_test(): run_arr_test() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_rpc_return_remote_object(): def check(client, is_local): make_shape = client.get_function("ffi.Shape") @@ -455,7 +456,7 @@ def check_minrpc(): check_minrpc() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") def test_local_func(): client = rpc.LocalSession() @@ -472,7 +473,7 @@ def check_remote(): check_remote() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") @pytest.mark.parametrize("device_key", ["test_device", "127.0.0.1:5555"]) def test_rpc_tracker_register(device_key): # test registration @@ -545,7 +546,7 @@ def _target(host, port, device_key, timeout): remote.cpu() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") @pytest.mark.parametrize("device_key", ["test_device", "127.0.0.1:5555"]) def test_rpc_tracker_request(device_key): # test concurrent request @@ -586,7 +587,7 @@ def test_rpc_tracker_request(device_key): tracker.terminate() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") @pytest.mark.parametrize("device_key", ["test_device", "127.0.0.1:5555"]) def test_rpc_tracker_via_proxy(device_key): """ @@ -628,7 +629,7 @@ def test_rpc_tracker_via_proxy(device_key): tracker_server.terminate() -@tvm.testing.requires_rpc +@pytest.mark.skipif(not env.has_rpc(), reason="need rpc") @pytest.mark.parametrize("with_proxy", (True, False)) def test_rpc_session_timeout_error(with_proxy): port = 9000 diff --git a/tests/python/s_tir/dlight/test_primitives.py b/tests/python/s_tir/dlight/test_primitives.py index b21e007396f5..ec7f7dc2bfa9 100644 --- a/tests/python/s_tir/dlight/test_primitives.py +++ b/tests/python/s_tir/dlight/test_primitives.py @@ -17,9 +17,12 @@ # pylint: disable=missing-docstring # ruff: noqa: F841 +import pytest + import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -50,7 +53,8 @@ def main(p0: T.Buffer((), "int32"), T_stack: T.Buffer((T.int64(3),), "int32")): ) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_normalize_primfunc_with_scalar(): sch = tvm.s_tir.Schedule(main) f_normalize_prim_func = tvm.get_global_func("s_tir.schedule.NormalizePrimFunc") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py index a32997e4c53a..2f18af3d602d 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py @@ -24,6 +24,7 @@ import tvm.testing from tvm.s_tir.schedule import Schedule from tvm.script import tirx as T +from tvm.testing import env torch = pytest.importorskip("torch") @@ -65,8 +66,8 @@ def main( C[vi, vj] = C[vi, vj] + T.cast(A[vi, vk], "float32") * T.cast(B[vk, vj], "float32") -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_run_target(mod=None, tgt_str=None, in_dtype="float16", out_dtype="float16"): if mod is None: return @@ -93,8 +94,8 @@ def test_run_target(mod=None, tgt_str=None, in_dtype="float16", out_dtype="float torch.allclose(c_th, c_f, rtol=0.05, atol=0.05) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_f16f16f16_mma_gemm(): # fmt: off mod = Gemm_F16F16F16 @@ -212,8 +213,8 @@ def test_f16f16f16_mma_gemm(): test_run_target(mod) -@tvm.testing.requires_gpu -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_f16f16f32_mma_gemm(): mod = Gemm_F16F16F32 sch = Schedule(mod) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py index d8e45d52d08f..a058959bdbad 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py @@ -28,6 +28,7 @@ from tvm.s_tir.meta_schedule.runner.config import EvaluatorConfig from tvm.script import tirx as T from tvm.target import Target +from tvm.testing import env logging.basicConfig() logging.getLogger("tvm.s_tir.meta_schedule").setLevel(logging.DEBUG) @@ -47,7 +48,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @pytest.mark.skip("Integration test") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tune_matmul_cpu(): with tempfile.TemporaryDirectory() as work_dir: target = Target({"kind": "llvm", "num-cores": 16}) @@ -85,7 +86,8 @@ def test_tune_matmul_cpu(): @pytest.mark.skip("Integration test") -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_tune_matmul_cuda(): with tempfile.TemporaryDirectory() as work_dir: target = Target("nvidia/geforce-rtx-3070") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py index a3bce951e207..155767db553b 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py @@ -32,6 +32,7 @@ from tvm.s_tir.schedule import SBlockRV, Schedule from tvm.script import tirx as T from tvm.target import Target +from tvm.testing import env logging.basicConfig() logging.getLogger("tvm.s_tir.meta_schedule").setLevel(logging.DEBUG) @@ -66,7 +67,7 @@ def two_step(a: T.handle, c: T.handle) -> None: @pytest.mark.skip("Integration test") -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tune_matmul_cpu(): with tempfile.TemporaryDirectory() as work_dir: target = Target({"kind": "llvm", "num-cores": 16}) @@ -86,7 +87,8 @@ def test_tune_matmul_cpu(): @pytest.mark.skip("Integration test") -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_tune_matmul_cuda(): with tempfile.TemporaryDirectory() as work_dir: target = Target("nvidia/geforce-rtx-3070") diff --git a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index b9db349ca414..3d7a3fdf595b 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -17,6 +17,7 @@ # pylint: disable=missing-docstring # ruff: noqa: E501 import numpy as np +import pytest import tvm import tvm.testing @@ -50,6 +51,7 @@ shared_16x32_to_ldmatrix_32x16_layout, shared_32x16_to_ldmatrix_32x16_layout, ) +from tvm.testing import env from tvm.testing.tir import mma_schedule M = 4096 @@ -183,7 +185,8 @@ def run_test( return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_f16f16f32_m16n16k16(): def index_map(i, j): return ( @@ -240,7 +243,8 @@ def index_map(i, j): print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_f16f16f16_m16n16k16(): def index_map(i, j): return ( @@ -297,7 +301,8 @@ def index_map(i, j): print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_i8i8i32_m16n16k32(): def index_map_A(i, j): return ( @@ -368,7 +373,8 @@ def index_map_C(i, j): print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) -@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8, 9), reason="need cuda compute >= 8.9") def test_e4m3e4m3f32_m16n16k32(): def index_map_A(i, j): return ( @@ -411,7 +417,8 @@ def index_map_C(i, j): print("e4m3e4m3f32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) -@tvm.testing.requires_cuda_compute_version(8, 9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8, 9), reason="need cuda compute >= 8.9") def test_e5m2e5m2f32_m16n16k32(): def index_map_A(i, j): return ( diff --git a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py index ec330adab2e9..5113218119c6 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py @@ -17,6 +17,7 @@ # pylint: disable=missing-docstring # ruff: noqa: E501 import numpy as np +import pytest import tvm import tvm.testing @@ -41,6 +42,7 @@ shared_16x16_to_local_64x4_layout_B, shared_16x16_to_local_64x4_layout_C, ) +from tvm.testing import env from tvm.testing.tir import mfma_schedule M = 1024 @@ -160,7 +162,8 @@ def run_test( return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) -@tvm.testing.requires_matrixcore +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_matrixcore(), reason="need matrixcore") def test_i8i8i32_m16n16k16(): def index_map_A(i, j): return ( @@ -210,7 +213,8 @@ def index_map_C(i, j): print("test_i8i8i32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) -@tvm.testing.requires_matrixcore +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_matrixcore(), reason="need matrixcore") def test_f16f16f32_m16n16k16(): def index_map_A(i, j): return ( @@ -260,7 +264,8 @@ def index_map_C(i, j): print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) -@tvm.testing.requires_matrixcore +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_matrixcore(), reason="need matrixcore") def test_f32f32f32_m16n16k4(): def index_map_A(i, j): return ( diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py index 428fb9b89565..1ed69262a464 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py @@ -25,6 +25,7 @@ from tvm import s_tir from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env def count_cp_async(stmt): @@ -125,7 +126,8 @@ def ptx_global_to_shared_dyn_copy_fp16x8( C[tx, i] = A_shared[tx, i] + B_shared[tx, i] -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_inject_async_copy(): for dtype, vec_size in [("float16", 8), ("float16", 4), ("float32", 4), ("float32", 1)]: if vec_size == 1: @@ -157,7 +159,8 @@ def test_inject_async_copy(): tvm.testing.assert_allclose(B_nd.numpy(), A_np) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_inject_async_copy_shared_dyn(): f = ptx_global_to_shared_dyn_copy_fp16x8 @@ -350,7 +353,8 @@ def tvm_callback_cuda_postproc(code, _): tvm.register_global_func(func_name, prev_postproc, override=True) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cp_async_in_if_then_else(postproc_if_missing_async_support): @T.prim_func(s_tir=True) def simple_compute( @@ -411,7 +415,8 @@ def simple_compute( "This bug should be addressed. See discussion in https://github.com/apache/tvm/pull/16769 " "and https://github.com/apache/tvm/pull/16569#issuecomment-1992720448" ) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_vectorize_cp_async_in_if_then_else(postproc_if_missing_async_support): @T.prim_func(s_tir=True) def complex_compute( diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py index 1ba1b4839d0a..183505d010a4 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py @@ -34,6 +34,7 @@ shared_16x16_to_ldmatrix_32x8_layout, ) from tvm.script import tirx as T +from tvm.testing import env from tvm.testing.tir import mma_schedule @@ -1547,7 +1548,8 @@ def build_and_run(sch): tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_async_pipelined_mma_gemm_simple(): sch = get_mma_schedule() @@ -1588,7 +1590,8 @@ def test_async_pipelined_mma_gemm_simple(): build_and_run(sch) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_async_nested_pipeline_mma_gemm_ideal_annotation(): sch = get_mma_schedule() diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index 1afe7028b9d3..c17ce80cb7eb 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. # ruff: noqa: F401, F821, F841 +import pytest + import tvm import tvm.testing from tvm import s_tir from tvm.script import tirx as T +from tvm.testing import env def run_passes(func: tvm.tirx.PrimFunc): @@ -34,7 +37,8 @@ def run_passes(func: tvm.tirx.PrimFunc): return tvm.s_tir.transform.ThreadSync("shared")(mod) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_sync_read_thread_id_independent_location(): @T.prim_func(check_well_formed=False, s_tir=True) def func(p0_arg: T.Buffer((1, 2, 1, 1), "float32"), p1: T.Buffer(2, "float32")) -> None: @@ -98,7 +102,8 @@ def expected(A: T.Buffer((4, 4), "float32"), E: T.Buffer((4, 4), "float32")): tvm.ir.assert_structural_equal(mod["main"], expected) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_sync_bind(): @T.prim_func(private=True, s_tir=True) def func(A: T.Buffer((16 * 512), "float32")): diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index 6b8e7c3a229f..6b964f13fe30 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -26,6 +26,7 @@ import tvm from tvm.script import tirx as T from tvm.target import codegen +from tvm.testing import env llvm_version, arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters( # Testing mcpu type @@ -107,7 +108,7 @@ def sve_device_vector_length(): return int(out) -@tvm.testing.requires_aarch64_sve +@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve") def test_scalable_div(sve_device_vector_length): np.random.seed(0) target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]} @@ -128,7 +129,7 @@ def my_func(a: T.handle): tvm.testing.assert_allclose(A_nd.numpy()[0], ref) -@tvm.testing.requires_aarch64_sve +@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve") def test_scalable_buffer_load_store(sve_device_vector_length): np.random.seed(0) target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]} @@ -153,7 +154,7 @@ def my_func(a: T.handle, b: T.handle): tvm.testing.assert_allclose(B_nd.numpy(), A_np) -@tvm.testing.requires_aarch64_sve +@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve") def test_scalable_loop_bound(sve_device_vector_length): np.random.seed(0) @@ -181,7 +182,7 @@ def my_func(a: T.handle, b: T.handle): tvm.testing.assert_allclose(B_nd.numpy(), A_np) -@tvm.testing.requires_aarch64_sve +@pytest.mark.skipif(not env.has_aarch64_sve(), reason="need aarch64 sve") def test_scalable_broadcast(sve_device_vector_length): target = {"kind": "llvm", "mtriple": "aarch64-linux-gnu", "mattr": ["+sve"]} num_elements = sve_device_vector_length // 32 diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index fc6cf209d33a..2236505d6050 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -22,6 +22,7 @@ import tvm import tvm.testing from tvm.target import Target +from tvm.testing import env def test_all_targets_device_type_verify(): @@ -325,7 +326,8 @@ def test_target_features(): assert not target_with_features.features.is_missing -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("input_device", ["cuda", tvm.cuda()]) def test_target_from_device_cuda(input_device): target = Target.from_device(input_device) @@ -338,7 +340,8 @@ def test_target_from_device_cuda(input_device): assert str(target.attrs.get("arch", "")) == "sm_" + dev.compute_version.replace(".", "") -@tvm.testing.requires_rocm +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_rocm(), reason="need rocm") @pytest.mark.parametrize("input_device", ["rocm", tvm.rocm()]) def test_target_from_device_rocm(input_device): target = Target.from_device(input_device) @@ -351,7 +354,8 @@ def test_target_from_device_rocm(input_device): assert int(target.attrs["thread_warp_size"]) == dev.warp_size -@tvm.testing.requires_vulkan +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_vulkan(), reason="need vulkan") @pytest.mark.parametrize("input_device", ["vulkan", tvm.vulkan()]) def test_target_from_device_vulkan(input_device): target = Target.from_device(input_device) @@ -370,7 +374,8 @@ def test_target_from_device_vulkan(input_device): ) -@tvm.testing.requires_opencl +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_opencl(), reason="need opencl") @pytest.mark.parametrize("input_device", ["opencl", tvm.opencl()]) def test_target_from_device_opencl(input_device): target = Target.from_device(input_device) diff --git a/tests/python/testing/test_env.py b/tests/python/testing/test_env.py new file mode 100644 index 000000000000..bba6d983c7a1 --- /dev/null +++ b/tests/python/testing/test_env.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for the thin ``tvm.testing.env`` capability probes.""" + +import pytest + +import tvm +import tvm.testing +from tvm.testing import env + +# Probes that take no arguments and must return a plain bool without raising. +_BOOL_PROBES = [ + # runtime device + env.has_cuda, + env.has_rocm, + env.has_vulkan, + env.has_metal, + env.has_opencl, + env.has_nvptx, + env.has_llvm, + env.has_gpu, + # build support + env.has_cudnn, + env.has_cublas, + env.has_nccl, + env.has_hipblas, + env.has_cutlass, + env.has_rpc, + env.has_nnapi, + env.has_openclml, + env.has_mrvl, + env.has_nvshmem, + # version / capability + env.has_tensorcore, + env.has_matrixcore, + env.has_cudagraph, + # toolchain / environment + env.has_hexagon, + env.has_hexagon_toolchain, + env.has_adreno_opencl, + env.has_aprofile_aem_fvp, + # cpu features + env.has_arm_dot, + env.has_arm_fp16, + env.has_aarch64_sve, + env.has_aarch64_sme, + env.has_x86_vnni, + env.has_x86_avx512, + env.has_x86_amx, + # host architecture + env.is_x86, + env.is_aarch64, +] + + +@pytest.mark.parametrize("probe", _BOOL_PROBES, ids=lambda p: p.__name__) +def test_probe_returns_bool(probe): + """Every probe returns a real bool and never raises during collection/run.""" + assert isinstance(probe(), bool) + + +def test_has_cuda_implies_device(): + """has_cuda() requires a device (it also requires the kind to be enabled).""" + if env.has_cuda(): + assert tvm.cuda().exist + + +def test_has_gpu_is_raw_any_device(): + """has_gpu() is the disjunction of the raw device checks (no target gating).""" + any_device = ( + env._device_exists("cuda") # pylint: disable=protected-access + or env._device_exists("rocm") # pylint: disable=protected-access + or env._device_exists("opencl") # pylint: disable=protected-access + or env._device_exists("metal") # pylint: disable=protected-access + or env._device_exists("vulkan") # pylint: disable=protected-access + ) + assert env.has_gpu() == any_device + + +def test_target_enabled_respects_tvm_test_targets(monkeypatch): + """A device kind excluded from TVM_TEST_TARGETS is reported as not enabled.""" + env._target_enabled.cache_clear() # pylint: disable=protected-access + monkeypatch.setenv("TVM_TEST_TARGETS", "cuda;llvm") + try: + assert env._target_enabled("cuda") # pylint: disable=protected-access + assert env._target_enabled("llvm") # pylint: disable=protected-access + assert not env._target_enabled("opencl") # pylint: disable=protected-access + assert not env._target_enabled("metal") # pylint: disable=protected-access + finally: + env._target_enabled.cache_clear() # pylint: disable=protected-access + + +def test_tensorcore_implies_cuda(): + """Tensor Core support cannot be reported without a CUDA device.""" + if env.has_tensorcore(): + assert env.has_cuda() + + +def test_cudagraph_implies_cuda(): + """CUDA Graph support cannot be reported without a CUDA device.""" + if env.has_cudagraph(): + assert env.has_cuda() + + +def test_cuda_compute_is_monotonic(): + """has_cuda_compute is monotone in the requested version.""" + if not env.has_cuda(): + # Without a CUDA device every query is False, including the (0, 0) floor. + assert not env.has_cuda_compute(1, 0) + assert not env.has_cuda_compute(0, 0) + return + # A device that satisfies (major, minor) also satisfies anything lower. + assert env.has_cuda_compute(1, 0) + assert env.has_cuda_compute(0, 0) + + +def test_has_multi_gpu_is_bool(): + assert isinstance(env.has_multi_gpu(), bool) + assert isinstance(env.has_multi_gpu(1), bool) + # Requiring a single device is at least as permissive as requiring two. + assert env.has_multi_gpu(1) or not env.has_multi_gpu(2) + + +@pytest.mark.parametrize( + "probe,flag", + [ + (env.has_cutlass, "USE_CUTLASS"), + (env.has_rpc, "USE_RPC"), + (env.has_nnapi, "USE_NNAPI_CODEGEN"), + (env.has_openclml, "USE_CLML"), + (env.has_mrvl, "USE_MRVL"), + ], + ids=lambda v: getattr(v, "__name__", v), +) +def test_build_flag_probe_matches_libinfo(probe, flag): + """Pure build-flag probes agree with the build-info flag they wrap.""" + assert probe() == env._build_flag_enabled(flag) # pylint: disable=protected-access + + +@pytest.mark.parametrize( + "probe,parent", + [ + (env.has_cudnn, env.has_cuda), + (env.has_cublas, env.has_cuda), + (env.has_nccl, env.has_cuda), + (env.has_hipblas, env.has_rocm), + ], + ids=lambda v: v.__name__, +) +def test_library_probe_implies_parent_device(probe, parent): + """A CUDA/ROCm library cannot be reported without its parent device.""" + if probe(): + assert parent() + + +def test_llvm_min_version_is_monotone(): + if not env.has_llvm(): + assert not env.has_llvm_min_version(1) + return + # An LLVM that satisfies a higher floor also satisfies a lower one. + assert env.has_llvm_min_version(1) + + +def test_hexagon_run_implies_toolchain(): + """Full Hexagon support implies the compile-time toolchain is present.""" + if env.has_hexagon(): + assert env.has_hexagon_toolchain() + + +def test_probes_are_memoized(): + """Probes are cached so the driver/subprocess is hit once per process.""" + env.has_cuda() + info = env._device_exists.cache_info() # pylint: disable=protected-access + assert info.hits + info.misses >= 1 + + +# --- demonstration of the target idiom ------------------------------------- +# +# The standard gating idiom: a plain registered pytest marker (for ``-m`` +# selection) plus a skipif backed by a thin env probe (for runtime gating). + + +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") +def test_thin_cuda_idiom(): + dev = tvm.cuda() + assert dev.exist + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tirx-base/test_tir_imm_values.py b/tests/python/tirx-base/test_tir_imm_values.py index bf0002bea4a7..a7985920bd7c 100644 --- a/tests/python/tirx-base/test_tir_imm_values.py +++ b/tests/python/tirx-base/test_tir_imm_values.py @@ -25,6 +25,7 @@ import tvm.testing from tvm import tirx from tvm.script import tirx as T +from tvm.testing import env @pytest.mark.parametrize( @@ -146,7 +147,7 @@ def test_tir_special_floatimms(dtype, literal): compare_float_value(x.value, literal, "imm value should match feed value") -@tvm.testing.requires_llvm() +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tir_too_large_literal_f64(): # Behavior check: if literal f64 value is out of dtype range, the # object is still constructed, and eval to infinity. @@ -256,7 +257,7 @@ def check_tir_const_fold( assert expect == calc_res, flaky_msg -@tvm.testing.requires_llvm() +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tir_floatimm_const_fold(): """Behavior check: folding fp32 match platform f32 arithmetic""" @@ -314,7 +315,7 @@ def _func(x, y): ) -@tvm.testing.requires_llvm() +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tir_int8_const_fold(): """Behavior check: folding i8 operation match platform i8 arithmetic""" @@ -370,7 +371,7 @@ def imm_floordiv(x: T.int8, y: T.int8) -> T.int8: ) -@tvm.testing.requires_llvm() +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tir_uint8_const_fold(): """Behavior check: folding u8 operation match platform u8 arithmetic""" @@ -433,7 +434,7 @@ def imm_floordiv(x: T.uint8, y: T.uint8) -> T.uint8: ) -@tvm.testing.requires_llvm() +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tir_int32_const_fold(): """Behavior check: folding i32 operation match platform i32 arithmetic""" @@ -521,7 +522,7 @@ def imm_floormod(x: T.int32, y: T.int32) -> T.int32: ) -@tvm.testing.requires_llvm() +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_tir_uint32_const_fold(): """Behavior check: folding u32 operation match platform u32 arithmetic""" diff --git a/tests/python/tirx-base/test_tir_ptx_cp_async.py b/tests/python/tirx-base/test_tir_ptx_cp_async.py index 4585329daeb1..a2a4453a57c0 100644 --- a/tests/python/tirx-base/test_tir_ptx_cp_async.py +++ b/tests/python/tirx-base/test_tir_ptx_cp_async.py @@ -16,10 +16,12 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -49,7 +51,8 @@ def ptx_cp_async(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "floa B[tx, i] = A_shared[tx, i] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_ptx_cp_async(): f = ptx_cp_async diff --git a/tests/python/tirx-base/test_tir_ptx_griddepcontrol.py b/tests/python/tirx-base/test_tir_ptx_griddepcontrol.py index 59d9d460e519..11c418721983 100644 --- a/tests/python/tirx-base/test_tir_ptx_griddepcontrol.py +++ b/tests/python/tirx-base/test_tir_ptx_griddepcontrol.py @@ -16,10 +16,12 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -37,7 +39,8 @@ def ptx_griddepcontrol(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float3 T.evaluate(T.ptx.griddepcontrol.launch_dependents(dtype="")) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ptx_griddepcontrol(): f = ptx_griddepcontrol mod = tvm.compile(f, target="cuda") diff --git a/tests/python/tirx-base/test_tir_ptx_ldmatrix.py b/tests/python/tirx-base/test_tir_ptx_ldmatrix.py index 2f4cf58832e6..4f0351a17767 100644 --- a/tests/python/tirx-base/test_tir_ptx_ldmatrix.py +++ b/tests/python/tirx-base/test_tir_ptx_ldmatrix.py @@ -16,10 +16,12 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -57,7 +59,8 @@ def ptx_ldmatrix( B[8 * j + tx // 4, 8 * k + (tx % 4) * 2 + i] = A_local[4 * k + 2 * j + i] -@tvm.testing.requires_cuda_compute_version(7, 5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(7, 5), reason="need cuda compute >= 7.5") def test_ptx_ldmatrix(): f = ptx_ldmatrix _, _, param_num, param_trans = f.params diff --git a/tests/python/tirx-base/test_tir_ptx_mma.py b/tests/python/tirx-base/test_tir_ptx_mma.py index 9c1a83224172..475632cad91f 100644 --- a/tests/python/tirx-base/test_tir_ptx_mma.py +++ b/tests/python/tirx-base/test_tir_ptx_mma.py @@ -16,10 +16,12 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -64,7 +66,8 @@ def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): C[(tx % 32) // 4, (tx % 32) % 4 * 2 + mma_accum_c_id] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m8n8k4_row_col_fp64pf64fp64(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k4_row_col_fp64pf64fp64) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -140,7 +143,8 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(7) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(7), reason="need cuda compute >= 7.0") def test_gemm_mma_m8n8k4_row_row_fp16fp16fp16(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp16) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -223,7 +227,8 @@ def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(7) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(7), reason="need cuda compute >= 7.0") def test_gemm_mma_m8n8k4_row_row_fp16fp16fp32(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k4_row_row_fp16fp16fp32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -293,8 +298,9 @@ def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): # This test uses mma instructions that are not available on NVCC 10.1. # Failure occurs during the external call to nvcc, when attempting to # generate the .fatbin file. -@tvm.testing.requires_nvcc_version(11) -@tvm.testing.requires_cuda_compute_version(7, 5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_nvcc_version(11), reason="need nvcc >= 11") +@pytest.mark.skipif(not env.has_cuda_compute(7, 5), reason="need cuda compute >= 7.5") def test_gemm_mma_m8n8k16_row_col_s8s8s32(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k16_row_col_s8s8s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -364,8 +370,9 @@ def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): # This test uses mma instructions that are not available on NVCC 10.1. # Failure occurs during the external call to nvcc, when attempting to # generate the .fatbin file. -@tvm.testing.requires_nvcc_version(11) -@tvm.testing.requires_cuda_compute_version(7, 5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_nvcc_version(11), reason="need nvcc >= 11") +@pytest.mark.skipif(not env.has_cuda_compute(7, 5), reason="need cuda compute >= 7.5") def test_gemm_mma_m8n8k16_row_col_s8u8s32(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k16_row_col_s8u8s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -435,8 +442,9 @@ def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): # This test uses mma instructions that are not available on NVCC 10.1. # Failure occurs during the external call to nvcc, when attempting to # generate the .fatbin file. -@tvm.testing.requires_nvcc_version(11) -@tvm.testing.requires_cuda_compute_version(7, 5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_nvcc_version(11), reason="need nvcc >= 11") +@pytest.mark.skipif(not env.has_cuda_compute(7, 5), reason="need cuda compute >= 7.5") def test_gemm_mma_m8n8k32_row_col_s4s4s32(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k32_row_col_s4s4s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -498,8 +506,9 @@ def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): # This test uses mma instructions that are not available on NVCC 10.1. # Failure occurs during the external call to nvcc, when attempting to # generate the .fatbin file. -@tvm.testing.requires_nvcc_version(11) -@tvm.testing.requires_cuda_compute_version(7, 5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_nvcc_version(11), reason="need nvcc >= 11") +@pytest.mark.skipif(not env.has_cuda_compute(7, 5), reason="need cuda compute >= 7.5") def test_gemm_mma_m8n8k32_row_col_s4u4s32(): sch = tvm.s_tir.Schedule(gemm_mma_m8n8k32_row_col_s4u4s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -564,7 +573,8 @@ def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle) ] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k8_row_col_fp16fp16fp32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k8_row_col_fp16fp16fp32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -640,7 +650,8 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k16_row_col_fp16fp16fp16(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp16) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -716,7 +727,8 @@ def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k16_row_col_fp16fp16fp32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k16_row_col_fp16fp16fp32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -792,7 +804,8 @@ def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k16_row_col_s8s8s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k16_row_col_s8s8s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -868,7 +881,8 @@ def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k16_row_col_s8u8s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k16_row_col_s8u8s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -944,7 +958,8 @@ def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k32_row_col_s8s8s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k32_row_col_s8s8s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -1020,7 +1035,8 @@ def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k32_row_col_s8u8s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k32_row_col_s8u8s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -1096,7 +1112,8 @@ def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k64_row_col_s4s4s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k64_row_col_s4s4s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -1164,7 +1181,8 @@ def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k64_row_col_s4u4s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k64_row_col_s4u4s32) cuda_mod = tvm.compile(sch.mod, target="cuda") @@ -1233,7 +1251,8 @@ def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): ] = Accum[mma_accum_c_id] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_gemm_mma_m16n8k256_row_col_b1b1s32(): sch = tvm.s_tir.Schedule(gemm_mma_m16n8k256_row_col_b1b1s32) cuda_mod = tvm.compile(sch.mod, target="cuda") diff --git a/tests/python/tirx-base/test_tir_ptx_mma_sp.py b/tests/python/tirx-base/test_tir_ptx_mma_sp.py index 9286d76155a2..e924702efd9f 100644 --- a/tests/python/tirx-base/test_tir_ptx_mma_sp.py +++ b/tests/python/tirx-base/test_tir_ptx_mma_sp.py @@ -16,10 +16,12 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env def gen_2in4_mask(m: int, n: int): @@ -256,7 +258,8 @@ def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: C[i // 2 * 8 + tx // 4, tx % 4 * 2 + i % 2] = accum[i] -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_mma_sp_m16n8k16_f16(): def get_meta_m16n8k16_half(mask): assert mask.shape == (16, 4, 2) @@ -293,7 +296,8 @@ def get_meta_m16n8k16_half(mask): tvm.testing.assert_allclose(C_tvm.numpy(), C_np, atol=1e-3, rtol=1e-3) -@tvm.testing.requires_cuda_compute_version(8) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(8), reason="need cuda compute >= 8.0") def test_mma_sp_m16n8k32_f16(): def get_meta_m16n8k32_half(mask): assert mask.shape == (16, 8, 2) diff --git a/tests/python/tirx-base/test_tir_ptx_scalar_f32_math.py b/tests/python/tirx-base/test_tir_ptx_scalar_f32_math.py index a667b213b17a..98e582d874db 100644 --- a/tests/python/tirx-base/test_tir_ptx_scalar_f32_math.py +++ b/tests/python/tirx-base/test_tir_ptx_scalar_f32_math.py @@ -16,10 +16,12 @@ # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -43,7 +45,8 @@ def ptx_scalar_f32_math( C_max[tx] = T.ptx.max_f32(A[tx], B[tx]) -@tvm.testing.requires_cuda_compute_version(7) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(7), reason="need cuda compute >= 7.0") def test_ptx_scalar_f32_math(): f = ptx_scalar_f32_math mod = tvm.compile(f, target="cuda") diff --git a/tests/python/tirx-transform/test_tir_transform_lower_intrin.py b/tests/python/tirx-transform/test_tir_transform_lower_intrin.py index 30ead37c841b..cfb271de9e2d 100644 --- a/tests/python/tirx-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tirx-transform/test_tir_transform_lower_intrin.py @@ -16,9 +16,11 @@ # under the License. # ruff: noqa: RUF005 import numpy as np +import pytest import tvm import tvm.testing +from tvm.testing import env def lower_intrin(params, stmt): @@ -94,7 +96,7 @@ def get_ref_data(): return list(itertools.product(x, y)) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_lower_floordiv(): data = get_ref_data() for dtype in ["int32", "int64", "int16"]: @@ -128,7 +130,7 @@ def test_lower_floordiv(): check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_lower_floormod(): data = get_ref_data() for dtype in ["int32", "int64", "int16"]: @@ -157,7 +159,7 @@ def test_lower_floormod(): check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_lower_floordiv_overflow_checks(): """ Regression tests for overflow checks in TryFindShiftCoefficientForPositiveRange. diff --git a/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py index bbd6df01eff4..22a144175c8f 100644 --- a/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py @@ -22,6 +22,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import tirx as T +from tvm.testing import env @tvm.register_global_func("tvm.test_matmul") @@ -112,7 +113,7 @@ def main( tvm.ir.assert_structural_equal(After, Expected) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_call_packed_return_non_i32(): # This call packed that return non i32 types expected_value = np.array([1.2, 1.4], dtype="float32") diff --git a/tests/python/tirx/codegen/test_codegen_ampere.py b/tests/python/tirx/codegen/test_codegen_ampere.py index f0c8911cd9b4..8bb7dd79c6ce 100644 --- a/tests/python/tirx/codegen/test_codegen_ampere.py +++ b/tests/python/tirx/codegen/test_codegen_ampere.py @@ -35,6 +35,7 @@ import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.device("cuda") @@ -70,7 +71,8 @@ def _run_mma(mod, K, no_c_ptr, np_in): np.testing.assert_allclose(D.numpy(), ref, atol=1e-2, rtol=1e-2) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("a_type", ["float16", "bfloat16"]) @pytest.mark.parametrize("no_c_ptr", [False, True]) def test_ptx_mma_m16n8k16(a_type, no_c_ptr): @@ -140,7 +142,8 @@ def G2L(buf_local, buf_global, block_8x8, mode="row"): _run_mma(mod, 16, no_c_ptr, _np_in(a_type)) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("a_type", ["float16", "bfloat16"]) @pytest.mark.parametrize("no_c_ptr", [False, True]) def test_ptx_mma_m16n8k8(a_type, no_c_ptr): diff --git a/tests/python/tirx/codegen/test_codegen_blackwell.py b/tests/python/tirx/codegen/test_codegen_blackwell.py index f6c526a2a193..61348ca48e61 100644 --- a/tests/python/tirx/codegen/test_codegen_blackwell.py +++ b/tests/python/tirx/codegen/test_codegen_blackwell.py @@ -22,6 +22,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env def _get_source(func: tvm.tirx.PrimFunc) -> str: @@ -32,7 +33,8 @@ def _get_source(func: tvm.tirx.PrimFunc) -> str: return src, mod -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_tmem_alloc_dealloc_relinquish(): N_COLS = 512 cta_group = 1 @@ -67,7 +69,8 @@ def test_tmem(A: T.Buffer((16, 16), "float16")): assert f"tcgen05.relinquish_alloc_permit.cta_group::{cta_group}.sync.aligned" in src -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_mbarrier_try_wait_once_codegen(): # fmt: off @T.prim_func @@ -86,7 +89,8 @@ def test_try_wait_once(A: T.Buffer((16, 16), "float16")): assert "selp.u32" in src -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_fence_before_after_thread_sync(): # fmt: off @T.prim_func @@ -108,7 +112,8 @@ def test_fence(A: T.Buffer((16, 16), "float16")): assert "tcgen05.fence::before_thread_sync" in src -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_tcgen05_ld_st_roundtrip(): HEIGHT = 128 WIDTH = 256 @@ -173,7 +178,8 @@ def test_ld_st(A: T.Buffer((HEIGHT, WIDTH), "float32"), B: T.Buffer((HEIGHT, WID np.testing.assert_allclose(A.numpy(), B.numpy()) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_tcgen05_cp_ld_roundtrip(): dtype = "float32" dtype_bits = tvm.DataType(dtype).bits @@ -254,7 +260,8 @@ def test_cp_ld(A: T.Buffer((HEIGHT, WIDTH), dtype, layout=T.TileLayout(T.S[(HEIG @pytest.mark.parametrize("swizzle", [0, 1, 2, 3]) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_tcgen05_mma_ss_no_tma(swizzle): d_type, a_type, b_type = "float32", "float16", "float16" M, N, K = 128, 128, 64 diff --git a/tests/python/tirx/codegen/test_codegen_hopper.py b/tests/python/tirx/codegen/test_codegen_hopper.py index 8f14dfc3c22d..38e1f30cfbbc 100644 --- a/tests/python/tirx/codegen/test_codegen_hopper.py +++ b/tests/python/tirx/codegen/test_codegen_hopper.py @@ -23,6 +23,7 @@ import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env from tvm.tirx import Buffer @@ -57,7 +58,8 @@ def main(A_ptr: T.handle): @pytest.mark.parametrize("inc", [False, True]) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ptx_setmaxnreg(inc): # fmt: off @T.prim_func @@ -77,7 +79,8 @@ def func(A: T.Buffer(1)): @pytest.mark.parametrize("trans", [False, True]) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_stmatrix_sync_aligned(trans): # fmt: off @T.prim_func @@ -199,7 +202,8 @@ def main(A: T.Buffer((16, 16), "float16")): @pytest.mark.parametrize("trans", [False, True]) @pytest.mark.parametrize("num", [1, 2, 4]) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ptx_stmatrix_noncontiguous(trans, num): """Symmetric stmatrix API: ``num`` independent src handles. @@ -267,7 +271,8 @@ def main(A: T.Buffer((16, 16), "float16")): np.testing.assert_allclose(A.numpy(), A_ref) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_bar_arrive(): # fmt: off @T.prim_func @@ -283,7 +288,8 @@ def func(A: T.Buffer(1)): assert 'bar.arrive %0, %1;" : : "r"(name_bar_id), "r"(thread_count) : "memory"' in src -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_bar_sync(): # fmt: off @T.prim_func @@ -299,7 +305,8 @@ def func(A: T.Buffer(1)): assert 'bar.sync %0, %1;" : : "r"(name_bar_id), "r"(thread_count) : "memory"' in src -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_fence_mbarrier_init_release_clsuter(): # fmt: off @T.prim_func @@ -314,7 +321,8 @@ def func(A: T.Buffer(1)): assert "fence.mbarrier_init.release.cluster" in src -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ptx_elect_sync(): # fmt: off @T.prim_func @@ -331,7 +339,8 @@ def func(A: T.Buffer(1)): assert "elect.sync %%rx|%%px, %2;" in src -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("sem,scope", [("sc", "cta"), ("acq_rel", "gpu"), ("sc", "sys")]) def test_ptx_fence(sem, scope): # fmt: off @@ -347,7 +356,8 @@ def func(A: T.Buffer(1)): assert f"fence.{sem}.{scope};" in src -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_fence_proxy_async(): # fmt: off @T.prim_func @@ -365,7 +375,8 @@ def func(A: T.Buffer(1)): assert "fence.proxy.async.shared::cta" in src -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("dtype", ["float16", "float32", "float8_e4m3fn", "float8_e5m2"]) @pytest.mark.parametrize( "inputs", @@ -449,7 +460,8 @@ def get_np_dtype(dtype): assert np.allclose(A.numpy().astype("float32"), B.numpy().astype("float32")) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( ("shape", "dtype", "encode_args", "error_msg"), [ @@ -525,7 +537,8 @@ def test_tensormap_encode_tiled_runtime_validation(shape, dtype, encode_args, er @pytest.mark.parametrize("swizzle", [1, 2, 3]) @pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"]) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_cp_async_bulk_tensor_global_to_shared_swizzle(swizzle, dtype): def get_ir(swizzle, dtype): dtype = tvm.DataType(dtype) @@ -623,7 +636,8 @@ def main(A_ptr: T.handle, B_ptr: T.handle): ), ], ) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_cp_async_bulk_tensor_global_to_shared_multicast1(inputs): # 1 CTA does the copy, and then multicast to all CTAs in the cluster def get_ir(shape, tma_args): @@ -697,7 +711,8 @@ def main(A_ptr: T.handle, B_ptr: T.handle): ((16, 16, 4), [16, 16, 4, 64, 64 * 16, 16, 16, 1, 1, 1, 1, 0, 0, 0, 0]), ], ) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_cp_async_bulk_tensor_global_to_shared_multicast2(inputs): # 4 CTAs in the cluster do the copy of separate chunks, and then multicast to all CTAs in the cluster # noqa: E501 def get_ir(shape, tma_args): @@ -787,7 +802,8 @@ def main(A_ptr: T.handle, B_ptr: T.handle): ((16, 16, 4), [16, 16, 4, 64, 64 * 16, 16, 16, 4, 1, 1, 1, 0, 0, 0, 0]), ], ) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_cp_async_bulk_tensor_shared_to_global(inputs): def get_ir(shape, tma_args): assert shape[0] % 4 == 0 @@ -839,7 +855,8 @@ def main(A_ptr: T.handle): np.testing.assert_allclose(A.numpy(), A_ref) -@tvm.testing.requires_cuda_compute_version(9, exact=True) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9, exact=True), reason="need cuda compute == 9.0") def test_wgmma_ss_nt(): def get_ir( shapeA, @@ -994,7 +1011,8 @@ def main(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle): tvm.testing.assert_allclose(C_tvm.numpy(), C_ref, rtol=1e-3, atol=1e-3) -@tvm.testing.requires_cuda_compute_version(9, exact=True) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9, exact=True), reason="need cuda compute == 9.0") def test_wgmma_rs_nt(): def get_ir( shapeA, shapeB, shapeC, B_tma_args, in_dtype, in_dtype_bits, out_dtype, B_encode_args @@ -1150,7 +1168,8 @@ def main(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle): tvm.testing.assert_allclose(C_tvm.numpy(), C_ref, rtol=1e-3, atol=1e-3) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ptx_map_shared_rank(): @T.prim_func def func(A: T.Buffer(1)): diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py index 676d8d95ae5f..dc5a46a751ec 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py @@ -28,6 +28,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout @@ -509,7 +510,8 @@ def test_layout_permute_copy_preserves_smem_strides(): # recognizer accepts, and emit lowers to the # ``base_off + sum_j bit_j(f) · signed_strides[j]`` precomputed form. # ---------------------------------------------------------------------------- -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_gmem_smem_swizzle_fast_path_fires_with_var_bounds(): """Warp-scope 32x64 fp16 G2S/S2G with 128b swizzled SMEM. Fast path must fire: a 3-slot ``v_[]`` signed_strides buffer + bit-select adds diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py index fc62806c9bf6..4c51c9535e5b 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_ld_stmatrix.py @@ -39,6 +39,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout, laneid, tid_in_wg, tx @@ -319,7 +320,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: @pytest.mark.parametrize("trans", [False, True]) @pytest.mark.parametrize("direction", ["ld", "st"]) @pytest.mark.parametrize("num", [1, 2, 4]) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ldstmatrix(scope, trans, direction, num): kernel, (M, N) = _BUILDERS[scope](num, direction, trans) compiled, src = _compile_src(kernel) @@ -349,7 +351,8 @@ def test_ldstmatrix(scope, trans, direction, num): @pytest.mark.parametrize("trans", [False, True]) @pytest.mark.parametrize("direction", ["ld", "st"]) @pytest.mark.parametrize("num", [1, 2, 4]) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ldstmatrix_swizzle(scope, trans, direction, num): kernel, (M, N) = _BUILDERS[scope](num, direction, trans, swizzle=True) compiled, src = _compile_src(kernel) @@ -424,7 +427,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: return kernel, shape -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ldstmatrix_swizzle_multi_iter_pow2(): """32x64 fp16 warp; outer m_outer split into multiple BitIters (no LinearIter). Fast path must fire with a 3-slot signed_strides buffer.""" @@ -453,7 +457,8 @@ def test_ldstmatrix_swizzle_multi_iter_pow2(): np.testing.assert_allclose(B.numpy(), A_np) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_ldstmatrix_swizzle_multi_iter_linear(): """40x64 fp16 warp; outer ext=5 is non-pow2 but stride lands on swizzle period (Case 1.D pure) so the LinearIter relaxation fires. Pattern has diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py index 5493fe0c28e6..27bf74ed4082 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py @@ -30,6 +30,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx import IntImm, Var from tvm.tirx.cuda.operator.tile_primitive.copy_async.dsmem import copy_dsmem_impl from tvm.tirx.exec_scope import ExecScope @@ -122,7 +123,8 @@ def _layout_physical_elements(layout): return max_offset + 1 -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("shape,dtype,src_spec,dst_spec,expected", DSMEM_CONFIGS) def test_dsmem(shape, dtype, src_spec, dst_spec, expected): """Dispatch assertion + GPU correctness for DSMEM copy. diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py index 3cdf31efd864..84f23cf8eaa2 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py @@ -31,6 +31,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.cuda.operator.tile_primitive.tma_utils import SwizzleMode, mma_shared_layout from tvm.tirx.layout import R, S, TCol, TileLayout, TLane @@ -219,7 +220,8 @@ def _execute(kernel, A_init, expected): ) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "name,s_full,s_full_shape,s_region", [ @@ -276,7 +278,8 @@ def test_single_cp(name, s_full, s_full_shape, s_region): _run_2d(s_full, T_LAY_BASIC, s_full_shape, s_region, "uint8", A_np, expected) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_multi_cp_sw0_4tiles(): s_full = TileLayout(S[(4, 32, 16) : (512, 16, 1)]) t_full = TileLayout(S[(4, 32, 16) : (16 @ TCol, 1 @ TLane, 1 @ TCol)] + R[4 : 32 @ TLane]) @@ -285,7 +288,8 @@ def test_multi_cp_sw0_4tiles(): _run_3d_4tile(s_full, t_full, [4, 32, 16], "uint8", A_np, expected) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_align_middle_2_to_1_nvfp4_sfb(): """SFB-style nvfp4 case: TMEM mid canonicalizes to single iter (16@TCol + 4@TCol merge), but SMEM mid stays as 2 iters @@ -394,7 +398,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle): _execute(kernel, A_np, expected) -@tvm.testing.requires_cuda_compute_version(10) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "bad", [ diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py index 1b0455e27234..3e9cb455b039 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py @@ -26,6 +26,7 @@ from tvm.ir.type import TensorMapType from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx import IntImm, StringImm, Var from tvm.tirx.cuda.operator.tile_primitive.tma_utils import ( mma_atom_layout, @@ -1046,7 +1047,8 @@ def test_copy_tma_codegen(case): # =========================================================================== -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("swizzle_len", [3]) @pytest.mark.parametrize("dtype", ["float16"]) def test_copy_tma_symbolic_dimension(dtype, swizzle_len): @@ -1143,7 +1145,8 @@ def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: np.testing.assert_allclose(B_ref, B.numpy()) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("swizzle_len", [3]) @pytest.mark.parametrize("dtype", ["float16"]) def test_copy_tma_3d_with_view(dtype, swizzle_len): @@ -1248,7 +1251,8 @@ def copy_async(Q_ptr: T.handle, B_ptr: T.handle) -> None: # =========================================================================== -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "task", [ @@ -1423,7 +1427,8 @@ def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: np.testing.assert_allclose(B_ref, B.numpy()) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("dtype", ["float16"]) def test_copy_tma_gpu_smoke_s2g(dtype): """Smoke test: compile and run TMA S2G store on GPU.""" @@ -1487,7 +1492,8 @@ def copy_async(A_ptr: T.handle, B_ptr: T.handle) -> None: np.testing.assert_allclose(A_np, B.numpy()) -@tvm.testing.requires_cuda_compute_version(9) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("dtype", ["float16"]) def test_copy_tma_dynamic_cta_mask(dtype): """Regression test for B00004: dynamic cta_mask expression in TMA multicast. diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py index 516366365f34..c15965970e15 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm/test_gemm_mma_m16n8k_.py @@ -38,6 +38,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, laneid from tvm.tirx.operator.tile_primitive import list_registered_schedules @@ -411,7 +412,8 @@ def test_cuda_gemm_mma_rejects_fractional_beta(): _lower(_build_gemm(alpha=1.0, beta=0.5)) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) def test_cuda_gemm_mma_numerical(dtype): """End-to-end D = A @ B on a single m16n8k16 tile (one warp). @@ -503,7 +505,8 @@ def gemm(A_ptr: T.handle, B_ptr: T.handle, D_ptr: T.handle): ] -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("Mt, Nt, Kt, kinst", _TILED_SHAPES) @pytest.mark.parametrize("dtype, beta", _TILED_MODES) def test_cuda_gemm_mma_numerical_tiled(dtype, beta, Mt, Nt, Kt, kinst): @@ -537,7 +540,8 @@ def test_cuda_gemm_mma_numerical_tiled(dtype, beta, Mt, Nt, Kt, kinst): tvm.testing.assert_allclose(golden, D_dev.numpy(), atol=2e-2, rtol=2e-2) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @pytest.mark.parametrize( "transpose_A, transpose_B", @@ -596,7 +600,8 @@ def test_cuda_gemm_mma_lowers_tiled(Mt, Nt, Kt, kinst): assert f"m16n8k{kinst}" in script -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize( "Mt, Nt, Kt, kinst", [ @@ -637,7 +642,8 @@ def test_cuda_gemm_mma_lowers_transpose(transpose_A, transpose_B): assert "m16n8k16" in script -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize( "transpose_A, transpose_B", [(False, False), (True, False), (False, True), (True, True)], diff --git a/tests/python/tirx/test_bench_utils.py b/tests/python/tirx/test_bench_utils.py index 9e56d466237c..e5bf13bc845e 100644 --- a/tests/python/tirx/test_bench_utils.py +++ b/tests/python/tirx/test_bench_utils.py @@ -21,7 +21,7 @@ pytest.importorskip("triton") # tvm.tirx.bench imports triton.profiler -import tvm.testing +from tvm.testing import env from tvm.tirx.bench import _compute_group_count, _parse_proton_tree, bench, tensor_bytes # ── _parse_proton_tree ────────────────────────────────────────────────────── @@ -91,7 +91,8 @@ def test_parse_proton_tree_empty(): # ── bench ─────────────────────────────────────────────────────────────────── -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_bench_basic(): """bench returns positive times for each impl.""" M, N = 256, 256 @@ -108,7 +109,8 @@ def make_input(): assert results["impls"]["matmul"] > 0 -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_bench_multiple_impls(): """Multiple impls each get their own timing.""" M, N = 128, 128 @@ -129,7 +131,8 @@ def make_input(): assert all(v > 0 for v in results["impls"].values()) -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_bench_multiple_input_groups(): """Multiple input groups cycle correctly (L2 eviction).""" M, N = 128, 128 @@ -176,7 +179,8 @@ def test_compute_groups_moderate_tensors(): assert n == 4 -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_bench_legacy_callable_api(): """bench still accepts the existing single-callable API used by TIRx tests.""" M, N = 128, 128 @@ -189,7 +193,8 @@ def test_bench_legacy_callable_api(): assert result > 0 -@tvm.testing.requires_cuda +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_bench_callable_inputs(): """bench accepts a factory callable and auto-computes groups.""" M, N = 256, 256 diff --git a/tests/python/tvmscript/test_tvmscript_ops.py b/tests/python/tvmscript/test_tvmscript_ops.py index f053473bd7a2..7793fb72c6d9 100644 --- a/tests/python/tvmscript/test_tvmscript_ops.py +++ b/tests/python/tvmscript/test_tvmscript_ops.py @@ -21,6 +21,7 @@ import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env @T.prim_func(s_tir=True) @@ -174,7 +175,7 @@ def ceildiv_test(A: T.Buffer(16, "int32")): A[i] = T.ceildiv(A[i], 4) -@tvm.testing.requires_llvm +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_ceildiv(): f = tvm.compile(ceildiv_test, "llvm") a = tvm.runtime.tensor(np.arange(16).astype("int32")) diff --git a/tests/scripts/task_python_integration_gpuonly.sh b/tests/scripts/task_python_integration_gpuonly.sh index b1dd1fe104ea..f3ea1f985629 100755 --- a/tests/scripts/task_python_integration_gpuonly.sh +++ b/tests/scripts/task_python_integration_gpuonly.sh @@ -19,6 +19,7 @@ set -exo pipefail export TVM_TEST_TARGETS='cuda;opencl;metal;rocm;nvptx;{"kind":"opencl","device":"mali,adreno"}' +# Every GPU test carries the `gpu` marker; the specific backend is gated by skipif. export PYTEST_ADDOPTS="-m gpu $PYTEST_ADDOPTS" export TVM_RELAY_TEST_TARGETS="cuda" export TVM_INTEGRATION_TESTSUITE_NAME=python-integration-gpu diff --git a/tests/scripts/task_python_unittest_gpuonly.sh b/tests/scripts/task_python_unittest_gpuonly.sh index b011cb57cc00..be3ea0794761 100755 --- a/tests/scripts/task_python_unittest_gpuonly.sh +++ b/tests/scripts/task_python_unittest_gpuonly.sh @@ -18,6 +18,7 @@ set -euxo pipefail +# Every GPU test carries the `gpu` marker; the specific backend is gated by skipif. export PYTEST_ADDOPTS="-m gpu ${PYTEST_ADDOPTS:-}" # Test most of the enabled runtimes here. From 694dacb9646edb544226197918fb85365a7ed8b1 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Tue, 16 Jun 2026 03:52:14 -0700 Subject: [PATCH 08/23] [TIRX][CUDA] Framework support for FA4, CLC intrinsics, and nvfp4 tcgen05 GEMM (#19785) --- python/tvm/backend/cuda/lang/pipeline.py | 11 +- .../tvm/backend/cuda/lang/tile_scheduler.py | 135 +++++++++++++- python/tvm/backend/cuda/op.py | 79 +++++++- .../backend/cuda/operator/intrinsics/sync.py | 100 ++++++++++- .../tile_primitive/copy_async/tcgen05_ldst.py | 35 ++-- .../tile_primitive/elementwise/reg.py | 67 +++++++ python/tvm/backend/cuda/script.py | 6 + python/tvm/support/nvcc.py | 76 ++++++-- .../tirx/script/builder/external_kernel.py | 2 +- src/backend/cuda/op/target_builtin.cc | 6 + src/target/llvm/codegen_llvm.cc | 17 ++ src/target/llvm/codegen_llvm.h | 3 + src/tirx/ir/layout/tile_slice.cc | 6 +- .../codegen/test_target_codegen_llvm.py | 39 ++++ .../python/tirx/codegen/test_codegen_cuda.py | 11 ++ .../tirx/codegen/test_codegen_nvshmem.py | 3 + tests/python/tirx/codegen/test_cuda_copy.py | 11 ++ .../tirx/codegen/test_cuda_cta_reduce.py | 13 ++ .../tirx/codegen/test_cuda_warp_reduce.py | 13 ++ tests/python/tirx/conftest.py | 40 +++++ .../tile_primitive/cuda/copy/test_fallback.py | 5 + .../cuda/copy/test_gmem_smem.py | 4 + .../tile_primitive/cuda/copy/test_reg.py | 5 + .../cuda/copy_async/test_ldgsts.py | 3 + .../cuda/copy_async/test_tmem.py | 7 + .../cuda/copy_async/test_tmem_16xnb.py | 144 +++++++++++++++ .../cuda/elementwise/test_binary.py | 13 ++ .../cuda/elementwise/test_fma.py | 15 ++ .../cuda/elementwise/test_unary.py | 168 +++++++++++++++++- .../cuda/gemm_async/test_gemm_async.py | 23 +++ .../permute_layout/test_permute_layout.py | 7 + .../cuda/reduction/test_reduction.py | 23 +++ tests/python/tirx/test_buffer_print.py | 4 + tests/python/tirx/test_control_flow.py | 8 + tests/python/tirx/test_layout.py | 35 ++++ tests/scripts/task_python_unittest.sh | 1 + 36 files changed, 1096 insertions(+), 42 deletions(-) create mode 100644 tests/python/tirx/conftest.py diff --git a/python/tvm/backend/cuda/lang/pipeline.py b/python/tvm/backend/cuda/lang/pipeline.py index ee86090398e9..40fd40c3fac6 100644 --- a/python/tvm/backend/cuda/lang/pipeline.py +++ b/python/tvm/backend/cuda/lang/pipeline.py @@ -110,7 +110,7 @@ def wait(self, stage, phase): T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset) @T.inline - def arrive(self, stage, cta_id=None, pred=None): + def arrive(self, stage, cta_id=None, pred=None, count=None): # Default: local-CTA arrive — emits the simple # ``mbarrier.arrive.shared.b64`` form. To arrive on a remote # CTA's mbarrier in a cluster kernel, callers must pass @@ -119,11 +119,18 @@ def arrive(self, stage, cta_id=None, pred=None): # the cross-CTA path was both surprising (``bar.arrive(stage)`` # silently ``mapa`` ed across the cluster) and a per-call cost # of ~3 PTX ops on every single-CTA kernel. + # + # ``count`` (cross-CTA path only) emits the explicit arrival-count + # operand, i.e. ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``. + # When ``None`` the implicit count-of-1 form is emitted. Passing + # ``count=1`` is semantically identical but spells the count explicitly. if cta_id is None: T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) else: actual_pred = True if pred is None else pred - T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) + T.ptx.mbarrier.arrive( + self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred, count=count + ) def ptr_to(self, idx): return self.buf.ptr_to(idx) diff --git a/python/tvm/backend/cuda/lang/tile_scheduler.py b/python/tvm/backend/cuda/lang/tile_scheduler.py index 3fd27f25ee5f..c6154f2462f6 100644 --- a/python/tvm/backend/cuda/lang/tile_scheduler.py +++ b/python/tvm/backend/cuda/lang/tile_scheduler.py @@ -20,6 +20,7 @@ instances are automatically treated as meta values inside @T.prim_func. """ +from tvm.backend.cuda.lang.pipeline import Pipeline, PipelineState from tvm.script import tirx as T @@ -753,13 +754,20 @@ class FlashAttentionLPTScheduler(BaseTileScheduler): """ def __init__( - self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, l2_swizzle: int + self, + prefix: str, + num_batches: int, + num_heads: int, + num_m_blocks: int, + l2_swizzle: int, + num_ctas: int | None = None, ): super().__init__(prefix) self._num_batches = num_batches self._num_heads = num_heads self._num_m_blocks = num_m_blocks self._l2_swizzle = l2_swizzle + self._num_ctas = num_ctas self._total_tasks = num_batches * num_heads * num_m_blocks # Derived constants for L2 swizzle @@ -807,10 +815,131 @@ def init(self, cta_id): @T.inline def next_tile(self): - """Advance to next tile by striding by num_ctas.""" - self.linear_idx = self._total_tasks + """Advance to the next tile. + + Single-tile mode (``num_ctas=None``, the default): each CTA owns one + task; terminate. Persistent mode (``num_ctas=N``): stride by N, like + :class:`FlashAttentionLinearScheduler`, while keeping the LPT + L2 + swizzle index mapping. + """ + if self._num_ctas is None: + self.linear_idx = self._total_tasks + else: + self.linear_idx = self.linear_idx + self._num_ctas + self.update_current_m_n_idx(self.linear_idx) # fmt: on def valid(self): """Check if there are more tiles to process.""" return self.linear_idx < self._total_tasks + + +class _CLCWorker(ClusterPersistentScheduler2D): + """Per-role CLC handle: IS-A ClusterPersistentScheduler2D (so m_idx / n_idx work as + usual) plus the role-local barrier phase and handshake. A coord-free role (e.g. an + MMA warp consuming whatever a loader staged) arms the loop with reset() not init(). + """ + + def __init__(self, clc, prefix): + super().__init__( + prefix, + num_m_tiles=clc._num_m_tiles, + num_n_tiles=clc._num_n_tiles, + num_clusters=clc._num_m_tiles * clc._num_n_tiles, + l2_group_size=clc._l2_group_size, + ) + self._clc = clc + self._sa = PipelineState(1, 0) + self._done = T.local_scalar("int32") + self._nxt = T.local_scalar("uint32") + + @T.inline + def reset(self): + self._done = 0 + + @T.inline + def init(self, cluster_id): + # Explicit base call: TVMScript's parser has no zero-arg super(). + ClusterPersistentScheduler2D.init(self, cluster_id) + self._done = 0 + + def valid(self): + return self._done == 0 + + @T.inline + def consume(self): + # Single-elected-thread scope: wait for the handle, decode, release the slot. + self._clc.sched_arr.full.wait(0, self._sa.phase) + self._sa.advance() + self._nxt = T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0])) + self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True) + + @T.inline + def consume_wg(self, wg_id, warp_id, lane_id): + # Warpgroup scope: all threads decode; one elected lane releases the slot. + self._clc.sched_arr.full.wait(0, self._sa.phase) + self._sa.advance() + self._nxt = T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0])) + T.cuda.warpgroup_sync(wg_id + 1) + if (warp_id == 0) & (lane_id == 0): + self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True) + + @T.inline + def advance_coords(self): + if self._nxt != 0xFFFFFFFF: + self.update_current_m_n_idx(self._nxt // self._clc._cta_group) + + @T.inline + def mark_done_if_drained(self): + if self._nxt == 0xFFFFFFFF: + self._done = 1 + + +@T.meta_class +class ClusterLaunchControlScheduler: + """Blackwell Cluster Launch Control (CLC) tile scheduler. + + A scheduler warp runs ``run_scheduler`` (issues ``try_cancel`` to steal the next + cluster); worker roles each take a ``worker()`` handle and pull the stolen tile + through the shared smem handshake. Owns the CLC smem: the 16B response handle, the + arrival barrier (handle ready), and the finished barrier (slot consumed; + ``finish_arrivals`` arrivals per round). Tile-coord mapping is delegated to + ``ClusterPersistentScheduler2D`` (group-major L2 ordering). + """ + + def __init__(self, pool, num_m_tiles, num_n_tiles, l2_group_size, cta_group, finish_arrivals): + self._num_m_tiles = num_m_tiles + self._num_n_tiles = num_n_tiles + self._l2_group_size = l2_group_size + self._cta_group = cta_group + self.sched_arr = Pipeline(pool, 1, full="tma", empty="mbar", init_empty=1) + self.sched_fin = Pipeline(pool, 1, full="mbar", empty="mbar", init_empty=finish_arrivals) + self.clc_handle = pool.alloc((4,), "uint32", align=16) + self._s_done = T.local_scalar("int32") + self._s_nxt = T.local_scalar("uint32") + + def worker(self, prefix): + return _CLCWorker(self, prefix) + + @T.inline + def run_scheduler(self, cbx): + # cta0 drives try_cancel; both CTAs expect_bytes + consume the handle so the + # finished-barrier count is met and the slot can be reissued. + if T.ptx.elect_sync(): + sa = PipelineState(1, 0) + sf = PipelineState(1, 1) + self._s_done = 0 + while self._s_done == 0: + if cbx == 0: + self.sched_fin.empty.wait(0, sf.phase) + sf.advance() + T.ptx.clc_try_cancel( + T.address_of(self.clc_handle[0]), T.address_of(self.sched_arr.full.buf[0]) + ) + self.sched_arr.full.arrive(0, 16) # expect_bytes for the 16B handle + self.sched_arr.full.wait(0, sa.phase) + sa.advance() + self._s_nxt = T.ptx.clc_query_cancel(T.address_of(self.clc_handle[0])) + self.sched_fin.empty.arrive(0, cta_id=0, pred=True) + if self._s_nxt == 0xFFFFFFFF: + self._s_done = 1 diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py index e76d5fbe2452..9570e266623c 100644 --- a/python/tvm/backend/cuda/op.py +++ b/python/tvm/backend/cuda/op.py @@ -653,12 +653,12 @@ def ptx_mbarrier_init(bar, thread_count): return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count) -def ptx_mbarrier_arrive(bar, cta_id=None, pred=None): +def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None): """TVM intrinsic to call mbarrier.arrive.shared::cta.b64 or @p mapa.shared::cluster.u32 - @p mbarrier.arrive.shared::cluster.b64 + @p mbarrier.arrive.shared::cluster.b64 [, count] Parameters ---------- @@ -670,11 +670,29 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None): pred : Optional[PrimExpr] The predicate to guard the operation. + + count : Optional[PrimExpr] + Explicit arrival count operand for the cross-CTA (cluster) form. When + ``None`` the implicit count-of-1 form is emitted; when given, emits + ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``. """ if cta_id is None and pred is None: return call_intrin("", "tirx.ptx_mbarrier_arrive", bar) assert cta_id is not None and pred is not None - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) + if count is None: + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred, count) + + +def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count): + """Cross-CTA ``mbarrier.arrive`` on CTA ``cta_id`` with an explicit count. + + Convenience for an already-elected thread: emits + ``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64 _, + [addr], count`` with the guard defaulted to 1. + """ + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, count) + def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): @@ -706,7 +724,11 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): """ if cta_id is None and pred is None: return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count) - assert cta_id is not None and pred is not None + assert cta_id is not None + # Cross-CTA expect_tx from an already-elected thread: default the guard to 1 + # (the caller has elected a single lane), so callers can pass cta_id alone. + if pred is None: + pred = True return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred) @@ -729,6 +751,23 @@ def ptx_mbarrier_try_wait(bar, phase): return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase) +def ptx_mbarrier_try_wait_acquire_cluster(bar, phase): + """``mbarrier.try_wait.parity.acquire.cluster`` retry loop. + + Cluster-scope acquire wait — used to wait on a barrier that a remote CTA in + the cluster arrives on (a group cluster wait). + + Parameters + ---------- + bar : Var + The pointer to barrier variable. + + phase : int + The phase of the barrier. + """ + return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar, phase) + + def ptx_mbarrier_try_wait_once(bar, phase, ticks): """TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``. @@ -1261,6 +1300,38 @@ def ptx_barrier_cluster_wait(acquire=False, aligned=True): return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned) +def ptx_clc_try_cancel(handle, mbar): + """TVM intrinsic to call clusterlaunchcontrol.try_cancel. + + Async-requests cancelling the next cluster's launch (work-stealing): writes the + 16B response handle to smem and signals ``mbar`` (complete_tx, multicast to both + cluster CTAs). + + Parameters + ---------- + handle : PrimExpr + Pointer to the 16B (uint4) smem response handle. + + mbar : PrimExpr + Pointer to the mbarrier signalled when the handle lands. + """ + return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar) + + +def ptx_clc_query_cancel(handle): + """TVM intrinsic to call clusterlaunchcontrol.query_cancel. + + Decodes the response handle written by :func:`ptx_clc_try_cancel`. Returns the + cancelled cluster's first ``ctaid.x``, or ``0xFFFFFFFF`` when no work was stolen. + + Parameters + ---------- + handle : PrimExpr + Pointer to the 16B (uint4) smem response handle. + """ + return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle) + + def ptx_elect_sync(): """TVM intrinsic to call elect.sync""" return call_intrin("uint32", "tirx.ptx_elect_sync") diff --git a/python/tvm/backend/cuda/operator/intrinsics/sync.py b/python/tvm/backend/cuda/operator/intrinsics/sync.py index 0fcdb31a46f1..791d9cc981fc 100644 --- a/python/tvm/backend/cuda/operator/intrinsics/sync.py +++ b/python/tvm/backend/cuda/operator/intrinsics/sync.py @@ -168,6 +168,54 @@ def _ptx_barrier_cluster_wait(acquire, aligned): ) +# ============================================================================= +# clusterlaunchcontrol.try_cancel / query_cancel — Blackwell Cluster Launch +# Control (CLC) work-stealing, written from the PTX ISA spec (section +# "clusterlaunchcontrol", PTX ISA 8.6). try_cancel async-requests cancelling the +# next cluster's launch, writing a 16B response to smem + signalling mbar. query +# decodes the response: on success it extracts the cancelled cluster's first +# ctaid.x (via the get_first_ctaid::x form); a single uint32 is returned, with +# 0xFFFFFFFF as the "no work stolen" sentinel (a device helper returns one scalar). +# ============================================================================= +device_intrinsic( + "ptx_clc_try_cancel", + c_signature="(void* handle, void* mbar)", + body=( + " unsigned int addr = (unsigned int)__cvta_generic_to_shared(handle);\n" + " unsigned int bar = (unsigned int)__cvta_generic_to_shared(mbar);\n" + " asm volatile(\n" + ' "clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes"\n' + ' ".multicast::cluster::all.b128 [%0], [%1];\\n"\n' + ' :: "r"(addr), "r"(bar) : "memory");' + ), +) + + +device_intrinsic( + "ptx_clc_query_cancel", + c_signature="(void* handle)", + return_type="uint32_t", + tvm_return_type="uint32", + body=( + " unsigned int addr = (unsigned int)__cvta_generic_to_shared(handle);\n" + " unsigned int first_ctaid_x;\n" + " asm volatile(\n" + ' "{\\n"\n' + ' ".reg .pred canceled;\\n"\n' + ' ".reg .b128 response;\\n"\n' + ' "ld.shared.b128 response, [%1];\\n"\n' + ' "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 canceled, response;\\n"\n' + ' "mov.u32 %0, 0xffffffff;\\n"\n' + ' "@canceled clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128"\n' + ' " %0, response;\\n"\n' + ' "}\\n"\n' + ' : "=r"(first_ctaid_x) : "r"(addr) : "memory");\n' + ' asm volatile("fence.proxy.async.shared::cta;\\n" ::: "memory");\n' + " return first_ctaid_x;" + ), +) + + # ============================================================================= # mbarrier.init.shared.b64 [addr], count ; — 1 form. # ============================================================================= @@ -208,7 +256,7 @@ def _ptx_barrier_cluster_wait(acquire, aligned): ' "{\\n"\n' ' ".reg .pred p;\\n"\n' ' ".reg .b32 remAddr32;\\n"\n' - ' "setp.eq.u32 p, %2, 1;\\n"\n' + ' "setp.ne.s32 p, %2, 0;\\n"\n' ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n' ' "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\\n"\n' ' "}\\n"\n' @@ -217,15 +265,38 @@ def _ptx_barrier_cluster_wait(acquire, aligned): ) +# Same cross-CTA arrive, but with an explicit arrival-count operand +# (``..., [remAddr32], count``). Matches the ``tma::cluster::arrive`` spelling. +device_intrinsic( + "_ptx_mbarrier_arrive_remote_count", + helper_name="tvm_builtin_ptx_mbarrier_arrive_remote_count", + c_signature="(void* barrier, int cta_id, int pred, int count)", + body=( + " unsigned int barrier_addr = __cvta_generic_to_shared(barrier);\n" + " asm volatile(\n" + ' "{\\n"\n' + ' ".reg .pred p;\\n"\n' + ' ".reg .b32 remAddr32;\\n"\n' + ' "setp.ne.s32 p, %2, 0;\\n"\n' + ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n' + ' "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32], %3;\\n"\n' + ' "}\\n"\n' + ' :: "r"(barrier_addr), "r"(cta_id), "r"(pred), "r"(count) : "memory");' + ), +) + + @register_codegen("ptx_mbarrier_arrive") def _codegen_mbarrier_arrive(*args): - """Dispatch by arg count: 1 -> local, 3 -> remote (cluster-mapped).""" + """Dispatch by arg count: 1 -> local, 3 -> remote, 4 -> remote+count.""" if len(args) == 1: result = CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_local"](list(args)) elif len(args) == 3: result = CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote"](list(args)) + elif len(args) == 4: + result = CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote_count"](list(args)) else: - raise ValueError(f"ptx_mbarrier_arrive expects 1 or 3 args, got {len(args)}") + raise ValueError(f"ptx_mbarrier_arrive expects 1, 3, or 4 args, got {len(args)}") return result[0] if isinstance(result, tuple) else result @@ -252,7 +323,7 @@ def _codegen_mbarrier_arrive(*args): ' "{\\n"\n' ' ".reg .pred p;\\n"\n' ' ".reg .b32 remAddr32;\\n"\n' - ' "setp.eq.u32 p, %2, 1;\\n"\n' + ' "setp.ne.s32 p, %2, 0;\\n"\n' ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n' ' "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\\n"\n' ' "}\\n"\n' @@ -303,6 +374,27 @@ def _codegen_mbarrier_arrive_expect_tx(*args): ) +# mbarrier.try_wait.parity.acquire.cluster — cluster-scope acquire wait used for +# cross-CTA barrier handshakes (e.g. the tmem-finished handoff). +device_intrinsic( + "ptx_mbarrier_try_wait_acquire_cluster", + c_signature="(void* barrier, int phase)", + body=( + " unsigned int barrier_addr_int = __cvta_generic_to_shared(barrier);\n" + " asm volatile(\n" + ' "{\\n"\n' + ' ".reg .pred P1;\\n"\n' + ' "LAB_WAIT_AC:\\n"\n' + ' "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\\n"\n' + ' "@P1 bra.uni DONE_AC;\\n"\n' + ' "bra.uni LAB_WAIT_AC;\\n"\n' + ' "DONE_AC:\\n"\n' + ' "}\\n"\n' + ' :: "r"(barrier_addr_int), "r"(phase) : "memory");' + ), +) + + # ============================================================================= # mbarrier.try_wait.parity — ONE-SHOT non-blocking variant. Returns true # if the requested parity has already been reached, false otherwise. diff --git a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py index ffd5e18a3a5c..081ea5a772d3 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py @@ -369,20 +369,24 @@ def _emit_16xnb_path( tmem_st, tmem_extent = get_st_extent(tmem_region) local_st, local_extent = get_st_extent(local_region) - # Local slice must be the full (frag_rows, K_cols) view. + # Rows must span the full frag. The COLUMN extent may be a sub-multiple of + # the atom's full width ``width_elems`` — i.e. a per-chunk column slice of a + # wider frag (e.g. an epilogue that loads one big (128, MMA_N) frag in + # EPI_TILE-wide chunks). The atom layout maps consecutive columns to + # consecutive registers within each slab, so a column slice occupies a + # contiguous register window; we emit ``num_eff`` (the slice's atom rep) at + # the slab base + the column's register offset. When the slice IS the full + # atom (the common case), num_eff == num and reg offset == 0 (no change). assert analyzer.can_prove_equal(local_st[0], 0) assert analyzer.can_prove_equal(local_extent[0], frag_rows) - assert analyzer.can_prove_equal(local_extent[1], width_elems) - - # TMEM slice must start at row 0 and span ``frag_rows`` rows. For Layout - # F the buffer is already (64, W) so frag_rows=64 covers the full slice; - # for Layout D + frag_rows=64 the slice reads the *first* half-slab and - # the rest of the buffer's 128 rows is invisible to this atom. For - # Layout D + frag_rows=128 the slice covers all 128 physical lanes via - # two PTX issues (row=0 + row=16). assert analyzer.can_prove_equal(tmem_st[0], 0) assert analyzer.can_prove_equal(tmem_extent[0], frag_rows) - assert analyzer.can_prove_equal(tmem_extent[1], width_elems) + # local and tmem column slices must match and divide the atom's full width. + assert analyzer.can_prove_equal(local_extent[1], tmem_extent[1]) + slice_w = int(local_extent[1]) + assert width_elems % slice_w == 0, f"slice width {slice_w} must divide atom width {width_elems}" + num_eff = num * slice_w // width_elems + regs_eff = regs_per_thread_per_slab * slice_w // width_elems del tmem_rows # only used for the structural check above col_off = tmem_st[1] @@ -410,13 +414,18 @@ def impl(): # for the register-pointer arguments of the PTX builtin. local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems])) local_32b = local_storage.view("uint32") - local_reg_base = local_col_off_elems // elem_per_32b + # Register offset of the column slice within each slab. The old + # ``local_col_off // elem_per_32b`` is only correct when the slice IS the + # full atom; in general consecutive columns advance registers at the rate + # (regs_per_thread_per_slab / width_elems). For a full-atom load the + # offset is 0 either way, so existing callers are unaffected. + local_reg_base = local_col_off_elems * regs_per_thread_per_slab // width_elems for slab in range(n_slabs): reg_base = slab * regs_per_thread_per_slab op( tmem_buf.allocated_addr[0], - *[local_32b[local_reg_base + reg_base + i] for i in range(regs_per_thread_per_slab)], # noqa: E501 - shape=shape, num=num, row=slab * 16, col=col_off_32b, + *[local_32b[local_reg_base + reg_base + i] for i in range(regs_eff)], + shape=shape, num=num_eff, row=slab * 16, col=col_off_32b, ) # fmt: on return impl diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py index eddf9f3d8eac..64d77a21cf69 100644 --- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py @@ -45,8 +45,10 @@ from ..copy._common import _carve_tail, _verify_s_tail_contig from ..layout_utils import get_sublayout_from_region, layout_signature from ._common import ( + _TID_AXIS_FOR_SCOPE, _all_threads_active, _tensor_shape_of, + _thread_cnt, align_operands_to_anchor, buffer_regions, compute_dtype_of, @@ -67,6 +69,68 @@ def _validate_anchor_layout(anchor_br) -> tuple[bool, str | None]: return True, None +def _validate_scope_level_anchor(anchor_br, sctx: DispatchContext) -> tuple[bool, str | None]: + """For warp/warpgroup/cta scope, require dst to be scope-level: after + canonicalizing with the target its thread axes are the scope's intra-thread + axis (laneid/tid_in_wg/tx) and, sorted by stride, tile a complete ``T:1`` + chain over all ``T`` threads of the scope. Rejects thread-local ``.local()`` + views; thread scope is exempt. + """ + scope = sctx.scope_kind + if scope == "thread": + return True, None + expected_axis = _TID_AXIS_FOR_SCOPE.get(scope) + if expected_axis is None: + return True, None + expected_cnt = _thread_cnt(sctx) + + # Canonicalize the sliced anchor with the target so warp/lane axes fuse. + st, ext = get_st_extent(anchor_br) + sliced = get_sublayout_from_region(anchor_br.buffer.layout, anchor_br.buffer.shape, st, ext) + with sctx.target: + canon = sliced.canonicalize() if hasattr(sliced, "canonicalize") else sliced + shard = getattr(canon, "shard", None) + if shard is None: + return False, f"{scope}-scope op operand layout is not a TileLayout after slicing" + + thread_iters = [it for it in shard if it.axis.is_thread()] + if not thread_iters: + return ( + False, + f"{scope}-scope op needs a {scope}-level operand whose layout carries " + f"thread axes ({expected_axis} composing to {expected_cnt}:1); got a " + f"thread-local view with no thread axes — pass the {scope}-level tensor, " + f"not its `.local()` (per-thread) view", + ) + bad = sorted({it.axis.name for it in thread_iters if it.axis.name != expected_axis}) + if bad: + return ( + False, + f"{scope}-scope op operand carries thread axes {bad}; after " + f"canonicalization a {scope}-level layout must use only {expected_axis!r}", + ) + # Sorted by stride the thread iters must tile a complete chain 1, e0, + # e0*e1, ... up to the scope thread count — i.e. cover all T threads with + # no gap or overlap (extents alone would miss gaps/overlaps). + running = 1 + for it in sorted(thread_iters, key=lambda i: int(i.stride)): + stride, extent = int(it.stride), int(it.extent) + if stride != running: + return ( + False, + f"{scope}-scope op operand thread axes do not tile a complete " + f"{expected_cnt}:1 (sorted by stride: expected {running}, got {stride})", + ) + running *= extent + if running != expected_cnt: + return ( + False, + f"{scope}-scope op operand thread axes span {running} threads, not the " + f"full {expected_cnt} of the {scope}", + ) + return True, None + + def _check_layout_operands_agree(plan) -> tuple[bool, str | None]: """Replica sigs must match across non-trivial-layout operands. @@ -133,6 +197,9 @@ def check(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str ok3, reason3 = _validate_anchor_layout(anchor) if not ok3: return False, reason3 + ok_scope, reason_scope = _validate_scope_level_anchor(anchor, sctx) + if not ok_scope: + return False, reason_scope # Shape compat (NumPy-style broadcast): anchor's tensor shape is the # result shape; every operand must broadcast TO anchor. anchor_tshape = _tensor_shape_of(anchor.region) diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py index a1148f9b67ee..a46aa7e7e472 100644 --- a/python/tvm/backend/cuda/script.py +++ b/python/tvm/backend/cuda/script.py @@ -53,6 +53,8 @@ def __init__(self): self.stmatrix = _op_wrapper(_cuda_op.ptx_stmatrix) self.setmaxnreg: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_setmaxnreg) self.elect_sync: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_elect_sync) + self.clc_try_cancel = _op_wrapper(_cuda_op.ptx_clc_try_cancel) + self.clc_query_cancel = _op_wrapper(_cuda_op.ptx_clc_query_cancel) self.fetch_register: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_fetch_register) self.ld = _op_wrapper(_cuda_op.ptx_ld) self.ld_acquire = _op_wrapper(_cuda_op.ptx_ld_acquire) @@ -276,6 +278,9 @@ def __init__(self): self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init) self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait) self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once) + self.try_wait_acquire_cluster = _op_wrapper( + _cuda_op.ptx_mbarrier_try_wait_acquire_cluster + ) self.arrive = MbarrierArriveNamespace() @@ -284,6 +289,7 @@ class MbarrierArriveNamespace: def __init__(self): self.expect_tx = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_expect_tx) + self.cluster_count = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_cluster_count) def __call__(self, *args, **kwds): return _op_wrapper(_cuda_op.ptx_mbarrier_arrive)(*args, **kwds) diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py index ea5939fceffc..b421042fb30b 100644 --- a/python/tvm/support/nvcc.py +++ b/python/tvm/support/nvcc.py @@ -32,7 +32,7 @@ def compile_cuda( - code, target_format=None, arch=None, options=None, path_target=None, compiler="nvcc" + code, target_format=None, arch=None, options=None, path_target=None, compiler="nvrtc" ): """Compile CUDA code with NVCC or NVRTC. @@ -54,7 +54,7 @@ def compile_cuda( Output file. compiler : str, optional - Compiler backend: "nvcc" or "nvrtc". + Compiler backend: "nvrtc" (default) or "nvcc". This can be set by the TVM_CUDA_COMPILE_MODE environment variable. Returns @@ -191,7 +191,7 @@ def _compile_cuda_nvcc( "--expt-extended-lambda", "--use_fast_math", "--ptxas-options=-v", # printing out number of registers - "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers # noqa: E501 + f"--ptxas-options=--verbose,--register-usage-level={os.environ.get('TVM_CUDA_PTXAS_REG_LEVEL', '10')},--warn-on-local-memory-usage", # noqa: E501 ] major, _ = parse_compute_version(get_target_compute_version(Target.current(allow_none=True))) @@ -342,14 +342,23 @@ def _compile_cuda_nvrtc( line for line in code.splitlines() if line.strip() not in headers_to_strip ) - # NVRTC compiles device code and does not include the host-side cuda.h. - # CUtensorMap is a host-side structure, to reference and use it in device code, - # we must forward-declare it for NVRTC. + # NVRTC compiles device code and does not include the host-side cuda.h + # (it is guarded behind ``#ifndef __CUDACC_RTC__`` in generated code and is + # stripped above), so the complete ``CUtensorMap_st`` layout that cuda.h + # normally provides is missing. TMA kernels take ``CUtensorMap`` by value as + # ``__grid_constant__`` params, which requires the complete type. Define the + # ``CUtensorMap_st`` tag with cuda.h's layout (64-byte aligned, 128 bytes) + # plus the typedef alias. This is compatible with cccl's ````, + # which only forward-declares ``struct CUtensorMap_st;`` and re-typedefs the + # alias (a redundant typedef to the same type is legal in C++); defining the + # tag rather than ``struct CUtensorMap`` avoids the previous redefinition + # clash with that header. if "CUtensorMap" in code_filtered: code_filtered = ( - "struct __align__(128) CUtensorMap {\n" + "struct alignas(64) CUtensorMap_st {\n" " unsigned long long opaque[16];\n" - "};\n\n" + code_filtered + "};\n" + "typedef struct CUtensorMap_st CUtensorMap;\n\n" + code_filtered ) # Add standard type definitions and compatibility macros that NVRTC doesn't provide. @@ -371,6 +380,13 @@ def _compile_cuda_nvrtc( #define __volatile__ volatile #endif +// NVRTC does not pull in the host , so INFINITY is undefined. Provide it +// from libcu++ (same float +inf value nvcc's yields). +#include +#ifndef INFINITY +#define INFINITY (::cuda::std::numeric_limits::infinity()) +#endif + """ code_filtered = nvrtc_preamble + code_filtered @@ -406,6 +422,9 @@ def _compile_cuda_nvrtc( compile_opts = [ f"--gpu-architecture={arch}".encode(), b"-default-device", + # nvcc enables 128-bit integers by default on Linux; NVRTC requires the + # flag to be passed explicitly for kernels that use __int128_t. + b"--device-int128", ] if use_nvshmem: @@ -469,6 +488,21 @@ def _compile_cuda_nvrtc( ] ) + # Define the vector-deprecation silencing macros as no-ops for every NVRTC + # compile. These live in vector_types.h, which the fp4/fp6/fp8 headers use + # but do not include; depending on the include chain NVRTC pulls in, the + # macro can be left undefined and trigger a bogus "declaration has no storage + # class" error. Defining them empty is harmless (they only gate host-side + # deprecation warnings) and matches what the NVSHMEM path already did. + compile_opts.extend( + [ + b"-D__NV_SILENCE_DEPRECATION_BEGIN=", + b"-D__NV_SILENCE_DEPRECATION_END=", + b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=", + b"-D__NV_SILENCE_HOST_DEPRECATION_END=", + ] + ) + compile_opts.extend( [ b"-U__CUDA_NO_HALF_OPERATORS__", @@ -481,6 +515,24 @@ def _compile_cuda_nvrtc( ] ) + # Mirror the nvcc path's ptxas options. register-usage-level drives ptxas + # register allocation / instruction scheduling and is perf-relevant (FA4 was + # tuned around it, hence the env-driven default); -v and + # --warn-on-local-memory-usage are diagnostic. NVRTC rejects -O3 and + # --register-usage-level as top-level flags but forwards them to its internal + # ptxas via --ptxas-options (ptxas already defaults to -O3). NB: unlike nvcc, + # NVRTC does not comma-split --ptxas-options, so each ptxas flag must be its + # own entry. The nvcc-only --expt-relaxed-constexpr / --expt-extended-lambda + # have no NVRTC equivalent and are intentionally not mirrored. + reg_level = os.environ.get("TVM_CUDA_PTXAS_REG_LEVEL", "10") + compile_opts.extend( + [ + b"--ptxas-options=-v", + f"--ptxas-options=--register-usage-level={reg_level}".encode(), + b"--ptxas-options=--warn-on-local-memory-usage", + ] + ) + # Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support if options: nvcc_only_prefixes = ( @@ -802,7 +854,7 @@ def tvm_callback_cuda_compile(code): Compile CUDA code using the configured backend (nvcc or nvrtc). This callback is invoked by TVM's C++ backend during CUDA module compilation. - By default, uses nvcc to generate fatbin. The current target is fetched + By default, uses nvrtc to generate cubin. The current target is fetched inside the callback (via ``tvm.target.Target.current(allow_none=True)``) so the caller does not need to push/pop a target scope around the invocation. @@ -810,9 +862,9 @@ def tvm_callback_cuda_compile(code): Environment Variables --------------------- TVM_CUDA_COMPILE_MODE : str - Compiler backend: "nvcc" (default) or "nvrtc" - - "nvcc": Use nvcc subprocess, generates fatbin + Compiler backend: "nvrtc" (default) or "nvcc" - "nvrtc": Use NVRTC via cuda-bindings for faster JIT, generates cubin + - "nvcc": Use nvcc subprocess, generates fatbin TVM_KERNEL_DUMP : str If set, dump generated CUDA/intermediate files and append "-lineinfo" so profilers can correlate SASS back to the dumped source. @@ -830,7 +882,7 @@ def tvm_callback_cuda_compile(code): # The current Target is fetched inside compile_cuda via # tvm.target.Target.current(allow_none=True) when arch is unset; the # caller no longer needs to push/pop a target scope. - compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower() + compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc").lower() if compiler == "nvrtc": return compile_cuda(code, target_format="cubin", compiler="nvrtc") diff --git a/python/tvm/tirx/script/builder/external_kernel.py b/python/tvm/tirx/script/builder/external_kernel.py index c1f5d5871655..d56ed9ea0384 100644 --- a/python/tvm/tirx/script/builder/external_kernel.py +++ b/python/tvm/tirx/script/builder/external_kernel.py @@ -159,7 +159,7 @@ def compile_to_device_module( # pylint: disable=arguments-differ target_format = "cubin" if use_nvshmem else "ptx" output_path = f"{temp_dir}/{kernel_name}.{target_format}" - compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc") + compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc") nvcc.compile_cuda( source_code, target_format=target_format, diff --git a/src/backend/cuda/op/target_builtin.cc b/src/backend/cuda/op/target_builtin.cc index 005fe5b32263..353c04b501ec 100644 --- a/src/backend/cuda/op/target_builtin.cc +++ b/src/backend/cuda/op/target_builtin.cc @@ -152,6 +152,9 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive_expect_tx) TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait_acquire_cluster) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_arrive) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -497,6 +500,8 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = { TIRX_DEVICE_INTRIN_ALIAS(ptx_bar_sync, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_arrive, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_query_cancel, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_try_cancel, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_commit_group, ptx, kOpaque), @@ -540,6 +545,7 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = { TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_init, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_test_wait_parity, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait, ptx, kOpaque), + TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_acquire_cluster, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_once, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mma, ptx, kOpaque), TIRX_DEVICE_INTRIN_ALIAS(ptx_mma_legacy, ptx, kOpaque), diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 88a28ebccb5f..f32dcdde11fd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -133,6 +133,8 @@ void CodeGenLLVM::Init(const std::string& module_name, LLVMTarget* llvm_target, builder_.reset(new IRBuilder(*ctx)); module_.reset(new llvm::Module(module_name, *ctx)); md_builder_.reset(new llvm::MDBuilder(*ctx)); + functions_.clear(); + function_symbol_owners_.clear(); // types t_void_ = llvm::Type::getVoidTy(*ctx); t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), GetGlobalAddressSpace()); @@ -260,6 +262,21 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false); auto [symbol_name, linkage_type] = GetLinkage(gvar, func); + if (auto it = function_symbol_owners_.find(symbol_name); it != function_symbol_owners_.end()) { + constexpr const char* kFFISymbolPrefix = "__tvm_ffi_"; + std::string user_symbol = symbol_name; + if (user_symbol.rfind(kFFISymbolPrefix, 0) == 0) { + user_symbol = user_symbol.substr(std::char_traits::length(kFFISymbolPrefix)); + } + TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" << user_symbol + << "' in LLVM codegen: IRModule keys '" << it->second + << "' and '" << gvar->name_hint + << "' both lower to the same exported symbol '" << symbol_name + << "'. " + << "Each exposed PrimFunc in one IRModule must have a unique " + "global_symbol."; + } + function_symbol_owners_[symbol_name] = gvar->name_hint; auto function = module_->getFunction(MakeStringRef(symbol_name)); if (function == nullptr) { diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 8526b3f642df..08396d596daa 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -547,6 +547,9 @@ class CodeGenLLVM : public ExprFunctor, // that function. std::unordered_map functions_; + // Map from the generated LLVM function symbol to the GlobalVar that owns it. + std::unordered_map function_symbol_owners_; + // Whether current function is restricted bool is_restricted_{true}; // The analyzer information diff --git a/src/tirx/ir/layout/tile_slice.cc b/src/tirx/ir/layout/tile_slice.cc index 3f4db4837964..ce1809ae9907 100644 --- a/src/tirx/ir/layout/tile_slice.cc +++ b/src/tirx/ir/layout/tile_slice.cc @@ -144,7 +144,11 @@ ffi::Optional SlicePerGroup(TileLayout layout, PrimExpr begin, PrimE ffi::Optional TileLayoutNode::Slice(const Array& shape, const Region& region) const { arith::Analyzer analyzer; - auto [grouped_layout, seps] = Group(ffi::GetRef(this), shape); + // Canonicalize the whole layout first so scope fusion (e.g. wid_in_wg+laneid + // -> tid_in_wg) runs globally; otherwise grouping can split sibling thread + // axes and SlicePerGroup's per-group fusion leaves an ill-formed mix. + TileLayout canon = this->Canonicalize().as().value(); + auto [grouped_layout, seps] = Group(canon, shape); std::vector new_shard; ffi::Map new_offset; for (size_t i = 0; i < seps.size() - 1; ++i) { diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 7c093f9be27b..624d587b825f 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -30,6 +30,45 @@ from tvm.testing import env +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") +def test_duplicate_primfunc_global_symbol_diagnostic(): + @I.ir_module(s_tir=True) + class Module: + @T.prim_func(s_tir=True) + def first_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True}) + A[0] = T.float32(1) + + @T.prim_func(s_tir=True) + def second_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True}) + A[0] = T.float32(2) + + with pytest.raises( + tvm.error.InternalError, match="Duplicate PrimFunc global_symbol 'dup_symbol'" + ) as err: + tvm.compile(Module, target="llvm") + assert "first_unique_key" in str(err.value) + assert "second_unique_key" in str(err.value) + + +@pytest.mark.skipif(not env.has_llvm(), reason="need llvm") +def test_unique_primfunc_global_symbols_compile(): + @I.ir_module(s_tir=True) + class Module: + @T.prim_func(s_tir=True) + def first_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol_a", "tirx.noalias": True}) + A[0] = T.float32(1) + + @T.prim_func(s_tir=True) + def second_unique_key(A: T.Buffer((1,), "float32")): + T.func_attr({"global_symbol": "dup_symbol_b", "tirx.noalias": True}) + A[0] = T.float32(2) + + tvm.compile(Module, target="llvm") + + @pytest.mark.skipif(not env.has_llvm(), reason="need llvm") def test_llvm_intrin(): @I.ir_module(s_tir=True) diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py b/tests/python/tirx/codegen/test_codegen_cuda.py index f253d6d375c6..521a72f6d732 100644 --- a/tests/python/tirx/codegen/test_codegen_cuda.py +++ b/tests/python/tirx/codegen/test_codegen_cuda.py @@ -21,6 +21,7 @@ import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.device("cuda") @@ -118,6 +119,8 @@ def main(A: T.Buffer((1,), "uint64")): assert "*(void* *)" not in src +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_atomic_add(): @T.prim_func def main(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "float32")): @@ -442,6 +445,8 @@ def main(A: T.Buffer((16, 16), "int32")): assert "tvm_builtin_cuda_atomic_cas" in src +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cuda_func_call(): def test_add_one(): add_one = """ @@ -497,6 +502,8 @@ def main(a: T.Buffer((16, 16), "int32")): test_print() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_shuffle_xor_sync(): # fmt: off @T.prim_func @@ -532,6 +539,8 @@ def func(A_ptr: T.handle): np.testing.assert_allclose(A.numpy(), A_ref) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("cp_size", [4, 8, 16]) @pytest.mark.parametrize("cache_hint", ["", "evict_last"]) @pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256]) @@ -575,6 +584,8 @@ def main(A: T.Buffer((N), "float16")): print(src) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("trans", [False, True]) @pytest.mark.parametrize("num", [1, 2, 4]) def test_ptx_ldmatrix(trans, num): diff --git a/tests/python/tirx/codegen/test_codegen_nvshmem.py b/tests/python/tirx/codegen/test_codegen_nvshmem.py index ff9f17170ddd..d3869077428e 100644 --- a/tests/python/tirx/codegen/test_codegen_nvshmem.py +++ b/tests/python/tirx/codegen/test_codegen_nvshmem.py @@ -28,6 +28,7 @@ from tvm.runtime import disco as di from tvm.script import tirx as T from tvm.support.popen_pool import PopenWorker +from tvm.testing import env NUM_WORKERS = 4 @@ -61,6 +62,8 @@ def create_nvshmem_array(sess, shape, dtype, init_data_fn=None, zero_out=True): return arr +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.skip(reason="nvshmem doesn't work with pytest") def test_codegen_nvshmem(): def _test_func(): diff --git a/tests/python/tirx/codegen/test_cuda_copy.py b/tests/python/tirx/codegen/test_cuda_copy.py index cb08f4247318..047eb1f12ca3 100644 --- a/tests/python/tirx/codegen/test_cuda_copy.py +++ b/tests/python/tirx/codegen/test_cuda_copy.py @@ -21,6 +21,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -34,6 +35,8 @@ def _build_and_run(func, *np_args): return (*tuple(a.numpy() for a in rt_args), mod) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_128b(): """copy_128b: copies 16 bytes (4 float32 elements) via uint4 load/store.""" @@ -63,6 +66,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_128b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_64b(): """copy_64b: copies 8 bytes (2 float32 elements) via uint2 load/store.""" @@ -92,6 +97,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_64b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_32b(): """copy_32b: copies 4 bytes (1 float32 element) via unsigned int load/store.""" @@ -121,6 +128,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_32b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_16b(): """copy_16b: copies 2 bytes (1 float16 element) via unsigned short load/store.""" @@ -150,6 +159,8 @@ def func(out_ptr: T.handle): assert "tvm_builtin_copy_16b" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_copy_8b(): """copy_8b: copies 1 byte (1 uint8 element) via unsigned char load/store.""" diff --git a/tests/python/tirx/codegen/test_cuda_cta_reduce.py b/tests/python/tirx/codegen/test_cuda_cta_reduce.py index 51b8f1099a91..bf07da1b6798 100644 --- a/tests/python/tirx/codegen/test_cuda_cta_reduce.py +++ b/tests/python/tirx/codegen/test_cuda_cta_reduce.py @@ -21,6 +21,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -35,6 +36,8 @@ def _build_and_run(func, n): return out.numpy(), mod +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_sum_4_warps(): """CTA sum with 4 warps (128 threads): all threads get the same sum.""" NUM_WARPS = 4 @@ -61,6 +64,8 @@ def func(out_ptr: T.handle): assert "cta_reduce_sum_4" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_sum_8_warps(): """CTA sum with 8 warps (256 threads).""" NUM_WARPS = 8 @@ -86,6 +91,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, expected)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_max_4_warps(): """CTA max with 4 warps: all threads get the maximum value.""" NUM_WARPS = 4 @@ -110,6 +117,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, float(N))) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_min_4_warps(): """CTA min with 4 warps: all threads get the minimum value.""" NUM_WARPS = 4 @@ -134,6 +143,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, 1.0)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_cta_sum_1_warp(): """CTA sum with 1 warp: degenerates to a pure warp reduce.""" NUM_WARPS = 1 @@ -159,6 +170,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(N, expected)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16]) def test_cta_sum_all_warp_counts(num_warps): """Parametric test: cta_sum with various warp counts.""" diff --git a/tests/python/tirx/codegen/test_cuda_warp_reduce.py b/tests/python/tirx/codegen/test_cuda_warp_reduce.py index df568a95e483..e5167a055c9a 100644 --- a/tests/python/tirx/codegen/test_cuda_warp_reduce.py +++ b/tests/python/tirx/codegen/test_cuda_warp_reduce.py @@ -21,6 +21,7 @@ import tvm from tvm.script import tirx as T +from tvm.testing import env DEV = tvm.cuda(0) TARGET = tvm.target.Target("cuda") @@ -35,6 +36,8 @@ def _build_and_run(func, n=32): return out.numpy(), mod +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_sum_full(): """Full warp sum (width=32): each lane gets the sum of all 32 values.""" @@ -57,6 +60,8 @@ def func(out_ptr: T.handle): assert "warp_reduce_sum_32" in mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_sum_partial_8(): """Partial warp sum (width=8): 4 groups of 8 lanes, each group sums independently.""" @@ -85,6 +90,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_max_partial_4(): """Partial warp max (width=4): 8 groups of 4 lanes.""" @@ -109,6 +116,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_min_full(): """Full warp min (width=32).""" @@ -129,6 +138,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, np.full(32, 1.0)) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_warp_sum_partial_2(): """Smallest partial warp sum (width=2): 16 pairs of adjacent lanes.""" @@ -155,6 +166,8 @@ def func(out_ptr: T.handle): np.testing.assert_allclose(result, expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("width", [2, 4, 8, 16, 32]) def test_warp_sum_all_widths(width): """Parametric test: warp_sum with every valid width.""" diff --git a/tests/python/tirx/conftest.py b/tests/python/tirx/conftest.py new file mode 100644 index 000000000000..fb8ba62f4f41 --- /dev/null +++ b/tests/python/tirx/conftest.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Suite-level hardware gate for the tirx tests. + +The tirx kernels and codegen paths target Blackwell (sm_100a) — they emit +PTX/SASS (tcgen05, tmem, cp.async ``.async`` modifiers, fp8 conversions, ...) +that ptxas/NVRTC reject for older targets, and many tests execute on the +device. Running the suite on a CPU-only node or a pre-sm_100 GPU therefore +fails at compile/run time rather than skipping. Gate the whole directory on a +real sm_100a device so it skips cleanly where the hardware is absent and runs +in full where it is present. +""" + +import pytest + +from tvm.testing import env + + +def pytest_collection_modifyitems(config, items): + if env.has_cuda_compute(10): + return + skip = pytest.mark.skip( + reason="tirx suite requires a CUDA compute capability 10.0 (sm_100a) device" + ) + for item in items: + item.add_marker(skip) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py index 75faf61366fe..1824b41eae43 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py @@ -32,6 +32,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env # Force the fallback dispatch to register before any test compiles a kernel. # Without this import, in fresh pytest workers the `copy/fallback` variant @@ -128,6 +129,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: return kernel +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "scope,n_threads,shape,why", [ @@ -158,6 +161,8 @@ def test_fallback_round_trip(scope, n_threads, shape, why): np.testing.assert_array_equal(B.numpy(), A_np) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") def test_fallback_thread_scope(): """``T.thread()`` — single thread, no gate. Either ``gmem_smem`` picks it up (n_elements % 1 == 0) or ``fallback`` does — both end up emitting diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py index dc5a46a751ec..c31ca79db918 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py @@ -103,6 +103,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: ] +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "scope,n_threads,shape", [pytest.param(*t, id=f"{t[0]}-{t[1]}-{'x'.join(map(str, t[2]))}") for t in TASKS], @@ -194,6 +196,8 @@ def test_gmem_smem_roundtrip(scope, n_threads, shape, dtype): ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", "float32"] ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py index 451622530318..26c4d5de9b18 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py @@ -35,6 +35,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx @@ -228,6 +229,8 @@ def _expected(shape, dtype): return out +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize("non_r_scope", ["shared", "global"]) @pytest.mark.parametrize( "scope,n_threads,k", @@ -287,6 +290,8 @@ def test_reg_roundtrip(scope, n_threads, k, dtype, non_r_scope): ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(9), reason="need cuda compute >= 9.0") @pytest.mark.parametrize( "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", "float32"] ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py index b4d54d2b4109..96f92832532a 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py @@ -24,6 +24,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout @@ -65,6 +66,8 @@ ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize( "dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16", "float32"] ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py index 0f910a43766d..55e32339c72d 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py @@ -24,10 +24,13 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TCol, TileLayout, TLane from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["float16", "float32"]) @pytest.mark.parametrize("width_32b", [4, 8, 16, 32]) def test_copy_tmem2reg_async(dtype, width_32b): @@ -132,6 +135,8 @@ def copy_async_test(A_ptr: T.handle, B_ptr: T.handle) -> None: # ---------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"]) @pytest.mark.parametrize("width_32b", [2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("offset_32b", [0, 3, 10]) @@ -224,6 +229,8 @@ def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: np.testing.assert_allclose(B.numpy(), A_np) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("dtype", ["float16", "float32"]) @pytest.mark.parametrize("width_32b", [4, 8, 16, 32]) @pytest.mark.parametrize("local_offset_32b", [0, 2, 4]) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py index 420935946028..aac93c0252c7 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py @@ -43,6 +43,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import ( S, TCol, @@ -152,6 +153,8 @@ def _expected_reg_value_16b( # -------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) # subset; full reps below @pytest.mark.parametrize("dtype", ["float32"]) @@ -162,6 +165,8 @@ def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype): _run_load_test(shape, rep, dtype) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "shape, rep", [ @@ -175,6 +180,8 @@ def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep): _run_load_test(shape, rep, "float32") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @@ -201,6 +208,8 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype): # We only need to spot-check that the dispatch fires correctly and the per- # thread reg ↔ TMEM mapping round-trips bit-exactly — the M=64 sweep above # already covers the (lane, reg) decomposition, so a sparse rep set suffices. +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) @pytest.mark.parametrize("rep", [1, 2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @@ -214,6 +223,8 @@ def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype): # with the scatter-encoded TileLayout that ``tmem_datapath_layout("F", ...)`` # produces. ``.16x*b`` M=64 PTX has the matching scatter built in, so the # round-trip is bit-exact in the same way as Layout D + M=64. +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) @pytest.mark.parametrize("rep", [1, 2, 4]) @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) @@ -639,6 +650,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: # -------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("shape", list(_SHAPE_REPS)) @pytest.mark.parametrize("rep", [1, 4, 16]) @pytest.mark.parametrize("dtype", ["float32"]) @@ -853,5 +866,136 @@ def kernel(A_ptr: T.handle) -> None: ) +# -------------------------------------------------------------------------- +# Test 3: column-slice loads of a wider frag +# +# An epilogue may allocate one wide ``(128, K)`` frag and load it from TMEM in +# EPI_TILE-wide column chunks (``frag[:, c:c+w]``) so all loads are in flight +# before a single ``wait.ld``. The ``.16x*b`` dispatch must emit each slice as +# its own atom (``num_eff`` derived from the slice width) at the correct +# per-slab register offset. We verify this is *bit-exact identical* to one +# full-width load of the same frag — which the sweeps above already validate +# against the layout-derived expectation. M=128 here exercises the 2-slab path +# (the slice's two slabs live ``regs_per_thread_per_slab`` apart, not adjacent). +# -------------------------------------------------------------------------- + + +def _run_sliced_vs_full_load(shape, full_rep, n_chunks): + dtype = "float32" + K_cols_fp32 = _COL_FACTOR_FP32[shape] * full_rep + assert K_cols_fp32 % n_chunks == 0 + chunk_elem = K_cols_fp32 // n_chunks # fp32: elem == fp32 col + frag_rows = 128 # M=128 => 2 slabs + per_thread_elems = _REGS_FACTOR[shape] * full_rep * 2 # *2 for the second slab + + tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32)) + stage_width_elem = tmem_col_width_32b + CHUNK_FP32 = 128 + n_stage = tmem_col_width_32b // CHUNK_FP32 if tmem_col_width_32b > CHUNK_FP32 else 1 + stage_w = tmem_col_width_32b if n_stage == 1 else CHUNK_FP32 + VEC_LEN = 4 # 128-bit / fp32 + + atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_fp32), dtype) + stage_view = TileLayout(S[(128, stage_w) : (1 @ axis_tid_in_wg, 1)]) + + @T.prim_func + def kernel(A_ptr: T.handle, Bf_ptr: T.handle, Bs_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (128, stage_width_elem), dtype) + Bf = T.match_buffer(Bf_ptr, (128, per_thread_elems), dtype) # full-load dump + Bs = T.match_buffer(Bs_ptr, (128, per_thread_elems), dtype) # sliced-load dump + A_flat = A.view(-1) + + T.device_entry() + warp_id = T.warp_id([4]) + T.cta_id([2]) + wg_id = T.warpgroup_id([1]) + T.warp_id_in_wg([4]) + T.lane_id([32]) + tid_in_wg = T.thread_id([128]) + + tmem_addr = T.alloc_shared([1], "uint32") + if wg_id == 0: + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=tmem_col_width_32b, cta_group=1) + T.tvm_storage_sync("shared") + tmem = T.decl_buffer( + (128, stage_width_elem), + dtype, + scope="tmem", + allocated_addr=tmem_addr[0], + layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + ) + # Stage A -> TMEM via the standard .32x32b path. + stage_reg = T.alloc_local((stage_w,), dtype) + stage_local = stage_reg.view(128, stage_w, layout=stage_view) + for ci in range(n_stage): + coff = ci * stage_w + for i in range(stage_w // VEC_LEN): + g = T.meta_var(tid_in_wg * stage_width_elem + coff + i * VEC_LEN) + Tx.copy(stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN], A_flat[g : g + VEC_LEN]) + T.cuda.cta_sync() + Tx.wg.copy_async(tmem[:, coff : coff + stage_w], stage_local[:, :]) + T.ptx.tcgen05.wait.st() + T.cuda.cta_sync() + + # (a) one full-width load + ff = T.alloc_local((per_thread_elems,), dtype) + ffl = ff.view(frag_rows, K_cols_fp32, layout=atom_view) + Tx.wg.copy_async(ffl[:, :], tmem[0:frag_rows, 0:K_cols_fp32]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(per_thread_elems): + Bf[tid_in_wg, i] = ff[i] + + # (b) the same frag loaded in n_chunks column slices + sf = T.alloc_local((per_thread_elems,), dtype) + sfl = sf.view(frag_rows, K_cols_fp32, layout=atom_view) + for ck in range(n_chunks): + lo = T.meta_var(ck * chunk_elem) + Tx.wg.copy_async( + sfl[:, lo : lo + chunk_elem], tmem[0:frag_rows, lo : lo + chunk_elem] + ) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + for i in range(per_thread_elems): + Bs[tid_in_wg, i] = sf[i] + + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + mod = tvm.compile(mod, target=target, tir_pipeline="tirx") + A_np = tvm.testing.generate_random_array(dtype, (128, stage_width_elem)) + Bf_np = np.zeros((128, per_thread_elems), dtype=dtype) + Bs_np = np.zeros((128, per_thread_elems), dtype=dtype) + DEV = tvm.cuda(0) + A = tvm.runtime.tensor(A_np, DEV) + Bf = tvm.runtime.tensor(Bf_np, DEV) + Bs = tvm.runtime.tensor(Bs_np, DEV) + mod(A, Bf, Bs) + # Sliced load must reproduce the full-width load bit-for-bit. + np.testing.assert_array_equal(Bs.numpy().view(np.uint32), Bf.numpy().view(np.uint32)) + + +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") +@pytest.mark.parametrize( + "full_rep, n_chunks", + [ + (32, 8), # 16x256b.x32 (256 fp32 cols) loaded in 8 chunks of 32 cols (nvfp4 EPI_TILE=32) + (32, 16), # ...in 16 chunks of 16 cols (nvfp4 EPI_TILE=16) + (32, 4), # ...in 4 chunks of 64 cols + (16, 8), # 16x256b.x16 (128 fp32 cols) in 8 chunks of 16 cols + (16, 2), # ...in 2 chunks of 64 cols + ], +) +def test_tcgen05_ld_16x256b_sliced_matches_full_M128(full_rep, n_chunks): + """Per-chunk column-slice load of a wide M=128 frag == full-width load.""" + _run_sliced_vs_full_load("16x256b", full_rep, n_chunks) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py index 1ce0d34ea6e0..8d39ba355633 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py @@ -23,6 +23,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, wg_local_layout @@ -67,6 +68,8 @@ ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"]) @pytest.mark.parametrize("operands_type", ["region_region", "region_const", "const_region"]) @pytest.mark.parametrize("dtype", ["float16"]) @@ -223,6 +226,8 @@ def bad_kernel() -> None: tvm.compile(mod, target=target, tir_pipeline="tirx") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"]) @pytest.mark.parametrize("op_type", ["add", "mul"]) def test_binary_op_shared_subcta_scope(exec_scope, op_type): @@ -276,6 +281,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["cta", "warpgroup", "warp"]) @pytest.mark.parametrize("rhs_kind", ["region", "broadcast", "const"]) @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"]) @@ -392,6 +399,8 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("storage_scope", ["shared", "local"]) @pytest.mark.parametrize("exec_scope", ["cta", "thread"]) @pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"]) @@ -495,6 +504,8 @@ def get_prim_func(): tvm.testing.assert_allclose(A_ref, A.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["add", "sub", "mul"]) def test_binary_op_packed_f32x2_auto_dispatch(op_type): target = tvm.target.Target("cuda") @@ -568,6 +579,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_name", ["add", "sub", "mul"]) def test_binary_op_warpgroup_wg_local_layout(op_name): dtype = "float32" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py index aa0f5ced8f58..02352638e4d6 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py @@ -26,6 +26,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import S, TileLayout, wg_local_layout @@ -41,6 +42,8 @@ def _get_sm_version(): # --------------------------------------------------------------------------- # FMA op: scalar scale + scalar bias # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_scalar_scalar(): sm = _get_sm_version() if sm < 100: @@ -78,6 +81,8 @@ def test_func(A_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # FMA op: buffer scale + scalar bias (Horner pattern) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_buffer_scale_scalar_bias(): sm = _get_sm_version() if sm < 100: @@ -119,6 +124,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # Binary op with scalar broadcast (PrimExpr scalar, e.g. BufferLoad) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_mul_scalar_broadcast(): sm = _get_sm_version() if sm < 100: @@ -158,6 +165,8 @@ def test_func(A_ptr: T.handle, S_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # Binary add with rounding mode # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_add_rounding_mode(): sm = _get_sm_version() if sm < 100: @@ -199,6 +208,8 @@ def test_func(A_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # FMA op: layout=None local buffer (no TileLayout) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_no_layout(): sm = _get_sm_version() if sm < 100: @@ -238,6 +249,8 @@ def test_func(A_ptr: T.handle) -> None: # --------------------------------------------------------------------------- # Binary sub with rounding mode (buffer-buffer) # --------------------------------------------------------------------------- +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_sub_buffer_buffer_rounding(): sm = _get_sm_version() if sm < 100: @@ -278,6 +291,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(expected, A_dev.numpy(), atol=1e-6) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_fma_warpgroup_wg_local_layout(): rows, cols = 128, 8 dtype = "float32" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py index c20df63bebf0..fb70b3754123 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py @@ -23,6 +23,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.cuda.operator.tile_primitive.layout_utils import ( cast_layout_supported_for_local as _cast_layout_supported_for_local, ) @@ -54,6 +55,8 @@ ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["zero", "sqrt"]) @pytest.mark.parametrize( "src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"), ("float32", "bfloat16")] @@ -145,6 +148,8 @@ def get_ref(A_np): tvm.testing.assert_allclose(B_ref, B.numpy(), atol=1e-2, rtol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"]) def test_unary_op_shared_subcta_scope(exec_scope): dtype = "float16" @@ -209,6 +214,8 @@ def unary_op_subcta(A_ptr: T.handle) -> None: ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sqrt", "exp"]) @pytest.mark.parametrize("bias_type", ["const", "region"]) @pytest.mark.parametrize( @@ -432,6 +439,8 @@ def get_ref(A_np, bias_np): ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["reciprocal", "exp", "exp2"]) @pytest.mark.parametrize( "src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"), ("float32", "bfloat16")] @@ -554,6 +563,8 @@ def test_unary(A_ptr: T.handle, B_ptr: T.handle) -> None: ), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sqrt", "exp"]) @pytest.mark.parametrize("bias_type", ["const", "region"]) @pytest.mark.parametrize( @@ -682,6 +693,8 @@ def get_ref(A_np, bias_np): tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("shape", [(128, 8), (128, 4, 16), (128, 5, 5)]) @pytest.mark.parametrize("op_type", ["fill"]) @pytest.mark.parametrize("exec_scope", ["thread", "cta"]) @@ -740,6 +753,8 @@ def test_unary_cta(A_ptr: T.handle) -> None: tvm.testing.assert_allclose(A.numpy(), np.full(shape, value.value), atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["zero", "sqrt", "reciprocal", "exp", "silu"]) @pytest.mark.parametrize("dtype", ["float16"]) def test_unary_op_local_thread_wise(op_type, dtype): @@ -791,6 +806,8 @@ def kernel(A_ptr: T.handle) -> None: tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-2, rtol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("shape", [(8,), (16, 16), (5, 5)]) @pytest.mark.parametrize("A_dtype", ["float16", "float32"]) @pytest.mark.parametrize("B_dtype", ["float16", "float32"]) @@ -831,6 +848,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_warpgroup_local_view(A_dtype, B_dtype): """T.cast in warpgroup scope with offset (tid_in_wg + layout offset). Covers offset/tid_in_wg/warpgroup scope.""" # noqa: E501 @@ -884,6 +903,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype, B_dtype): """Regression: GEMM-epilogue cast pattern must emit the packed vec2 cuda intrinsic. @@ -944,6 +965,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) def test_cast_cta_local_view(A_dtype, B_dtype): """T.cast with view+layout in CTA scope (128 threads, register->register).""" @@ -988,6 +1011,8 @@ def test_cast(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"), ("float32", "bfloat16")]) @pytest.mark.parametrize("slice_start,slice_end", [(0, 4), (2, 6), (4, 8)]) def test_cast_local_view_sliced(A_dtype, B_dtype, slice_start, slice_end): @@ -1087,6 +1112,8 @@ def test_cast_layout_partition_and_validation(): check(part) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("slice_start,slice_end", [(0, 2), (2, 4)]) def test_cast_mixed_axes_and_subregion(slice_start, slice_end): """Test cast with mixed axes and subregion.""" @@ -1095,7 +1122,7 @@ def test_cast_mixed_axes_and_subregion(slice_start, slice_end): LOCAL_LEN = 4 full_shape = (8, N_WARPS, 4, LOCAL_LEN) g_layout = TileLayout(S[full_shape]) - cast_layout = TileLayout(S[full_shape : (4 @ laneid, 2 @ warpid, 1 @ laneid, 1)]) + cast_layout = TileLayout(S[full_shape : (4 @ laneid, 1 @ warpid, 1 @ laneid, 1)]) A_ref = np.zeros(full_shape, dtype="float32") for j in range(full_shape[0]): @@ -1207,8 +1234,12 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: target = tvm.target.Target("cuda") with target: mod = tvm.IRModule({"main": kernel}) + # The mismatched dst also fails the scope-level check (thread axes don't + # span the full CTA), which fires first — either rejection is fine. with pytest.raises( - Exception, match="tile_local_valid|layout signature mismatch|thread part mismatch" + Exception, + match="tile_local_valid|layout signature mismatch|thread part mismatch" + "|do not tile a complete|not the full", ): tvm.compile(mod, target=target, tir_pipeline="tirx") @@ -1277,5 +1308,138 @@ def k(A_ptr: T.handle, B_ptr: T.handle) -> None: ), f"expected packed vec2 cast {intrinsic}; got:\n{src[:2000]}" +# ----------------------------------------------------------------------------- +# Scope-level operand check: a warp/wg/cta reg op needs a scope-level layout +# (thread axes spanning all the scope's threads), not a thread-local .local(). +# ----------------------------------------------------------------------------- +_SL_ROWS, _SL_COLS = 128, 8 + + +def _sl_compile(fn): + target = tvm.target.Target("cuda") + with target: + tvm.compile(tvm.IRModule({"main": fn}), target=target, tir_pipeline="tirx") + + +def test_cast_wg_rejects_thread_local_view(): + """Tx.wg.cast on a .local() (thread-axis-stripped) view is rejected.""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + _wg = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([_SL_ROWS]) + src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tid, i] + Tx.wg.cast(dst.local(), src.local()) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tid, i] = dst_row[i] + + with pytest.raises(Exception, match="thread-local view"): + _sl_compile(kernel) + + +def test_cast_cta_rejects_thread_local_view(): + """Tx.cta.cast on a .local() view is rejected (cta -> tx).""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + tx_var = T.thread_id([_SL_ROWS]) + src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)])) + dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tx_var, i] + Tx.cta.cast(dst.local(), src.local()) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tx_var, i] = dst_row[i] + + with pytest.raises(Exception, match="thread-local view"): + _sl_compile(kernel) + + +def test_cast_wg_rejects_partial_thread_coverage(): + """A tid_in_wg layout covering only 64 of the 128 wg threads is rejected.""" + half = 64 + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32", layout=TileLayout(S[(half, _SL_COLS)])) + B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16", layout=TileLayout(S[(half, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + _wg = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([_SL_ROWS]) + src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)])) + dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tid, i] + Tx.wg.cast(dst, src) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tid, i] = dst_row[i] + + with pytest.raises(Exception, match="not the full 128"): + _sl_compile(kernel) + + +def test_cast_wg_accepts_wg_level_layout(): + """Tx.wg.cast on a wg-level (tid_in_wg-distributed) layout compiles.""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + _wg = T.warpgroup_id([1]) + tid = T.thread_id_in_wg([_SL_ROWS]) + src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src_row = src.local(_SL_COLS) + for i in T.serial(_SL_COLS): + src_row[i] = A[tid, i] + Tx.wg.cast(dst, src) + dst_row = dst.local(_SL_COLS) + for i in T.serial(_SL_COLS): + B[tid, i] = dst_row[i] + + _sl_compile(kernel) + + +def test_cast_thread_accepts_local_view(): + """thread scope is exempt: a thread-axis-free local tile still compiles.""" + + @T.prim_func + def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + T.device_entry() + _bx = T.cta_id([1]) + tx_var = T.thread_id([_SL_ROWS]) + src = T.alloc_buffer((_SL_COLS,), "float32", scope="local", layout=TileLayout(S[(_SL_COLS,)])) + dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local", layout=TileLayout(S[(_SL_COLS,)])) + for i in T.serial(_SL_COLS): + src[i] = A[tx_var, i] + Tx.cast(dst, src) + for i in T.serial(_SL_COLS): + B[tx_var, i] = dst[i] + + _sl_compile(kernel) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py index e0a270e7091a..32ac00e39d5f 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py @@ -32,6 +32,7 @@ from tvm.ir.type import PointerType, PrimType from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.cuda.operator.tile_primitive.gemm_async import sf_tmem_layout from tvm.tirx.cuda.operator.tile_primitive.tma_utils import ( mma_atom_layout, @@ -167,6 +168,8 @@ def pack_sf_fp8_uint32(sf_uint8, n_total=128): return packed +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "task", [ @@ -293,6 +296,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_gemm_tcgen05_cta_group_1_layout_f_m64(): """M=64 MMA with C operand allocated as Layout F (datapath="F"). @@ -405,6 +410,8 @@ def gemm_layout_f(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-2, rtol=1e-2) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "task", [ @@ -545,6 +552,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") def test_gemm_tcgen05_cta_group_2_layout_b(): """Test cta_group=2 with Layout B (2x2 datapath, M=128 total, 64 per CTA). @@ -675,6 +684,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") @pytest.mark.parametrize( "task", @@ -864,6 +875,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") @pytest.mark.parametrize( "task", @@ -1089,6 +1102,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") def test_gemm_block_scaled_nvfp4_cta_group_1(): """Test block-scaled nvfp4 GEMM with cta_group=1. @@ -1258,6 +1273,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") def test_gemm_block_scaled_nvfp4_cta_group_2(): """Test block-scaled nvfp4 GEMM with cta_group=2. @@ -1462,6 +1479,8 @@ def gemm_async_fn(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle, SFA_ptr: T. np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes") def test_gemm_block_scaled_fp8_sf_id(): """Test sf_id auto-derivation from layout for fp8 block-scaled MMA. @@ -1681,6 +1700,8 @@ def per_block_quantize_fp8(mat, block_size=32): ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize( "task", [ @@ -1960,6 +1981,8 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0") @pytest.mark.parametrize("k_lo,k_hi", [(0, 16), (0, 32), (16, 32), (16, 48), (32, 64)]) def test_gemm_tcgen05_contiguous_kslice_partial_k(k_lo, k_hi): """A slice on the *contiguous* (K) axis of a swizzled gemm_async operand must diff --git a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py index 67cc1e0bd6fa..0402719ba1e5 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py @@ -43,6 +43,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env # Helpers exposed by the dispatcher module for direct algorithm tests. from tvm.tirx.cuda.operator.tile_primitive.permute_layout.warp_xor_swizzle import ( @@ -167,6 +168,8 @@ def _compile_and_run(prim_func, np_inputs): return [t.numpy() for t in tensors], mod.mod.imports[0].inspect_source() +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @needs_cuda @pytest.mark.parametrize( "name, pipe, blk, dtype", @@ -231,6 +234,8 @@ def f(A: T.handle, B: T.handle): np.testing.assert_array_equal(B_flat, ref, err_msg=f"{name} stage {s}") +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @needs_cuda def test_identity_passes_through_as_copy(): """L_src == L_dst should still compile and produce a correct (identity) copy.""" @@ -255,6 +260,8 @@ def f(A: T.handle, B: T.handle): np.testing.assert_array_equal(B_out, A_np) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @needs_cuda @pytest.mark.parametrize("dtype", ["uint32", "int32", "float32"]) @pytest.mark.parametrize( diff --git a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py index 0474ad2dc46a..9031aa4f487f 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py @@ -21,6 +21,7 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.testing import env from tvm.tirx.layout import R, S, TileLayout, laneid, wg_local_layout @@ -41,6 +42,8 @@ ((32, 32), (32,), (-1,), (1, 1), (2,), (5, 8), (5,)), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) @pytest.mark.parametrize("accum", [False, True]) @@ -129,6 +132,8 @@ def test_reduction(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(ref, B.numpy()[tuple(reduce_slice_dst)], atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("exec_scope", ["warp", "warpgroup", "thread"]) @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("accum", [False, True]) @@ -264,6 +269,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: ((2, 3, 4), (3, 4), (0,)), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("accum", [False, True]) def test_reduction_local_thread_wise(src_shape, dst_shape, axes, op_type, accum): @@ -367,6 +374,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: ((4, 8), (1, 8), (1,), False, None), ], ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) def test_reduction_local_view_basic(inner_dims, dst_dims, axes, accum, slice_end, op_type): """Test view-based local reduction with simple purely-local layouts.""" @@ -484,6 +493,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(ref, B.numpy(), atol=1e-5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("n_groups, n_warps", [(1, 1), (1, 4), (2, 8)]) @pytest.mark.parametrize("op_type", ["sum", "max", "min"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) @@ -616,6 +627,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 7, 10, 15, 100]) @pytest.mark.parametrize("op_type", ["max", "min"]) @pytest.mark.parametrize("accum", [False, True]) @@ -685,6 +698,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-5) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 9, 17, 63, 65, 100]) @pytest.mark.parametrize("accum", [False, True]) def test_reduction_local_optimized_packed_add_sum(reduction_len, accum): @@ -746,6 +761,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-4) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) def test_reduction_op_warp_shuffle(op_type, dtype): @@ -807,6 +824,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_type", ["sum", "max"]) @pytest.mark.parametrize("dtype", ["float32", "float16"]) def test_reduction_op_warp_shuffle_multi_elem(op_type, dtype): @@ -875,6 +894,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_reduction_warp_shuffle_multi_warp_loop(): """Test intra-warp + cross-warp reduction via T.sum in a for loop with multiple warps. @@ -951,6 +972,8 @@ def test_func(A_ptr: T.handle, B_ptr: T.handle) -> None: tvm.testing.assert_allclose(B_ref, B_dev.numpy(), atol=1e-3) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") @pytest.mark.parametrize("op_name", ["sum", "max"]) def test_reduction_warpgroup_wg_local_layout(op_name): rows, cols = 128, 16 diff --git a/tests/python/tirx/test_buffer_print.py b/tests/python/tirx/test_buffer_print.py index 211f4d390313..dbd0da8f849a 100644 --- a/tests/python/tirx/test_buffer_print.py +++ b/tests/python/tirx/test_buffer_print.py @@ -18,10 +18,12 @@ import re import numpy as np +import pytest import tvm import tvm.testing from tvm.script import tirx as T +from tvm.testing import env def generate_random_data(shape, dtype): @@ -181,6 +183,8 @@ def verify_cuda_code_string(func, expected_var_name, expected_string_literal): ) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_print(): DEV = tvm.cuda() target = tvm.target.Target("cuda") diff --git a/tests/python/tirx/test_control_flow.py b/tests/python/tirx/test_control_flow.py index 1f905bd03cc9..9085c2b0213b 100644 --- a/tests/python/tirx/test_control_flow.py +++ b/tests/python/tirx/test_control_flow.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm from tvm.script import tirx as T +from tvm.testing import env def run_test_break_continue(func, shape, expected): @@ -32,6 +34,8 @@ def run_test_break_continue(func, shape, expected): np.testing.assert_allclose(arr.numpy(), expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_break_continue1(): # fmt: off @T.prim_func @@ -53,6 +57,8 @@ def func(A_ptr: T.handle): run_test_break_continue(func, (10,), expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_break_continue2(): # fmt: off @T.prim_func @@ -79,6 +85,8 @@ def func(A_ptr: T.handle): run_test_break_continue(func, (9,), expected) +@pytest.mark.gpu +@pytest.mark.skipif(not env.has_cuda(), reason="need cuda") def test_break_continue3(): # fmt: off @T.prim_func diff --git a/tests/python/tirx/test_layout.py b/tests/python/tirx/test_layout.py index e3711cb00cd2..0dcf212ce271 100644 --- a/tests/python/tirx/test_layout.py +++ b/tests/python/tirx/test_layout.py @@ -1733,5 +1733,40 @@ def test_slice_single_shard_skips_defensive_floormod(): # we just assert offset is non-empty and structurally sane (not None). +def test_slice_tcgen05_frag_layout_scope_consistent(): + """Slicing a wid_in_wg+laneid frag layout (tcgen05 16x256b) must stay + scope-consistent: the sliced result canonicalizes to a single tid_in_wg + chain over the full 128 threads (regression for the per-group-fusion bug). + """ + frag = TileLayout( + S[(4, 2, 2, 8, 4, 4, 2) : (1 @ wid_in_wg, 16, 2, 4 @ laneid, 4, 1 @ laneid, 1)] + ) + + def thread_chain(layout): + canon = layout.canonicalize() + names = {it.axis.name for it in canon.shard if it.axis.is_thread()} + titers = sorted( + ((int(it.stride), int(it.extent)) for it in canon.shard if it.axis.is_thread()), + ) + running = 1 + for stride, extent in titers: + assert stride == running, f"non-contiguous thread chain: {titers}" + running *= extent + return names, running + + with tvm.target.Target("cuda"): + # Full-region slice and a column sub-slice must both canonicalize to a + # single tid_in_wg chain covering all 128 warpgroup threads. + full = frag.slice([128, 32], [(0, 128), (0, 32)]) + names, total = thread_chain(full) + assert names == {"tid_in_wg"}, names + assert total == 128, total + + col = frag.slice([128, 32], [(0, 128), (16, 32)]) + names_c, total_c = thread_chain(col) + assert names_c == {"tid_in_wg"}, names_c + assert total_c == 128, total_c + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index ec052281ad11..15bb51bdf73d 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -55,6 +55,7 @@ TEST_FILES=( "tirx-analysis" "tirx-base" "tirx-transform" + "tirx" "tvmscript" "relax" ) From bce6ebcde4ceee6ad3cff22a5d95a91c30ed6e95 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 16 Jun 2026 07:30:13 -0400 Subject: [PATCH 09/23] [Relax][TensorRT] Update TensorRT runtime to 10 (#19789) This pr fixes #19609. TensorRT 10 removed a large set of APIs that the Relax TensorRT BYOC integration relied on, so it failed to compile against TRT >= 10. Port the runtime and codegen to the TRT10 API and require TensorRT >= 10: - Lifetime: obj->destroy() -> delete (destroy() removed in TRT10). - Builder: drop implicit-batch mode (networks are always explicit-batch via createNetworkV2(0); setMaxBatchSize removed); setMaxWorkspaceSize -> setMemoryPoolLimit(kWORKSPACE); buildEngineWithConfig -> buildSerializedNetwork + deserializeCudaEngine, keeping the IRuntime alive alongside the engine. - Execution: the binding-index model (getNbBindings / getBindingIndex / setBindingDimensions / execute / executeV2) -> the named-tensor model (getNbIOTensors / setInputShape / setTensorAddress / enqueueV3); deserializeCudaEngine drops the trailing IPluginFactory* argument. - Layers: addConvolution / addPooling / addDeconvolution / addPadding -> the *Nd variants; set{Stride,Dilation} -> *Nd; IFullyConnectedLayer / addFullyConnected removed -> dense rebuilt with addConstant + addMatrixMultiply. - Add a build-time guard that emits a clear error on TensorRT < 10. Also fix pre-existing issues that prevented this path from running end-to-end: the runtime had drifted from the current tvm-ffi API (TVMTensorCopyToBytes / TVMGetLastError, VectorToTrtDims over ffi::Array, a stale `override` on the destructor), and the conv converters read a Relay-era "channels" attribute that Relax does not emit (output channels are now derived from the kernel shape). All tests are verified correct locally. This pr barely includes api updates and there is no new parts added --- src/relax/backend/contrib/tensorrt/codegen.cc | 3 +- .../contrib/tensorrt/tensorrt_builder.cc | 129 +++++------ .../extra/contrib/tensorrt/tensorrt_builder.h | 18 +- .../contrib/tensorrt/tensorrt_calibrator.h | 5 +- .../extra/contrib/tensorrt/tensorrt_ops.cc | 161 ++++++-------- .../extra/contrib/tensorrt/tensorrt_ops.h | 9 +- .../contrib/tensorrt/tensorrt_runtime.cc | 138 +++++++----- .../extra/contrib/tensorrt/tensorrt_utils.h | 21 +- tests/python/relax/test_codegen_tensorrt.py | 204 +++++++++++++++++- 9 files changed, 436 insertions(+), 252 deletions(-) diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc b/src/relax/backend/contrib/tensorrt/codegen.cc index 7fa6d48bdc24..07ba1c81e653 100644 --- a/src/relax/backend/contrib/tensorrt/codegen.cc +++ b/src/relax/backend/contrib/tensorrt/codegen.cc @@ -61,7 +61,8 @@ struct TensorRTCompilerConfigNode : public ffi::Object { "TensorRT version as (major, minor, patch).", refl::DefaultValue(ffi::Array({6, 0, 1}))) .def_ro("use_implicit_batch", &TensorRTCompilerConfigNode::use_implicit_batch, - "Use implicit batch", refl::DefaultValue(true)) + "Use implicit batch (removed in TensorRT 10; networks are always explicit-batch)", + refl::DefaultValue(false)) .def_ro("max_workspace_size", &TensorRTCompilerConfigNode::max_workspace_size, "Max workspace size", refl::DefaultValue(size_t(1) << 30)) .def_ro("remove_no_mac_subgraphs", &TensorRTCompilerConfigNode::remove_no_mac_subgraphs, diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc index 4caa8e383e15..f0c2a26b2e66 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc @@ -40,36 +40,24 @@ namespace contrib { TensorRTBuilder::TensorRTBuilder(TensorRTLogger* logger, const std::vector& data_entry, - size_t max_workspace_size, bool use_implicit_batch, bool use_fp16, - int batch_size, nvinfer1::IInt8Calibrator* calibrator) - : data_entry_(data_entry), + size_t max_workspace_size, bool use_fp16, + nvinfer1::IInt8Calibrator* calibrator) + : trt_logger_(logger), + data_entry_(data_entry), max_workspace_size_(max_workspace_size), - use_implicit_batch_(use_implicit_batch), use_fp16_(use_fp16), use_int8_(false), - batch_size_(batch_size), calibrator_(calibrator) { // Create TRT builder and network. - builder_ = nvinfer1::createInferBuilder(*logger); + builder_ = nvinfer1::createInferBuilder(*trt_logger_); -#if TRT_VERSION_GE(6, 0, 1) - // Use INetworkV2. - auto flags = - 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - if (use_implicit_batch_) { - flags = 0U; - builder_->setMaxBatchSize(batch_size_); - } + // TensorRT 10 removed implicit-batch mode and the kEXPLICIT_BATCH creation flag; every network is + // explicit-batch, so the batch dimension is simply dimension 0 of each binding and is varied + // through optimization profiles rather than IBuilder::setMaxBatchSize. if (calibrator_ != nullptr) { use_int8_ = true; } - network_ = builder_->createNetworkV2(flags); -#else - builder_->setMaxBatchSize(batch_size_); - builder_->setMaxWorkspaceSize(max_workspace_size_); - builder_->setFp16Mode(use_fp16_); - network_ = builder_->createNetwork(); -#endif + network_ = builder_->createNetworkV2(0U); } nvinfer1::DataType DLDataType2NVDataType(DLDataType data_type) { @@ -87,10 +75,7 @@ void TensorRTBuilder::AddInput(int nid, uint32_t entry_id, const JSONGraphNode& for (size_t i = 0; i < shapes.size(); ++i) { const std::string name = node_name + "_" + std::to_string(i); auto shape = shapes[i]; - // Remove batch dim when not in explicit batch mode. - if (use_implicit_batch_ && shape.size() > 1) { - shape.erase(shape.begin()); - } + // TensorRT 10 is always explicit-batch: keep the full shape including the batch dimension. nvinfer1::Dims dims = VectorToTrtDims(shape); auto input_tensor = network_->addInput(name.c_str(), DLDataType2NVDataType(dtypes[i]), dims); node_output_map_[nid].push_back(TensorRTOpInput(input_tensor)); @@ -168,11 +153,10 @@ void TensorRTBuilder::AddLayer(int nid, const JSONGraphNode& node) { } TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { - // Process graph to create INetworkDefinition. -// Build engine. -#if TRT_VERSION_GE(6, 0, 1) + // Build engine. config_ = builder_->createBuilderConfig(); - config_->setMaxWorkspaceSize(max_workspace_size_); + // TensorRT 10 replaced IBuilderConfig::setMaxWorkspaceSize with a tunable memory pool. + config_->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); if (use_fp16_) { config_->setFlag(nvinfer1::BuilderFlag::kFP16); } @@ -184,40 +168,48 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { LOG(INFO) << "config finishes setting up calibrator as INT8 mode ... "; } - // Add profiles. - if (!use_implicit_batch_) { - auto profile = builder_->createOptimizationProfile(); - for (int i = 0; i < network_->getNbInputs(); ++i) { - auto name = network_->getInput(i)->getName(); - const uint32_t entry_id = entry_id_map_[name]; - std::vector shape(data_entry_[entry_id]->shape, - data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); - auto dims = VectorToTrtDims(shape); + // Every network is explicit-batch in TRT10, so always add an optimization profile that pins each + // input to its concrete shape (with a minimum batch of 1 for dynamic batch dimensions). + auto profile = builder_->createOptimizationProfile(); + for (int i = 0; i < network_->getNbInputs(); ++i) { + auto name = network_->getInput(i)->getName(); + const uint32_t entry_id = entry_id_map_[name]; + std::vector shape(data_entry_[entry_id]->shape, + data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); + auto dims = VectorToTrtDims(shape); - profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims); - profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims); - // Set minimum batch size to 1 when dynamic batching is used. - if (network_->getInput(i)->getDimensions().nbDims >= 1 && - network_->getInput(i)->getDimensions().d[0] == -1) { - dims.d[0] = 1; - } - profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kOPT, dims); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMAX, dims); + // The network inputs are built with static shapes, so the profile must match them exactly; only + // lower kMIN for a genuinely dynamic (-1) leading dimension. + if (network_->getInput(i)->getDimensions().nbDims >= 1 && + network_->getInput(i)->getDimensions().d[0] == -1) { + dims.d[0] = 1; } - config_->addOptimizationProfile(profile); + profile->setDimensions(name, nvinfer1::OptProfileSelector::kMIN, dims); } - nvinfer1::ICudaEngine* engine = builder_->buildEngineWithConfig(*network_, *config_); -#else - nvinfer1::ICudaEngine* engine = builder_->buildCudaEngine(*network_); -#endif - TVM_FFI_ICHECK_EQ(engine->getNbBindings(), - network_input_names_.size() + network_output_names_.size()); + config_->addOptimizationProfile(profile); + + // TensorRT 10 removed buildEngineWithConfig; build a serialized engine and deserialize it through + // an IRuntime that is kept alive alongside the engine (TensorRTEngineAndContext::runtime). + nvinfer1::IHostMemory* plan = builder_->buildSerializedNetwork(*network_, *config_); + TVM_FFI_ICHECK(plan) << "Failed to build TensorRT serialized network."; + nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(*trt_logger_); + nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(plan->data(), plan->size()); + delete plan; + if (engine == nullptr) { + delete runtime; + TVM_FFI_THROW(InternalError) << "Failed to deserialize the TensorRT engine."; + } + TVM_FFI_ICHECK_EQ( + engine->getNbIOTensors(), + static_cast(network_input_names_.size() + network_output_names_.size())); nvinfer1::IExecutionContext* context = engine->createExecutionContext(); CleanUp(); - TVM_FFI_ICHECK(engine); TVM_FFI_ICHECK(context); - return {engine, context, network_input_names_, network_output_names_}; + return {runtime, engine, context, network_input_names_, network_output_names_}; } nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, @@ -236,10 +228,9 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, } weight.count = count; weight.values = new float[count]; - TVM_FFI_ICHECK_EQ(TVMTensorCopyToBytes(const_cast(dptr), - const_cast(weight.values), weight_bytes), - 0) - << TVMGetLastError(); + // Tensor::CopyToBytes throws on failure (the old C API TVMTensorCopyToBytes/TVMGetLastError + // were removed during the tvm-ffi refactor). + Tensor::CopyToBytes(dptr, const_cast(weight.values), weight_bytes); trt_weights_.push_back(weight); return weight; } @@ -247,35 +238,25 @@ nvinfer1::Weights TensorRTBuilder::GetDLTensorAsWeights(const DLTensor* dptr, nvinfer1::ITensor* TensorRTBuilder::GetInputAsTensor(const TensorRTOpInput& input) { if (input.type == kTensor) return input.tensor; auto shape = input.weight_shape; - // Remove batch dim when not in explicit batch mode. - // Example: - // x = dims (1, 32, 224, 224) which becomes TRT Dims (32, 224, 224) - // y = dims (1, 32) - // z = add(x, y) - // y needs to have TRT dims (32,), otherwise broadcasting will result in z having - // TRT Dims(1, 32, 224, 224) when it should be (32, 224, 224). - if (use_implicit_batch_ && shape.size() > 1 && shape[0] == 1) { - shape.erase(shape.begin()); - } + // TensorRT 10 is always explicit-batch, so the constant keeps its full shape. return network_->addConstant(VectorToTrtDims(shape), input.weight)->getOutput(0); } void TensorRTBuilder::CleanUp() { + // TensorRT 10 removed obj->destroy(); objects are released with the delete operator. VLOG(1) << "Destroying TensorRT network"; TVM_FFI_ICHECK(network_); - network_->destroy(); + delete network_; network_ = nullptr; -#if TRT_VERSION_GE(6, 0, 1) VLOG(1) << "Destroying TensorRT config"; TVM_FFI_ICHECK(config_); - config_->destroy(); + delete config_; config_ = nullptr; -#endif VLOG(1) << "Destroying TensorRT builder"; TVM_FFI_ICHECK(builder_); - builder_->destroy(); + delete builder_; builder_ = nullptr; VLOG(1) << "Destroying TensorRT weights"; diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.h b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.h index 96905598737c..108f56b9f32f 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.h +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.h @@ -48,6 +48,9 @@ using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry; * perform inference. */ struct TensorRTEngineAndContext { + // TensorRT 10 builds a serialized engine which is then deserialized through an IRuntime. The + // runtime must outlive the engine it produced, so it is owned alongside the engine/context. + nvinfer1::IRuntime* runtime = nullptr; nvinfer1::ICudaEngine* engine = nullptr; nvinfer1::IExecutionContext* context = nullptr; std::vector inputs; @@ -67,12 +70,10 @@ class TensorRTBuilder { * \brief Create TensorRT builder. * \param logger TensorRT logger to use for errors and warnings. * \param max_workspace_size Workspace size parameter for TensorRT engine build phase. - * \param use_implicit_batch Whether to use implicit batch mode (default) * \param use_fp16 Whether to automatically convert a model to fp16 - * \param batch_size If use_implicit_batch, */ TensorRTBuilder(TensorRTLogger* logger, const std::vector& data_entry, - size_t max_workspace_size, bool use_implicit_batch, bool use_fp16, int batch_size, + size_t max_workspace_size, bool use_fp16, nvinfer1::IInt8Calibrator* calibrator = nullptr); /*! @@ -124,13 +125,14 @@ class TensorRTBuilder { /*! \brief Maps a node to its outputs. */ std::unordered_map> node_output_map_; + /*! \brief TensorRT logger, used to create the builder and the deserialization runtime. */ + TensorRTLogger* trt_logger_ = nullptr; + /*! \brief TensorRT builder. */ nvinfer1::IBuilder* builder_ = nullptr; -#if TRT_VERSION_GE(6, 0, 1) /*! \brief TensorRT builder config. */ nvinfer1::IBuilderConfig* config_ = nullptr; -#endif /*! \brief TensorRT network definition. */ nvinfer1::INetworkDefinition* network_ = nullptr; @@ -147,18 +149,12 @@ class TensorRTBuilder { /*! \brief Max workspace size in bytes for TRT. */ size_t max_workspace_size_; - /*! \brief Whether to use implicit batch mode. */ - bool use_implicit_batch_; - /*! \brief Whether to automatically convert model to 16-bit floating point precision. */ bool use_fp16_; /*! \brief whether to automatically convert model to int8 precision */ bool use_int8_; - /*! \brief Batch size to optimize for. */ - int batch_size_; - /*! \brief Input names. */ std::vector network_input_names_; diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h index 408d50cc7e08..aa10d8f0d9df 100755 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h @@ -123,7 +123,10 @@ class TensorRTCalibrator : public nvinfer1::IInt8EntropyCalibrator2 { const int num_inputs = data_sizes_[0].size(); buffers_.assign(num_inputs, nullptr); for (int i = 0; i < num_inputs; ++i) { - TVM_FFI_CHECK_CUDA_ERROR(cudaMalloc(&buffers_[i], data_sizes_[0][i] * sizeof(float))); + // data_sizes_ holds the per-sample element count; getBatch() copies a full batch + // (batch_size_ * per-sample) into each buffer, so the device buffer must be sized to match. + TVM_FFI_CHECK_CUDA_ERROR( + cudaMalloc(&buffers_[i], batch_size_ * data_sizes_[0][i] * sizeof(float))); } } }; diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc index f8463cb50e65..d3e68778fde9 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc @@ -252,11 +252,16 @@ class Conv1DOpConverter : public TensorRTOpConverter { auto dilation = params->node.GetAttr>("dilation"); auto padding = params->node.GetAttr>("padding"); int groups = static_cast(params->node.GetAttr("groups")); + // Relax conv attrs carry no "channels" field (unlike Relay); the number of output channels is + // the first dimension of the OIHW/OIW kernel. int channels = weight_shape[0]; - channels = static_cast(params->node.GetAttr("channels")); auto shuffle_layer = params->network->addShuffle(*input_tensor); - std::vector new_shape = {input_dims[0], input_dims[1], 1}; + // Emulate a 1D convolution with a 2D convolution by appending a trailing unit spatial + // dimension (NCW -> NCW1). In explicit-batch mode (TensorRT 10) input_dims already includes the + // batch dimension, so derive the reshape from the full input rank instead of hard-coding it. + std::vector new_shape(input_dims); + new_shape.push_back(1); shuffle_layer->setReshapeDimensions(VectorToTrtDims(new_shape)); input_tensor = shuffle_layer->getOutput(0); @@ -265,21 +270,22 @@ class Conv1DOpConverter : public TensorRTOpConverter { nvinfer1::Weights bias{weight_type, nullptr, 0}; - auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, - params->inputs.at(1).weight, bias); + auto conv_layer = params->network->addConvolutionNd(*input_tensor, channels, kernel_size, + params->inputs.at(1).weight, bias); TVM_FFI_ICHECK(conv_layer != nullptr); - conv_layer->setPadding(nvinfer1::DimsHW(static_cast(padding[0]), 0)); + conv_layer->setPaddingNd(nvinfer1::DimsHW(static_cast(padding[0]), 0)); TVM_FFI_ICHECK_EQ(strides.size(), 1); const auto trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), 1); - conv_layer->setStride(trt_strides); + conv_layer->setStrideNd(trt_strides); TVM_FFI_ICHECK_EQ(dilation.size(), 1); const auto trt_dilation = nvinfer1::DimsHW(static_cast(dilation[0]), 1); - conv_layer->setDilation(trt_dilation); + conv_layer->setDilationNd(trt_dilation); conv_layer->setNbGroups(groups); input_tensor = conv_layer->getOutput(0); - auto conv_output_dims = TrtDimsToVector(input_tensor->getDimensions()); - std::vector back_shape = {0, 0}; + // Drop the trailing unit dimension (NOW1 -> NOW); 0 copies the corresponding input dimension, + // so the number of leading dims to keep matches the original input rank. + std::vector back_shape(input_dims.size(), 0); auto shuffle_back_layer = params->network->addShuffle(*input_tensor); shuffle_back_layer->setReshapeDimensions(VectorToTrtDims(back_shape)); params->outputs.push_back(shuffle_back_layer->getOutput(0)); @@ -304,47 +310,36 @@ class Conv2DOpConverter : public TensorRTOpConverter { auto dilation = params->node.GetAttr>("dilation"); auto padding = params->node.GetAttr>("padding"); int groups = static_cast(params->node.GetAttr("groups")); + // Relax conv attrs carry no "channels" field (unlike Relay); the number of output channels is + // the first dimension of the OIHW/OIW kernel. int channels = weight_shape[0]; - channels = static_cast(params->node.GetAttr("channels")); // TRT conv2d op doesn't support asymmetric padding before 5.1, so we // workaround by adding a padding layer before the pooling op. nvinfer1::DimsHW prepadding, postpadding; bool use_asymmetric_padding; GetPadding(padding, &use_asymmetric_padding, &prepadding, &postpadding); -#if !TRT_VERSION_GE(5, 1, 5) - if (use_asymmetric_padding) { - auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); - TVM_FFI_ICHECK(pad_layer != nullptr); - input_tensor = pad_layer->getOutput(0); - // No need for conv op to do any padding. - use_asymmetric_padding = false; - prepadding = nvinfer1::DimsHW(0, 0); - } -#endif const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; nvinfer1::Weights bias{weight_type, nullptr, 0}; - auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size, - params->inputs.at(1).weight, bias); + auto conv_layer = params->network->addConvolutionNd(*input_tensor, channels, kernel_size, + params->inputs.at(1).weight, bias); TVM_FFI_ICHECK(conv_layer != nullptr); conv_layer->setName(params->LayerName().c_str()); if (use_asymmetric_padding) { -#if TRT_VERSION_GE(5, 1, 5) conv_layer->setPrePadding(prepadding); conv_layer->setPostPadding(postpadding); -#endif } else { - conv_layer->setPadding(prepadding); + conv_layer->setPaddingNd(prepadding); } TVM_FFI_ICHECK_EQ(strides.size(), 2); const auto trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), static_cast(strides[1])); - conv_layer->setStride(trt_strides); + conv_layer->setStrideNd(trt_strides); TVM_FFI_ICHECK_EQ(dilation.size(), 2); const auto trt_dilation = nvinfer1::DimsHW(static_cast(dilation[0]), static_cast(dilation[1])); - conv_layer->setDilation(trt_dilation); + conv_layer->setDilationNd(trt_dilation); conv_layer->setNbGroups(groups); params->outputs.push_back(conv_layer->getOutput(0)); } @@ -374,7 +369,8 @@ class Conv3DOpConverter : public TensorRTOpConverter { bool use_asymmetric_padding; GetPadding3D(padding, &use_asymmetric_padding, &prepadding, &postpadding); - const int num_outputs = static_cast(params->node.GetAttr("channels")); + // Relax conv3d has no "channels" attr; output channels = weight_shape[0] (OIDHW kernel). + const int num_outputs = static_cast(weight_shape[0]); const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; nvinfer1::Weights bias{weight_type, nullptr, 0}; @@ -410,31 +406,27 @@ class DenseOpConverter : public TensorRTOpConverter { void Convert(TensorRTOpConverterParams* params) const { auto input_tensor = params->inputs.at(0).tensor; - auto input_dims = TrtDimsToVector(input_tensor->getDimensions()); - TVM_FFI_ICHECK(input_dims.size() > 0 && input_dims.size() <= 3); - const size_t required_rank = TRT_HAS_IMPLICIT_BATCH(params) ? 3 : 4; - const bool need_reshape_on_input = input_dims.size() != required_rank; - if (need_reshape_on_input) { - // Add dims of size 1 until rank is required_rank. - std::vector new_shape(input_dims); - while (new_shape.size() < required_rank) new_shape.insert(new_shape.end(), 1); - input_tensor = Reshape(params, input_tensor, new_shape); - } - // Weights are in KC format. + // Weights are in KC (out_units x in_features) format. TVM_FFI_ICHECK_EQ(params->inputs.at(1).weight_shape.size(), 2); - const int num_units = params->inputs.at(1).weight_shape[0]; - const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; - nvinfer1::Weights bias{weight_type, nullptr, 0}; - nvinfer1::IFullyConnectedLayer* fc_layer = params->network->addFullyConnected( - *input_tensor, num_units, params->inputs.at(1).weight, bias); - TVM_FFI_ICHECK(fc_layer != nullptr); - auto output_tensor = fc_layer->getOutput(0); - if (need_reshape_on_input) { - // Remove added dims. - input_dims[input_dims.size() - 1] = num_units; - output_tensor = Reshape(params, output_tensor, input_dims); - } - params->outputs.push_back(output_tensor); + // addMatrixMultiply requires the input to have at least 2 dimensions (rows x K); the old + // FullyConnected path padded the rank, so guard explicitly now that it is gone. + TVM_FFI_ICHECK_GE(input_tensor->getDimensions().nbDims, 2) + << "TensorRT dense expects an input of rank >= 2 (got " + << input_tensor->getDimensions().nbDims << ")"; + // TensorRT 10 removed IFullyConnectedLayer/addFullyConnected. Implement dense as a matrix + // multiply: out[.., O] = in[.., K] * weightᵀ, with weight a constant of shape [O, K]. + // IMatrixMultiplyLayer contracts the last dim of `input` (K) with the last dim of the + // transposed weight (also K) and broadcasts the remaining leading dimensions, which matches + // nn.dense semantics for any input rank >= 2 without the rank-padding reshape FC required. + auto* weight_tensor = params->network + ->addConstant(VectorToTrtDims(params->inputs.at(1).weight_shape), + params->inputs.at(1).weight) + ->getOutput(0); + auto* matmul_layer = params->network->addMatrixMultiply( + *input_tensor, nvinfer1::MatrixOperation::kNONE, *weight_tensor, + nvinfer1::MatrixOperation::kTRANSPOSE); + TVM_FFI_ICHECK(matmul_layer != nullptr); + params->outputs.push_back(matmul_layer->getOutput(0)); } }; @@ -666,33 +658,18 @@ class PoolingOpConverter : public TensorRTOpConverter { GetPadding(padding, &use_asymmetric_padding, &prepadding, &postpadding); bool ceil_mode = static_cast(params->node.GetAttr("ceil_mode")); -// TRT pooling op doesn't support asymmetric padding before 5.1, so we -// workaround by adding a padding layer before the pooling op. -#if !TRT_VERSION_GE(5, 1, 5) - if (use_asymmetric_padding) { - auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); - TVM_FFI_ICHECK(pad_layer != nullptr); - input = pad_layer->getOutput(0); - // No need for pooling op to do any padding. - use_asymmetric_padding = false; - prepadding = nvinfer1::DimsHW(0, 0); - } -#endif - nvinfer1::DimsHW window_size = nvinfer1::DimsHW(static_cast(pool_size[0]), static_cast(pool_size[1])); - auto pool_layer = params->network->addPooling(*input, it->second, window_size); + auto pool_layer = params->network->addPoolingNd(*input, it->second, window_size); TVM_FFI_ICHECK(pool_layer != nullptr); nvinfer1::DimsHW trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), static_cast(strides[1])); - pool_layer->setStride(trt_strides); + pool_layer->setStrideNd(trt_strides); if (use_asymmetric_padding) { -#if TRT_VERSION_GE(5, 1, 5) pool_layer->setPrePadding(prepadding); pool_layer->setPostPadding(postpadding); -#endif } else { - pool_layer->setPadding(prepadding); + pool_layer->setPaddingNd(prepadding); } if (op_name == "nn.avg_pool2d") { bool count_include_pad = static_cast(params->node.GetAttr("count_include_pad")); @@ -783,7 +760,7 @@ class GlobalPoolingOpConverter : public TensorRTOpConverter { const int h = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[1] : input_dims[2]; const int w = TRT_HAS_IMPLICIT_BATCH(params) ? input_dims[2] : input_dims[3]; auto pool_layer = - params->network->addPooling(*input_tensor, it->second, nvinfer1::DimsHW(h, w)); + params->network->addPoolingNd(*input_tensor, it->second, nvinfer1::DimsHW(h, w)); TVM_FFI_ICHECK(pool_layer != nullptr); params->outputs.push_back(pool_layer->getOutput(0)); } @@ -993,7 +970,7 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCHW"); TVM_FFI_ICHECK(params->node.GetAttr("out_layout") == "" || params->node.GetAttr("out_layout") == "NCHW"); - TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "IOHW"); auto dilation = params->node.GetAttr>("dilation"); TVM_FFI_ICHECK(static_cast(dilation[0]) == 1 && static_cast(dilation[1]) == 1); auto strides = params->node.GetAttr>("strides"); @@ -1006,35 +983,26 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { nvinfer1::DimsHW prepadding, postpadding; bool use_asymmetric_padding; GetPadding(padding, &use_asymmetric_padding, &prepadding, &postpadding); -#if !TRT_VERSION_GE(5, 1, 5) - if (use_asymmetric_padding) { - auto pad_layer = params->network->addPadding(*input_tensor, prepadding, postpadding); - TVM_FFI_ICHECK(pad_layer != nullptr); - input_tensor = pad_layer->getOutput(0); - // No need for conv op to do any padding. - use_asymmetric_padding = false; - prepadding = nvinfer1::DimsHW(0, 0); - } -#endif - const int num_outputs = static_cast(params->node.GetAttr("channels")); + // Relax conv2d_transpose uses an IOHW kernel ([in, out, h, w]) by default, which is also the + // layout TensorRT's deconvolution expects, so the weight is passed through unchanged and the + // output channel count is the second kernel dimension. + const int num_outputs = static_cast(weight_shape[1]); const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]); const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; nvinfer1::Weights bias{weight_type, nullptr, 0}; - auto deconv_layer = params->network->addDeconvolution(*input_tensor, num_outputs, kernel_size, - params->inputs.at(1).weight, bias); + auto deconv_layer = params->network->addDeconvolutionNd(*input_tensor, num_outputs, kernel_size, + params->inputs.at(1).weight, bias); TVM_FFI_ICHECK(deconv_layer != nullptr); if (use_asymmetric_padding) { -#if TRT_VERSION_GE(5, 1, 5) deconv_layer->setPrePadding(prepadding); deconv_layer->setPostPadding(postpadding); -#endif } else { - deconv_layer->setPadding(prepadding); + deconv_layer->setPaddingNd(prepadding); } const auto trt_strides = nvinfer1::DimsHW(static_cast(strides[0]), static_cast(strides[1])); - deconv_layer->setStride(trt_strides); + deconv_layer->setStrideNd(trt_strides); deconv_layer->setNbGroups(groups); nvinfer1::ITensor* output = deconv_layer->getOutput(0); // Output padding. @@ -1044,7 +1012,7 @@ class Conv2DTransposeOpConverter : public TensorRTOpConverter { postpadding.w() != 0) { // Output padding for Conv2D transpose is always asymmetric and applied to post only. prepadding = nvinfer1::DimsHW(0, 0); - auto pad_layer = params->network->addPadding(*output, prepadding, postpadding); + auto pad_layer = params->network->addPaddingNd(*output, prepadding, postpadding); output = pad_layer->getOutput(0); } } @@ -1065,7 +1033,7 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { TVM_FFI_ICHECK_EQ(params->node.GetAttr("data_layout"), "NCDHW"); TVM_FFI_ICHECK(params->node.GetAttr("out_layout") == "" || params->node.GetAttr("out_layout") == "NCDHW"); - TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "OIDHW"); + TVM_FFI_ICHECK_EQ(params->node.GetAttr("kernel_layout"), "IODHW"); auto dilation = params->node.GetAttr>("dilation"); TVM_FFI_ICHECK_EQ(dilation.size(), 3); TVM_FFI_ICHECK(static_cast(dilation[0]) == 1 && static_cast(dilation[1]) == 1 && @@ -1078,7 +1046,10 @@ class Conv3DTransposeOpConverter : public TensorRTOpConverter { bool use_asymmetric_padding; GetPadding3D(padding, &use_asymmetric_padding, &prepadding, &postpadding); - const int num_outputs = static_cast(params->node.GetAttr("channels")); + // Relax conv3d_transpose uses an IODHW kernel ([in, out, d, h, w]) by default, matching the + // layout TensorRT's deconvolution expects, so the weight passes through unchanged and the + // output channel count is the second kernel dimension. + const int num_outputs = static_cast(weight_shape[1]); const auto kernel_size = nvinfer1::Dims3(weight_shape[2], weight_shape[3], weight_shape[4]); const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type; nvinfer1::Weights bias{weight_type, nullptr, 0}; @@ -1186,7 +1157,7 @@ class PadOpConverter : public TensorRTOpConverter { nvinfer1::DimsHW(static_cast(padding_arr[0]), static_cast(padding_arr[1])); nvinfer1::DimsHW postpadding = nvinfer1::DimsHW(static_cast(padding_arr[2]), static_cast(padding_arr[3])); - auto pad_layer = params->network->addPadding(*input, prepadding, postpadding); + auto pad_layer = params->network->addPaddingNd(*input, prepadding, postpadding); params->outputs.push_back(pad_layer->getOutput(0)); } }; @@ -1282,9 +1253,9 @@ class AdaptivePoolingOpConverter : public TensorRTOpConverter { const auto stride = nvinfer1::DimsHW(h / output_size.h(), w / output_size.w()); const auto window_size = nvinfer1::DimsHW(h - (output_size.h() - 1) * stride.h(), w - (output_size.w() - 1) * stride.w()); - auto pool_layer = params->network->addPooling(*input_tensor, it->second, window_size); + auto pool_layer = params->network->addPoolingNd(*input_tensor, it->second, window_size); TVM_FFI_ICHECK(pool_layer != nullptr); - pool_layer->setStride(stride); + pool_layer->setStrideNd(stride); params->outputs.push_back(pool_layer->getOutput(0)); } }; diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.h b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.h index 26ea40075458..5e4c30ed7f30 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.h +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.h @@ -35,11 +35,10 @@ #include "NvInfer.h" #include "tensorrt_utils.h" -#if TRT_VERSION_GE(6, 0, 1) -#define TRT_HAS_IMPLICIT_BATCH(params) (params->network->hasImplicitBatchDimension()) -#else -#define TRT_HAS_IMPLICIT_BATCH(params) (true) -#endif +// TensorRT 10 removed implicit-batch mode; every network is explicit-batch. Keep the macro so the +// converters' batch-aware branches read clearly, but it is unconditionally false (and no longer +// calls the deprecated INetworkDefinition::hasImplicitBatchDimension()). +#define TRT_HAS_IMPLICIT_BATCH(params) (false) namespace tvm { namespace runtime { diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc index 40ca760d96f2..932c52b394dc 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc @@ -40,6 +40,9 @@ #include "../json/json_runtime.h" #ifdef TVM_GRAPH_EXECUTOR_TENSORRT +#include +#include + #include "NvInfer.h" #include "tensorrt_builder.h" #include "tensorrt_calibrator.h" @@ -125,6 +128,10 @@ class TensorRTRuntime : public JSONRuntimeBase { for (size_t i = 0; i < nodes_.size(); ++i) { if (nodes_[i].HasAttr("use_implicit_batch") && nodes_[i].HasAttr("max_workspace_size")) { use_implicit_batch_ = static_cast(nodes_[i].GetAttr("use_implicit_batch")); + if (use_implicit_batch_) { + LOG(WARNING) << "use_implicit_batch=True is ignored: TensorRT 10 removed implicit-batch " + "mode, so the engine is always built and run in explicit-batch mode."; + } // Allow max_workspace_size to be overridden at runtime. size_t runtime_max_workspace_size = support::GetEnv("TVM_TENSORRT_MAX_WORKSPACE_SIZE", size_t(0)); @@ -145,17 +152,20 @@ class TensorRTRuntime : public JSONRuntimeBase { /*! \brief Destroy engines and contexts. */ void DestroyEngines() { for (auto& it : trt_engine_cache_) { + // TensorRT 10 removed obj->destroy(); release with delete. The deserialization runtime must + // outlive the engine it produced, so delete the context, then the engine, then the runtime. VLOG(1) << "Destroying TensorRT context for function '" << it.first.first << "' (batch size " << it.first.second << ")"; - it.second.context->destroy(); + delete it.second.context; VLOG(1) << "Destroying TensorRT engine for function '" << it.first.first << "' (batch size " << it.first.second << ")"; - it.second.engine->destroy(); + delete it.second.engine; + delete it.second.runtime; } trt_engine_cache_.clear(); } - ~TensorRTRuntime() override { + ~TensorRTRuntime() { VLOG(1) << "Destroying TensorRT runtime"; DestroyEngines(); VLOG(1) << "Destroyed TensorRT runtime"; @@ -166,11 +176,13 @@ class TensorRTRuntime : public JSONRuntimeBase { auto& engine_and_context = GetOrBuildEngine(); int batch_size = GetBatchSize(); if (batch_size == 0) return; - auto engine = engine_and_context.engine; auto context = engine_and_context.context; - const int num_bindings = engine->getNbBindings(); - std::vector bindings(num_bindings, nullptr); - std::vector binding_sizes(num_bindings, 0); + + // TensorRT 10 uses named-tensor I/O (setInputShape/setTensorAddress/enqueueV3, no binding + // indices). Track input device pointers and per-sample element counts for the INT8 calibrator. + std::vector input_bindings; + std::vector input_binding_sizes; + // Setup input bindings. for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; @@ -178,28 +190,28 @@ class TensorRTRuntime : public JSONRuntimeBase { for (size_t j = 0; j < nodes_[nid].GetOpShape().size(); ++j) { uint32_t eid = EntryID(nid, j); const std::string name = nodes_[nid].GetOpName() + "_" + std::to_string(j); - int binding_index = engine->getBindingIndex(name.c_str()); - TVM_FFI_ICHECK_NE(binding_index, -1); -#if TRT_VERSION_GE(6, 0, 1) - if (!use_implicit_batch_) { - std::vector shape(data_entry_[eid]->shape, - data_entry_[eid]->shape + data_entry_[eid]->ndim); - auto dims = VectorToTrtDims(shape); - TVM_FFI_ICHECK(context->setBindingDimensions(binding_index, dims)); - } -#endif + std::vector shape(data_entry_[eid]->shape, + data_entry_[eid]->shape + data_entry_[eid]->ndim); + auto dims = VectorToTrtDims(shape); + TVM_FFI_ICHECK(context->setInputShape(name.c_str(), dims)); + + void* device_ptr = nullptr; if (data_entry_[eid]->device.device_type == kDLCUDA) { - bindings[binding_index] = data_entry_[eid]->data; + device_ptr = data_entry_[eid]->data; } else { - auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + auto device_buffer = GetOrAllocateDeviceBuffer(name, eid); device_buffer.CopyFrom(data_entry_[eid]); - bindings[binding_index] = device_buffer->data; + device_ptr = device_buffer->data; } + TVM_FFI_ICHECK(context->setTensorAddress(name.c_str(), device_ptr)); - auto dims = engine->getBindingDimensions(binding_index); + // Per-sample element count (exclude the batch dimension d[0]); the INT8 calibrator + // multiplies by the batch size itself when copying calibration data, so including the + // batch dim here would over-read the device buffer by a factor of batch_size. int num_elements = 1; - for (int i = 0; i < dims.nbDims; ++i) num_elements *= dims.d[i]; - binding_sizes[binding_index] = num_elements; + for (int k = 1; k < dims.nbDims; ++k) num_elements *= dims.d[k]; + input_bindings.push_back(device_ptr); + input_binding_sizes.push_back(static_cast(num_elements)); } } } @@ -209,7 +221,7 @@ class TensorRTRuntime : public JSONRuntimeBase { if (calibrator_ != nullptr) { LOG(INFO) << "Starting adding last " << num_calibration_batches_remaining_ << "-th batch data to the calibrator"; - calibrator_->AddBatchData(bindings, binding_sizes); + calibrator_->AddBatchData(input_bindings, input_binding_sizes); num_calibration_batches_remaining_--; } return; @@ -219,34 +231,31 @@ class TensorRTRuntime : public JSONRuntimeBase { for (size_t i = 0; i < outputs_.size(); ++i) { uint32_t eid = EntryID(outputs_[i]); const std::string& name = engine_and_context.outputs[i]; - int binding_index = engine->getBindingIndex(name.c_str()); - TVM_FFI_ICHECK_NE(binding_index, -1); + void* device_ptr = nullptr; if (data_entry_[eid]->device.device_type == kDLCUDA) { - bindings[binding_index] = data_entry_[eid]->data; + device_ptr = data_entry_[eid]->data; } else { - auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); - bindings[binding_index] = device_buffer->data; + auto device_buffer = GetOrAllocateDeviceBuffer(name, eid); + device_ptr = device_buffer->data; } + TVM_FFI_ICHECK(context->setTensorAddress(name.c_str(), device_ptr)); } -#if TRT_VERSION_GE(6, 0, 1) - if (use_implicit_batch_) { - TVM_FFI_ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; - } else { - TVM_FFI_ICHECK(context->executeV2(bindings.data())) << "Running TensorRT failed."; - } -#else - TVM_FFI_ICHECK(context->execute(batch_size, bindings.data())) << "Running TensorRT failed."; -#endif + // Run on TVM's current CUDA stream so the engine is ordered after the inputs produced upstream + // (and to avoid TensorRT's default-stream synchronization warning). enqueueV3 is async-only in + // TRT10, so synchronize afterwards to preserve Run()'s blocking semantics. + const DLDevice& dev = data_entry_[input_var_eid_[0]]->device; + const int device_id = dev.device_type == kDLCUDA ? dev.device_id : 0; + cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); + TVM_FFI_ICHECK(context->enqueueV3(stream)) << "Running TensorRT failed."; + TVM_FFI_CHECK_CUDA_ERROR(cudaStreamSynchronize(stream)); // Copy outputs from GPU buffers if needed. for (size_t i = 0; i < outputs_.size(); ++i) { uint32_t eid = EntryID(outputs_[i]); const std::string& name = engine_and_context.outputs[i]; - int binding_index = engine->getBindingIndex(name.c_str()); - TVM_FFI_ICHECK_NE(binding_index, -1); if (data_entry_[eid]->device.device_type != kDLCUDA) { - auto device_buffer = GetOrAllocateDeviceBuffer(eid, binding_index); + auto device_buffer = GetOrAllocateDeviceBuffer(name, eid); device_buffer.CopyTo(const_cast(data_entry_[eid])); } } @@ -269,8 +278,11 @@ class TensorRTRuntime : public JSONRuntimeBase { } return false; } - // Check for engine with compatible max_batch_size. - if (batch_size <= max_batch_size_) { + // Single-engine mode: TensorRT 10 engines are explicit-batch and their optimization profile + // pins the built batch size, so a cached engine can only serve that exact batch. Require an + // exact match (otherwise a smaller batch would be rejected by setInputShape) and rebuild on any + // change. This replaces the implicit-batch "any batch <= max" reuse that TRT10 removed. + if (batch_size == max_batch_size_) { *compatible_engine_batch_size = max_batch_size_; return true; } @@ -325,8 +337,8 @@ class TensorRTRuntime : public JSONRuntimeBase { void BuildEngineFromJson(int batch_size) { const bool use_fp16 = support::GetEnv("TVM_TENSORRT_USE_FP16", false) || use_fp16_; - TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_, - use_fp16, batch_size, calibrator_.get()); + TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_fp16, + calibrator_.get()); for (size_t i = 0; i < input_nodes_.size(); ++i) { auto nid = input_nodes_[i]; const auto& node = nodes_[nid]; @@ -372,11 +384,20 @@ class TensorRTRuntime : public JSONRuntimeBase { infile.close(); std::string serialized_engine; LoadBinaryFromFile(path, &serialized_engine); - // Deserialize engine + // Deserialize engine. TensorRT 10 dropped the trailing IPluginFactory* argument and the runtime + // must outlive the engine, so it is owned by the cached TensorRTEngineAndContext. nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger_); TensorRTEngineAndContext engine_and_context; + engine_and_context.runtime = runtime; engine_and_context.engine = - runtime->deserializeCudaEngine(&serialized_engine[0], serialized_engine.size(), nullptr); + runtime->deserializeCudaEngine(&serialized_engine[0], serialized_engine.size()); + if (engine_and_context.engine == nullptr) { + // A stale or incompatible (e.g. different TensorRT version) .plan file. Drop it and rebuild. + delete runtime; + LOG(WARNING) << "Failed to deserialize cached TensorRT engine from " << path + << "; it will be rebuilt."; + return false; + } engine_and_context.context = engine_and_context.engine->createExecutionContext(); // Load metadata namespace json = ::tvm::ffi::json; @@ -424,7 +445,7 @@ class TensorRTRuntime : public JSONRuntimeBase { trt_engine_cache_[std::make_pair(symbol_name_, batch_size)].engine->serialize(); SaveBinaryToFile(path, std::string(static_cast(serialized_engine->data()), serialized_engine->size())); - serialized_engine->destroy(); + delete serialized_engine; // Serialize metadata namespace json = ::tvm::ffi::json; json::Object meta_obj; @@ -454,26 +475,27 @@ class TensorRTRuntime : public JSONRuntimeBase { return symbol_name_ + (support::GetEnv("TVM_TENSORRT_USE_FP16", false) ? "_fp16" : "_fp32"); } - /*! \brief Retreive a GPU buffer for input or output or allocate if needed. */ - Tensor GetOrAllocateDeviceBuffer(int entry_id, int binding_index) { + /*! \brief Retreive a GPU buffer for input or output or allocate if needed. Keyed by TensorRT IO + * tensor name (TRT10 has no binding indices). */ + Tensor GetOrAllocateDeviceBuffer(const std::string& name, int entry_id) { std::vector shape(data_entry_[entry_id]->shape, data_entry_[entry_id]->shape + data_entry_[entry_id]->ndim); - if (device_buffers_.count(binding_index)) { + if (device_buffers_.count(name)) { // Buffer is already initialized. - if (shape[0] > device_buffers_[binding_index]->shape[0]) { + if (shape[0] > device_buffers_[name]->shape[0]) { // Buffer is too small. Need to allocate bigger buffer. - device_buffers_[binding_index] = + device_buffers_[name] = runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); - } else if (shape[0] < device_buffers_[binding_index]->shape[0]) { + } else if (shape[0] < device_buffers_[name]->shape[0]) { // Buffer is too large. Create view. - return device_buffers_[binding_index].CreateView(shape, data_entry_[entry_id]->dtype); + return device_buffers_[name].CreateView(shape, data_entry_[entry_id]->dtype); } } else { // Buffer not initialized yet. - device_buffers_[binding_index] = + device_buffers_[name] = runtime::Tensor::Empty(shape, data_entry_[entry_id]->dtype, {kDLCUDA, 0}); } - return device_buffers_.at(binding_index); + return device_buffers_.at(name); } void CreateInt8Calibrator(const TensorRTEngineAndContext& engine_and_context) { @@ -498,7 +520,7 @@ class TensorRTRuntime : public JSONRuntimeBase { * is not "cuda". Since TensorRT execution can only read data from GPU, we need to copy data from * the runtime device to these buffers first. These will be allocated for the highest batch size * used by all engines. */ - std::unordered_map device_buffers_; + std::unordered_map device_buffers_; /*! \brief TensorRT logger. */ TensorRTLogger logger_; diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_utils.h b/src/runtime/extra/contrib/tensorrt/tensorrt_utils.h index ab9b169f26d6..e0c06f018be4 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_utils.h +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_utils.h @@ -30,6 +30,15 @@ #include "NvInfer.h" +// This integration targets the TensorRT 10 API. TensorRT 10 removed a large set of APIs the +// pre-TRT10 code relied on (implicit batch, binding indices, addConvolution/addPooling/addPadding, +// IFullyConnectedLayer, IBuilder::setMaxBatchSize, IBuilderConfig::setMaxWorkspaceSize, +// IExecutionContext::execute, obj->destroy(), ...). Emit a clear error instead of a flood of +// "has no member" diagnostics on older releases. +#if !defined(NV_TENSORRT_MAJOR) || NV_TENSORRT_MAJOR < 10 +#error "TVM's TensorRT runtime requires TensorRT 10.0 or newer (or set USE_TENSORRT_RUNTIME=OFF)." +#endif + // There is a conflict between cpplint and clang-format-10. // clang-format off #define TRT_VERSION_GE(major, minor, patch) \ @@ -42,18 +51,18 @@ namespace runtime { namespace contrib { /*! - * \brief Helper function to convert an vector to TRT Dims. - * \param vec Vector. + * \brief Helper function to convert a vector-like container to TRT Dims. + * \param vec A container supporting size() and operator[] (e.g. std::vector or ffi::Array). * \return TRT Dims. */ -template -inline nvinfer1::Dims VectorToTrtDims(const std::vector& vec) { +template +inline nvinfer1::Dims VectorToTrtDims(const Container& vec) { nvinfer1::Dims dims; // Dims(nbDims=0, d[0]=1) is used to represent a scalar in TRT. dims.d[0] = 1; - dims.nbDims = vec.size(); + dims.nbDims = static_cast(vec.size()); for (size_t i = 0; i < vec.size(); ++i) { - dims.d[i] = vec[i]; + dims.d[i] = static_cast(vec[i]); } return dims; } diff --git a/tests/python/relax/test_codegen_tensorrt.py b/tests/python/relax/test_codegen_tensorrt.py index 57390515d72d..5f90f826ddf9 100644 --- a/tests/python/relax/test_codegen_tensorrt.py +++ b/tests/python/relax/test_codegen_tensorrt.py @@ -114,5 +114,207 @@ def get_ref(): tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) +def _offload_and_compare(mod, params_np, patterns, data_np, rtol=1e-2, atol=1e-2): + """Offload a single-op module to TensorRT and compare against the LLVM reference. + + Each module here contains a single instance of the op under test, which both exercises the + individual converter and avoids the structurally-identical-composite deduplication that would + otherwise collapse repeated ops. + """ + ref = build_and_run(mod, [data_np, *params_np.values()], "llvm", legalize=True) + offloaded = tvm.transform.Sequential( + [ + relax.transform.BindParams("main", params_np), + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + )(mod) + out = build_and_run(offloaded, [data_np], "cuda") + tvm.testing.assert_allclose(out, ref, rtol=rtol, atol=atol) + + +def test_tensorrt_conv1d(): + # Regression test: explicit-batch (batch > 1) 1D convolution. The pre-TRT10 converter assumed an + # implicit batch dimension and dropped the spatial dimension under explicit batch. + @tvm.script.ir_module + class Conv1d: + @R.function + def main(data: R.Tensor((2, 8, 16), "float32"), weight: R.Tensor((4, 8, 3), "float32")): + with R.dataflow(): + out = relax.op.nn.conv1d(data, weight, padding=1) + R.output(out) + return out + + data = np.random.randn(2, 8, 16).astype("float32") + weight = np.random.randn(4, 8, 3).astype("float32") + patterns = [("tensorrt.nn.conv1d", is_op("relax.nn.conv1d")(wildcard(), wildcard()))] + _offload_and_compare(Conv1d, {"weight": weight}, patterns, data) + + +def test_tensorrt_max_pool2d(): + @tvm.script.ir_module + class MaxPool: + @R.function + def main(data: R.Tensor((2, 8, 16, 16), "float32")): + with R.dataflow(): + out = relax.op.nn.max_pool2d(data, pool_size=(2, 2), strides=(2, 2)) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + patterns = [("tensorrt.nn.max_pool2d", is_op("relax.nn.max_pool2d")(wildcard()))] + _offload_and_compare(MaxPool, {}, patterns, data) + + +def test_tensorrt_avg_pool2d(): + @tvm.script.ir_module + class AvgPool: + @R.function + def main(data: R.Tensor((2, 8, 16, 16), "float32")): + with R.dataflow(): + out = relax.op.nn.avg_pool2d(data, pool_size=(2, 2), strides=(2, 2)) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + patterns = [("tensorrt.nn.avg_pool2d", is_op("relax.nn.avg_pool2d")(wildcard()))] + _offload_and_compare(AvgPool, {}, patterns, data) + + +def test_tensorrt_softmax(): + @tvm.script.ir_module + class Softmax: + @R.function + def main(data: R.Tensor((2, 8, 16, 16), "float32")): + with R.dataflow(): + out = relax.op.nn.softmax(data, axis=1) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + patterns = [("tensorrt.nn.softmax", is_op("relax.nn.softmax")(wildcard()))] + _offload_and_compare(Softmax, {}, patterns, data) + + +def test_tensorrt_sigmoid(): + @tvm.script.ir_module + class Sigmoid: + @R.function + def main(data: R.Tensor((2, 8, 16, 16), "float32")): + with R.dataflow(): + out = relax.op.sigmoid(data) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + patterns = [("tensorrt.sigmoid", is_op("relax.sigmoid")(wildcard()))] + _offload_and_compare(Sigmoid, {}, patterns, data) + + +def test_tensorrt_tanh(): + @tvm.script.ir_module + class Tanh: + @R.function + def main(data: R.Tensor((2, 8, 16, 16), "float32")): + with R.dataflow(): + out = relax.op.tanh(data) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + patterns = [("tensorrt.tanh", is_op("relax.tanh")(wildcard()))] + _offload_and_compare(Tanh, {}, patterns, data) + + +def test_tensorrt_conv2d_transpose(): + # Default IOHW kernel layout ([in, out, h, w]); output channels are weight_shape[1]. + @tvm.script.ir_module + class ConvTranspose: + @R.function + def main( + data: R.Tensor((2, 8, 16, 16), "float32"), weight: R.Tensor((8, 4, 3, 3), "float32") + ): + with R.dataflow(): + out = relax.op.nn.conv2d_transpose(data, weight, padding=1) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + weight = np.random.randn(8, 4, 3, 3).astype("float32") + patterns = [ + ("tensorrt.nn.conv2d_transpose", is_op("relax.nn.conv2d_transpose")(wildcard(), wildcard())) + ] + _offload_and_compare(ConvTranspose, {"weight": weight}, patterns, data) + + +def test_tensorrt_conv3d_transpose(): + # Default IODHW kernel layout ([in, out, d, h, w]); output channels are weight_shape[1]. + @tvm.script.ir_module + class ConvTranspose3d: + @R.function + def main( + data: R.Tensor((2, 4, 8, 8, 8), "float32"), weight: R.Tensor((4, 2, 3, 3, 3), "float32") + ): + with R.dataflow(): + out = relax.op.nn.conv3d_transpose(data, weight, padding=1) + R.output(out) + return out + + data = np.random.randn(2, 4, 8, 8, 8).astype("float32") + weight = np.random.randn(4, 2, 3, 3, 3).astype("float32") + patterns = [ + ("tensorrt.nn.conv3d_transpose", is_op("relax.nn.conv3d_transpose")(wildcard(), wildcard())) + ] + _offload_and_compare(ConvTranspose3d, {"weight": weight}, patterns, data) + + +def test_tensorrt_int8_calibration(monkeypatch): + # INT8 calibration path: the first N runs feed calibration batches, then the INT8 engine is + # built and run. Validates that the calibrator copies a full batch (batch_size * per-sample + # elements) without over-reading the input or over-writing the device buffers, which previously + # crashed for batch > 1. + @tvm.script.ir_module + class Conv2dInt8: + @R.function + def main( + data: R.Tensor((2, 8, 16, 16), "float32"), weight: R.Tensor((4, 8, 3, 3), "float32") + ): + with R.dataflow(): + out = relax.op.nn.conv2d(data, weight, padding=1) + R.output(out) + return out + + data = np.random.randn(2, 8, 16, 16).astype("float32") + weight = np.random.randn(4, 8, 3, 3).astype("float32") + ref = build_and_run(Conv2dInt8, [data, weight], "llvm", legalize=True) + + patterns = [("tensorrt.nn.conv2d", is_op("relax.nn.conv2d")(wildcard(), wildcard()))] + offloaded = tvm.transform.Sequential( + [ + relax.transform.BindParams("main", {"weight": weight}), + relax.transform.FuseOpsByPattern(patterns), + relax.transform.MergeCompositeFunctions(), + relax.transform.RunCodegen(), + ] + )(Conv2dInt8) + + num_calibration_batches = 2 + monkeypatch.setenv("TVM_TENSORRT_USE_INT8", "1") + monkeypatch.setenv("TENSORRT_NUM_CALI_INT8", str(num_calibration_batches)) + + dev = tvm.device("cuda", 0) + vm = relax.VirtualMachine(tvm.compile(offloaded, "cuda"), dev) + data_trt = tvm.runtime.tensor(data, dev) + out = None + for _ in range(num_calibration_batches + 1): + out = vm["main"](data_trt).numpy() + + assert np.isfinite(out).all() + # INT8 is lossy, so use a generous tolerance; the key assertion is that calibration completed + # without a CUDA error. + tvm.testing.assert_allclose(out, ref, rtol=0.2, atol=0.1 * float(np.abs(ref).max())) + + if __name__ == "__main__": - test_tensorrt_offload() + tvm.testing.main() From e8e94786ea6a040928dc76a29076550fb5422456 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 16 Jun 2026 07:31:04 -0400 Subject: [PATCH 10/23] [Tests] Make TargetCreation.DeduplicateKeys host-agnostic on AArch64 (#19786) This pr fixes #19718. The test asserted target->attrs.size()==2, which is host-specific: LLVM target canonicalization legitimately adds host attrs (feature.has_sve / has_asimd / is_aarch64 / mtriple on AArch64), so the target ends up with 9 attrs there and the assertion fails, while it happens to be 2 on x86. The test only means to verify that duplicate keys are deduplicated, so assert that the "keys" entry did not leak into the generic attrs map instead of pinning the host-specific attr count. --- tests/cpp/target_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index e6e73e84d626..d842c868d50b 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -379,7 +379,7 @@ TEST(TargetCreation, DeduplicateKeys) { TVM_FFI_ICHECK_EQ(target->keys.size(), 2U); TVM_FFI_ICHECK_EQ(target->keys[0], "cpu"); TVM_FFI_ICHECK_EQ(target->keys[1], "arm_cpu"); - TVM_FFI_ICHECK_EQ(target->attrs.size(), 2U); + TVM_FFI_ICHECK_EQ(target->attrs.count("keys"), 0U); TVM_FFI_ICHECK_EQ(target->GetAttr("device"), "arm_cpu"); } From 00813d6b14392a919da8137787c8e0e61a3aa992 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 16 Jun 2026 07:32:37 -0400 Subject: [PATCH 11/23] [Tests] Replace remaining requires_* helpers with standard pytest (#19787) This pr is the Follow-up to #19777. This pr removes the last `requires_*` decorators so test gating is plain pytest everywhere, with no custom indirection left. --- python/tvm/contrib/hexagon/pytest_plugin.py | 17 ---- python/tvm/testing/utils.py | 43 ---------- .../relax/backend/adreno/test_clml_ops.py | 38 ++++++--- .../relax/backend/adreno/test_texture_ops.py | 83 ++++++++++++------- tests/python/relax/backend/adreno/utils.py | 46 ++-------- tests/python/runtime/test_runtime_dlpack.py | 7 +- 6 files changed, 93 insertions(+), 141 deletions(-) diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py index 97a644400b51..ac1bc7af99e7 100644 --- a/python/tvm/contrib/hexagon/pytest_plugin.py +++ b/python/tvm/contrib/hexagon/pytest_plugin.py @@ -59,23 +59,6 @@ def shape_nhwc(batch, in_channel, in_size): return (batch, in_size, in_size, in_channel) -def _compose(args, decs): - """Helper to apply multiple markers""" - if len(args) > 0: - func = args[0] - for dec in reversed(decs): - func = dec(func) - return func - return decs - - -def requires_hexagon_toolchain(func): - """Skip a test unless the Hexagon toolchain is available (compile-only).""" - return pytest.mark.skipif( - not tvm.testing.env.has_hexagon_toolchain(), reason="need hexagon toolchain" - )(func) - - def android_serial_number() -> str | None: """Return the android serial number""" serial = os.getenv(ANDROID_SERIAL_NUMBER, default="") diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index c90e610af4d6..51c862919552 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -589,49 +589,6 @@ def skip_if_no_reference_system(func): return skip_if_32bit(reason="Reference system unavailable in i386 container")(func) -def requires_package(*packages): - """Mark a test as requiring python packages to run. - - If the packages listed are not available, tests marked with - `requires_package` will appear in the pytest results as being skipped. - This is equivalent to using ``foo = pytest.importorskip('foo')`` inside - the test body. - - Parameters - ---------- - packages : List[str] - - The python packages that should be available for the test to - run. - - Returns - ------- - mark: pytest mark - - The pytest mark to be applied to unit tests that require this - - """ - - def has_package(package): - try: - __import__(package) - return True - except ImportError: - return False - - marks = [ - pytest.mark.skipif(not has_package(package), reason=f"Cannot import '{package}'") - for package in packages - ] - - def wrapper(func): - for mark in marks: - func = mark(func) - return func - - return wrapper - - def parametrize_targets(*args): """Parametrize a test over a specific set of targets. diff --git a/tests/python/relax/backend/adreno/test_clml_ops.py b/tests/python/relax/backend/adreno/test_clml_ops.py index 69b437bd0df7..428f21abdf25 100644 --- a/tests/python/relax/backend/adreno/test_clml_ops.py +++ b/tests/python/relax/backend/adreno/test_clml_ops.py @@ -44,7 +44,7 @@ get_relax_reshape_mod, get_unary_op_mod, ) -from utils import requires_adreno_clml, verify_results +from utils import skip_unless_adreno_clml, verify_results import tvm import tvm.testing @@ -105,7 +105,8 @@ def verify( verify_results(clml_mod, target=clml_target, ref_target=ref_target) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "kernel_h, kernel_w, padding, stride, dilation, out_channels, shape, has_bias, has_bn, has_activation, has_pad, is_depthwise", @@ -199,7 +200,8 @@ def test_conv2d_offload( verify(mod, clml_codegen, inputs_np, params_np) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "dshape, kshape, channels, kernel_size, strides, padding, out_shape", @@ -244,7 +246,8 @@ def test_conv2d_transpose( verify(mod, clml_codegen, inputs_np, params_np, target_test=False) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.skipif( CLML_VERSION < 3, reason="Requires compiler supporting CLML v5 or above", @@ -314,7 +317,8 @@ def _get_axis_tuple(axis): verify(mod, clml_codegen, inputs_np, params_np) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "a_shape, b_shape, op", @@ -333,7 +337,8 @@ def _get_axis_tuple(axis): ((1, 256), (1, 256), R.maximum), ], ) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml def test_binary_ops(a_shape, b_shape, op, dtype): (mod, inputs_np) = get_binary_op_mod(a_shape, b_shape, op, dtype) clml_codegen = [ @@ -368,7 +373,8 @@ def test_binary_ops(a_shape, b_shape, op, dtype): verify(mod, clml_codegen, inputs_np, {}) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize( "dtype", [ @@ -384,7 +390,8 @@ def test_binary_ops(a_shape, b_shape, op, dtype): ((1, 14, 14, 256), R.nn.relu), ], ) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml def test_unary_ops(a_shape, op, dtype): (mod, inputs_np) = get_unary_op_mod(a_shape, op, dtype) clml_codegen = [ @@ -412,7 +419,8 @@ def test_unary_ops(a_shape, op, dtype): verify(mod, clml_codegen, inputs_np, {}) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "trials", @@ -439,7 +447,8 @@ def test_max_pool(dtype, trials): verify(mod, clml_codegen, inputs_np, {}) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "trials", @@ -467,7 +476,8 @@ def test_avg_pool(dtype, trials): verify(mod, clml_codegen, inputs_np, {}) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "trials", @@ -488,7 +498,8 @@ def test_reshape(dtype, trials): @pytest.mark.skip(reason="Codegen Comparision Failing") -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "trials", @@ -514,7 +525,8 @@ def test_global_avg_pool(dtype, trials): verify(mod, clml_codegen, inputs_np, {}) -@requires_adreno_clml +@pytest.mark.gpu +@skip_unless_adreno_clml @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "trials", diff --git a/tests/python/relax/backend/adreno/test_texture_ops.py b/tests/python/relax/backend/adreno/test_texture_ops.py index cfd2f358dffa..292f7d6ad5da 100644 --- a/tests/python/relax/backend/adreno/test_texture_ops.py +++ b/tests/python/relax/backend/adreno/test_texture_ops.py @@ -16,7 +16,7 @@ # under the License. import pytest -from utils import requires_adreno_opencl_vulkan, verify_results +from utils import skip_unless_adreno_opencl_vulkan, verify_results import tvm import tvm.testing @@ -30,7 +30,8 @@ ref_target = tvm.target.Target("llvm") -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d(target): @I.ir_module @@ -47,7 +48,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_relu(target): @I.ir_module @@ -65,7 +67,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_relu_conv2d_relu(target): @I.ir_module @@ -84,7 +87,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_relu_tanh(target): @I.ir_module @@ -103,7 +107,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_add(target): @I.ir_module @@ -123,7 +128,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_sum(target): @I.ir_module @@ -141,7 +147,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_sum_keepdims(target): @I.ir_module @@ -159,7 +166,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_sum_reduce(target): @I.ir_module @@ -177,7 +185,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_transpose(target): @I.ir_module @@ -195,7 +204,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_expand_dims(target): @I.ir_module @@ -213,7 +223,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_squeeze(target): @I.ir_module @@ -231,7 +242,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_strided_slice(target): @I.ir_module @@ -251,7 +263,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_relu_concat(target): @I.ir_module @@ -270,7 +283,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_relu_concat_split(target): @I.ir_module @@ -290,7 +304,8 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_relu_concat_split_transpose_concat(target): @I.ir_module @@ -312,7 +327,8 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl @pytest.mark.skip(reason="Known failure: numerical mismatch in texture lowering") -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_maxpool2d(target): @I.ir_module @@ -338,7 +354,8 @@ def main( @pytest.mark.skip(reason="Known failure: numerical mismatch in texture lowering") -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_avgpool2d(target): @I.ir_module @@ -356,7 +373,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_softmax(target): @I.ir_module @@ -374,7 +392,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_layernorm(target): @I.ir_module @@ -397,7 +416,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_binary_broadcast(target): @I.ir_module @@ -417,7 +437,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_binary_ewise_scalar(target): @I.ir_module @@ -435,7 +456,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_residual_block(target): r""" @@ -483,7 +505,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_conv2d_fallback_to_buffer_conv2d(target): r""" @@ -522,7 +545,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_conv2d_conv2d_conv2d_concat(target): r""" @@ -562,7 +586,8 @@ def main( @pytest.mark.skip(reason="Known failure: numerical mismatch in texture lowering") -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_pooling_branching_texture_params(target): r""" @@ -613,7 +638,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_injective_inputs1(target): r""" @@ -662,7 +688,8 @@ def main( verify_results(Input, target, ref_target) -@requires_adreno_opencl_vulkan +@pytest.mark.gpu +@skip_unless_adreno_opencl_vulkan @tvm.testing.parametrize_targets(*TARGETS) def test_injective_nwo_inputs2(target): r""" diff --git a/tests/python/relax/backend/adreno/utils.py b/tests/python/relax/backend/adreno/utils.py index 608530c32590..f576c202cd65 100644 --- a/tests/python/relax/backend/adreno/utils.py +++ b/tests/python/relax/backend/adreno/utils.py @@ -56,49 +56,19 @@ def __call__(self): return self.check -def _adreno_requires(predicate, reason): - """Tag a GPU test with the ``gpu`` marker plus an eager runtime skip. - - The predicate is evaluated when the decorator is applied (at collection - time), so the skip condition is resolved eagerly. - """ - - def decorator(func): - func = pytest.mark.skipif(not predicate(), reason=reason)(func) - return pytest.mark.gpu(func) - - return decorator - +# Eager skips for Adreno GPU tests, resolved at import time. Pair each with +# ``@pytest.mark.gpu`` at the test site so CI's ``-m gpu`` filter selects it. # OpenCL or Vulkan -requires_adreno_opencl_vulkan = _adreno_requires( - run_time_check("any").check, - "need adreno opencl or vulkan", -) - -# Any Vulkan -requires_adreno_vulkan = _adreno_requires( - lambda: tvm.runtime.enabled("vulkan") and run_time_check("vulkan").check(), - "need adreno vulkan", -) - -# Any OpenCL -requires_adreno_opencl = _adreno_requires( - lambda: tvm.runtime.enabled("opencl") and run_time_check("opencl").check(), - "need adreno opencl", -) - -# Real Adreno GPU OpenCL Target -requires_adreno_opencl_real = _adreno_requires( - lambda: tvm.runtime.enabled("opencl") and run_time_check("real").check(), - "need real adreno opencl", +skip_unless_adreno_opencl_vulkan = pytest.mark.skipif( + not run_time_check("any").check(), + reason="need adreno opencl or vulkan", ) # CLML Codegen -requires_adreno_clml = _adreno_requires( - lambda: tvm.get_global_func("relax.is_openclml_runtime_enabled", allow_missing=True) - is not None, - "need adreno openclml", +skip_unless_adreno_clml = pytest.mark.skipif( + tvm.get_global_func("relax.is_openclml_runtime_enabled", allow_missing=True) is None, + reason="need adreno openclml", ) diff --git a/tests/python/runtime/test_runtime_dlpack.py b/tests/python/runtime/test_runtime_dlpack.py index 886ee85bf78d..b1fcc83dcf1b 100644 --- a/tests/python/runtime/test_runtime_dlpack.py +++ b/tests/python/runtime/test_runtime_dlpack.py @@ -15,13 +15,17 @@ # specific language governing permissions and limitations # under the License. import numpy as np +import pytest import tvm import tvm.testing from tvm import te +# These tests exercise the PyTorch DLPack interop path; skip the whole module +# when torch is unavailable. +pytest.importorskip("torch") + -@tvm.testing.requires_package("torch") def test_from_dlpack_shape_one(): # A test case for the issue https://github.com/pytorch/pytorch/issues/99803 import torch @@ -47,7 +51,6 @@ def test_from_dlpack_shape_one(): tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) -@tvm.testing.requires_package("torch") def test_from_dlpack_strided(): import torch from torch.utils.dlpack import to_dlpack From 8ede60c9be51d8ba3e6d0dd47226237f6a42ed54 Mon Sep 17 00:00:00 2001 From: Zephyr <114734429+ZephyrLi-pro@users.noreply.github.com> Date: Tue, 16 Jun 2026 19:34:44 +0800 Subject: [PATCH 12/23] [TIRx][RISC-V] Use scalable RVV loops for fixed vectorize (#19776) This PR improves TIRx vectorization for RISC-V RVV targets. Fixed-width `T.vectorized` loops can be lowered to fixed LLVM vectors such as `<16 x float>`, which LLVM/RVV may scalarize into repeated scalar `flw/fsub.s/fsw` instructions. This PR rewrites fixed-width vectorized loops on RVV targets into scalable `T.vscale() * 4` chunks with lane masks, allowing LLVM to generate RVV load/store instructions instead. The change is limited to RISC-V RVV and does not enable the same automatic rewrite for Arm SVE. Tested on a RISC-V K3 board: Before: flw/fsub.s/fsw = 16/16/16, vle32/vse32 = 0/0 After: flw/fsub.s/fsw = 0/0/0, vle32/vse32 = 1/1 Also added a RISC-V LLVM codegen regression test. --- src/tirx/transform/vectorize_loop.cc | 97 +++++++++++++++++-- .../codegen/test_target_codegen_riscv.py | 43 ++++++++ 2 files changed, 130 insertions(+), 10 deletions(-) diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc index a1e954f95184..fe6734863bb8 100644 --- a/src/tirx/transform/vectorize_loop.cc +++ b/src/tirx/transform/vectorize_loop.cc @@ -54,16 +54,27 @@ bool IsVScaleCall(const PrimExpr& expr) { return false; } +bool TargetHasRVV(Target target) { + if (!target.defined()) return false; + static auto target_has_feature_fn = + tvm::ffi::Function::GetGlobalRequired("target.target_has_feature"); + return target_has_feature_fn("v", target).cast(); +} + // File-local helper: true if the target supports Variable-Length Array extensions // (AArch64 SVE or RISC-V V). bool TargetHasVLA(Target target) { if (!target.defined()) return false; bool has_vla = target->GetAttr("feature.has_sve").value_or(false); - static auto target_has_feature_fn = - tvm::ffi::Function::GetGlobalRequired("target.target_has_feature"); - has_vla |= target_has_feature_fn("v", target).cast(); + has_vla |= TargetHasRVV(target); return has_vla; } + +bool ContainsCallNode(const Stmt& stmt) { + return CheckContains::StmtContains(stmt, [](const PrimExpr& expr) { + return expr.as() != nullptr; + }); +} } // namespace inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { @@ -132,7 +143,8 @@ bool EnableBufferLevelPredication(Target target) { */ class TryPredicateBufferAccesses : public StmtExprMutator { public: - TryPredicateBufferAccesses() {} + explicit TryPredicateBufferAccesses(bool allow_offset_predication) + : allow_offset_predication_(allow_offset_predication) {} /*! * \brief Run the pass to try to exact predicates. @@ -157,7 +169,10 @@ class TryPredicateBufferAccesses : public StmtExprMutator { return {false, stmt}; } - base_ = Downcast(lt->a)->base; + Ramp pred_ramp = Downcast(lt->a); + base_ = pred_ramp->base; + stride_ = pred_ramp->stride; + lanes_ = pred_ramp->lanes; limit_ = Downcast(lt->b)->value; // Now we can try to predicate @@ -190,11 +205,21 @@ class TryPredicateBufferAccesses : public StmtExprMutator { } Ramp ramp = Downcast(node->indices[0]); - // The vectorized access pattern must match the base of the predicate - if (!ffi::StructuralEqual()(ramp->base, base_)) { + if (!ffi::StructuralEqual()(ramp->stride, stride_) || + !ffi::StructuralEqual()(ramp->lanes, lanes_)) { return node; } + bool same_base = ffi::StructuralEqual()(ramp->base, base_); + if (!same_base) { + // The lane mask describes which lanes are active, independent of the + // memory base. This covers accesses such as A[offset + i] guarded by + // a predicate over i. + if (!allow_offset_predication_) { + return node; + } + } + DataType buf_predicate_dtype = DataType(DataType::kUInt, 1, ramp->dtype.get_lanes_or_vscale_factor(), ramp->dtype.is_scalable_vector()); @@ -202,15 +227,27 @@ class TryPredicateBufferAccesses : public StmtExprMutator { num_accesses_rewritten_ += 1; auto writer = node.CopyOnWrite(); - writer->predicate = lane_mask; + if (node->predicate.defined() && allow_offset_predication_) { + // Buffer predicates are uint1 lane masks, so mask merging uses bitwise + // and rather than logical &&. + writer->predicate = node->predicate.value() & lane_mask; + } else { + writer->predicate = lane_mask; + } return node; } /*! \brief The variable base expr of the predicate. */ PrimExpr base_; + /*! \brief The lane stride of the predicate. */ + PrimExpr stride_; + /*! \brief The lane count of the predicate. */ + PrimExpr lanes_; /*! \brief The limit of the predicate. The expr specifies the upper bound of the base's * evaluated value. */ PrimExpr limit_; + /*! \brief Whether to predicate offset buffer accesses that use the same lane layout. */ + bool allow_offset_predication_; /*! \brief The number of buffer accesses in the stmt we will analyze. */ size_t num_accesses_analyzed_ = 0; /*! \brief The number of buffer accesses rewritten with predicates. */ @@ -819,7 +856,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor success_stmt_pair = - TryPredicateBufferAccesses().Run(then_case, condition); + TryPredicateBufferAccesses(TargetHasRVV(target_)).Run(then_case, condition); bool can_remove_if_then_else = success_stmt_pair.first; if (can_remove_if_then_else) { return success_stmt_pair.second; @@ -975,12 +1012,19 @@ class LoopVectorizer : public StmtMutator { if (op->kind == ForKind::kVectorized) { auto* extent_as_int = op->extent.as(); + TVM_FFI_ICHECK(is_zero(op->min)); + // General calls still have vectorization paths that query a compile-time + // lane count, so keep them on the existing fixed-width path for now. + if (extent_as_int && extent_as_int->value > 1 && TargetHasRVV(target_) && + !ContainsCallNode(op->body)) { + return VectorizeFixedLoopForRVV(op, extent_as_int->value); + } + if (!extent_as_int || extent_as_int->value < 1) { bool is_scalable_expr = CheckContains::ExprContains(op->extent, IsVScaleCall); TVM_FFI_ICHECK(is_scalable_expr && TargetHasVLA(target_)) << "Failed to vectorize loop with extent " << op->extent << " for target " << target_; } - TVM_FFI_ICHECK(is_zero(op->min)); return Vectorizer(op->loop_var, op->extent, target_)(op->body); } else { return StmtMutator::VisitStmt_(op); @@ -999,6 +1043,39 @@ class LoopVectorizer : public StmtMutator { } private: + Stmt VectorizeFixedLoopForRVV(const ForNode* op, int64_t extent) { + // Match the existing TIRx scalable-vector convention. LLVM/RVV still + // selects the runtime vector length with vsetvli. + static constexpr int kDefaultVScaleFactor = 4; + DataType index_dtype = op->loop_var->dtype; + PrimExpr zero = make_const(index_dtype, 0); + PrimExpr fixed_extent = make_const(index_dtype, extent); + PrimExpr scalable_lanes = CreateNewLanes(/*is_scalable=*/true, kDefaultVScaleFactor); + DataType lane_dtype = scalable_lanes.dtype(); + PrimExpr scalable_lanes_index = scalable_lanes; + if (scalable_lanes_index.dtype() != index_dtype) { + scalable_lanes_index = Cast(index_dtype, scalable_lanes_index); + } + PrimExpr num_chunks = ceildiv(fixed_extent, scalable_lanes_index); + + Var outer(op->loop_var->name_hint + ".vla.o", index_dtype); + Var inner(op->loop_var->name_hint + ".vla.i", lane_dtype); + PrimExpr inner_index = inner; + if (inner_index.dtype() != index_dtype) { + inner_index = Cast(index_dtype, inner_index); + } + PrimExpr index = outer * scalable_lanes_index + inner_index; + Stmt body = Substitute(op->body, {{op->loop_var, index}}); + Stmt guarded_body = IfThenElse(index < fixed_extent, body, std::nullopt, op->span); + Stmt vector_loop = + For(inner, make_const(lane_dtype, 0), scalable_lanes, ForKind::kVectorized, guarded_body, + std::nullopt, op->annotations, std::nullopt, op->span); + Stmt loop = For(outer, zero, num_chunks, ForKind::kSerial, vector_loop, std::nullopt, {}, + std::nullopt, op->span); + + return this->VisitStmt(loop); + } + Target target_ = Target::Current(); }; diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index 5b9b1ecd7707..3ac75dc33745 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -16,6 +16,7 @@ # under the License. # ruff: noqa: E501, F841 +import re import pytest import tvm @@ -113,5 +114,47 @@ def rvv_with_vscale(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle): f = tvm.tirx.build(rvv_with_vscale, target) +@pytest.mark.skipif(not env.has_llvm_min_version(14), reason="need llvm >= 14") +def test_rvv_fixed_width_vectorized_loop_uses_scalable_chunks(): + @T.prim_func(s_tir=True) + def fixed16_negative( + A: T.Buffer((14, 23, 67, 99), "float32"), + B: T.Buffer((14, 23, 67, 99), "float32"), + ): + for n, c, h, wo in T.grid(14, 23, 67, 7): + for wi in T.vectorized(0, 16): + if wo * 16 + wi < 99: + B[n, c, h, wo * 16 + wi] = T.float32(0) - A[n, c, h, wo * 16 + wi] + + @T.prim_func(s_tir=True) + def fixed16_negative_int64(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): + for wi in T.vectorized(T.int64(0), T.int64(16)): + B[wi] = T.float32(0) - A[wi] + + target = tvm.target.Target( + { + "kind": "llvm", + "device": "riscv_cpu", + "mtriple": "riscv64-linux-gnu", + "mcpu": "generic-rv64", + "mattr": ["+64bit", "+a", "+c", "+d", "+f", "+m", "+v"], + } + ) + + def check_codegen(func): + with target: + f = tvm.tirx.build(func, target) + + assembly = f.inspect_source("asm") + assert "vle32.v" in assembly + assert "vse32.v" in assembly + assert not re.search(r"\bflw\b", assembly) + assert not re.search(r"\bfsub\.s\b", assembly) + assert not re.search(r"\bfsw\b", assembly) + + check_codegen(fixed16_negative) + check_codegen(fixed16_negative_int64) + + if __name__ == "__main__": tvm.testing.main() From c4c737a08f771866589462b818423ea9ca3ab7c3 Mon Sep 17 00:00:00 2001 From: Shushi Hong <820958424@qq.com> Date: Tue, 16 Jun 2026 09:09:31 -0400 Subject: [PATCH 13/23] [Docs] Modernize test-gating documentation (#19788) This pr updates the contributor guide and tvm.testing docstrings/comments to describe the current gating API --------- Co-authored-by: Tianqi Chen Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/contribute/code_guide.rst | 11 +++++- docs/contribute/testing.rst | 62 +++++++++++++++++----------------- python/tvm/testing/env.py | 22 ++++++------ python/tvm/testing/plugin.py | 8 +++-- python/tvm/testing/utils.py | 58 +++++++++++++++---------------- 5 files changed, 87 insertions(+), 74 deletions(-) diff --git a/docs/contribute/code_guide.rst b/docs/contribute/code_guide.rst index fd40cec579cf..6419b7f9d77a 100644 --- a/docs/contribute/code_guide.rst +++ b/docs/contribute/code_guide.rst @@ -139,7 +139,16 @@ If you want your test to run over a variety of targets, use the :py:func:`tvm.te def test_mytest(target, dev): ... -will run ``test_mytest`` with ``target="llvm"``, ``target="cuda"``, and few others. This also ensures that your test is run on the correct hardware by the CI. If you only want to test against a couple targets use ``@tvm.testing.parametrize_targets("target_1", "target_2")``. If you want to test on a single target, use the associated decorator from :py:func:`tvm.testing`. For example, CUDA tests use the ``@tvm.testing.requires_cuda`` decorator. +will run ``test_mytest`` with ``target="llvm"``, ``target="cuda"``, and few others. This also ensures that your test is run on the correct hardware by the CI. If you only want to test against a couple targets use ``@tvm.testing.parametrize_targets("target_1", "target_2")``. If you want to test on a single target, gate the test on the corresponding capability probe instead of using a per-target decorator. Mark GPU tests with ``@pytest.mark.gpu`` so the CI can select them, and skip when the required feature is unavailable with ``@pytest.mark.skipif``. For example, CUDA tests use: + +.. code:: python + + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_mycudatest(): + ... + +The ``tvm.testing.env`` module exposes a ``has_*()`` probe for each runtime and hardware feature (e.g. ``has_cuda()``, ``has_rocm()``, ``has_vulkan()``, ``has_llvm()``). To skip a test when an optional Python package is missing, use ``pytest.importorskip("package_name")``. Network Resources diff --git a/docs/contribute/testing.rst b/docs/contribute/testing.rst index c2f502503099..c5777bad38a0 100644 --- a/docs/contribute/testing.rst +++ b/docs/contribute/testing.rst @@ -111,9 +111,9 @@ parameters. For instance, there may be target-specific implementations that should be tested, where some targets have more than one implementation. These can be done by explicitly parametrizing over tuples of arguments, such as shown below. In these -cases, only the explicitly listed targets will run, but they will -still have the appropriate ``@tvm.testing.requires_RUNTIME`` mark -applied to them. +cases, only the explicitly listed targets will run, and each target is +automatically gated on whether it can run on the current machine (a GPU +target gets ``@pytest.mark.gpu`` plus a skip when no device is present). .. code-block:: python @@ -134,34 +134,34 @@ marks are as follows. - ``@pytest.mark.gpu`` - Tags a function as using GPU capabilities. This has no effect on its own, but can be paired with - command-line arguments ``-m gpu`` or ``-m 'not gpu'`` to restrict - which tests pytest will execute. This should not be called on its - own, but is part of other marks used in unit-tests. - -- ``@tvm.testing.uses_gpu`` - Applies ``@pytest.mark.gpu``. This - should be used to mark unit tests that may use the GPU, if one is - present. This decorator is only needed for tests that explicitly - loop over ``tvm.testing.enabled_targets()``, but that is no longer - the preferred style of writing unit tests (see below). When using - ``tvm.testing.parametrize_targets()``, this decorator is implicit - for GPU targets, and does not need to be explicitly applied. - -- ``@tvm.testing.requires_gpu`` - Applies ``@tvm.testing.uses_gpu``, - and additionally marks that the test should be skipped - (``@pytest.mark.skipif``) entirely if no GPU is present. - -- ``@tvm.testing.requires_RUNTIME`` - Several decorators - (e.g. ``@tvm.testing.requires_cuda``), each of which skips a test if - the specified runtime cannot be used. A runtime cannot be used if it - is disabled in the ``config.cmake``, or if a compatible device is - not present. For runtimes that use the GPU, this includes - ``@tvm.testing.requires_gpu``. - -When using parametrized targets, each test run is decorated with the -``@tvm.testing.requires_RUNTIME`` that corresponds to the target -being used. As a result, if a target is disabled in ``config.cmake`` -or does not have appropriate hardware to run, it will be explicitly -listed as skipped. + the command-line arguments ``-m gpu`` or ``-m 'not gpu'`` to restrict + which tests pytest will execute. Apply it to any test that needs a + GPU so that the CI runs it only on GPU nodes. + +- ``@pytest.mark.skipif(not tvm.testing.env.has_X(), reason=...)`` - + Skips a test when a required runtime or hardware feature is not + available. The :py:mod:`tvm.testing.env` module exposes one memoized + probe per capability (e.g. ``has_cuda()``, ``has_rocm()``, + ``has_vulkan()``, ``has_gpu()``, ``has_llvm()``), each of which + returns ``False`` when the runtime is disabled in ``config.cmake`` or + no compatible device is present. Pair it with ``@pytest.mark.gpu`` + for tests that use the GPU:: + + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_cuda_vectorize_add(): + # Test code goes here + +- ``pytest.importorskip("package_name")`` - Skips a test (or the whole + module, when called at import time) if an optional Python package is + not installed. Use this instead of a ``skipif`` for package + dependencies. + +When using ``tvm.testing.parametrize_targets()``, each parametrized run +is gated automatically on whether its target can run on the current +machine. As a result, if a target is disabled in ``config.cmake`` or +does not have appropriate hardware to run, it will be explicitly listed +as skipped, and GPU targets are tagged with ``@pytest.mark.gpu`` for you. There also exists a ``tvm.testing.enabled_targets()`` that returns all targets that are enabled and runnable on the current machine, diff --git a/python/tvm/testing/env.py b/python/tvm/testing/env.py index 6b39b5f2f674..0c9b48e5c16a 100644 --- a/python/tvm/testing/env.py +++ b/python/tvm/testing/env.py @@ -115,9 +115,9 @@ def _device_exists(kind: str, index: int = 0) -> bool: def _build_flag_enabled(flag: str) -> bool: """Return whether an optional build flag (e.g. ``USE_CUTLASS``) is on. - Mirrors the historical ``Feature`` check: a flag counts as enabled - unless it is explicitly disabled, so library flags carrying a path - still register as present. + A flag counts as enabled unless it is explicitly disabled, so library + flags carrying a path (rather than a boolean) still register as present. + Callers gate on this via ``@pytest.mark.skipif(not tvm.testing.env.has_cutlass(), ...)``. """ try: value = tvm.support.libinfo().get(flag, "OFF") @@ -130,8 +130,8 @@ def _build_flag_enabled(flag: str) -> bool: def _target_enabled(kind: str) -> bool: """True if ``kind`` is selected by ``TVM_TEST_TARGETS`` (or the default set). - Restores the historical ``target_kind_enabled`` opt-out, so CI can exclude a - flaky backend (e.g. opencl) via ``TVM_TEST_TARGETS`` and have its tests skip + Honors the ``TVM_TEST_TARGETS`` opt-out, so CI can exclude a flaky + backend (e.g. opencl) via ``TVM_TEST_TARGETS`` and have its tests skip even when a device is physically present. """ try: @@ -343,8 +343,9 @@ def _nvcc_version() -> tuple: def has_nvcc_version(major: int, minor: int = 0, release: int = 0) -> bool: """True if a CUDA device is present and nvcc is at least ``(major, minor, release)``. - Implies :func:`has_cuda`, matching the historical ``requires_nvcc_version`` - decorator which also required the CUDA runtime. + Returns False when no CUDA device is present, so it implies :func:`has_cuda`. + Gate a test with ``@pytest.mark.skipif(not tvm.testing.env.has_nvcc_version(11, 4), + reason="need nvcc >= 11.4")`` (add ``@pytest.mark.gpu`` for GPU selection). """ return has_cuda() and _nvcc_version() >= (major, minor, release) @@ -389,9 +390,10 @@ def has_matrixcore() -> bool: def has_cudagraph() -> bool: """True if a CUDA device is present and the toolkit supports CUDA Graphs. - Implies :func:`has_cuda`, matching the historical ``requires_cudagraph`` - decorator (``parent_features="cuda"``): ``nvcc.have_cudagraph()`` only - checks the toolkit version, so the device guard must be explicit. + Implies :func:`has_cuda`: ``nvcc.have_cudagraph()`` only checks the + toolkit version, so the device guard must be explicit. Gate a test with + ``@pytest.mark.skipif(not tvm.testing.env.has_cudagraph(), reason=...)`` + (add ``@pytest.mark.gpu`` for CI selection). """ try: from tvm.support import nvcc # pylint: disable=import-outside-toplevel diff --git a/python/tvm/testing/plugin.py b/python/tvm/testing/plugin.py index bba2da6aee0d..91aeb6374f34 100644 --- a/python/tvm/testing/plugin.py +++ b/python/tvm/testing/plugin.py @@ -210,9 +210,11 @@ def update_parametrize_target_arg( raise TypeError(msg) from err if "target" in metafunc.fixturenames: - # Update any explicit use of @pytest.mark.parmaetrize to - # parametrize over targets. This adds the appropriate - # @tvm.testing.requires_* markers for each target. + # Update any explicit use of @pytest.mark.parametrize to + # parametrize over targets. This attaches the appropriate + # per-target gating markers (pytest.mark.gpu for GPU-family + # targets, plus a pytest.mark.skipif guarded by the relevant + # tvm.testing.env.has_*() probe) via _target_to_requirement. for mark in metafunc.definition.iter_markers("parametrize"): update_parametrize_target_arg(mark, *mark.args, **mark.kwargs) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 51c862919552..9adeba689b3b 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -29,38 +29,38 @@ Testing Markers *************** -We use pytest markers to specify the requirements of test functions. Currently -there is a single distinction that matters for our testing environment: does -the test require a gpu. For tests that require just a gpu or just a cpu, we -have the decorator :py:func:`requires_gpu` that enables the test when a gpu is -available. To avoid running tests that don't require a gpu on gpu nodes, this -decorator also sets the pytest marker `gpu` so we can use select the gpu subset -of tests (using `pytest -m gpu`). - -Unfortunately, many tests are written like this: +We use pytest markers to specify the requirements of test functions. +Currently there is a single distinction that matters for our testing +environment: does the test require a gpu. Tests that require a gpu are +tagged with the ``gpu`` pytest marker -- the only registered marker (see +the ``markers`` entry in ``pyproject.toml``). This lets us select the +gpu subset of tests with ``pytest -m gpu`` (and exclude them on cpu-only +nodes with ``pytest -m "not gpu"``). + +The ``gpu`` marker only controls which testing node a test runs on; it +does not check whether the required hardware or libraries are actually +present. To gate a test on a specific capability, combine the marker +with a ``skipif`` that consults the memoized environment probes in +:py:mod:`tvm.testing.env`: .. code-block:: python - def test_something(): - for target in all_targets(): - do_something() - -The test uses both gpu and cpu targets, so the test needs to be run on both cpu -and gpu nodes. But we still want to only run the cpu targets on the cpu testing -node. The solution is to mark these tests with the gpu marker so they will be -run on the gpu nodes. But we also modify all_targets (renamed to -enabled_targets) so that it only returns gpu targets on gpu nodes and cpu -targets on cpu nodes (using an environment variable). - -Instead of using the all_targets function, future tests that would like to -test against a variety of targets should use the -:py:func:`tvm.testing.parametrize_targets` functionality. This allows us -greater control over which targets are run on which testing nodes. - -If in the future we want to add a new type of testing node (for example -fpgas), we need to add a new marker in `tests/python/pytest.ini` and a new -function in this module. Then targets using this node should be added to the -`TVM_TEST_TARGETS` environment variable in the CI. + @pytest.mark.gpu + @pytest.mark.skipif(not tvm.testing.env.has_cuda(), reason="need cuda") + def test_cuda_vectorize_add(): + ... + +There is one ``has_*`` (or ``is_*``) probe per capability -- for example +:py:func:`tvm.testing.env.has_gpu`, :py:func:`tvm.testing.env.has_cuda`, +and :py:func:`tvm.testing.env.has_vulkan`. For optional Python packages, +prefer ``pytest.importorskip("pkg_name")`` instead of a ``skipif``. + +To run a test against a variety of targets, use +:py:func:`tvm.testing.parametrize_targets`; it parametrizes the test over +the enabled targets and applies the appropriate ``gpu`` tag and skip +conditions per target automatically. The set of enabled targets is +controlled by the ``TVM_TEST_TARGETS`` environment variable, so the CI +can run different targets on different testing nodes. """ From beb65113832bf48164490e446fc0c52a1379cfea Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Tue, 16 Jun 2026 22:44:45 +0800 Subject: [PATCH 14/23] [Web] Destroy GPUDevice once on buffer creation error (#19790) ## Why In `tryCreateBuffer`, each of the three popped error scopes independently called `device.destroy()` and `console.error`, so a buffer that triggers more than one error type destroyed the device repeatedly and logged duplicate errors. ## How - Collect all three `popErrorScope()` results via `Promise.all` and call `device.destroy()` at most once - Log every captured error instead of relying on per-scope handlers --------- Signed-off-by: Guan-Ming (Wesley) Chiu <105915352+guan404ming@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- web/src/webgpu.ts | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/web/src/webgpu.ts b/web/src/webgpu.ts index 199fa14235ee..5a25833d1a8f 100644 --- a/web/src/webgpu.ts +++ b/web/src/webgpu.ts @@ -169,9 +169,20 @@ function tryCreateBuffer(device: GPUDevice, descriptor: GPUBufferDescriptor) { const buffer = device.createBuffer(descriptor); - device.popErrorScope().then((error) => {if (error) {device.destroy(); console.error(error);}}); - device.popErrorScope().then((error) => {if (error) {device.destroy(); console.error(error);}}); - device.popErrorScope().then((error) => {if (error) {device.destroy(); console.error(error);}}); + // Destroy at most once even if multiple error types fire. + Promise.all([ + device.popErrorScope(), + device.popErrorScope(), + device.popErrorScope(), + ]).then((errors) => { + const captured = errors.filter((error): error is GPUError => error !== null); + if (captured.length > 0) { + device.destroy(); + captured.forEach((error) => console.error(error)); + } + }).catch((err) => { + console.error("Failed to pop error scopes:", err); + }); return buffer; } From a7864af9454a97fe294f49d76fef240c31a00753 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 16 Jun 2026 11:32:08 -0400 Subject: [PATCH 15/23] [REFACTOR] Phase out unused queue and rang license entries (#19794) This PR removes the obsolete queue and rang license files and drops the leftover rang include-directory hook from the CMake setup. --- CMakeLists.txt | 1 - licenses/LICENSE.blockingconcurrentqueue.txt | 26 -------------------- licenses/LICENSE.concurrentqueue.txt | 22 ----------------- licenses/LICENSE.rang.txt | 24 ------------------ 4 files changed, 73 deletions(-) delete mode 100644 licenses/LICENSE.blockingconcurrentqueue.txt delete mode 100644 licenses/LICENSE.concurrentqueue.txt delete mode 100644 licenses/LICENSE.rang.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 99a2569b142a..ffcd3ab7ff2e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,7 +120,6 @@ tvm_option(TVM_BUILD_PYTHON_MODULE "Build Python module with scikit-build-core" # include directories include_directories(${CMAKE_INCLUDE_PATH}) include_directories("include") -include_directories(SYSTEM ${RANG_PATH}) include_directories(SYSTEM ${COMPILER_RT_PATH}) # initial variables diff --git a/licenses/LICENSE.blockingconcurrentqueue.txt b/licenses/LICENSE.blockingconcurrentqueue.txt deleted file mode 100644 index d08e53a3c518..000000000000 --- a/licenses/LICENSE.blockingconcurrentqueue.txt +++ /dev/null @@ -1,26 +0,0 @@ -©2015-2016 Cameron Desrochers. Distributed under the terms of the simplified -BSD license, available at the top of concurrentqueue.h. - -Uses Jeff Preshing's semaphore implementation (under the terms of its -separate zlib license, embedded below). - - -zlib license ------------- -Copyright (c) 2015 Jeff Preshing - -This software is provided 'as-is', without any express or implied -warranty. In no event will the authors be held liable for any damages -arising from the use of this software. - -Permission is granted to anyone to use this software for any purpose, -including commercial applications, and to alter it and redistribute it -freely, subject to the following restrictions: - -1. The origin of this software must not be misrepresented; you must not - claim that you wrote the original software. If you use this software - in a product, an acknowledgement in the product documentation would be - appreciated but is not required. -2. Altered source versions must be plainly marked as such, and must not be - misrepresented as being the original software. -3. This notice may not be removed or altered from any source distribution. diff --git a/licenses/LICENSE.concurrentqueue.txt b/licenses/LICENSE.concurrentqueue.txt deleted file mode 100644 index b36f9eadc9f9..000000000000 --- a/licenses/LICENSE.concurrentqueue.txt +++ /dev/null @@ -1,22 +0,0 @@ -Simplified BSD license: -Copyright (c) 2013-2016, Cameron Desrochers. -All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, this list of -conditions and the following disclaimer. -- Redistributions in binary form must reproduce the above copyright notice, this list of -conditions and the following disclaimer in the documentation and/or other materials -provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY -EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF -MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL -THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT -OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, -EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/LICENSE.rang.txt b/licenses/LICENSE.rang.txt deleted file mode 100644 index cf1ab25da034..000000000000 --- a/licenses/LICENSE.rang.txt +++ /dev/null @@ -1,24 +0,0 @@ -This is free and unencumbered software released into the public domain. - -Anyone is free to copy, modify, publish, use, compile, sell, or -distribute this software, either in source code form or as a compiled -binary, for any purpose, commercial or non-commercial, and by any -means. - -In jurisdictions that recognize copyright laws, the author or authors -of this software dedicate any and all copyright interest in the -software to the public domain. We make this dedication for the benefit -of the public at large and to the detriment of our heirs and -successors. We intend this dedication to be an overt act of -relinquishment in perpetuity of all present and future rights to this -software under copyright law. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. -IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR -OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -OTHER DEALINGS IN THE SOFTWARE. - -For more information, please refer to From 5388ea33a1e484b8379e0ff754bcdced20a07938 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 16 Jun 2026 12:04:07 -0400 Subject: [PATCH 16/23] [REFACTOR][HEXAGON] Phase out Hexagon app and test wrappers (#19796) ## Summary The old Hexagon app and test wrappers depend on RPC helper artifacts that are no longer part of the supported app flow. This PR removes those wrappers and related helper references while keeping the core Hexagon target, codegen, and runtime implementation in place. ## Changes - Remove the obsolete Hexagon app wrapper directories. - Remove the Hexagon contrib test directory and its dedicated pytest/RPC launcher helpers. - Drop stale CI/docs references to the removed app and test helper paths. --- apps/hexagon_api/CMakeLists.txt | 166 --- apps/hexagon_api/README.md | 58 -- apps/hexagon_launcher/CMakeLists.txt | 81 -- apps/hexagon_launcher/README.md | 145 --- .../cmake/HexagonLauncher.cmake | 75 -- .../cmake/android/CMakeLists.txt | 100 -- .../cmake/hexagon/CMakeLists.txt | 105 -- apps/hexagon_launcher/launcher_android.cc | 170 --- apps/hexagon_launcher/launcher_core.cc | 231 ---- apps/hexagon_launcher/launcher_core.h | 133 --- apps/hexagon_launcher/launcher_hexagon.cc | 237 ----- apps/hexagon_launcher/launcher_main.cc | 159 --- apps/hexagon_launcher/launcher_rpc.idl | 33 - apps/hexagon_launcher/launcher_util.cc | 68 -- apps/hexagon_launcher/launcher_util.h | 34 - ci/jenkins/data.py | 4 - python/tvm/contrib/hexagon/_ci_env_check.py | 4 +- python/tvm/contrib/hexagon/build.py | 830 --------------- .../tvm/contrib/hexagon/hexagon_profiler.py | 125 --- python/tvm/contrib/hexagon/meta_schedule.py | 195 ---- .../hexagon/profiling/process_lwp_data.py | 388 ------- python/tvm/contrib/hexagon/pytest_plugin.py | 384 ------- python/tvm/contrib/hexagon/session.py | 287 ----- .../hexagon/runtime/profiler/README.md | 85 -- .../runtime/rpc/android_bash.sh.template | 31 - tests/python/contrib/test_hexagon/README.md | 130 --- .../python/contrib/test_hexagon/README_RPC.md | 371 ------- tests/python/contrib/test_hexagon/__init__.py | 19 - .../contrib/test_hexagon/benchmark_util.py | 277 ----- tests/python/contrib/test_hexagon/conftest.py | 27 - .../contrib/test_hexagon/conv2d/README.md | 37 - .../contrib/test_hexagon/conv2d/__init__.py | 19 - .../conv2d/test_conv2d_blocked.md | 494 --------- .../test_hexagon/conv2d/test_conv2d_conv2d.md | 986 ------------------ .../contrib/test_hexagon/infrastructure.py | 376 ------- .../contrib/test_hexagon/pytest_util.py | 176 ---- .../test_hexagon/test_async_dma_pipeline.py | 889 ---------------- .../test_benchmark_elemwise_add.py | 425 -------- .../test_hexagon/test_benchmark_maxpool2d.py | 358 ------- .../contrib/test_hexagon/test_dma_builtin.py | 190 ---- .../contrib/test_hexagon/test_memory_alloc.py | 84 -- .../test_hexagon/test_meta_schedule.py | 370 ------- .../contrib/test_hexagon/test_parallel_hvx.py | 240 ----- .../test_parallel_hvx_load_vtcm.py | 562 ---------- .../test_hexagon/test_parallel_scalar.py | 177 ---- .../test_relax_2d_buffer_allocation.py | 93 -- .../test_hexagon/test_relax_integration.py | 115 -- .../test_hexagon/test_run_unit_tests.py | 183 ---- .../contrib/test_hexagon/test_sigmoid.py | 120 --- .../test_software_pipeline_async.py | 206 ---- .../python/contrib/test_hexagon/test_take.py | 397 ------- .../contrib/test_hexagon/test_thread_pool.py | 108 -- .../python/contrib/test_hexagon/test_vtcm.py | 94 -- .../test_hexagon/test_vtcm_bandwidth.py | 194 ---- tests/python/testing/test_env.py | 2 +- 55 files changed, 2 insertions(+), 11845 deletions(-) delete mode 100644 apps/hexagon_api/CMakeLists.txt delete mode 100644 apps/hexagon_api/README.md delete mode 100644 apps/hexagon_launcher/CMakeLists.txt delete mode 100644 apps/hexagon_launcher/README.md delete mode 100644 apps/hexagon_launcher/cmake/HexagonLauncher.cmake delete mode 100644 apps/hexagon_launcher/cmake/android/CMakeLists.txt delete mode 100644 apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt delete mode 100644 apps/hexagon_launcher/launcher_android.cc delete mode 100644 apps/hexagon_launcher/launcher_core.cc delete mode 100644 apps/hexagon_launcher/launcher_core.h delete mode 100644 apps/hexagon_launcher/launcher_hexagon.cc delete mode 100644 apps/hexagon_launcher/launcher_main.cc delete mode 100644 apps/hexagon_launcher/launcher_rpc.idl delete mode 100644 apps/hexagon_launcher/launcher_util.cc delete mode 100644 apps/hexagon_launcher/launcher_util.h delete mode 100644 python/tvm/contrib/hexagon/build.py delete mode 100644 python/tvm/contrib/hexagon/hexagon_profiler.py delete mode 100644 python/tvm/contrib/hexagon/meta_schedule.py delete mode 100644 python/tvm/contrib/hexagon/profiling/process_lwp_data.py delete mode 100644 python/tvm/contrib/hexagon/pytest_plugin.py delete mode 100644 python/tvm/contrib/hexagon/session.py delete mode 100644 src/backend/hexagon/runtime/profiler/README.md delete mode 100644 src/backend/hexagon/runtime/rpc/android_bash.sh.template delete mode 100644 tests/python/contrib/test_hexagon/README.md delete mode 100644 tests/python/contrib/test_hexagon/README_RPC.md delete mode 100644 tests/python/contrib/test_hexagon/__init__.py delete mode 100644 tests/python/contrib/test_hexagon/benchmark_util.py delete mode 100644 tests/python/contrib/test_hexagon/conftest.py delete mode 100644 tests/python/contrib/test_hexagon/conv2d/README.md delete mode 100644 tests/python/contrib/test_hexagon/conv2d/__init__.py delete mode 100644 tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md delete mode 100644 tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md delete mode 100644 tests/python/contrib/test_hexagon/infrastructure.py delete mode 100644 tests/python/contrib/test_hexagon/pytest_util.py delete mode 100644 tests/python/contrib/test_hexagon/test_async_dma_pipeline.py delete mode 100644 tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py delete mode 100644 tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py delete mode 100644 tests/python/contrib/test_hexagon/test_dma_builtin.py delete mode 100644 tests/python/contrib/test_hexagon/test_memory_alloc.py delete mode 100644 tests/python/contrib/test_hexagon/test_meta_schedule.py delete mode 100644 tests/python/contrib/test_hexagon/test_parallel_hvx.py delete mode 100644 tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py delete mode 100644 tests/python/contrib/test_hexagon/test_parallel_scalar.py delete mode 100644 tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py delete mode 100644 tests/python/contrib/test_hexagon/test_relax_integration.py delete mode 100644 tests/python/contrib/test_hexagon/test_run_unit_tests.py delete mode 100644 tests/python/contrib/test_hexagon/test_sigmoid.py delete mode 100644 tests/python/contrib/test_hexagon/test_software_pipeline_async.py delete mode 100644 tests/python/contrib/test_hexagon/test_take.py delete mode 100644 tests/python/contrib/test_hexagon/test_thread_pool.py delete mode 100644 tests/python/contrib/test_hexagon/test_vtcm.py delete mode 100644 tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt deleted file mode 100644 index 62dca9d4e644..000000000000 --- a/apps/hexagon_api/CMakeLists.txt +++ /dev/null @@ -1,166 +0,0 @@ -cmake_minimum_required(VERSION 3.2) - -project(hexagon_api) - -include(ExternalProject) - -# Required variables: -# ANDROID_ABI -# ANDROID_PLATFORM -# USE_ANDROID_TOOLCHAIN (Android toolchain .cmake file) -# USE_HEXAGON_ARCH -# USE_HEXAGON_SDK -# USE_HEXAGON_TOOLCHAIN (Path to Hexagon toolchain ending with "Tools") -# Optional variable: -# USE_OUTPUT_BINARY_DIR (Path to copy the output binaries to) -# USE_HEXAGON_GTEST (Path to Hexagon specific gtest version) - -set(TVM_SOURCE_DIR "${CMAKE_SOURCE_DIR}/../..") - -if(DEFINED USE_OUTPUT_BINARY_DIR) - set(HEXAGON_API_BINARY_DIR "${USE_OUTPUT_BINARY_DIR}") -else() - set(HEXAGON_API_BINARY_DIR "${CMAKE_BINARY_DIR}/hexagon_rpc") -endif() -file(MAKE_DIRECTORY ${HEXAGON_API_BINARY_DIR}) - -if(DEFINED USE_HEXAGON_GTEST) - if(EXISTS ${USE_HEXAGON_GTEST}) - message(STATUS "Found Hexagon gtest at ${USE_HEXAGON_GTEST}") - else() - message(WARNING "Could not find Hexagon gtest at ${USE_HEXAGON_GTEST}. Disabling Hexagon gtest support.") - unset(USE_HEXAGON_GTEST) - endif() -endif() - -# Build X86 binaries: -# - tvm_rpc_x86 - -ExternalProject_Add(x86_tvm_runtime_rpc - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime tvm_rpc - CMAKE_ARGS - "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" - "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER}" - "-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}" - "-DUSE_HEXAGON_TOOLCHAIN=${USE_HEXAGON_TOOLCHAIN}" - "-DCMAKE_CXX_STANDARD=17" - "-DTVM_FFI_USE_LIBBACKTRACE=OFF" - "-DTVM_FFI_USE_THREADS=OFF" - "-DTVM_FFI_USE_DL_LIBS=OFF" - "-DUSE_RPC=ON" - "-DUSE_CPP_RPC=ON" - "-DUSE_HEXAGON=ON" - "-DUSE_HEXAGON_RPC=ON" - "-DBUILD_STATIC_RUNTIME=ON" - "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(x86_tvm_runtime_rpc BINARY_DIR) -ExternalProject_Add_Step(x86_tvm_runtime_rpc copy_rpc_server - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/tvm_rpc - ${HEXAGON_API_BINARY_DIR}/tvm_rpc_x86 - DEPENDEES install -) - -# Build Android binaries: -# - libtvm_runtime.so -# - tvm_rpc_android - -ExternalProject_Add(android_tvm_runtime_rpc - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime tvm_rpc - CMAKE_ARGS - "-DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER}" - "-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}" - "-DCMAKE_TOOLCHAIN_FILE=${USE_ANDROID_TOOLCHAIN}" - "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" - "-DANDROID_ABI=${ANDROID_ABI}" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DCMAKE_CXX_STANDARD=17" - "-DTVM_FFI_USE_LIBBACKTRACE=OFF" - "-DTVM_FFI_USE_THREADS=OFF" - "-DTVM_FFI_USE_DL_LIBS=OFF" - "-DUSE_RPC=ON" - "-DUSE_CPP_RPC=ON" - "-DUSE_HEXAGON=ON" - "-DUSE_HEXAGON_RPC=ON" - "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" - "-DUSE_ALTERNATIVE_LINKER=OFF" - "-DUSE_RANDOM=ON" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) - -ExternalProject_Get_Property(android_tvm_runtime_rpc BINARY_DIR) -ExternalProject_Add_Step(android_tvm_runtime_rpc copy_runtime - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/lib/libtvm_runtime.so - ${HEXAGON_API_BINARY_DIR} - DEPENDEES install -) -ExternalProject_Add_Step(android_tvm_runtime_rpc copy_rpc_server - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/tvm_rpc - ${HEXAGON_API_BINARY_DIR}/tvm_rpc_android - DEPENDEES install -) - - -# Build Hexagon binaries: -# - libhexagon_rpc_skel.so -# - libtvm_runtime.a -if(DEFINED USE_HEXAGON_GTEST) - set(GTEST_FLAG "-DUSE_HEXAGON_GTEST=${USE_HEXAGON_GTEST}") -endif() - -if(NOT DEFINED USE_HEXAGON_QHL) - # USE_HEXAGON_QHL defaults to ON for rpc runtime if not explicitly set to OFF - set(USE_HEXAGON_QHL ON) -endif() - -ExternalProject_Add(hexagon_tvm_runtime_rpc - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime hexagon_rpc_sim - CMAKE_ARGS - "-DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER}" - "-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}" - "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang" - "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DUSE_HEXAGON_EXTERNAL_LIBS=${USE_HEXAGON_EXTERNAL_LIBS}" - "-DHEXAGON_EXTERNAL_LIBS_SHA=${HEXAGON_EXTERNAL_LIBS_SHA}" - "-DCMAKE_CXX_STANDARD=17" - "-DTVM_FFI_USE_LIBBACKTRACE=OFF" - "-DTVM_FFI_USE_THREADS=OFF" - "-DTVM_FFI_USE_DL_LIBS=OFF" - "-DUSE_RPC=OFF" - "-DUSE_HEXAGON=ON" - "-DUSE_HEXAGON_RPC=ON" - "-DBUILD_STATIC_RUNTIME=ON" - "-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}" - "-DUSE_ALTERNATIVE_LINKER=OFF" - "-DUSE_CUSTOM_LOGGING=ON" - "-DUSE_HEXAGON_QHL=${USE_HEXAGON_QHL}" - "-DUSE_RANDOM=ON" - "${GTEST_FLAG}" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(hexagon_tvm_runtime_rpc BINARY_DIR) -ExternalProject_Add_Step(hexagon_tvm_runtime_rpc copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/lib/libtvm_runtime.a - ${BINARY_DIR}/libhexagon_rpc_skel.so - ${BINARY_DIR}/libhexagon_rpc_sim.so - ${HEXAGON_API_BINARY_DIR} - DEPENDEES install -) - -configure_file("${TVM_SOURCE_DIR}/src/runtime/hexagon/rpc/android_bash.sh.template" - ${HEXAGON_API_BINARY_DIR} COPYONLY) diff --git a/apps/hexagon_api/README.md b/apps/hexagon_api/README.md deleted file mode 100644 index f8f5f792af87..000000000000 --- a/apps/hexagon_api/README.md +++ /dev/null @@ -1,58 +0,0 @@ - - - - - - - - - - - - - - - - - -# Hexagon API app - -This is a meta-app that build the necessary binaries for use with -the `HexagonLauncher` utility from `tvm.contrib.hexagon`. - -It will build the TVM runtime for Android, the RPC server application -for Android, and the RPC library for Hexagon with the TVM runtime for -Hexagon built into it. - -## Configuration - -There is a set of configuration variables that are required for CMake: -- `ANDROID_ABI`: Set this to `arm64-v8a`. -- `ANDROID_PLATFORM`: This can be `android-28`. -- `USE_ANDROID_TOOLCHAIN`: The path to the Android toolchain file, i.e. -`android.toolchain.cmake`. This file is a part of the Android NDK. -- `USE_HEXAGON_ARCH`: The version string of the Hexagon architecture -to use, i.e. vNN. The typical setting would be `v68` or later. -- `USE_HEXAGON_SDK`: The path to the Hexagon SDK. Set this path in such -a way that `${USE_HEXAGON_SDK}/setup_sdk_env.source` exists. -- `USE_HEXAGON_TOOLCHAIN`: Path to Hexagon toolchain. It can be the -Hexagon toolchain included in the SDK, for example -`${USE_HEXAGON_TOOLCHAIN}/tools/HEXAGON_Tools/x.y.z/Tools`. The `x.y.z` -in the path is the toolchain version number, which is specific to the -version of the SDK. - -Additionally, the variable `USE_OUTPUT_BINARY_DIR` can be set to indicate -the location where the generated binaries will be placed. If not set, it -defaults to `hexagon_rpc` subdirectory in the current build directory. - - -## Build - -The build will generate the following binaries: -- `tvm_runtime.so`: TVM runtime for Android (shared library). -- `tvm_rpc_android`: RPC server for Android. -- `libhexagon_rpc_skel.so`: RPC library for Hexagon. -- `libtvm_runtime.a`: TVM runtime for Hexagon (static library). - -The RPC library for Hexagon contains the TVM runtime, so the static -TVM runtime for Hexagon is not strictly necessary. diff --git a/apps/hexagon_launcher/CMakeLists.txt b/apps/hexagon_launcher/CMakeLists.txt deleted file mode 100644 index c08e743a2592..000000000000 --- a/apps/hexagon_launcher/CMakeLists.txt +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonLauncher C CXX) - -include(ExternalProject) - -set(LAUNCHER_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}") - -set(VARS_NEEDED - ANDROID_ABI - ANDROID_PLATFORM - USE_ANDROID_TOOLCHAIN - USE_HEXAGON_ARCH - USE_HEXAGON_SDK - USE_HEXAGON_TOOLCHAIN -) -foreach(V IN LISTS VARS_NEEDED) - if(NOT ${V}) - message(SEND_ERROR "Please set ${V}") - endif() -endforeach() - - -ExternalProject_Add(android_launcher_binaries - SOURCE_DIR "${LAUNCHER_SOURCE_DIR}/cmake/android" - BUILD_COMMAND $(MAKE) - CMAKE_ARGS - "-DCMAKE_TOOLCHAIN_FILE=${USE_ANDROID_TOOLCHAIN}" - "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" - "-DANDROID_ABI=${ANDROID_ABI}" - "-DCMAKE_CXX_STANDARD=17" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(android_launcher_binaries BINARY_DIR) -ExternalProject_Add_Step(android_launcher_binaries copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/launcher_android - ${BINARY_DIR}/lib/libtvm_runtime.so - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDEES install -) - -ExternalProject_Add(hexagon_launcher_binaries - SOURCE_DIR "${LAUNCHER_SOURCE_DIR}/cmake/hexagon" - BUILD_COMMAND $(MAKE) - CMAKE_ARGS - "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang" - "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" - "-DCMAKE_CXX_STANDARD=17" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DUSE_CUSTOM_LOGGING=ON" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(hexagon_launcher_binaries BINARY_DIR) -ExternalProject_Add_Step(hexagon_launcher_binaries copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/liblauncher_rpc_skel.so - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDEES install -) diff --git a/apps/hexagon_launcher/README.md b/apps/hexagon_launcher/README.md deleted file mode 100644 index f3c3519ea7d9..000000000000 --- a/apps/hexagon_launcher/README.md +++ /dev/null @@ -1,145 +0,0 @@ - - - - - - - - - - - - - - - - -# Hexagon Graph Launcher - -## Compilation - -The launcher consists of two parts: part running on Hexagon, and part running -on Android. Each component must be compiled separately. - -The supported Snapdragon architectures are 855, 865, and 888. - -### Prerequisites - -1. Android NDK version r19c or later. -2. Hexagon SDK version 4.0.0 or later. - -Android NDK can be downloaded from https://developer.android.com/ndk. -Hexagon SDK is available at https://developer.qualcomm.com/software/hexagon-dsp-sdk. - -### Manual compilation - -Since some source files are shared between the Hexagon and Android builds, -make sure to delete all object files between compilations. Compile the Hexagon -code first. - -#### Compilation of the Hexagon part - -Create a subdirectory for the build files, and run `cmake` with the -following variables set: - -``` -cmake -DCMAKE_C_COMPILER=/path/to/hexagon-clang \ - -DCMAKE_CXX_COMPILER=/path/to/hexagon-clang++ \ - -DUSE_HEXAGON_ARCH=v65|v66|v68|v69|v73|v75 \ - -DUSE_HEXAGON_SDK=/path/to/hexagon/SDK \ - /path/to/apps/hexagon_launcher/cmake/hexagon -``` - -Run `make`. This will create `liblauncher_rpc_skel.so`. The static version of -the TVM runtime for Hexagon will be built as a part of the process. - -#### Compilation of the Android part - -2. Create a subdirectory for the build files (different from the one used for - Hexagon files), and run `cmake` with the following variables set: - -``` -cmake -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-28 \ - -DUSE_HEXAGON_SDK=/p/Hexagon_SDK/4.3.0.0 \ - -DUSE_HEXAGON_ARCH=v65|v66|v68|v69|v73|v75 \ - /path/to/apps/hexagon_launcher/cmake/android -``` - -Run `make`. This will create `launcher_android`. The TVM runtime for Android will -be built as a part of the process. Depending on the version of CMake that you are -using, you may see the following warnings---they can be ignored. - -``` -An old version of CMake is being used that cannot automatically detect -compiler attributes. Compiler identification is being bypassed. Some -values may be wrong or missing. Update to CMake 3.19 or newer to use -CMake's built-in compiler identification. -``` - -## Execution - -From the Android shell, do -``` -./launcher_android --in_config input.json --out_config output.json -``` - -You may need to add the location of `libtvm_runtime.so` to `LD_LIBRARY_PATH`. -See below for more information about the setup and launcher's inputs. - -### Preparation steps - -Copy the following binaries to the device: -- `liblauncher_rpc_skel.so`: created by the compilation step for Hexagon, -- `libgcc.so`: take this one from the Hexagon toolchain, -- `launcher_android`: created by the compilation step for Android, -- `libtvm_runtime.so`: built for Android. - -These are only the binaries related to the launcher itself. To run a model -copy the shared object with the model and the model JSON file over to the -device (both are obtained from relay). Also, copy all input files for the -model as well. - -## Profiling using hexagon launcher - -### Enabling lightweight profiling (LWP) instrumentation - -This profiling option can be used to get function and loop level processor cycles. -This needs to be enabled explicitly while compiling a model. - -Here, `instrument_lwp` is used to enable the tir pass which instruments the code with the builtin calls. - -During codegen, profiling builtin calls can be replaced with a target specific handler to record runtime -information into a buffer. This buffer is written into a JSON file which is processed to construct -function and loop level profiling information. - -To generate LWP JSON file, add `--gen_lwp_json` flag to launcher_android: - -``` -./launcher_android --in_config input.json --out_config output.json --gen_lwp_json -``` - -Please note that `--gen_lwp_json` flag by itself doesn't enable profiling and is only used to dump -the profiling data into a json file called lwp.json. This file will be created at the same location -on the device where launcher_android is executed from. To generate the data, profiling instrumentation -must be enabled while compiling a model as mentioned above. - -Use this command to pull `lwp.json` from the device: - -``` -adb -s pull /path/to/lwp.json -``` - -**Note:** Please refer to src/runtime/hexagon/profiler/README.md for information on how -to enable profiling using Hexagon RPC launcher and also to learn about additional profiling related -config options. - -# Disclaimer - -The launcher does not perform any correctness verification. In order to verify -correctness, the user needs to copy the output files from the device and -verify their contents. - -This launcher is intended for use with prototyping and does not utilize any -performance acceleration, as such the measured performance may be very poor. diff --git a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake b/apps/hexagon_launcher/cmake/HexagonLauncher.cmake deleted file mode 100644 index 52e2cff1f895..000000000000 --- a/apps/hexagon_launcher/cmake/HexagonLauncher.cmake +++ /dev/null @@ -1,75 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# On successful execution, sets -# SDK_INCLUDE_DIRS -# QAIC_EXE_PATH -# and -# QAIC_FLAGS -# LAUNCHER_SRC -# LAUNCHER_RPC_IDL -# LAUNCHER_RPC_H -# LAUNCHER_RPC_SKEL_C -# LAUNCHER_RPC_STUB_C - -if(USE_CUSTOM_LOGGING) - add_definitions(-DTVM_LOG_CUSTOMIZE=1) -endif() -if(NOT DEFINED USE_HEXAGON_SDK) - message(SEND_ERROR "Please set USE_HEXAGON_SDK to the location of Hexagon SDK") -endif() -if (NOT DEFINED USE_HEXAGON_ARCH) - message(SEND_ERROR "Please set USE_HEXAGON_ARCH to the Hexagon architecture version") -endif() - -set(TVM_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../") - -include(ExternalProject) -include("${TVM_SOURCE_DIR}/cmake/utils/Utils.cmake") -include("${TVM_SOURCE_DIR}/cmake/modules/HexagonSDK.cmake") - -get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" - SDK_INCLUDE SDK_INCLUDE_DIRS - QAIC_EXE QAIC_EXE_PATH -) -if(NOT SDK_INCLUDE_DIRS OR NOT QAIC_EXE_PATH) - message(WARNING "Could not locate some Hexagon SDK components") -endif() - -include_directories(SYSTEM ${SDK_INCLUDE_DIRS}) - -foreach(INCDIR IN LISTS SDK_INCLUDE_DIRS) - list(APPEND QAIC_FLAGS "-I${INCDIR}") -endforeach() - -set(LAUNCHER_SRC "${CMAKE_CURRENT_SOURCE_DIR}/../../") -set(CMAKE_SKIP_RPATH TRUE) - -# Qaic for the domain header. -# -# Don't add paths to these filenames, or otherwise CMake may spontaneously -# add -o option to the qaic invocation (with an undesirable path). -set(LAUNCHER_RPC_IDL "launcher_rpc.idl") -set(LAUNCHER_RPC_H "launcher_rpc.h") -set(LAUNCHER_RPC_SKEL_C "launcher_rpc_skel.c") -set(LAUNCHER_RPC_STUB_C "launcher_rpc_stub.c") - -include_directories( - "${LAUNCHER_SRC}" - "${TVM_SOURCE_DIR}/include" - "${TVM_SOURCE_DIR}/3rdparty/dlpack/include" -) diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt deleted file mode 100644 index 0846ce786909..000000000000 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ /dev/null @@ -1,100 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonAndroidLauncher C CXX) - -include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") -# From the include above, get -# SDK_INCLUDE_DIRS -# QAIC_EXE_PATH -# and -# QAIC_FLAGS -# LAUNCHER_SRC -# LAUNCHER_RPC_IDL -# LAUNCHER_RPC_H -# LAUNCHER_RPC_SKEL_C -# LAUNCHER_RPC_STUB_C - -add_custom_command( - OUTPUT ${LAUNCHER_RPC_STUB_C} ${LAUNCHER_RPC_H} - COMMAND ${QAIC_EXE_PATH} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" - MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" -) - -get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" - RPCMEM_ROOT RPCMEM_ROOT_DIR - DSPRPC_LIB DSPRPC_LIB_DIRS -) -if(NOT RPCMEM_ROOT_DIR) - message(WARNING "Could not locate some Hexagon SDK components") -endif() - -include_directories(SYSTEM - "${SDK_INCLUDE_DIRS}" - "${RPCMEM_ROOT_DIR}/inc" - "${CMAKE_CURRENT_BINARY_DIR}" # Output of qaic will go here -) - -link_directories(${DSPRPC_LIB_DIRS}) - - -set(STUB_SRCS - "${LAUNCHER_SRC}/launcher_android.cc" - "${LAUNCHER_SRC}/launcher_core.cc" - "${LAUNCHER_SRC}/launcher_main.cc" - "${LAUNCHER_SRC}/launcher_util.cc" -) - -add_executable(launcher_android - "${LAUNCHER_RPC_H}" - "${LAUNCHER_RPC_STUB_C}" - "${STUB_SRCS}" -) - -ExternalProject_Add(android_tvm_runtime - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime - CMAKE_ARGS - "-DANDROID_ABI=${ANDROID_ABI}" - "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" - "-DCMAKE_CXX_STANDARD=17" - "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" - "-DUSE_HEXAGON=ON" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DTVM_FFI_USE_LIBBACKTRACE=OFF" - "-DTVM_FFI_USE_THREADS=OFF" - "-DTVM_FFI_USE_DL_LIBS=OFF" - "-DUSE_LLVM=OFF" - "-DUSE_RPC=OFF" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(android_tvm_runtime BINARY_DIR) -ExternalProject_Add_Step(android_tvm_runtime copy_binaries - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${BINARY_DIR}/lib/libtvm_runtime.so - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDEES install -) - -add_dependencies(launcher_android android_tvm_runtime) -add_library(a_tvm_runtime SHARED IMPORTED) -set_target_properties(a_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/lib/libtvm_runtime.so") - -target_link_libraries(launcher_android cdsprpc log a_tvm_runtime) diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt deleted file mode 100644 index a0557307ba50..000000000000 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ /dev/null @@ -1,105 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -cmake_minimum_required(VERSION 3.2) -project(HexagonLauncherRPCSkel C CXX ASM) - -include("${CMAKE_CURRENT_SOURCE_DIR}/../HexagonLauncher.cmake") -# From the include above get -# SDK_INCLUDE_DIRS -# QAIC_EXE_PATH -# and -# QAIC_FLAGS -# LAUNCHER_SRC -# LAUNCHER_RPC_IDL -# LAUNCHER_RPC_H -# LAUNCHER_RPC_SKEL_C -# LAUNCHER_RPC_STUB_C - -add_custom_command( - OUTPUT ${LAUNCHER_RPC_SKEL_C} ${LAUNCHER_RPC_H} - COMMAND ${QAIC_EXE_PATH} ${QAIC_FLAGS} "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" - MAIN_DEPENDENCY "${LAUNCHER_SRC}/${LAUNCHER_RPC_IDL}" -) - -get_hexagon_sdk_property("${USE_HEXAGON_SDK}" "${USE_HEXAGON_ARCH}" - QURT_INCLUDE QURT_INCLUDE_DIRS - QURT_LIB QURT_LIB_DIRS -) -if(NOT QURT_INCLUDE_DIRS OR NOT QURT_LIB_DIRS) - message(WARNING "Could not locate some Hexagon SDK components") -endif() - -include_directories(SYSTEM - ${QURT_INCLUDE_DIRS} - ${CMAKE_CURRENT_BINARY_DIR} # Output of qaic will go here -) - -link_directories(${QURT_LIB_DIRS}) - -add_definitions(-D_MACH_I32=int) - -# Extra compile flags (both C and C++). -set(EXTRA_COMP_FLAGS - "-O3" - "-m${USE_HEXAGON_ARCH}" -) -string(REGEX REPLACE ";" " " EXTRA_COMP_FLAGS_STR "${EXTRA_COMP_FLAGS}") -set(CMAKE_C_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_C_FLAGS}") -set(CMAKE_CXX_FLAGS "${EXTRA_COMP_FLAGS_STR} ${CMAKE_CXX_FLAGS}") - -set(SKEL_SRCS - "${LAUNCHER_SRC}/launcher_core.cc" - "${LAUNCHER_SRC}/launcher_hexagon.cc" -) -set(PROFILER_DIR "${TVM_SOURCE_DIR}/src/runtime/hexagon/profiler") - -add_library(launcher_rpc_skel SHARED - "${LAUNCHER_RPC_H}" - "${LAUNCHER_RPC_SKEL_C}" - "${SKEL_SRCS}" - "${PROFILER_DIR}/prof_utils.cc" - "${PROFILER_DIR}/lwp_handler.S" -) - -ExternalProject_Add(static_hexagon_tvm_runtime - SOURCE_DIR "${TVM_SOURCE_DIR}" - BUILD_COMMAND $(MAKE) runtime - CMAKE_ARGS - "-DBUILD_STATIC_RUNTIME=ON" - "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" - "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DCMAKE_CXX_STANDARD=17" - "-DUSE_HEXAGON=ON" - "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" - "-DTVM_FFI_USE_LIBBACKTRACE=OFF" - "-DTVM_FFI_USE_THREADS=OFF" - "-DTVM_FFI_USE_DL_LIBS=OFF" - "-DUSE_LLVM=OFF" - "-DUSE_RPC=OFF" - "-DUSE_CUSTOM_LOGGING=ON" - INSTALL_COMMAND "" - BUILD_ALWAYS ON -) -ExternalProject_Get_Property(static_hexagon_tvm_runtime BINARY_DIR) - -add_dependencies(launcher_rpc_skel static_hexagon_tvm_runtime) -add_library(h_tvm_runtime STATIC IMPORTED) -set_target_properties(h_tvm_runtime PROPERTIES IMPORTED_LOCATION "${BINARY_DIR}/libtvm_runtime.a") - -target_link_libraries(launcher_rpc_skel -Wl,--whole-archive h_tvm_runtime -Wl,--no-whole-archive) diff --git a/apps/hexagon_launcher/launcher_android.cc b/apps/hexagon_launcher/launcher_android.cc deleted file mode 100644 index 34db0bdacb60..000000000000 --- a/apps/hexagon_launcher/launcher_android.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -#include "launcher_core.h" -#include "launcher_rpc.h" - -AEEResult enable_unsigned_pd(bool enable) { - remote_rpc_control_unsigned_module data; - data.domain = CDSP_DOMAIN_ID; - data.enable = static_cast(enable); - AEEResult rc = remote_session_control(DSPRPC_CONTROL_UNSIGNED_MODULE, &data, sizeof(data)); - if (rc != AEE_SUCCESS) { - std::cout << "error " << (enable ? "enabling" : "disabling") << " unsigned PD\n"; - } - return rc; -} - -AEEResult set_remote_stack_size(int size) { - remote_rpc_thread_params data; - data.domain = CDSP_DOMAIN_ID; - data.prio = -1; - data.stack_size = size; - AEEResult rc = remote_session_control(FASTRPC_THREAD_PARAMS, &data, sizeof(data)); - if (rc != AEE_SUCCESS) { - std::cout << "error setting remote stack size: " << std::hex << rc << '\n'; - } - return rc; -} - -struct RPCChannel : public ExecutionSession { - explicit RPCChannel(const std::string& uri, bool gen_lwp_json = false) - : ExecutionSession(gen_lwp_json) { - enable_unsigned_pd(true); - set_remote_stack_size(128 * 1024); - - int rc = launcher_rpc_open(uri.c_str(), &handle); - if (rc != AEE_SUCCESS) { - handle = -1; - } - } - - ~RPCChannel() { - if (handle == -1) { - return; - } - - for (void* ptr : allocations) { - rpcmem_free(ptr); - } - if (model_loaded) { - unload_model(); - } - launcher_rpc_close(handle); - handle = -1; - } - - void* alloc_mem(size_t nbytes, size_t align) override { - void* host_ptr = rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, nbytes); - if (host_ptr != nullptr) { - allocations.push_back(host_ptr); - } - return host_ptr; - } - - void free_mem(void* addr) override { - auto f = std::find(allocations.begin(), allocations.end(), addr); - if (f != allocations.end()) { - allocations.erase(f); - rpcmem_free(addr); - } - } - - bool load_model(const std::string& model_path, const std::string& model_json) override { - AEEResult rc = launcher_rpc_load(handle, model_path.c_str(), model_json.c_str()); - if (rc != AEE_SUCCESS) { - std::cout << "error loading graph module: " << std::hex << rc << '\n'; - } else { - model_loaded = true; - } - return rc == AEE_SUCCESS; - } - - bool unload_model() override { - AEEResult rc = launcher_rpc_unload(handle); - if (rc != AEE_SUCCESS) { - std::cout << "error unloading model: " << std::hex << rc << '\n'; - } - model_loaded = false; - return rc == AEE_SUCCESS; - } - - bool set_input(int input_idx, const tensor_meta* input_meta, const void* input_data) override { - AEEResult rc = launcher_rpc_set_input( - handle, input_idx, reinterpret_cast(input_meta), - input_meta->meta_size(), reinterpret_cast(input_data), - input_meta->data_size()); - if (rc != AEE_SUCCESS) { - std::cout << "error setting model input no." << input_idx << ": " << std::hex << rc << '\n'; - } - return rc == AEE_SUCCESS; - } - - bool run(uint64_t* pcycles, uint64_t* usecs) override { - AEEResult rc = launcher_rpc_run(handle, pcycles, usecs, gen_lwp_json); - if (rc != AEE_SUCCESS) { - std::cout << "error running model: " << std::hex << rc << '\n'; - } - return rc == AEE_SUCCESS; - } - - bool get_num_outputs(int* num_outputs) override { - AEEResult rc = launcher_rpc_get_num_outputs(handle, num_outputs); - if (rc != AEE_SUCCESS) { - std::cout << "error getting number of outputs: " << std::hex << rc << '\n'; - } - return rc == AEE_SUCCESS; - } - - bool get_output(int output_idx, tensor_meta* output_meta, int meta_size, void* output_data, - int data_size) override { - AEEResult rc = launcher_rpc_get_output( - handle, output_idx, reinterpret_cast(output_meta), meta_size, - reinterpret_cast(output_data), data_size); - if (rc != AEE_SUCCESS) { - std::cout << "error getting output no." << output_idx << ": " << std::hex << rc << '\n'; - } - return rc == AEE_SUCCESS; - } - - bool model_loaded = false; - remote_handle64 handle = -1; - std::vector allocations; -}; - -ExecutionSession* create_execution_session(bool gen_lwp_json) { - auto* session = new RPCChannel(launcher_rpc_URI CDSP_DOMAIN, gen_lwp_json); - if (session->handle == -1) { - delete session; - session = nullptr; - std::cout << "Error opening FastRPC channel\n"; - } - return session; -} diff --git a/apps/hexagon_launcher/launcher_core.cc b/apps/hexagon_launcher/launcher_core.cc deleted file mode 100644 index 44fc48c92701..000000000000 --- a/apps/hexagon_launcher/launcher_core.cc +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "launcher_core.h" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -const std::string TensorConfig::file_key = "file"; // NOLINT(runtime/string) -const std::string TensorConfig::shape_key = "shape"; // NOLINT(runtime/string) -const std::string TensorConfig::dtype_key = "dtype"; // NOLINT(runtime/string) - -std::string tensor_meta::to_string() const { - std::stringstream out; - out << "ndim=" << ndim << ", dtype=" << tvm::runtime::DLDataTypeToString(dtype) << ", shape="; - for (int i = 0; i != ndim; ++i) { - out << shape[i]; - if (i + 1 < ndim) { - out << 'x'; - } - } - return out.str(); -} - -void TensorConfig::Load(tvm::ffi::json::Object obj) { - namespace json = ::tvm::ffi::json; - for (const auto& kv : obj) { - std::string key = std::string(kv.first.cast()); - if (key == file_key) { - file_name = std::string(kv.second.cast()); - } else if (key == shape_key) { - auto arr = kv.second.cast(); - shape.clear(); - shape.reserve(arr.size()); - for (const auto& elem : arr) { - shape.push_back(static_cast(elem.cast())); - } - if (shape.empty()) { - std::cout << "error: empty shape\n"; - bad = true; - } - } else if (key == dtype_key) { - dtype = std::string(kv.second.cast()); - } else { - std::cout << "unknown tensor config key: " << key << '\n'; - bad = true; - } - } -} - -tvm::ffi::json::Value TensorConfig::SaveToJSON() const { - namespace json = ::tvm::ffi::json; - json::Object obj; - obj.Set(tvm::ffi::String(file_key), tvm::ffi::String(file_name)); - json::Array shape_arr; - for (int v : shape) { - shape_arr.push_back(static_cast(v)); - } - obj.Set(tvm::ffi::String(shape_key), std::move(shape_arr)); - obj.Set(tvm::ffi::String(dtype_key), tvm::ffi::String(dtype)); - return obj; -} - -void ModelConfig::Load(tvm::ffi::json::Object obj) { - namespace json = ::tvm::ffi::json; - for (const auto& kv : obj) { - std::string key = std::string(kv.first.cast()); - if (key == "model-library") { - model_library = std::string(kv.second.cast()); - } else if (key == "model-json") { - model_json = std::string(kv.second.cast()); - } else if (key == "inputs") { - auto arr = kv.second.cast(); - inputs.clear(); - inputs.reserve(arr.size()); - for (const auto& elem : arr) { - TensorConfig tc; - tc.Load(elem.cast()); - inputs.push_back(std::move(tc)); - } - bad = std::any_of(inputs.begin(), inputs.end(), [](auto t) { return t.bad; }); - } else { - std::cout << "unknown model config key: " << key << '\n'; - bad = true; - } - } -} - -tvm::ffi::json::Value OutputConfig::SaveToJSON() const { - namespace json = ::tvm::ffi::json; - json::Object obj; - obj.Set(tvm::ffi::String("pcycles"), static_cast(pcycles)); - obj.Set(tvm::ffi::String("usecs"), static_cast(usecs)); - json::Array outputs_arr; - for (const auto& tc : outputs) { - outputs_arr.push_back(tc.SaveToJSON()); - } - obj.Set(tvm::ffi::String("outputs"), std::move(outputs_arr)); - return obj; -} - -bool read_model_config(const std::string& file_name, ModelConfig* model_config) { - namespace json = ::tvm::ffi::json; - if (model_config == nullptr) { - return false; - } - std::ifstream mfc(file_name); - if (!mfc.is_open()) { - return false; - } - std::string content((std::istreambuf_iterator(mfc)), std::istreambuf_iterator()); - auto parsed = json::Parse(content); - model_config->Load(parsed.cast()); - if (model_config->bad) { - return false; - } - return true; -} - -bool write_output_config(const std::string& file_name, OutputConfig* output_config) { - namespace json = ::tvm::ffi::json; - std::ofstream ofc(file_name); - if (!ofc.is_open()) { - return false; - } - ofc << std::string(json::Stringify(output_config->SaveToJSON(), 2)); - if (!ofc) { - return false; - } - return true; -} - -Model::Model(tvm::runtime::Module executor, tvm::runtime::Module module, std::string json) - : model_executor(executor), graph_module(module), graph_json(json) { - // Lookup "run" ahead of time to reduce overhead in the model execution. - run = get_module_func(model_executor, "run"); -} - -const tvm::ffi::Function get_runtime_func(const std::string& name) { - if (auto pf = tvm::ffi::Function::GetGlobal(name)) { - return *pf; - } - return tvm::ffi::Function(); -} - -const tvm::ffi::Function get_module_func(tvm::runtime::Module module, const std::string& name) { - return module->GetFunction(name, false).value_or(tvm::ffi::Function()); -} - -void reset_device_api() { - const tvm::ffi::Function api = get_runtime_func("device_api.hexagon"); - tvm::ffi::Function::SetGlobal("device_api.cpu", api, true); -} - -tvm::runtime::Module load_module(const std::string& file_name) { - static const tvm::ffi::Function loader = get_runtime_func("ffi.Module.load_from_file.hexagon"); - tvm::ffi::Any rv = loader(file_name); - if (rv.type_code() == kTVMModuleHandle) { - TVM_FFI_ICHECK_EQ(rv.type_code(), kTVMModuleHandle) - << __func__ << ": loaded " << file_name << ", but did not get module handle"; - return rv.operator tvm::runtime::Module(); - } - return tvm::runtime::Module(); -} - -std::ostream& operator<<(std::ostream& os, const tvm::ffi::Array& strings) { - os << '['; - for (int i = 0, e = strings.size(); i != e; ++i) { - if (i != 0) os << ','; - os << static_cast(strings[i]); - } - os << ']'; - return os; -} - -tvm::runtime::Module create_graph_executor(const std::string& graph_json, - tvm::runtime::Module graph_module, tvm::Device device) { - std::string launcher_name = "tvm.graph_executor.create"; - - const tvm::ffi::Function create_executor = get_runtime_func(launcher_name); - uint64_t device_type = device.device_type; - uint64_t device_id = device.device_id; - - if (graph_json.empty()) { - LOG(ERROR) << __func__ << ": graph executor requires graph JSON"; - return tvm::runtime::Module(); - } - tvm::ffi::Any rv = create_executor(graph_json, graph_module, device_type, device_id); - return rv.operator tvm::runtime::Module(); -} - -tvm::runtime::Module create_aot_executor(tvm::runtime::Module factory_module, tvm::Device device) { - tvm::ffi::Function list_modules = get_module_func(factory_module, "list_module_names"); - tvm::ffi::Array module_names = list_modules(); - if (module_names.size() != 1) { - LOG(WARNING) << __func__ << ": expecting single module, got: " << module_names << ", using " - << module_names[0]; - } - tvm::ffi::Function f = get_module_func(factory_module, module_names[0]); - if (f.get() == nullptr) { - LOG(ERROR) << __func__ << ": failed to obtain function " << module_names[0]; - return tvm::runtime::Module(); - } - return f(device); -} diff --git a/apps/hexagon_launcher/launcher_core.h b/apps/hexagon_launcher/launcher_core.h deleted file mode 100644 index 1af8a1fbf8f1..000000000000 --- a/apps/hexagon_launcher/launcher_core.h +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_LAUNCHER_LAUNCHER_CORE_H_ -#define TVM_RUNTIME_HEXAGON_LAUNCHER_LAUNCHER_CORE_H_ - -#include -#include -#include -#include -#include -#include - -#include -#include - -struct tensor_meta { - int ndim; - DLDataType dtype; - int64_t shape[]; - - int meta_size() const { return meta_size(ndim); } - int data_size() const { - int size = tvm::runtime::DataType(dtype).bytes(); - for (int d = 0; d != ndim; ++d) { - size *= shape[d]; - } - return size; - } - - static int meta_size(int ndim) { return sizeof(tensor_meta) + ndim * sizeof(int64_t); } - - std::string to_string() const; -}; - -struct TensorConfig { - static const std::string file_key; - static const std::string shape_key; - static const std::string dtype_key; - - std::string file_name; - std::vector shape; - std::string dtype; - bool bad = false; - - void Load(tvm::ffi::json::Object obj); - tvm::ffi::json::Value SaveToJSON() const; -}; - -struct ModelConfig { - std::string model_library; - std::string model_json; - std::vector inputs; - bool bad = false; - - void Load(tvm::ffi::json::Object obj); -}; - -struct OutputConfig { - uint64_t pcycles; - uint64_t usecs; - std::vector outputs; - - tvm::ffi::json::Value SaveToJSON() const; -}; - -struct Model { - Model(tvm::ffi::Module executor, tvm::ffi::Module module, std::string json); - - tvm::ffi::Module model_executor; - tvm::ffi::Module graph_module; - std::string graph_json; - - static tvm::Device device() { return tvm::Device{static_cast(kDLHexagon), 0}; } - static tvm::Device external() { return tvm::Device{static_cast(kDLCPU), 0}; } - - tvm::ffi::Function run; -}; - -struct ExecutionSession { - explicit ExecutionSession(bool lwp_json = false) : gen_lwp_json(lwp_json) {} - - template - T* alloc(size_t bytes, size_t align = 1) { - return reinterpret_cast(alloc_mem(bytes, align)); - } - void free(void* ptr) { free_mem(ptr); } - - virtual void* alloc_mem(size_t bytes, size_t align) = 0; - virtual void free_mem(void* ptr) = 0; - - virtual bool load_model(const std::string& model_path, const std::string& model_json) = 0; - virtual bool unload_model() = 0; - - virtual bool set_input(int input_idx, const tensor_meta* input_meta, const void* input_data) = 0; - virtual bool run(uint64_t* pcycles, uint64_t* usecs) = 0; - virtual bool get_num_outputs(int* num_outputs) = 0; - virtual bool get_output(int output_idx, tensor_meta* output_meta, int meta_size, - void* output_data, int data_size) = 0; - bool gen_lwp_json = false; -}; - -bool read_model_config(const std::string& file_name, ModelConfig* model_config); -bool write_output_config(const std::string& file_name, OutputConfig* output_config); - -void reset_device_api(); - -tvm::ffi::Module load_module(const std::string& file_name); - -const tvm::ffi::Function get_runtime_func(const std::string& name); -const tvm::ffi::Function get_module_func(tvm::ffi::Module module, const std::string& name); - -tvm::ffi::Module create_aot_executor(tvm::ffi::Module factory_module, tvm::Device device); -tvm::ffi::Module create_graph_executor(const std::string& graph_json, tvm::ffi::Module graph_module, - tvm::Device device); - -#endif // TVM_RUNTIME_HEXAGON_LAUNCHER_LAUNCHER_CORE_H_ diff --git a/apps/hexagon_launcher/launcher_hexagon.cc b/apps/hexagon_launcher/launcher_hexagon.cc deleted file mode 100644 index 32f1e621986d..000000000000 --- a/apps/hexagon_launcher/launcher_hexagon.cc +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern "C" { -#include -#include -#include -#include -#include -} - -#include -#include - -#include -#include -#include - -#include "launcher_core.h" -#include "launcher_rpc.h" - -static std::unique_ptr TheModel; -bool WriteLWPOutput(const std::string&); - -static AEEResult error_too_small(const std::string& func_name, const std::string& value_name, - int given, int needed) { - LOG(ERROR) << func_name.c_str() << ": " << value_name.c_str() << " value too small (" << given - << "), need at least " << needed; - return AEE_EBADPARM; -} - -int __QAIC_HEADER(launcher_rpc_open)(const char* uri, remote_handle64* handle) { - *handle = 0; // Just use any value. - reset_device_api(); - static const tvm::ffi::Function acq_res = - get_runtime_func("device_api.hexagon.acquire_resources"); - acq_res(); - return AEE_SUCCESS; -} - -int __QAIC_HEADER(launcher_rpc_close)(remote_handle64 handle) { - // Comment to stop clang-format from single-lining this function. - static const tvm::ffi::Function rel_res = - get_runtime_func("device_api.hexagon.release_resources"); - rel_res(); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_load)(remote_handle64 handle, const char* module_path, - const char* graph_json) { - if (TheModel) { - // Need to unload first. - LOG(ERROR) << __func__ << ": model already loaded, unload first"; - return AEE_EUNABLETOLOAD; - } - - tvm::runtime::Module module = load_module(module_path); - std::string module_type = module->type_key(); - tvm::runtime::Module executor; - if (module_type == "AotExecutorFactory") { - executor = create_aot_executor(module, Model::external()); - } else if (module_type == "library") { - // We're not expecting "GraphExecutorFactory" here. - executor = create_graph_executor(graph_json, module, Model::device()); - } else { - LOG(ERROR) << __func__ << ": unexpected module type: " << module_type; - // Fall through. - } - - if (executor.get() == nullptr) { - LOG(ERROR) << __func__ << ": failed to create executor for module" << module_path; - return AEE_EUNABLETOLOAD; - } - - TheModel = std::make_unique(executor, module, graph_json); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_unload)(remote_handle64 handle) { - if (TheModel) { - TheModel.reset(); - } - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_get_num_inputs)(remote_handle64 handle, int* num_inputs) { - if (!TheModel) { - // No model created. - return AEE_EBADSTATE; - } - - tvm::ffi::Function get_num_inputs = get_module_func(TheModel->model_executor, "get_num_inputs"); - *num_inputs = get_num_inputs(); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_set_input)(remote_handle64 handle, int input_idx, - const unsigned char* input_meta, int meta_size, - const unsigned char* input_value, int value_size) { - if (!TheModel) { - // No model created. - LOG(ERROR) << __func__ << ": no model created"; - return AEE_EBADSTATE; - } - - const auto* meta = reinterpret_cast(input_meta); - if (meta_size < meta->meta_size()) { - return error_too_small(__func__, "meta_size", meta_size, meta->meta_size()); - } - if (value_size < meta->data_size()) { - return error_too_small(__func__, "value_size", value_size, meta->data_size()); - } - - DLTensor tensor{ - const_cast(input_value), - Model::external(), - meta->ndim, - meta->dtype, - const_cast(meta->shape), - /*strides*/ nullptr, - /*byte_offset*/ 0, - }; - DLManagedTensor managed{tensor, /*manager_ctx*/ nullptr, /*deleter*/ nullptr}; - - auto input = tvm::runtime::Tensor::FromDLPack(&managed); - - tvm::ffi::Function set_input = get_module_func(TheModel->model_executor, "set_input"); - set_input(input_idx, input); - - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_get_num_outputs)(remote_handle64 handle, int* num_outputs) { - if (!TheModel) { - // No model created. - return AEE_EBADSTATE; - } - - tvm::ffi::Function get_num_outputs = get_module_func(TheModel->model_executor, "get_num_outputs"); - *num_outputs = get_num_outputs(); - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_get_output)(remote_handle64 handle, int output_idx, - unsigned char* output_meta, int meta_size, - unsigned char* output_value, int value_size) { - if (!TheModel) { - // No model created. - return AEE_EBADSTATE; - } - if (meta_size < 0 || value_size < 0) { - return AEE_EBADPARM; - } - if ((output_meta == nullptr && meta_size != 0) || (output_value == nullptr && value_size != 0)) { - // If the pointer is null, the size must be 0. - return AEE_EBADPARM; - } - - tvm::ffi::Function get_output = get_module_func(TheModel->model_executor, "get_output"); - tvm::runtime::Tensor output = get_output(output_idx); - - std::vector shape_vec{output->shape, output->shape + output->ndim}; - - auto* container = new tvm::runtime::Tensor::Container(static_cast(output_value), shape_vec, - output->dtype, Model::external()); - container->SetDeleter([](tvm::ffi::Object* container) { - delete static_cast(container); - }); - - tvm::runtime::Tensor host_output(tvm::ffi::GetObjectPtr(container)); - - if (meta_size != 0) { - auto* meta = reinterpret_cast(output_meta); - if (meta_size < meta->meta_size(output->ndim)) { - return error_too_small(__func__, "meta_size", meta_size, meta->meta_size(output->ndim)); - } - - meta->ndim = output->ndim; - meta->dtype = output->dtype; - std::copy(&output->shape[0], &output->shape[output->ndim], meta->shape); - } - - if (value_size != 0) { - size_t data_size = tvm::runtime::GetDataSize(*output.operator->()); - if (value_size < data_size) { - return error_too_small(__func__, "value_size", value_size, data_size); - } - - host_output.CopyFrom(output); - } - - return AEE_SUCCESS; -} - -AEEResult __QAIC_HEADER(launcher_rpc_run)(remote_handle64 handle, uint64_t* pcycles, - uint64_t* usecs, int gen_lwp_json) { - if (!TheModel) { - // No model created. - LOG(ERROR) << __func__ << ": no model created"; - return AEE_EBADSTATE; - } - - uint64_t us_begin = HAP_perf_get_time_us(); - uint64_t pc_begin = HAP_perf_get_pcycles(); - - TheModel->run(); - - uint64_t pc_end = HAP_perf_get_pcycles(); - uint64_t us_end = HAP_perf_get_time_us(); - *pcycles = pc_end - pc_begin; - *usecs = us_end - us_begin; - - if (gen_lwp_json) { - if (!WriteLWPOutput("lwp.json")) { - LOG(ERROR) << "ERROR: failed to generate lwp json file"; - return AEE_EFAILED; - } - } - - return AEE_SUCCESS; -} diff --git a/apps/hexagon_launcher/launcher_main.cc b/apps/hexagon_launcher/launcher_main.cc deleted file mode 100644 index 8690996684b2..000000000000 --- a/apps/hexagon_launcher/launcher_main.cc +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include - -#include -#include -#include -#include - -#include "launcher_core.h" -#include "launcher_util.h" - -ExecutionSession* create_execution_session(bool gen_lwp_json); - -int parse_command_line(int argc, char* argv[], std::string* in_path, std::string* out_path, - bool* gen_lwp_json) { - static option long_options[] = { - {"in_config", required_argument, nullptr, 0}, - {"out_config", required_argument, nullptr, 0}, - {"gen_lwp_json", optional_argument, nullptr, 0}, - }; - - bool show_usage = false; - int opt, long_index = 0; - while ((opt = getopt_long(argc, argv, "i:o:u:", long_options, &long_index)) != -1) { - if (opt != 0) { - show_usage = true; - continue; - } - switch (long_index) { - case 0: - *in_path = std::string(optarg); - break; - case 1: - *out_path = std::string(optarg); - break; - case 2: - *gen_lwp_json = true; - break; - } - } - if (in_path->empty() || out_path->empty() || show_usage) { - std::cout << "Usage: " << argv[0] << " --" << long_options[0].name << " input.json --" - << long_options[1].name << " output.json\n"; - return 1; - } - return 0; -} - -int main(int argc, char* argv[]) { - std::string in_path, out_path; - bool gen_lwp_json; - if (parse_command_line(argc, argv, &in_path, &out_path, &gen_lwp_json) != 0) { - return 1; - } - - ModelConfig config; - if (!read_model_config(in_path, &config)) { - return 1; - } - - ExecutionSession* session_ptr = create_execution_session(gen_lwp_json); - if (session_ptr == nullptr) { - return 1; - } - ExecutionSession& session = *session_ptr; - - std::cout << "loading model files: "; - if (!config.model_json.empty()) { - std::cout << config.model_json << ", "; - } - std::cout << config.model_library << '\n'; - - std::string json = !config.model_json.empty() ? load_text_file(config.model_json) : ""; - if (!session.load_model(config.model_library, json.c_str())) { - return 1; - } - - int max_ndim = 0; - for (const TensorConfig& tc : config.inputs) { - max_ndim = std::max(max_ndim, tc.shape.size()); - } - auto* input_meta = session.alloc(tensor_meta::meta_size(max_ndim)); - - for (int i = 0, e = config.inputs.size(); i != e; ++i) { - const TensorConfig& tc = config.inputs[i]; - input_meta->ndim = tc.shape.size(); - input_meta->dtype = tvm::ffi::StringToDLDataType(tc.dtype); - std::copy(tc.shape.begin(), tc.shape.end(), input_meta->shape); - - auto* input_data = session.alloc(input_meta->data_size()); - std::cout << "loading input file #" << i << ": " << tc.file_name << '\n'; - load_binary_file(tc.file_name, input_data, input_meta->data_size()); - if (!session.set_input(i, input_meta, input_data)) { - return 1; - } - } - - OutputConfig output_config; - - std::cout << "running..." << std::flush; - if (!session.run(&output_config.pcycles, &output_config.usecs)) { - std::cout << '\n'; - return 1; - } - std::cout << '\n'; - std::cout << "Finished in " << output_config.pcycles << " pcycles, (" << output_config.usecs - << "us)\n"; - - auto* output_meta = session.alloc(128); - int num_outputs = 0; - if (!session.get_num_outputs(&num_outputs)) { - return 1; - } - - for (int i = 0; i != num_outputs; ++i) { - if (!session.get_output(i, output_meta, 128, nullptr, 0)) { - return 1; - } - int data_size = output_meta->data_size(); - auto* output_data = session.alloc(data_size); - if (!session.get_output(i, output_meta, 128, output_data, data_size)) { - return 1; - } - - TensorConfig oc; - oc.file_name = "output" + std::to_string(i) + ".dat"; - for (int i = 0, e = output_meta->ndim; i != e; ++i) { - oc.shape.push_back(output_meta->shape[i]); - } - oc.dtype = tvm::runtime::DLDataTypeToString(output_meta->dtype); - write_binary_file(oc.file_name, output_data, data_size); - output_config.outputs.push_back(std::move(oc)); - - session.free(output_data); - } - - if (!write_output_config(out_path, &output_config)) { - return 1; - } - return 0; -} diff --git a/apps/hexagon_launcher/launcher_rpc.idl b/apps/hexagon_launcher/launcher_rpc.idl deleted file mode 100644 index 27e5d1d15d68..000000000000 --- a/apps/hexagon_launcher/launcher_rpc.idl +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "remote.idl" -#include "AEEStdDef.idl" - -typedef sequence buffer; - -interface launcher_rpc : remote_handle64 { - AEEResult load(in string module_path, in string model_json); - AEEResult unload(); - AEEResult get_num_inputs(rout long num_inputs); - AEEResult set_input(in long input_idx, in buffer input_meta, in buffer input_value); - AEEResult get_num_outputs(rout long num_outputs); - AEEResult get_output(in long output_idx, rout buffer output_meta, rout buffer output_value); - AEEResult run(rout uint64_t pcycles, rout uint64_t usecs, in long gen_lwp_json); -}; diff --git a/apps/hexagon_launcher/launcher_util.cc b/apps/hexagon_launcher/launcher_util.cc deleted file mode 100644 index ddbafb3c84a9..000000000000 --- a/apps/hexagon_launcher/launcher_util.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "launcher_util.h" - -#include - -#include -#include -#include -#include -#include -#include - -size_t get_file_size(std::ifstream& in_file) { - std::ifstream::pos_type pos = in_file.tellg(); - size_t size = in_file.seekg(0, std::ios::end).tellg(); - in_file.seekg(pos, std::ios::beg); - return size; -} - -size_t get_file_size(std::ifstream&& in_file) { - return get_file_size(in_file); // calls the & version -} - -std::string load_text_file(const std::string& file_name) { - constexpr size_t block_size = 1024 * 1024; // 1MB - std::ifstream in_file(file_name); - TVM_FFI_ICHECK(in_file.is_open()) << "cannot open file " << file_name; - size_t file_size = get_file_size(in_file); - std::string buffer(file_size + 1, 0); - - in_file.read(&buffer[0], file_size); - return buffer; -} - -void* load_binary_file(const std::string& file_name, void* buffer, size_t buffer_size) { - std::ifstream in_file(file_name); - TVM_FFI_ICHECK(in_file.is_open()) << "cannot open file " << file_name; - size_t file_size = get_file_size(in_file); - - in_file.read(reinterpret_cast(buffer), - std::min(buffer_size, file_size)); - return buffer; -} - -void write_binary_file(const std::string& file_name, void* buffer, size_t buffer_size) { - std::ofstream out_file(file_name); - TVM_FFI_ICHECK(out_file.is_open()) << "cannot open file " << file_name; - - out_file.write(reinterpret_cast(buffer), buffer_size); -} diff --git a/apps/hexagon_launcher/launcher_util.h b/apps/hexagon_launcher/launcher_util.h deleted file mode 100644 index 13db89d052fb..000000000000 --- a/apps/hexagon_launcher/launcher_util.h +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_HEXAGON_LAUNCHER_LAUNCHER_UTIL_H_ -#define TVM_RUNTIME_HEXAGON_LAUNCHER_LAUNCHER_UTIL_H_ - -#include -#include -#include - -size_t get_file_size(std::ifstream& in_file); -size_t get_file_size(std::ifstream&& in_file); - -std::string load_text_file(const std::string& file_name); -void* load_binary_file(const std::string& file_name, void* buffer, size_t buffer_size); -void write_binary_file(const std::string& file_name, void* buffer, size_t buffer_size); - -#endif // TVM_RUNTIME_HEXAGON_LAUNCHER_LAUNCHER_UTIL_H_ diff --git a/ci/jenkins/data.py b/ci/jenkins/data.py index 44cdba1d02b2..0811451b53fe 100644 --- a/ci/jenkins/data.py +++ b/ci/jenkins/data.py @@ -33,10 +33,6 @@ files_to_stash = { # Executables and build files needed to run c++ tests "cpptest": ["build/cpptest", "build/build.ninja", "build/CMakeFiles/rules.ninja"], - # Folder for hexagon build - "hexagon_api": [ - "build/hexagon_api_output", - ], # runtime files "tvm_runtime": ["build/lib/libtvm_runtime.so", "build/config.cmake"], # compiler files diff --git a/python/tvm/contrib/hexagon/_ci_env_check.py b/python/tvm/contrib/hexagon/_ci_env_check.py index f7f14a23955e..ed7beb246a82 100644 --- a/python/tvm/contrib/hexagon/_ci_env_check.py +++ b/python/tvm/contrib/hexagon/_ci_env_check.py @@ -17,9 +17,7 @@ """Hexagon environment checks for CI usage -These may be required by either tvm.testing or -tvm.contrib.hexagon.pytest_plugin, and are separated here to avoid a -circular dependency. +These are required by tvm.testing and live here to avoid a circular dependency. """ import os diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py deleted file mode 100644 index 30b9c484581a..000000000000 --- a/python/tvm/contrib/hexagon/build.py +++ /dev/null @@ -1,830 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=consider-using-with, unnecessary-ellipsis -# ruff: noqa: RUF005, RUF012 - -"""Defines top-level glue functions for building Hexagon.""" - -import abc -import datetime -import logging -import multiprocessing as mp -import os -import pathlib -import random -import signal -import socket -import stat -import string -import subprocess -import sys -import tempfile - -from tvm.contrib.hexagon.hexagon_profiler import HexagonProfiler - -from .session import Session -from .tools import HEXAGON_SIMULATOR_NAME - -HEXAGON_RPC_LIB_DIR = os.environ.get("HEXAGON_RPC_LIB_DIR") -ANDROID_BASH_FILE_NAME = "android_bash.sh" -HEXAGON_REMOTE_DEVICE_KEY = "hexagon-dev" - - -def _check_call_verbose(cmd, **kwargs) -> None: - """ - Similar to subprocess.check_call(cmd), but if the exit code is non-zero - then the raised Exception's message provides more detail, including - the stdout/stderr provided by the subprocess. - """ - try: - subprocess.run( - cmd, - check=True, - encoding="UTF-8", - capture_output=True, - **kwargs, - ) - except subprocess.CalledProcessError as err: - error_msg = f"{err}\nstdout:\n{err.stdout}\nstderr:\n{err.stderr}" - raise Exception(error_msg) - - -def _get_hexagon_rpc_lib_dir() -> pathlib.Path: - """Find the Hexagon API binaries. - - Returns - ------- - pathlib.Path : - The path to the Hexagon API directory. - """ - global HEXAGON_RPC_LIB_DIR - HEXAGON_RPC_LIB_DIR = os.environ.get("HEXAGON_RPC_LIB_DIR") - if HEXAGON_RPC_LIB_DIR is None: - raise RuntimeError("hexagon_api binaries not found, please define HEXAGON_RPC_LIB_DIR") - return pathlib.Path(HEXAGON_RPC_LIB_DIR) - - -def _get_test_directory_name() -> str: - """Generate a time-stamped name for use as a test directory name.""" - date_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - random_str = "".join(random.choice(string.ascii_lowercase) for _ in range(10)) - return f"{date_str}-{random_str}" - - -def _get_adb_path() -> str: - """Define path to adb - - Order of search: - 1. From PATH - 2. From ANDROID_SDK_ROOT - 3. From ANDROID_HOME - 3. From default android sdk installation directory (platform specific) - """ - - def check_execution(exe_path): - try: - ret_code = subprocess.call( - [exe_path, "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - except FileNotFoundError: - ret_code = -1 - - return ret_code == 0 - - # Check if adb available via PATH - if check_execution("adb"): - return "adb" - - # Check if adb available via env vars or default directories - list_of_paths = [ - os.environ.get("ANDROID_SDK_ROOT", default=""), - os.environ.get("ANDROID_HOME", default=""), - ] - - if sys.platform == "darwin": - list_of_paths += [ - os.path.join(pathlib.Path.home(), "Library", "Android", "sdk", "platform-tools") - ] - if sys.platform == "win32": - list_of_paths += [ - os.path.join( - pathlib.Path.home(), "AppData", "Local", "Android", "sdk", "platform-tools" - ) - ] - if sys.platform == "linux": - list_of_paths += [os.path.join(pathlib.Path.home(), "Android", "Sdk", "platform-tools")] - - list_of_paths = [path for path in list_of_paths if path != ""] - - found_path = None - for candidate_path in list_of_paths: - adb_path = os.path.join(candidate_path, "adb") - if os.path.isfile(adb_path) and check_execution(adb_path): - found_path = adb_path - break - - if found_path is None: - raise RuntimeError( - "ADB was not found. It should be available via PATH, ANDROID_SDK_ROOT " - "or ANDROID_HOME env var." - ) - - return found_path - - -class HexagonLauncherRPC(metaclass=abc.ABCMeta): - """Base class for RPC-based launchers. - - This is an abstract class intended to be a base class for specific - implementations of RPC launchers. There are two public methods that - each launcher needs to implement: - - start_server - - stop server - and two "private" methods used in setting up the environment: - - _copy_to_remote - - _create_remote_directory - - The basic flow of interaction with the launcher is - launcher = HexagonLauncher(...) - launcher.start_server() - with launcher.create_session() as session: - # Do something with the session - launcher.stop_server() - - Parameters - ---------- - rpc_info : dict - Description of the RPC setup. Recognized keys: - "rpc_tracker_host" : str name of the host running the tracker (default "0.0.0.0") - "rpc_tracker_port" : int port number of the tracker (default: 9190) - "rpc_server_port" : int port number for the RPC server to use (default 7070) - "workspace_base" : str name of base test directory (default ".") - workspace : str or patlib.Path - The server's remote working directory. If this directory does not - exist, it will be created. If it does exist, the servermust have - write permissions to it. - If this parameter is None, a subdirectory in the `workspace_base` - directory will be created, otherwise the `workspace_base` is not - used. - """ - - def __init__( - self, - rpc_info: dict, - workspace: str | pathlib.Path | None = None, - serial_number: str | None = None, - ): - self._rpc_info = { - "rpc_tracker_host": "0.0.0.0", - "rpc_tracker_port": 9190, - "rpc_server_port": 7070, - "workspace_base": ".", - } - self._rpc_info.update(rpc_info) - self._workspace = self._create_workspace(workspace) - self._serial_number = serial_number - - @abc.abstractmethod - def start_server(self): - """Start the RPC server""" - ... - - @abc.abstractmethod - def stop_server(self): - """Stop the RPC server""" - ... - - @abc.abstractmethod - def cleanup_directory(self): - """Cleanup working directory""" - ... - - @abc.abstractmethod - def _copy_to_remote(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path): - """Copy a local file to a remote location. - - Parameters - ---------- - local_path : str or pathlib.Path - Path to the local file. - remote_path : str or pathlib.Path - Path to the remote file (to be written). - """ - ... - - @abc.abstractmethod - def _create_remote_directory(self, remote_path: str | pathlib.Path) -> pathlib.Path: - """Create a directory in the remote location. - - Parameters - ---------- - remote_path : str or pathlib.Path - Name of the directory to be created. - - Returns - ------- - pathlib.Path : - Absolute path of the remote workspace. - """ - ... - - def _create_workspace(self, workspace: str | pathlib.Path) -> pathlib.Path: - """Create a working directory for the server. - - Parameters - ---------- - workspace : str or pathlib.Path or NoneType - Name of the directory to create. If None, a new name is constructed - using workspace_base. - - Returns - ------- - pathlib.Path : - Created workspace. - """ - if not workspace: - base_dir = self._rpc_info["workspace_base"] - workspace = os.path.join(base_dir, _get_test_directory_name()) - return self._create_remote_directory(workspace) - - @abc.abstractmethod - def get_profile_output( - self, - hex_profiler: HexagonProfiler, - session: Session, - ) -> str: - """Extract profile output. - - Parameters - ---------- - hex_profiler : HexagonProfiler - HexagonProfiler object that contains the profiling related information. - session : Session - Remote session. The session must be established (via __enter__) - prior to calling this function. - - Returns - ------- - profile_data : str - Path of the profiling data file - """ - ... - - def create_session(self, session_name: str = "hexagon-rpc") -> Session: - """Create an RPC session. - - Parameters - ---------- - session_name : str - RPC session name. - - Returns - ------- - Session : - The session object. - """ - hexagon_session_kw = { - "remote_workspace": self._workspace, - "rpc_tracker": (self._rpc_info["rpc_tracker_host"], self._rpc_info["rpc_tracker_port"]), - "rpc_server_key": self._rpc_info["device_key"], - "serial_number": self._serial_number, - "session_name": session_name, - } - return Session(**hexagon_session_kw) - - def is_simulator(self): - return self._serial_number == HEXAGON_SIMULATOR_NAME - - -class HexagonLauncherAndroid(HexagonLauncherRPC): - """Hexagon Launcher for Android.""" - - ANDROID_HEXAGON_TEST_BASE_DIR = pathlib.Path("/data/local/tmp/hexagon_test") - ANDROID_HEXAGON_RPC_FILES = [ - "libhexagon_rpc_skel.so", - "libtvm_runtime.so", - "tvm_rpc_android", - ] - - def __init__( - self, - serial_number: str, - rpc_info: dict, - workspace: str | pathlib.Path | None = None, - hexagon_debug: bool = False, - clear_logcat: bool = False, - sysmon_profile: bool = False, - farf_config: str = "0x1e", - ): - """Configure a new HexagonLauncherAndroid - - Parameters - ---------- - serial_number : str - Android device serial number. - rpc_info : dict - Same as in HexagonLauncherRPC, except if the "workspace_base" - key is not present or is None, ANDROID_HEXAGON_TEST_BASE_DIR - is used as the base directory. - workspace : str or pathlib.Path, optional - Test workspace path on android. - hexagon_debug: bool, optional - Should the server run debug options. - clear_logcat: bool, optional - Should the server clear logcat before running. - sysmon_profile: bool, optional - Should the server run sysmon profiler in the background. - farf_config: str, optional - Configuration string for runtime log level filtering. - Use farf_config_from_python_log_level to generate a bitmask - string from a Python logging level (e.g., logging.INFO) - """ - if not rpc_info.get("workspace_base"): - rpc_info["workspace_base"] = self.ANDROID_HEXAGON_TEST_BASE_DIR - self._serial_number = serial_number - assert self._serial_number != "", "Android serial number is not set." - - adb_socket = rpc_info["adb_server_socket"] if rpc_info["adb_server_socket"] else "tcp:5037" - adb_exe = _get_adb_path() - self._adb_device_sub_cmd = [adb_exe, "-L", adb_socket, "-s", self._serial_number] - self.forwarded_ports_ = [] - self._hexagon_debug = hexagon_debug - self._clear_logcat = clear_logcat - self._sysmon_profile = sysmon_profile - self._sysmon_process = None - self._farf_config = farf_config - rpc_info["device_key"] = HEXAGON_REMOTE_DEVICE_KEY + "." + self._serial_number - - super().__init__(rpc_info, workspace, self._serial_number) - - def _copy_to_remote(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - - _check_call_verbose(self._adb_device_sub_cmd + ["push", str(local_path), str(remote_path)]) - - def _create_remote_directory(self, remote_path: str | pathlib.Path) -> pathlib.Path: - """Abstract method implementation. See description in HexagonLauncherRPC.""" - _check_call_verbose(self._adb_device_sub_cmd + ["shell", "mkdir", "-p", str(remote_path)]) - return pathlib.Path(remote_path) - - def _copy_binaries(self): - """Upload Android server binaries.""" - - # Create bash script - with open(_get_hexagon_rpc_lib_dir() / f"{ANDROID_BASH_FILE_NAME}.template") as src_f: - with tempfile.TemporaryDirectory() as temp_dir: - android_bash_script_path = pathlib.Path(temp_dir) / ANDROID_BASH_FILE_NAME - with open(android_bash_script_path, "w") as dest_f: - for line in src_f.readlines(): - if "" in line: - line = line.replace( - "", str(self._rpc_info["rpc_tracker_host"]) - ) - if "" in line: - line = line.replace( - "", str(self._rpc_info["rpc_tracker_port"]) - ) - if "" in line: - line = line.replace( - "", self._rpc_info["device_key"] - ) - if "" in line: - line = line.replace( - "", str(self._rpc_info["rpc_server_port"]) - ) - if "" in line: - line = line.replace("", str(self._farf_config)) - dest_f.write(line) - - # Make shell script executable - android_bash_stat = os.stat(android_bash_script_path) - os.chmod(android_bash_script_path, android_bash_stat.st_mode | stat.S_IEXEC) - self._copy_to_remote( - android_bash_script_path, self._workspace / android_bash_script_path.name - ) - - # Push files - lib_dir = _get_hexagon_rpc_lib_dir() - for item in self.ANDROID_HEXAGON_RPC_FILES: - self._copy_to_remote(lib_dir / item, self._workspace / item) - - def _process_forwarded_ports(self): - forwarded_ports = subprocess.check_output(self._adb_device_sub_cmd + ["forward", "--list"]) - existing_forwards = [] - for forward in str(forwarded_ports).split("\\n"): - entry = forward.split() - if len(entry) == 3: - _, local, _ = entry - existing_forwards.append(int(local.strip("tcp:"))) - return existing_forwards - - def _forward_ports(self, rpc_server_port, existing_forwards): - # Enable port forward for RPC server. We forward the first ten open ports - # starting from the rpc_server_port - port = rpc_server_port - while len(self.forwarded_ports_) < 10: - if port not in existing_forwards and not _is_port_in_use(port): - _check_call_verbose( - self._adb_device_sub_cmd + ["forward", f"tcp:{port}", f"tcp:{port}"] - ) - self.forwarded_ports_.append(port) - port += 1 - - def _reverse_ports(self, rpc_tracker_port): - _check_call_verbose( - self._adb_device_sub_cmd - + ["reverse", f"tcp:{rpc_tracker_port}", f"tcp:{rpc_tracker_port}"] - ) - - def _run_server_script(self): - """Setup the ADB connection and execute the server script.""" - - # Collect any existing adb port forwarding to avoid duplication - # with another running process - existing_forwards = self._process_forwarded_ports() - # Enable port reverse for RPC tracker - rpc_tracker_port = self._rpc_info["rpc_tracker_port"] - rpc_server_port = self._rpc_info["rpc_server_port"] - - self._reverse_ports(rpc_tracker_port) - self._forward_ports(rpc_server_port, existing_forwards) - - # Run server and connect to tracker - subprocess.Popen( - self._adb_device_sub_cmd - + ["shell", f"cd {self._workspace} && ./{ANDROID_BASH_FILE_NAME}"], - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - def _cleanup_port_forwarding(self): - # Removed pre-defined forward/reverse rules - rpc_tracker_port = self._rpc_info["rpc_tracker_port"] - _check_call_verbose( - self._adb_device_sub_cmd + ["reverse", "--remove", f"tcp:{rpc_tracker_port}"] - ) - for port in self.forwarded_ports_: - _check_call_verbose(self._adb_device_sub_cmd + ["forward", "--remove", f"tcp:{port}"]) - - def _terminate_remote(self): - # Send interupt to main and child processes - subprocess.Popen( - self._adb_device_sub_cmd - + ["shell", f"pkill -l sigint -P `cat {self._workspace}/rpc_pid.txt`"] - ) - subprocess.Popen( - self._adb_device_sub_cmd - + ["shell", f"kill -s sigint `cat {self._workspace}/rpc_pid.txt`"] - ) - # Wait for processes to destruct cleanly after receiving the intrupt - subprocess.Popen(self._adb_device_sub_cmd + ["shell", "sleep", "0.1s"]) - # Kill process children - subprocess.Popen( - self._adb_device_sub_cmd + ["shell", f"pkill -P `cat {self._workspace}/rpc_pid.txt`"] - ) - # Kill main process - subprocess.Popen( - self._adb_device_sub_cmd + ["shell", f"kill `cat {self._workspace}/rpc_pid.txt`"] - ) - - def cleanup_directory(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - subprocess.Popen(self._adb_device_sub_cmd + ["shell", f"rm -rf {self._workspace}"]) - - def _start_sysmon(self): - hexagon_sdk_root = os.environ.get("HEXAGON_SDK_ROOT", default="") - subprocess.call( - self._adb_device_sub_cmd - + ["push", f"{hexagon_sdk_root}/tools/utils/sysmon/sysMonApp", "/data/local/tmp/"] - ) - sysmon_process = subprocess.Popen( - self._adb_device_sub_cmd - + [ - "shell", - "/data/local/tmp/sysMonApp profiler --debugLevel 0 --samplePeriod 1 --q6 cdsp", - ], - stdin=subprocess.PIPE, - ) - return sysmon_process - - def _stop_sysmon(self): - if self._sysmon_process is not None: - self._sysmon_process.communicate(input=b"\n") - self._sysmon_process = None - - def _retrieve_sysmon(self): - pathlib.Path("./sysmon_output/").mkdir(exist_ok=True) - subprocess.call( - self._adb_device_sub_cmd + ["pull", "/sdcard/sysmon_cdsp.bin", "./sysmon_output/"] - ) - subprocess.call(self._adb_device_sub_cmd + ["root"]) - hexagon_sdk_root = os.environ.get("HEXAGON_SDK_ROOT", default="") - subprocess.call( - f"{hexagon_sdk_root}/tools/utils/sysmon/parser_linux_v2/HTML_Parser/sysmon_parser " - + "./sysmon_output/sysmon_cdsp.bin --outdir ./sysmon_output/", - shell=True, - ) - - def _clear_debug_logs(self): - subprocess.call(self._adb_device_sub_cmd + ["shell", "logcat", "-c"]) - - def _retrieve_debug_logs(self): - run_start_time = subprocess.check_output( - self._adb_device_sub_cmd - + [ - "shell", - "stat", - f"{self._workspace}/android_bash.sh | grep 'Change' | grep -oe '[0-9].*'", - ] - ) - run_start_time = run_start_time[:-1].decode("UTF-8") - subprocess.call( - self._adb_device_sub_cmd - + [ - "shell", - "logcat", - "-t", - f'"{run_start_time}"', - "-f", - f"{self._workspace}/logcat.txt", - ] - ) - subprocess.call(self._adb_device_sub_cmd + ["pull", f"{self._workspace}/logcat.txt", "."]) - - def _print_cdsp_logs(self): - crash_count = 0 - context_lines = 0 - print_buffer = "" - try: - with open("./logcat.txt") as f: - for line in f: - if "Process on cDSP CRASHED" in line: - if crash_count <= 5: - print(print_buffer, "\n") - context_lines = 40 - print_buffer = "" - crash_count += 1 - if context_lines > 0 and "platform_qdi_driver" in line: - context_lines -= 1 - print_buffer += line[80:] - - if crash_count <= 5: - print(print_buffer, "\n") - - print( - f"There were {crash_count} crashes on the cDSP during execution... " - + "Crash printing is limited to the first 5." - ) - except FileNotFoundError: - print("Unable to parse logcat file.") - - def start_server(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - self._copy_binaries() - if self._sysmon_profile: - self._sysmon_process = self._start_sysmon() - self._run_server_script() - if self._clear_logcat: - self._clear_debug_logs() - - def stop_server(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - if self._sysmon_profile and self._sysmon_process is not None: - self._stop_sysmon() - self._retrieve_sysmon() - if self._hexagon_debug: - self._retrieve_debug_logs() - self._print_cdsp_logs() - self._cleanup_port_forwarding() - self._terminate_remote() - if not self._hexagon_debug: - self.cleanup_directory() - - def get_profile_output( - self, - hex_profiler: HexagonProfiler, - session: Session, - ): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - profile_data = "" - if hex_profiler.is_lwp_enabled(): - temp_dir = hex_profiler.get_temp_dir() - remote_path = hex_profiler.get_remote_path() - if not temp_dir: - raise RuntimeError("tempdir not passed") - fname = "lwp.json" - out_path = os.path.join(remote_path, fname) - profile_data = temp_dir.relpath(fname) - ret = session.get_profile_output(hex_profiler.get_mode(), fname) - if ret: - subprocess.check_call(self._adb_device_sub_cmd + ["pull", out_path, profile_data]) - else: - raise RuntimeError("Error generating profile output") - elif hex_profiler.profiling_mode == "etm": - hex_profiler.pull_files_for_etm_processing(self._workspace) - else: - raise RuntimeError("Profiling not enabled") - return profile_data - - -class HexagonLauncherSimulator(HexagonLauncherRPC): - """Hexagon Launcher for Hexagon simulator.""" - - SIMULATOR_HEXAGON_RPC_FILES = ["tvm_rpc_x86", "libhexagon_rpc_sim.so"] - - def __init__(self, rpc_info: dict, workspace: str | pathlib.Path | None = None): - """Configure a new HexagonLauncherSimulator - - Parameters are same as for HexagonLauncherRPC. - """ - - self._toolchain = os.environ.get("HEXAGON_TOOLCHAIN") - if not self._toolchain: - raise RuntimeError("Please set HEXAGON_TOOLCHAIN env variable") - self._serial_number = HEXAGON_SIMULATOR_NAME - - super().__init__(rpc_info, workspace, self._serial_number) - - def _copy_to_remote(self, local_path: str | pathlib.Path, remote_path: str | pathlib.Path): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - _check_call_verbose(["cp", str(local_path), str(remote_path)]) - - def _create_remote_directory(self, remote_path: str | pathlib.Path) -> pathlib.Path: - """Abstract method implementation. See description in HexagonLauncherRPC.""" - _check_call_verbose(["mkdir", "-p", str(remote_path)]) - return pathlib.Path(os.path.abspath(remote_path)) - - def _copy_libcxx(self, dest_dir: str | pathlib.Path): - """Copy libc++ libraries to the remote workspace.""" - # Copy the v68 versions, since we don't have target information. - # The v68 ones should work everywhere on v68+. - lib_dir = os.path.join(self._toolchain, "target/hexagon/lib/v68/G0/pic") - - libcxx_files = [] - for entry in os.scandir(lib_dir): - if entry.is_dir() or entry.name.find(".so") == -1: - continue - if entry.name.startswith("libc++"): - libcxx_files.append(entry.name) - - # Use tar to preserve the symbolic links. Libc++ libraries use the - # typical .so versioning, so that libc++.so may be a symlink to - # something else. Also, shared libraries using libc++ could be - # directly linked against some version, e.g. libc++.so.1, so make - # sure that all files are copied over. The preservation of symbolic - # links is to save disk space. - tar_in = f"tar -cf - -C {lib_dir} " + " ".join(libcxx_files) - tar_out = f"tar -xf - -C {dest_dir!s}" - _check_call_verbose(tar_in + " | " + tar_out, shell=True) - - def start_server(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - # Copy binaries - lib_dir = _get_hexagon_rpc_lib_dir() - for item in self.SIMULATOR_HEXAGON_RPC_FILES: - self._copy_to_remote(lib_dir / item, self._workspace / item) - # Copy libc++ from the toolchain to the workspace - self._copy_libcxx(self._workspace) - self._rpc_info["device_key"] = HEXAGON_REMOTE_DEVICE_KEY + "." + str(os.getpid()) - - rpc_tracker_host = self._rpc_info["rpc_tracker_host"] - rpc_tracker_port = self._rpc_info["rpc_tracker_port"] - rpc_server_port = self._rpc_info["rpc_server_port"] - device_key = self._rpc_info["device_key"] - server_exe = os.path.join(".", "tvm_rpc_x86") - - args = [ - "server", - f"--tracker={rpc_tracker_host}:{rpc_tracker_port}", - f"--port={rpc_server_port}", - f"--key={device_key}", - "--timeout=0", - ] - - # pylint: disable=unused-argument - def _terminate_handler(self, signum, *rest): - # Terminate the Popen'ed (sub)process. - os.kill(self._subprocess_pid, signal.SIGTERM) - - def _start(self): - # This function will be running in a new process. It will start the RPC - # (x86) server as a subprocess of itself. - log_out = self._workspace / "stdout.txt" - log_err = self._workspace / "stderr.txt" - # Intercept the TERM signal so we can also terminate the subprocess. - signal.signal(signal.SIGTERM, lambda *a: _terminate_handler(self, *a)) - - with open(log_out, "w") as out, open(log_err, "w") as err: - p = subprocess.Popen( - [server_exe, *args], stdout=out, stderr=err, cwd=self._workspace - ) - # Insert the pid of the subprocess in the self object. - self._subprocess_pid = p.pid - p.wait() - - self._server_process = mp.Process(target=lambda *a: _start(self, *a)) - self._server_process.start() - - def cleanup_directory(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - - def stop_server(self): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - self._server_process.terminate() - - def get_profile_output( - self, - hex_profiler: HexagonProfiler, - session: Session, - ): - """Abstract method implementation. See description in HexagonLauncherRPC.""" - profile_data = "" - if hex_profiler.is_lwp_enabled(): - fname = "lwp.json" - profile_data = f"{self._workspace}/{fname}" - ret = session.get_profile_output(hex_profiler.get_mode(), fname) - if not ret: - raise RuntimeError("Error generating profile output") - elif hex_profiler.profiling_mode == "etm": - raise RuntimeError("ETM Profiling not supported on the simulator") - else: - raise RuntimeError("Profiling not enabled") - - return profile_data - - -# https://stackoverflow.com/a/52872579/2689797 -def _is_port_in_use(port: int) -> bool: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(("localhost", port)) == 0 - - -def farf_config_from_python_log_level(level) -> str: - """Generates a FARF configuration string enabling logging at the specified level - - Parameters - ---------- - level : str or int - Minimum level to log at. Must be a known Python logging level or string - (e.g., logging.INFO or "INFO") - """ - - # Runtime log levels can be selectively enabled by computing a bitmask - # corresponding to the levels you want to enable. These get forwarded to - # logcat by the DSP RPC daemon. The bits for each level are: - - # 0x01 - Hexagon LOW / TVM DEBUG / Python DEBUG - # 0x02 - Hexagon MEDIUM / TVM INFO / Python INFO - # 0x04 - Hexagon HIGH / TVM WARN / Python WARNING - # 0x08 - Hexagon ERROR / TVM ERROR / Python ERROR - # 0x10 - Hexagon FATAL / TVM FATAL / Python CRITICAL - - # Runtime logging can also be filtered on filenames by appending a - # comma-separated list of filenames. For more information, see - # the Hexagon SDK documentation. - - if level in (logging.DEBUG, "DEBUG"): - return "0x1F" - if level in (logging.INFO, "INFO"): - return "0x1E" - if level in (logging.WARNING, "WARNING"): - return "0x1C" - if level in (logging.ERROR, "ERROR"): - return "0x18" - if level in (logging.CRITICAL, "CRITICAL"): - return "0x10" - - raise ValueError("Argument must be a known Python logging level or string") - - -# pylint: disable=invalid-name -def HexagonLauncher( - serial_number: str, - rpc_info: dict, - workspace: str | pathlib.Path | None = None, - hexagon_debug: bool = False, - clear_logcat: bool = False, - sysmon_profile: bool = False, - farf_config: str = farf_config_from_python_log_level(logging.INFO), -): - """Creates a HexagonLauncher""" - if serial_number == HEXAGON_SIMULATOR_NAME: - return HexagonLauncherSimulator(rpc_info, workspace) - return HexagonLauncherAndroid( - serial_number, rpc_info, workspace, hexagon_debug, clear_logcat, sysmon_profile, farf_config - ) diff --git a/python/tvm/contrib/hexagon/hexagon_profiler.py b/python/tvm/contrib/hexagon/hexagon_profiler.py deleted file mode 100644 index 44a66ef7be39..000000000000 --- a/python/tvm/contrib/hexagon/hexagon_profiler.py +++ /dev/null @@ -1,125 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=consider-using-with -# ruff: noqa: RUF005 - -"""Define HexagonProfiler class to enable profiling for Hexagon""" - -import os -import subprocess - -from tvm.contrib.hexagon.profiling.process_lwp_data import process_lwp_output -from tvm.ir.transform import PassContext -from tvm.support import utils - - -class HexagonProfiler: - """Hexagon Profiler""" - - def __init__( - self, - dso_binary: str, - module, - hexagon_server_process, - enable_debug, - ): - """Configure HexagonProfiler""" - # Save test .so to process profiling data - self._temp_dir = utils.tempdir(keep_for_debug=enable_debug) - self._dso_binary_path = self._temp_dir.relpath(dso_binary) - module.save(self._dso_binary_path) - - self._android_serial_number = os.environ.get("ANDROID_SERIAL_NUMBER") - self._remote_path = "" - self._logcat_path = "" - - self._profiling_mode = None - config = PassContext.current().config - if self._android_serial_number is None: - raise RuntimeError("ANDROID_SERIAL_NUMBER must be set for profiling") - - if ("tirx.instrument_lwp", True) in config.items(): - # Set profiling mode - self._profiling_mode = "lwp" - - if self._android_serial_number != "simulator": - # Clear the logcat buffer and create a child process to redirect logcat output - # into a file. - launcher = hexagon_server_process["launcher"] - subprocess.check_call(launcher._adb_device_sub_cmd + ["logcat", "-c"]) - self._logcat_path = self._temp_dir.relpath("logcat.log") - self._fo = open(self._logcat_path, "w") - self._proc = subprocess.Popen( - launcher._adb_device_sub_cmd + ["logcat"], stdout=self._fo - ) - - # Get the remote workspace on the device from where the lwp data needs to be copied. - self._remote_path = launcher._workspace - - if self._profiling_mode is None: - raise RuntimeError("Profiling mode was not set or was not a valid one.") - - def get_mode(self): - return self._profiling_mode - - def is_lwp_enabled(self): - return self._profiling_mode == "lwp" - - def get_temp_dir(self): - return self._temp_dir - - def get_remote_path(self): - return self._remote_path - - def get_profile_output(self, hexagon_launcher, hexagon_session): - """Get runtime profiling data""" - prof_out = hexagon_launcher.get_profile_output(self, hexagon_session) - - print("lwp json can be found at -- ", prof_out) - - # Process lightweight profiling output into an easily readable csv file - # The post-processing requires following parameters: - # 1) Path of the binary file - # 2) android_serial_number - # 3) Path of the lwp json file (lwp.json) which gets created in the current directory - # 4) Path to the run log depending on the environment: - # a) For on-device runs: - # Use logcat output as the run log - # b) For simulator runs: - # Use "stdout.txt" as the run log. There is no need to specify the full path to - # "stdout.txt" as it will be inferred based on 'prof_out' location. - # 5) lwp processed output file - "lwp.csv" - # - lwp_csv = self._temp_dir.relpath("lwp.csv") - if self._android_serial_number == "simulator": - process_lwp_output( - self._dso_binary_path, self._android_serial_number, prof_out, "stdout.txt", lwp_csv - ) - else: - # For on-device run - self._proc.kill() # End the child process for logcat - self._fo.close() - if os.path.exists(self._logcat_path): - process_lwp_output( - self._dso_binary_path, - self._android_serial_number, - prof_out, - self._logcat_path, - lwp_csv, - ) - else: - raise RuntimeError("Error processing lwp output - missing logcat file") diff --git a/python/tvm/contrib/hexagon/meta_schedule.py b/python/tvm/contrib/hexagon/meta_schedule.py deleted file mode 100644 index 0084d1da7f56..000000000000 --- a/python/tvm/contrib/hexagon/meta_schedule.py +++ /dev/null @@ -1,195 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Meta schedule tuning utilities for Hexagon.""" - -import os -import tempfile -from collections.abc import Callable - -import tvm -from tvm.driver import build as tvm_build -from tvm.ir.module import IRModule -from tvm.ir.utils import derived_object -from tvm.runtime import Module, Tensor -from tvm.s_tir.meta_schedule.builder import LocalBuilder -from tvm.s_tir.meta_schedule.runner import ( - EvaluatorConfig, - PyRunner, - RunnerFuture, - RunnerInput, -) -from tvm.s_tir.meta_schedule.runner.rpc_runner import ( - RPCRunnerFuture, - default_alloc_argument, - default_run_evaluator, -) -from tvm.s_tir.meta_schedule.utils import cpu_count -from tvm.s_tir.transform import RemoveWeightLayoutRewriteBlock -from tvm.support.popen_pool import PopenPoolExecutor -from tvm.target import Target - -from .build import HexagonLauncherRPC -from .tools import export_module - - -@derived_object -class HexagonRPCRunner(PyRunner): - """RPCRunner for Hexagon. See the documentation of RPCRunner for more details.""" - - def __init__( - self, - hexagon_launcher: HexagonLauncherRPC, - evaluator_config: EvaluatorConfig | None = None, - cooldown_sec: float = 0.0, - alloc_repeat: int = 1, - max_workers: int | None = None, - initializer: Callable[[], None] | None = None, - ): - """ - Parameters - ---------- - hexagon_launcher : HexagonLauncherRPC - The RPC launcher for Hexagon. It is needed for creating hexagon.Session - object inside the worker function. - evaluator_config: EvaluatorConfig - The evaluator configuration. - cooldown_sec: float - The cooldown in seconds. - alloc_repeat: int - The number of times to random fill the allocation. - max_workers: Optional[int] = None - The maximum number of connections. Defaults to number of logical CPU cores. - initializer: Optional[Callable[[], None]] - The initializer function. - """ - - super().__init__() - self.hexagon_launcher = hexagon_launcher - self.evaluator_config = EvaluatorConfig._normalized(evaluator_config) - self.cooldown_sec = cooldown_sec - self.alloc_repeat = alloc_repeat - if max_workers is None: - max_workers = cpu_count(logical=True) - self.pool = PopenPoolExecutor( - max_workers=max_workers, - timeout=100, - initializer=initializer, - ) - - def run(self, runner_inputs: list[RunnerInput]) -> list[RunnerFuture]: - results = [] - for runner_input in runner_inputs: - future = RPCRunnerFuture( - future=self.pool.submit( - _worker_func, - self.hexagon_launcher, - self.evaluator_config, - self.alloc_repeat, - str(runner_input.artifact_path), - tuple(arg_info.as_json() for arg_info in runner_input.args_info), - ), - timeout_sec=100, - ) - results.append(future) - return results - - -def _worker_func(hexagon_launcher, evaluator_config, alloc_repeat, artifact_path, args_info): - with hexagon_launcher.create_session() as session: - device = session.device - _, remote_path = os.path.split(artifact_path) - uploaded = session.upload(artifact_path, remote_path) - rt_mod = session.load_module(uploaded) - repeated_args = default_alloc_argument( - session, - device, - args_info, - alloc_repeat, - ) - costs = default_run_evaluator( - session, - rt_mod, - device, - evaluator_config, - repeated_args, - ) - return costs - - -def get_hexagon_local_builder( - pass_context: tvm.transform.PassContext = None, - max_workers: int | None = None, - timeout_sec: float = 30.0, -): - """Return Hexagon-compatible Builder for meta schedule.""" - - def export_func(mod): - binary_path = export_module(mod, tempfile.mkdtemp()) - return str(binary_path) - - def default_build_with_context( - mod: IRModule, target: Target, _params: dict[str, Tensor] | None - ) -> Module: - with pass_context: - mod = RemoveWeightLayoutRewriteBlock(skip_tensor_rewrite=True)(mod) - return tvm_build(mod, target=target) - - if pass_context is not None: - return LocalBuilder( - f_build=default_build_with_context, - f_export=export_func, - max_workers=max_workers, - timeout_sec=timeout_sec, - ) - else: - return LocalBuilder(f_export=export_func, max_workers=max_workers, timeout_sec=timeout_sec) - - -def get_hexagon_rpc_runner( - hexagon_launcher: HexagonLauncherRPC, - number=3, - repeat=1, - min_repeat_ms=100, - max_workers: int | None = None, -): - """Return Hexagon-compatible RPC Runner for meta schedule. - - Parameters - ---------- - hexagon_launcher : HexagonLauncherRPC - The RPC launcher for Hexagon. - number: int - The number of times to run this function for taking average. - We call these runs as one `repeat` of measurement. - repeat: int - The number of times to repeat the measurement. - In total, the function will be invoked (1 + number x repeat) times, - where the first one is warm up and will be discarded. - The returned result contains `repeat` costs, - each of which is an average of `number` costs. - min_repeat_ms: int - Minimum repeat time in ms. if the execution latency is too short, - increase the number of runs to the given time (in ms) to reduce the measurement error. - """ - evaluator_config = EvaluatorConfig( - number=number, - repeat=repeat, - min_repeat_ms=min_repeat_ms, - enable_cpu_cache_flush=False, - ) - - return HexagonRPCRunner(hexagon_launcher, evaluator_config, max_workers=max_workers) diff --git a/python/tvm/contrib/hexagon/profiling/process_lwp_data.py b/python/tvm/contrib/hexagon/profiling/process_lwp_data.py deleted file mode 100644 index b82fbd4f88ae..000000000000 --- a/python/tvm/contrib/hexagon/profiling/process_lwp_data.py +++ /dev/null @@ -1,388 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501, E741 - -import argparse -import csv -import json -import os -import subprocess -from collections import OrderedDict -from re import compile, search - -ENABLE_DEBUG = False -""" -Process lightweight profiling output and generate a CSV file with processor -cycles for the instrumented functions and loops. - -Please note that some assumptions have been made while processing -the lightweight profiling output. They are as follows: - -1) We don't expect profiled functions to call another profiled function. - This constraint can be relaxed if needed but it simplifies the processing - significantly without introducing any limitations for our use case. -2) For now, it's also assumed that every unique section (loop) ID has same start - and end offset which will not be true while a loop gets unrolled as it will - create multiple profiling section with the same ID. The current - implementation doesn't handle this case. - -""" - - -def get_func_info(model_so): - """Get all the .text sections along with their start and end offset values""" - hexagon_nm_path = os.environ["HEXAGON_TOOLCHAIN"] + "/bin/hexagon-nm" - out = subprocess.Popen( - [hexagon_nm_path, "--print-size", model_so], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - ) - stdo, stde = out.communicate() - stdo = stdo.decode("utf-8") - - func_info = [] - for l in stdo.split("\n"): - info = {} - if search(" (T|t) ", l): # If .text section - parts = l.split(" ") - assert len(parts) == 4 - info["start"] = int(parts[0], base=16) - info["end"] = int(parts[0], base=16) + int(parts[1], base=16) - info["name"] = parts[3] - func_info.append(info) - - # Sort the entries in the increasing order of the start offset value. - func_info = sorted(func_info, key=lambda d: d["start"]) - - if ENABLE_DEBUG: - print("func_info :\n ") - for f in func_info: - print(f) - return func_info - - -def find_func(func_info, offset): - """For a given offset, find the function it belongs to.""" - fidx = 0 - lidx = len(func_info) - 1 - while fidx <= lidx: - midx = (fidx + lidx) // 2 - ms = func_info[midx]["start"] - me = func_info[midx]["end"] - if fidx == lidx: - assert offset >= ms and offset <= me, ( - f"Couldn't find a function for this offset: {offset}" - ) - return fidx - else: - if offset > me: - fidx = midx + 1 - elif offset < ms: - lidx = midx - 1 - else: - return midx - assert False, "Possible mismatch between model .so and LWP data" - - -def accumulate_cycles(overall_cycles, func_cycles, func_name): - """Accumulate function cycles""" - acc_cycles = overall_cycles[func_name] - for id in func_cycles: - assert id in acc_cycles, f"id [{id}] missing in the existing function record" - assert acc_cycles[id]["start"] == func_cycles[id]["start"], ( - "Offset value doesn't match with the existing function record." - ) - acc_cycles[id]["cycles"] += func_cycles[id]["cycles"] - acc_cycles[id]["count"] += func_cycles[id]["count"] - overall_cycles.update({func_name: acc_cycles}) - return overall_cycles - - -def adjust_per_loop_counts(overall_cycles, data): - """ - Use execution count and the number of entries recorded for each function/loop - to compute the overall cycles spent on them. - """ - for func in overall_cycles: - func_cycles = overall_cycles[func] - for id in func_cycles: - exec_count = data["loop_counts"][id] - rec_count = func_cycles[id]["count"] - assert exec_count != 0, "Execution count should have been non-zero." - assert rec_count != 0, "Entry count should have been non-zero." - exec_cycles = ((int(func_cycles[id]["cycles"])) * exec_count) // rec_count - func_cycles[id]["cycles"] = exec_cycles - func_cycles[id]["count"] = exec_count - overall_cycles.update({func: OrderedDict(sorted(func_cycles.items()))}) - return overall_cycles - - -def create_csv_report(overall_cycles, fname): - """Create csv report""" - header = [ - "function name", - "loop/function id", - "loop depth", - "start offset", - "end offset", - "pcycles", - "parent count", - ] - with open(fname, "w") as f: - writer = csv.writer(f) - writer.writerow(header) - for func in overall_cycles: - func_cycles = overall_cycles[func] - data = [] - root = -1 - outer_most = -1 - for key, value in func_cycles.items(): - if value["parent"] == -1: - assert root == -1, "Can't have multiple root nodes." - root = key - - data.append(func) - data.append(key) - if value["parent"] == -1: - data.append("-") # Total cycles over all invocations of this function. - elif value["parent"] == root: - data.append(0) - outer_most = key - else: - if outer_most > -1: - data.append(key - outer_most) - else: - data.append(key - value["parent"]) - data.append(hex(value["start"])) - data.append(hex(value["end"])) - data.append(value["cycles"]) - data.append(value["count"]) - writer.writerow(data) - data.clear() - - -def process_data(data, func_info, so_ld_addr): - """Process data""" - # Keep an ordered list of loop IDs as they are being visited. This is used - # to match entry and exit pairs. Once the function/loop is processed, it's - # removed from the list. - ordered_visited_list = [] - # Store information regarding visited nodes as they are being processed. Once - # the function/loop is processed, it's removed from the set. - visited_set = {} - # Dictionary to store cycles for the entire model which is grouped into functions. - overall_cycles = {} - func_cycles = {} - - func_idx = -1 - func_name = "" - prev_func_name = "" - func_start = 0 - func_end = 0 - save_data = False - # Iterate over all the entries in the LWP data file and process them - # to construct a report. - for entry in data["entries"]: - id = entry["id"] - offset = entry["ret"] - so_ld_addr - - # Recorded return address should fall within the function begin and end - # offsets. If not, find the function it belongs to. - if offset < func_start or offset > func_end: - prev_func_name = func_name - if ENABLE_DEBUG: - print("offset : ", offset) - print("id : ", id) - - func_idx = find_func(func_info, offset) - func_name = func_info[func_idx]["name"] - func_start = func_info[func_idx]["start"] - func_end = func_info[func_idx]["end"] - if ENABLE_DEBUG: - print("func_name : ", func_name) - - if save_data: - # overall_cycles = save_func_cycles(prev_func_name, overall_cycles, func_cycles, ordered_visited_list) - # Done processing the previous function, copy its info into 'overall_cycles'. - if prev_func_name not in overall_cycles: - overall_cycles[prev_func_name] = func_cycles.copy() - else: - # Accumulate cycles into existing function entry. - overall_cycles = accumulate_cycles(overall_cycles, func_cycles, prev_func_name) - # We don't allow for fused operators (functions) calling another operator. - if ENABLE_DEBUG: - print("ordered_visited_list : ", ordered_visited_list) - - assert len(ordered_visited_list) == 0, ( - f"\nDone processing function [{prev_func_name}] but ordered_visited_list not empty.\n" - f"\t Possible reasons -- \n" - f"\t\t1) Mismatch between model .so and json file.\n" - f"\t\t2) LWP buffer may have overflowed resulting into missing entries!" - ) - func_cycles.clear() - - save_data = True - - if id not in visited_set: # Found 'entry' record - visited_info = {"func_idx": func_idx, "ret": offset, "cyc": entry["cyc"]} - visited_set[id] = visited_info - ordered_visited_list.append(id) - else: # Found 'exit' record - # This should be the last entry in the ordered_visited_list. If not, error out. - assert ordered_visited_list[-1] == id, ( - "Problem with LWP output - Interleaved handler calls found." - f"Loop [{ordered_visited_list[-1]}] hasn't exited yet." - ) - ordered_visited_list.pop() - entry_node = visited_set.pop(id) - assert entry_node["func_idx"] == func_idx, ( - f"Error - Found under a different function name : {entry_node['func_idx']}" - ) - cycles = entry["cyc"] - entry_node["cyc"] - parent = -1 - if ordered_visited_list: - parent = int(ordered_visited_list[-1]) - if id in func_cycles: - fcycles = func_cycles[id] - fcycles["cycles"] += cycles - fcycles["count"] += 1 - func_cycles[id] = fcycles - else: - func_cycles[id] = { - "cycles": cycles, - "start": entry_node["ret"], - "end": offset, - "parent": parent, - "count": 1, - } - - # Done processing the previous function, copy its info into 'overall_cycles'. - if func_name not in overall_cycles: - overall_cycles[func_name] = func_cycles.copy() - else: - # Accumulate cycles into existing function entry. - overall_cycles = accumulate_cycles(overall_cycles, func_cycles, func_name) - # We don't allow for fused operators (functions) calling another operator. - if ENABLE_DEBUG: - print("ordered_visited_list : ", ordered_visited_list) - - assert len(ordered_visited_list) == 0, ( - f"\nDone processing function [{prev_func_name}] but ordered_visited_list not empty.\n" - f"\t Possible reasons -- \n" - f"\t\t1) Mismatch between model .so and json file.\n" - f"\t\t2) LWP buffer may have overflowed resulting into missing entries!" - ) - - overall_cycles = adjust_per_loop_counts(overall_cycles, data) - return overall_cycles - - -def get_load_addr(serial_number: str, lwp_json: str, run_log: str): - """Get load address of the binary file""" - if serial_number == "simulator": - basedir = os.path.dirname(lwp_json) - if run_log is None: - run_log = os.path.join(basedir, "stdout.txt") - else: - # If the directory name is specified for the run_log of the - # simulator (stdout.txt) then it must be same as lwp_json. - run_log_dir = os.path.dirname(run_log) - assert run_log_dir == "" or run_log_dir == basedir, ( - f"stdout.txt and {os.path.basename(lwp_json)} must be in the same directory" - ) - run_log = os.path.join(basedir, os.path.basename(run_log)) - # To extract load address for the simulator run - pattern = compile(r"Model.*: (\w+):") - else: - # To extract load address for on-device run - pattern = compile(r"Model.*: (\w+)") - - with open(run_log) as f: - lines = f.read() - a = pattern.search(lines) - load_addr = int(a.group(1), 16) - if ENABLE_DEBUG: - print("load_addr : ", load_addr) - return load_addr - - -def process_lwp_output( - binary_path: str, - serial_number: str, - lwp_json: str, - run_log: str, - lwp_out: str, - enable_debug: bool = False, -): - """Process lightweight profiling data""" - # Enable debug messages - global ENABLE_DEBUG - ENABLE_DEBUG = enable_debug - - # Get load address for the binary - load_addr = get_load_addr(serial_number, lwp_json, run_log) - # Opening JSON file - with open(lwp_json) as f: - # Returns JSON object as a dictionary - data = json.load(f) - - # Get function names, and their start and end offsets from the model .so - func_info = get_func_info(binary_path) - - # Get the load address for model .so. - so_ld_addr = load_addr - - # Process profiling data to construct a CSV report. - overall_cycles = process_data(data, func_info, so_ld_addr) - create_csv_report(overall_cycles, lwp_out) - print("lwp processed output written to -- ", lwp_out) - print("[NOTE: Use '--hexagon-debug' to keep the temp directory]") - - -def get_args(): - """Add commandline arguments to run the script manually if needed""" - parser = argparse.ArgumentParser() - parser.add_argument("--lwp-json", help="LWP json file", required=True) - parser.add_argument("--serial-num", help="device-id/simulator", required=True) - parser.add_argument("--test-so", help="Test shared library", required=True) - parser.add_argument( - "--run-log", - help="Logcat file for on-device run and stdout.txt for simulator run", - required=True, - ) - parser.add_argument("--lwp-out", help="LWP output file name", required=True) - parser.add_argument( - "--debug", - help="Enable debug output from the script", - dest="debug", - action="store_true", - required=False, - ) - parser.set_defaults(debug=False) - args = parser.parse_args() - - global ENABLE_DEBUG - ENABLE_DEBUG = args.debug - - return args - - -if __name__ == "__main__": - args = get_args() - process_lwp_output( - args.test_so, args.serial_num, args.lwp_json, args.run_log, args.lwp_out, args.debug - ) diff --git a/python/tvm/contrib/hexagon/pytest_plugin.py b/python/tvm/contrib/hexagon/pytest_plugin.py deleted file mode 100644 index ac1bc7af99e7..000000000000 --- a/python/tvm/contrib/hexagon/pytest_plugin.py +++ /dev/null @@ -1,384 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=invalid-name,redefined-outer-name -"""Hexagon testing fixtures used to deduce testing argument -values from testing parameters""" - -from __future__ import annotations - -import os -import random -import socket -from typing import TYPE_CHECKING - -import pytest - -import tvm -import tvm.rpc -import tvm.testing - -if TYPE_CHECKING: - from tvm.contrib.hexagon.build import HexagonLauncherRPC - from tvm.contrib.hexagon.session import Session - -HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN" -TVM_TRACKER_HOST = "TVM_TRACKER_HOST" -TVM_TRACKER_PORT = "TVM_TRACKER_PORT" -ANDROID_REMOTE_DIR = "ANDROID_REMOTE_DIR" -ANDROID_SERIAL_NUMBER = "ANDROID_SERIAL_NUMBER" -ADB_SERVER_SOCKET = "ADB_SERVER_SOCKET" -HEXAGON_SIMULATOR_NAME = "simulator" -RNG_SEEDED = False - -HEXAGON_AOT_LLVM_TARGET = { - "kind": "llvm", - "keys": ["hexagon", "cpu"], - "mattr": ["+hvxv68", "+hvx-length128b", "+hvx-qfloat", "-hvx-ieee-fp"], - "mcpu": "hexagonv68", - "mtriple": "hexagon", -} - - -@tvm.testing.fixture -def shape_nhwc(batch, in_channel, in_size): - return (batch, in_size, in_size, in_channel) - - -def android_serial_number() -> str | None: - """Return the android serial number""" - serial = os.getenv(ANDROID_SERIAL_NUMBER, default="") - # Setting ANDROID_SERIAL_NUMBER to an empty string should be - # equivalent to having it unset. - if not serial.strip(): - return None - - # Split android serial numbers into a list - serial = serial.split(",") - return serial - - -# NOTE on server ports: -# These tests use different port numbers for the RPC server (7070 + ...). -# The reason is that an RPC session cannot be gracefully closed without -# triggering TIME_WAIT state on the server socket. This prevents another -# server to bind to the same port until the wait time elapses. - -LISTEN_PORT_MIN = 6000 # Avoid hitting well-known Android debug ports -LISTEN_PORT_MAX = 9000 # Below the search range end (port_end=9199) of RPC server -PREVIOUS_PORT = None - - -def get_free_port() -> int: - """Return the next port that is available to listen on""" - global PREVIOUS_PORT - global RNG_SEEDED - - if tvm.testing.utils.IS_IN_CI and not RNG_SEEDED: - random.seed(0) - RNG_SEEDED = True - - if PREVIOUS_PORT is None: - port = random.randint(LISTEN_PORT_MIN, LISTEN_PORT_MAX) - else: - port = PREVIOUS_PORT + 1 - if port > LISTEN_PORT_MAX: - port = LISTEN_PORT_MIN - - while _is_port_in_use(port): - port = port + 1 if port < LISTEN_PORT_MAX else LISTEN_PORT_MIN - - PREVIOUS_PORT = port - return port - - -def _is_port_in_use(port: int) -> bool: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - return sock.connect_ex(("localhost", port)) == 0 - - -@pytest.fixture(scope="session") -def _tracker_info() -> str | int: - env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="") - env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="") - - if env_tracker_host or env_tracker_port: - # A tracker is already running, and we should connect to it - # when running tests. - assert env_tracker_host, "TVM_TRACKER_PORT is defined, but TVM_TRACKER_HOST is not" - assert env_tracker_port, "TVM_TRACKER_HOST is defined, but TVM_TRACKER_PORT is not" - env_tracker_port = int(env_tracker_port) - - try: - tvm.rpc.connect_tracker(env_tracker_host, env_tracker_port) - except RuntimeError as exc: - message = ( - "Could not connect to external tracker " - "specified by $TVM_TRACKER_HOST and $TVM_TRACKER_PORT " - f"({env_tracker_host}:{env_tracker_port})" - ) - raise RuntimeError(message) from exc - - yield (env_tracker_host, env_tracker_port) - - else: - # No tracker is provided to the tests, so we should start one - # for the tests to use. Import tvm.rpc.tracker lazily since it - # requires the optional tornado package. - pytest.importorskip("tornado", reason="tvm.rpc.tracker requires tornado") - from tvm.rpc.tracker import Tracker - - tracker = Tracker("127.0.0.1", get_free_port()) - try: - yield (tracker.host, tracker.port) - finally: - tracker.terminate() - - -@pytest.fixture(scope="session") -def tvm_tracker_host(_tracker_info) -> str: - host, _ = _tracker_info - return host - - -@pytest.fixture(scope="session") -def tvm_tracker_port(_tracker_info) -> int: - _, port = _tracker_info - return port - - -@pytest.fixture(scope="session") -def rpc_server_port_for_session() -> int: - return get_free_port() - - -@pytest.fixture() -def rpc_server_port() -> int: - return get_free_port() - - -@pytest.fixture(scope="session") -def adb_server_socket() -> str: - return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037") - - -@pytest.fixture(scope="session") -def hexagon_server_process( - request, - rpc_server_port_for_session, - adb_server_socket, - skip_rpc, - hexagon_debug, - sysmon_profile, - clear_logcat, -) -> HexagonLauncherRPC: - """Initials and returns hexagon launcher if ANDROID_SERIAL_NUMBER is defined. - This launcher is started only once per test session. - """ - android_serial_num = android_serial_number() - - if android_serial_num is None: - pytest.skip("ANDROID_SERIAL_NUMBER is not set.") - if android_serial_num == [HEXAGON_SIMULATOR_NAME]: - yield None - else: - from tvm.contrib.hexagon.build import HexagonLauncher - - # Requesting these fixtures sets up a local tracker, if one - # hasn't been provided to us. Delaying the evaluation of - # these fixtures avoids starting a tracker unless necessary. - tvm_tracker_host = request.getfixturevalue("tvm_tracker_host") - tvm_tracker_port = request.getfixturevalue("tvm_tracker_port") - - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": rpc_server_port_for_session, - "adb_server_socket": adb_server_socket, - } - workerinput = getattr(request.config, "workerinput", None) - if workerinput is None: # single-process execution - device_adr = read_device_list()[0] - else: # running in a subprocess here - device_adr = workerinput["device_adr"] - launcher = HexagonLauncher( - serial_number=device_adr, - rpc_info=rpc_info, - hexagon_debug=hexagon_debug, - sysmon_profile=sysmon_profile, - clear_logcat=clear_logcat, - ) - try: - if not skip_rpc: - launcher.start_server() - yield {"launcher": launcher, "device_adr": device_adr} - finally: - if not skip_rpc: - launcher.stop_server() - - -def read_device_list(): - return android_serial_number() - - -def pytest_configure(config): - # read device list if we are on the master - if not hasattr(config, "workerinput"): - config.iplist = read_device_list() - - -def pytest_configure_node(node): - # the master for each node fills node input dictionary - # which pytest-xdist will transfer to the subprocess - if node.config.iplist is not None: - node.workerinput["device_adr"] = node.config.iplist.pop() - - -@pytest.fixture -def hexagon_launcher( - hexagon_server_process, - rpc_server_port, - tvm_tracker_host, - tvm_tracker_port, - adb_server_socket, - hexagon_debug, - sysmon_profile, - clear_logcat, -) -> HexagonLauncherRPC: - """Initials and returns hexagon launcher which reuses RPC info and Android serial number.""" - android_serial_num = android_serial_number() - - if android_serial_num != [HEXAGON_SIMULATOR_NAME]: - rpc_info = hexagon_server_process["launcher"]._rpc_info - else: - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": rpc_server_port, - "adb_server_socket": adb_server_socket, - } - from tvm.contrib.hexagon.build import HexagonLauncher - - launcher = None - try: - if android_serial_num == [HEXAGON_SIMULATOR_NAME]: - launcher = HexagonLauncher(serial_number=android_serial_num[0], rpc_info=rpc_info) - launcher.start_server() - else: - launcher = HexagonLauncher( - serial_number=hexagon_server_process["device_adr"], - rpc_info=rpc_info, - hexagon_debug=hexagon_debug, - sysmon_profile=sysmon_profile, - clear_logcat=clear_logcat, - ) - yield launcher - finally: - if launcher is not None: - if android_serial_num == [HEXAGON_SIMULATOR_NAME]: - launcher.stop_server() - elif not hexagon_debug: - launcher.cleanup_directory() - - -@pytest.fixture -def hexagon_session(hexagon_launcher: HexagonLauncherRPC) -> Session: - if hexagon_launcher is None: - yield None - else: - with hexagon_launcher.create_session() as session: - yield session - - -# If the execution aborts while an RPC server is running, the python -# code that is supposed to shut it down will never execute. This will -# keep pytest from terminating (indefinitely), so add a cleanup -# fixture to terminate any still-running servers. -@pytest.fixture(scope="session", autouse=True) -def terminate_rpc_servers(): - # Since this is a fixture that runs regardless of whether the - # execution happens on simulator or on target, make sure the - # yield happens every time. - serial = os.environ.get(ANDROID_SERIAL_NUMBER) - yield [] - if serial == [HEXAGON_SIMULATOR_NAME]: - os.system("ps ax | grep tvm_rpc_x86 | awk '{print $1}' | xargs kill") - - -aot_host_target = tvm.testing.parameter(HEXAGON_AOT_LLVM_TARGET) - - -@tvm.testing.fixture -def aot_target(aot_host_target): - if aot_host_target == "c": - yield tvm.target.Target({"kind": "hexagon", "mtriple": "hexagon", "mcpu": "hexagonv68"}) - elif isinstance(aot_host_target, dict) and aot_host_target.get("kind") == "llvm": - yield aot_host_target - elif isinstance(aot_host_target, str) and aot_host_target.startswith("llvm"): - yield aot_host_target - else: - assert False, "Incorrect AoT host target: {aot_host_target}. Options are [c, llvm]." - - -@pytest.fixture(scope="session") -def skip_rpc(request) -> bool: - return request.config.getoption("--skip-rpc") - - -@pytest.fixture(scope="session") -def hexagon_debug(request) -> bool: - return request.config.getoption("--hexagon-debug") - - -@pytest.fixture(scope="session") -def sysmon_profile(request) -> bool: - return request.config.getoption("--sysmon-profile") - - -@pytest.fixture(scope="session") -def clear_logcat(request) -> bool: - return request.config.getoption("--clear-logcat") - - -def pytest_addoption(parser): - """Add pytest options.""" - - parser.addoption( - "--skip-rpc", - action="store_true", - default=False, - help="If set true, the RPC server initialization on Android would be skipped", - ) - parser.addoption( - "--hexagon-debug", - action="store_true", - default=False, - help="If set true, it will keep the hexagon test directories on the target. " - + "Additionally logcat logs will be copied from device and cdsp errors printed out.", - ) - parser.addoption( - "--sysmon-profile", - action="store_true", - default=False, - help="If set true, it will run sysmon profiler during the tests.", - ) - parser.addoption( - "--clear-logcat", - action="store_true", - default=False, - help="If set true, it will clear logcat before execution.", - ) diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py deleted file mode 100644 index 8bdaf33da094..000000000000 --- a/python/tvm/contrib/hexagon/session.py +++ /dev/null @@ -1,287 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=consider-using-from-import - -"""Defines a Session class for Hexagon devices.""" - -import os -import pathlib -import tempfile - -import tvm -import tvm.contrib.hexagon as hexagon -from tvm import rpc as _rpc -from tvm import runtime -from tvm.support import utils - -from .tools import HEXAGON_SIMULATOR_NAME, export_module - - -class Session: - """Hexagon Device Session - - Parameters - ---------- - remote_workspace : Union[str, pathlib.Path] - Remote workspace path - - rpc_tracker : tuple(str, int) - RPC tracker host and port number. - - rpc_server_key : str - RPC server key on remote device. - - serial_number : str - Device serial number. `simulator` used for hexagon simulator. - - session_name : str - Hexagon RPC session name. - - remote_stack_size_bytes : int - The stack size of the remote device, to be passed to - tvm.contrib.hexagon.create_hexagon_session. - - rpc_receive_buffer_size_bytes : int - RPC receive buffer size in bytes. - """ - - def __init__( - self, - remote_workspace: str | pathlib.Path, - rpc_tracker: tuple, - rpc_server_key: str, - serial_number: str, - session_name: str = "hexagon-rpc", - remote_stack_size_bytes: int = 256 * 1024, # Min size for main thread in QuRT/sim - rpc_receive_buffer_size_bytes: int = 256 * 1024 * 1024, # Size for passing hexagon tests - ): - self._workspace = str(remote_workspace) - self._rpc_tracker = rpc_tracker - self._rpc_server_key = rpc_server_key - self._serial_number = serial_number - self._session_name: str = session_name - self._remote_stack_size_bytes: int = remote_stack_size_bytes - self._rpc_receive_buffer_size_bytes: int = rpc_receive_buffer_size_bytes - self._rpc = None - self._requires_cpu_device = False - self._device = None - - def __enter__(self): - if self._rpc: - # Already initialized - return self - - tracker = _rpc.connect_tracker(self._rpc_tracker[0], self._rpc_tracker[1]) - try: - self._rpc = tracker.request( - self._rpc_server_key, - priority=0, - session_timeout=0, - session_constructor_args=[ - "tvm.contrib.hexagon.create_hexagon_session", - self._session_name, - self._remote_stack_size_bytes, - os.environ.get("HEXAGON_SIM_ARGS", ""), - self._rpc_receive_buffer_size_bytes, - ], - ) - func = self._rpc.get_function("device_api.hexagon.acquire_resources") - func() - return self - - except RuntimeError as exception: - raise exception - - def __exit__(self, exc_type, exc_value, exc_traceback): - try: - func = self._rpc.get_function("device_api.hexagon.release_resources") - func() - except RuntimeError as exception: - print( - "Exception occurred while calling release_resources() during Session __exit__: ", - exception, - ) - finally: - # close session to the tracker - shutdown_func = self._rpc._sess.get_function("CloseRPCConnection") - shutdown_func() - del self._rpc - - @property - def device(self): - """Session device.""" - - if self._device is not None: - return self._device - - if self._requires_cpu_device: - self._device = self._rpc.cpu(0) - else: - self._device = self._rpc.hexagon(0) - - return self._device - - def is_simulator(self): - return self._serial_number == HEXAGON_SIMULATOR_NAME - - def get_function(self, name): - return self._rpc.get_function(name) - - def upload(self, local_path: str | pathlib.Path, remote_filename: str) -> pathlib.Path: - """Upload a local file to the remote workspace. - - Parameters - ---------- - local_path : str or pathlib.Path - Path to the local file to be copied. - remote_filename : str - Name of the file in the remote workspace. - - Returns - ------- - pathlib.Path : - Uploaded file remote path. - """ - upload_func = self._rpc.get_function("tvm.rpc.server.upload") - remote_path = f"{self._workspace}/{remote_filename}" - with open(local_path, mode="rb") as src_f: - data = bytearray(src_f.read()) - upload_func(remote_path, data) - return remote_path - - def load_module(self, module: str | pathlib.Path | tvm.runtime.Module): - """Load TVM module. - - The session must be established (via __enter__) prior to - calling this function. - - Parameters - ---------- - module : Union[str, pathlib.Path, tvm.runtime.Module] - - The module to load. If `module` is a - `tvm.runtime.Module`, it will be uploaded to the remote - session and loaded. - - If the object passed is a string or pathlib.Path, it must - be a full path in the remote system. - - Returns - ------- - TVMModule : - TVM module object. - """ - - assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" - - if isinstance(module, tvm.runtime.Module): - with tempfile.TemporaryDirectory() as temp_dir: - binary_name = "test_binary.so" - binary_path = export_module(module, temp_dir, binary_name) - remote_file_path = self.upload(binary_path, binary_name) - else: - remote_file_path = module - - assert isinstance(remote_file_path, str | pathlib.Path), "Invalid path type:" + str( - type(remote_file_path) - ) - return self._rpc.get_function("tvm.hexagon.load_module")(str(remote_file_path)) - - def get_executor_from_factory( - self, module: runtime.Executable | str, hexagon_arch: str = "v68" - ): - """Create a local GraphModule which consumes a remote libmod. - - Parameters - ---------- - - module : Union[runtime.Executable, str] - - The module to upload to the remote - session and load. - hexagon_arch : str - The hexagon arch to be used - """ - if isinstance(module, runtime.Executable | str): - return self._vm_executable_executor(module, hexagon_arch=hexagon_arch) - - raise TypeError(f"Unsupported executor type: {type(module)}") - - def _set_device_type(self, module: str | pathlib.Path): - """Set session device type(hexagon, cpu) based on target in module. - - Parameters - ---------- - - module: TVMModule - TVM module object. - """ - # for cases when module is a single schedule without target attribute. - if not hasattr(module, "target"): - self._requires_cpu_device = False - else: - assert len(module.target) == 1 - for target in module.target: - target_type = str(target).split()[0] - - if target_type == "llvm": - self._requires_cpu_device = True - else: - self._requires_cpu_device = False - - def _vm_executable_executor(self, executable: runtime.Executable | str, hexagon_arch: str): - """Create a local TVM module which consumes a remote vm executable. - - Parameters - ---------- - - executable : runtime.Executable - The Executable to upload to the remote and load. This will typically be the - output of `tvm.compile` or the path to an already built and exported shared library - hexagon_arch : str - The hexagon arch to be used - - Returns - ------- - TVMModule : - TVM module object - """ - assert self._rpc is not None, "Hexagon session must be started using __enter__ prior to use" - - if isinstance(executable, runtime.Executable): - temp_dir = utils.tempdir() - path_exec = temp_dir.relpath("exec.so") - - executable.export_library( - path_exec, - fcompile=hexagon.create_aot_shared, - hexagon_arch=hexagon_arch, - ) - - path = self.upload(path_exec, "exec.so") - elif isinstance(executable, str): - path_exec = executable - else: - raise TypeError(f"Unsupported executor type: {type(executable)}") - - path = self.upload(path_exec, "exec.so") - return self._rpc.get_function("tvm.hexagon.load_module")(str(path)) - - def get_profile_output(self, mode: str, path: str): - assert isinstance(mode, str), f"Invalid mode type, {type(mode)} != str" - assert isinstance(path, str), f"Invalid path type, {type(path)} != str" - return self._rpc.get_function("tvm.hexagon.get_profile_output")(mode, path) diff --git a/src/backend/hexagon/runtime/profiler/README.md b/src/backend/hexagon/runtime/profiler/README.md deleted file mode 100644 index fdcc94f69203..000000000000 --- a/src/backend/hexagon/runtime/profiler/README.md +++ /dev/null @@ -1,85 +0,0 @@ - - - - - - - - - - - - - - - - - -# Hexagon lightweight instrumentation based profiling (LWP) - -For Hexagon, LWP can be used to get function and loop level processor cycle count. -This is done by instrumenting the code with profiling builtin calls using a TIR pass. -During codegen, these builtin calls are replaced with the calls to a hexagon specific -handler which records the runtime information into a buffer. -This buffer is written into a JSON file ('lwp.json') which is processed to construct -function and loop level profiling information as a csv file. - -**Note:** During codegen, the profiling builtin calls are ignored for other targets. - -The TIR pass offers several config flags to control the level of instrumentation -as mentioned below: - -1) `lwp_disable_func_prof`: To disable function level profiling. By default, it is -set to 'False', i.e., the function level profiling is enabled. - -2) `instr_siblings`: When enabled, only loops with siblings are instrumented and rest are -ignored. The inner-most loops are always excluded from instrumentation unless overwritten -using `lwp_min_height`. This is done to minimize the adverse effect of instrumentation on -actual performance. By default, it is set to 'True'. - -3) `lwp_max_depth`: To instrument loops up to a certain depth. This flag is effective -only when `instr_siblings` is disabled. By default, it is set to 0. - -4) `lwp_min_height`: To exclude inner loops up to a certain height from instrumentation. -By default, it is set to 1. - -For additional usage information on various config flags, please refer to the tests in -`tests/python/tir-transform/test_tir_transform_profiling_instr.py` - - -## How to use lightweight profiling with RPC Launcher: - -`tests/python/contrib/test_hexagon/test_launcher.py` contains two tests, `test_lwp` and -`test_lwp_multiple_conv2d`, to demonstrate lightweight profiling usage. - -The steps involved are as follows: - -1) While building a model, set `tir.instrument_lwp` to `True`. - By default, the builtin calls will only be inserted for the loops with siblings. But it - can be altered using LWP config options as described above. -2) Create `HexagonProfiler` object - -4) Run the model and get the profiling data as a CSV file. It is done by post-processing - 'lwp.json' file generated during runtime. - -``` - graph_mod.run(**inputs) - - # Get lightweight profiling output as a CSV file - profiler.get_profile_output(hexagon_launcher, hexagon_session, hexagon_server_process) -``` -**Note:** - -- For on-device runs, 'lwp.json' is copied into a temp directory along with the test .so and the processed - CSV file -- For the simulator runs, the file is generated in the simulator test output directory. Test .so - will still be in a separate temp directory. lwp CSV file will also be in the same directory. - -**Helpful Hints:** - -- To prevent the test directories on the Hexagon device as well as temporary test directory on x86 -from being deleted for profiling related runs, pass `--hexagon-debug` to pytest. - -``` -python -m pytest --hexagon-debug tests/python/contrib/test_hexagon/test_launcher.py::test_lwp -``` diff --git a/src/backend/hexagon/runtime/rpc/android_bash.sh.template b/src/backend/hexagon/runtime/rpc/android_bash.sh.template deleted file mode 100644 index c45b03818fd3..000000000000 --- a/src/backend/hexagon/runtime/rpc/android_bash.sh.template +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/sh -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -export LD_LIBRARY_PATH=. - -# Enable FARF-based logging for Hexagon code invoked by 'tvm_rpc_android_server'. -export ADSP_LIBRARY_PATH=`pwd` - -echo > tvm_rpc_android.farf - -./tvm_rpc_android server --port= --tracker=: --key= >${PWD}/tvm_rpc_android.log 2>&1 & - -rpc_pid=$! - -rm -f rpc_pid.txt -echo $rpc_pid >> rpc_pid.txt diff --git a/tests/python/contrib/test_hexagon/README.md b/tests/python/contrib/test_hexagon/README.md deleted file mode 100644 index 7c79626f9e57..000000000000 --- a/tests/python/contrib/test_hexagon/README.md +++ /dev/null @@ -1,130 +0,0 @@ - - - - - - - - - - - - - - - - - -# Test TVM on Hexagon -This document explains various pieces that are involved in testing TVM on an Android device which includes Hexagon DSP or Hexagon simulator. - -## What is HexagonLauncherRPC? -HexagonLauncherRPC is a class to handle interactions with an Android phone which includes Hexagon DSP or Hexagon simulator to run a TVMModule(function/operation/graph) on Hexagon. HexagonLauncherRPC reuses [minRPC](https://github.com/apache/tvm/tree/main/src/runtime/minrpc) implementation to set up an RPC connection from host (your local machine) to Hexagon target, and it is passed through Android RPC server. - -## Build Required Tools/Libraries -To build TVM for Hexagon and run tests you need to run multiple steps which includes preparing required tools, setting up environment variables and building various versions of TVM. Alternatively, you can skip these instructions and use docker image which has pre-installed required tools. We highly recommend to use docker, especially if this is your first time working with Hexagon. For instructions on using docker image follow ["use hexagon docker image"](#use-hexagon-docker-image). - -- Build TVMRuntime library and C++ RPC server for Android. -- Build minRPC server along with FastRPC for Hexagon. -- Build TVM library with Hexagon support for host machine. -- Build TVMRuntime library and RPC server for host machine. - -First, ensure to export Clang libraries to `LD_LIBRARY_PATH` and Hexagon toolchain to `HEXAGON_TOOLCHAIN`: -```bash -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"Path to `llvm-clang/lib` sub-directory. Currently we use LLVM-13 in TVM CI." - -export HEXAGON_TOOLCHAIN="Path to Hexagon toolchain. It can be the Hexagon toolchain included in the SDK, for example `HEXAGON_SDK_ROOT/tools/HEXAGON_Tools/x.y.z/Tools`. The `x.y.z` in the path is the toolchain version number, which is specific to the version of the SDK." -``` - -You can find more information about downloading [Hexagon SDK](https://developer.qualcomm.com/software/hexagon-dsp-sdk). - -First build Hexagon API application under `apps/hexagon_api`. This step will generate `tvm_rpc_android` and `libtvm_runtime.so` to run on Android. Also, it generates `libtvm_runtime.a` `libtvm_runtime.so`, `libhexagon_rpc_skel.so` and `libhexagon_rpc_sim.so` to run on Hexagon device or Hexagon simulator. - -```bash -cd apps/hexagon_api -mkdir build -cd build -cmake -DANDROID_ABI=arm64-v8a \ - -DANDROID_PLATFORM=android-28 \ - -DUSE_ANDROID_TOOLCHAIN="path to `android-ndk/build/cmake/android.toolchain.cmake` file" \ - -DUSE_HEXAGON_ARCH=v65|v66|v68|v69|v73|v75 \ - -DUSE_HEXAGON_SDK="path to Hexagon SDK" \ - -DUSE_HEXAGON_TOOLCHAIN="path to Hexagon toolchain `Tools` sub-directory which explained above" \ - -DUSE_OUTPUT_BINARY_DIR="path to `build/hexagon_api_output` which is a sub-directory of `tvm`" .. - -make -j2 -``` - -Next, we need to build TVM on host with RPC and Hexagon dependencies. To do that follow these commands. - -```bash -cd tvm -mkdir build -cd build -cmake -DUSE_LLVM="path to `llvm/bin/llvm-config`" \ - -DUSE_RPC=ON \ - -DCMAKE_CXX_COMPILER="path to `clang++` executable" \ - -DUSE_HEXAGON_SDK="path to Hexagon SDK" \ - -DUSE_HEXAGON=ON .. - -make -j2 -``` - -## Use Hexagon Docker Image -To use hexagon docker image, install TVM and Hexagon API follow these steps from your TVM home directory: - -```bash -# Log in to docker image -./docker/bash.sh ci_hexagon - -# Build TVM -rm -rf build -mkdir build && cd build -cmake -DUSE_LLVM=ON \ - -DUSE_RPC=ON \ - -DUSE_HEXAGON_SDK=$HEXAGON_SDK_PATH \ - -DUSE_HEXAGON=ON .. -make -j2 -``` - -Now that you have built required tools, you can jump to [run test examples](#run-tests). - -## Run Tests -You have the options of running Hexagon test on real hardware or on Hexagon simulator. Also, depending on whether you decided to use Hexagon docker image or not we will explain both cases here. - -### Only follow these steps if running tests outside of docker -```bash -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"path to `llvm-clang/lib` sub-directory" - -export HEXAGON_TOOLCHAIN="Path to Hexagon toolchain. It can be the Hexagon toolchain included in the HexagonSDK, for example `HEXAGON_SDK_ROOT/tools/HEXAGON_Tools/x.y.z/Tools`. The `x.y.z` in the path is the toolchain version number, which is specific to the version of the SDK." - -export PYTHONPATH=$PYTHONPATH:"path to `tvm/python`" -``` - -### Now, follow these steps -**Note:** If you are using Hexagon docker image, first step is to log into the Hexagon docker image. Following these commands you will log in to the most recent version of Hexagon docker image on your TVM local branch. Since we have already built TVM for hexagon, we can just log in and use it. From your TVM home directory: - -```bash -./docker/bash.sh ci_hexagon -``` - -Now, you need to export few environment variables and execute following commands: - -```bash -# Run RPC Tracker in the background -export TVM_TRACKER_HOST="Your host IP address or 0.0.0.0" -export TVM_TRACKER_PORT="Port number of your choice." -python -m tvm.exec.rpc_tracker --host $TVM_TRACKER_HOST --port $TVM_TRACKER_PORT& - -# Only For real hardware testing -export ANDROID_SERIAL_NUMBER="You can get this number by running 'adb devices' command" - -# Only For simulator testing -export HEXAGON_SHARED_LINK_FLAGS="-Lbuild/hexagon_api_output -lhexagon_rpc_sim" -export ANDROID_SERIAL_NUMBER="simulator" -``` - -Finally, to run a Hexagon Launcher tests you can run: -```bash -pytest tests/python/contrib/test_hexagon/test_launcher.py -``` diff --git a/tests/python/contrib/test_hexagon/README_RPC.md b/tests/python/contrib/test_hexagon/README_RPC.md deleted file mode 100644 index 2bbedc95995f..000000000000 --- a/tests/python/contrib/test_hexagon/README_RPC.md +++ /dev/null @@ -1,371 +0,0 @@ - - - - - - - - - - - - - - - - - - -# A life of a Hexagon API call - -The goal is to understand what exactly is happening during `A_data.copyfrom(np.array([2, 3]))`, where `A_data` lives in Hexagon. - -## Overview -The diagram below describes the sequence of calls and components involved when memcpy over the Hexagon device is invoked. - -![Overview of RPC](https://github.com/tlc-pack/web-data/raw/main/images/design/tvm-hex-rpc.png) - -The communication between x86 and Android is done via the standard TVM RPC protocol implemented mostly in `src/runtime/rpc/rpc_endpoint.cc`. - -A packet between Android and Hexagon is proxy-ed by the Hexagon FastRPC mechanism. FastRPC depends on the auto-generated implementations of client- and server- side API. During the build time, the Android side API (”stub”) and the Hexagon side API (”skel”) is generated from `src/runtime/hexagon/rpc/hexagon_rpc.idl` (see `cmake/modules/Hexagon.cmake`). - -When TVM’s RPC server on Android, `tvm_rpc_android_server`, invokes `hexagon_rpc_send(...)`, it actually calls into the same-name function defined in the stub with the exact same arguments (which includes the URI for the `*skel.so` library to use on Hexagon, which in our case is `libhexagon_rpc_skel.so`). Similarly, on the Hexagon side, `hexagon_rpc_send(...)` call is first intercepted by the “skel” API, which in turn calls the actual implementation defined in `src/runtime/hexagon/rpc/hexagon/rpc_server.cc`. - -## Initialization: Setting up Android and establishing connection between x86 host and Android - -What’s happening during the launcher initialization at [https://github.com/apache/tvm/blob/7cfaa88e6c18edc0a41e1a984d3cb9d8659a1c2c/tests/python/contrib/test_hexagon/test_launcher.py#L71-L73](https://github.com/apache/tvm/blob/7cfaa88e6c18edc0a41e1a984d3cb9d8659a1c2c/tests/python/contrib/test_hexagon/test_launcher.py#L71-L73) ? - -```python -launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) -launcher.upload(dso_binary_path, dso_binary) -launcher.start_server() -``` - -Here, we send various files over Android via `adb`, and initialize a RPC server via `tvm_rpc_android` binary (built from [https://github.com/apache/tvm/tree/main/apps/cpp_rpc](https://github.com/apache/tvm/tree/main/apps/cpp_rpc)): - -[https://github.com/apache/tvm/blob/0c0245ae2230fa07d3e4b8be490fc9c88965730c/python/tvm/contrib/hexagon/build.py#L373-L378](https://github.com/apache/tvm/blob/0c0245ae2230fa07d3e4b8be490fc9c88965730c/python/tvm/contrib/hexagon/build.py#L373-L378) - -```python -subprocess.Popen( - self._adb_device_sub_cmd + ["shell", f"cd {self._workspace} && ./android_bash.sh"], - stdout=subprocess.PIPE, - stdin=subprocess.PIPE, - stderr=subprocess.PIPE, -) -``` - -[https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android_bash.sh.template#L20](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android_bash.sh.template#L20) - -``` -./tvm_rpc_android server --port= --tracker=: --key=& -``` - -When we do `launcher.create_session()` , a remote RPC session between x86 and Android is established via this line: - -[https://github.com/apache/tvm/blob/0c0245ae2230fa07d3e4b8be490fc9c88965730c/python/tvm/contrib/hexagon/session.py#L57-L67](https://github.com/apache/tvm/blob/0c0245ae2230fa07d3e4b8be490fc9c88965730c/python/tvm/contrib/hexagon/session.py#L57-L67) - -```python -self._rpc = tracker.request( - ... - session_constructor_args=[ - "tvm.contrib.hexagon.create_hexagon_session", - self._session_name, - self._remote_stack_size_bytes, - ], -) -``` - -Which eventually jumps to the following line in C++, which creates a RPC client session on an x86 host and run a server initialization function `tvm.contrib.hexagon.create_hexagon_session` on Android: - -[https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L123-L129) - -```cpp -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed("rpc.Connect", [](ffi::PackedArgs args, ffi::Any* rv) { - auto url = args[0].cast(); - int port = args[1].cast(); - auto key = args[2].cast(); - *rv = RPCClientConnect(url, port, key, - ffi::PackedArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); - }); -} -``` - -`tvm.contrib.hexagon.create_hexagon_session` is defined here. It establishes a link between Android and Hexagon, this code runs on Android. - -[https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106) - -```cpp - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def_packed( - "tvm.contrib.hexagon.create_hexagon_session", [](ffi::PackedArgs args, ffi::Any* rv) { - auto session_name = args[0].cast(); - int remote_stack_size_bytes = args[1].cast(); - HexagonTransportChannel* hexagon_channel = - new HexagonTransportChannel(hexagon_rpc_URI CDSP_DOMAIN, remote_stack_size_bytes); - std::unique_ptr channel(hexagon_channel); - auto ep = RPCEndpoint::Create(std::move(channel), session_name, "", NULL); - auto sess = CreateClientSession(ep); - *rv = CreateRPCSessionModule(sess); - }); -} -``` - -`HexagonTransportChannel` is the one that actually knows how to talk to Hexagon. It uses functions such as `hexagon_rpc_send`, `hexagon_rpc_receive` defined in - -[https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/hexagon/rpc_server.cc](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/hexagon/rpc_server.cc) - -## x86 host → Android - -`A_data.copyfrom(np.array([2, 3]))` reaches this line. This is the boundary between Python and C++ land in TVM FFI: - -[https://github.com/apache/tvm/blob/b2757817af7ba3aefe16ea3ccb6d4982dd7fd531/python/tvm/runtime/ndarray.py#L183](https://github.com/apache/tvm/blob/b2757817af7ba3aefe16ea3ccb6d4982dd7fd531/python/tvm/runtime/ndarray.py#L183) - -```python -check_call(_LIB.TVMTensorCopyFromBytes(self.handle, data, nbytes)) -``` - -[https://github.com/apache/tvm/blob/37cd9837ff302e4490696ca57a9fbba6404c7046/src/runtime/tensor.cc#L322](https://github.com/apache/tvm/blob/37cd9837ff302e4490696ca57a9fbba6404c7046/src/runtime/tensor.cc#L322) - -```cpp -int TVMTensorCopyFromBytes(TVMArrayHandle handle, void* data, size_t nbytes) { - API_BEGIN(); - TensorCopyFromBytes(handle, data, nbytes); - API_END(); -} -``` - -Now we come to `TensorCopyFromBytes` function. The first non-obvious question is, which `DeviceAPI` is selected by `DeviceAPI::Get(handle->device)`? - -```cpp -void TensorCopyFromBytes(DLTensor* handle, const void* data, size_t nbytes) { - ... - DLTensor from; - ... - DeviceAPI::Get(handle->device)->CopyDataFromTo(&from, handle, nullptr); - // Synchronize in case data become unavailable later. - DeviceAPI::Get(handle->device)->StreamSync(handle->device, nullptr); -} -``` - -The answer: `RPCDeviceAPI` defined below, not `HexagonDeviceAPI`. - -[https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L34](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L34) - -```cpp -class RPCDeviceAPI final : public DeviceAPI { - ... -``` - -This is due to the fact that `sess.device`, used in `test_launcher.py` below, encodes two pieces of information: (1) The device is RPC and (2) it wraps the underlying “real” device Hexagon. - -[https://github.com/apache/tvm/blob/2b35cfd6ddb73afecd3f550f33881e1fdc7c3267/tests/python/contrib/test_hexagon/rpc/test_launcher.py#L112](https://github.com/apache/tvm/blob/2b35cfd6ddb73afecd3f550f33881e1fdc7c3267/tests/python/contrib/test_hexagon/rpc/test_launcher.py#L112) - -See below for how `sess.device` is created during `HexagonLauncher` initialization. - - `self.device = self._rpc.hexagon(0)`. - -[https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/python/tvm/contrib/hexagon/session.py#L64](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/python/tvm/contrib/hexagon/session.py#L64) - -`RPCDeviceAPI::CopyDataFromTo` is defined in [https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L80](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L80) - -Here, we meet another `GetAPI` call: - -```cpp -GetSess(dev_from)->GetDeviceAPI(remote_dev)->CopyDataFromTo(&from_tensor, &to_tensor, stream); -``` - -[https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L94](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_device_api.cc#L94) - -At first, it is not obvious where this `CopyDataFromTo` jumps to (initially I thought it would jump to `HexagonDeviceAPI`). Since `GetSess(dev_from)` returns the client RPC connection between x86 and Android, created during initialization in - -[https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L107](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L107) - -```cpp -Module RPCClientConnect(std::string url, int port, std::string key, ffi::PackedArgs init_seq) { - auto endpt = RPCConnect(url, port, "client:" + key, init_seq); - return CreateRPCSessionModule(CreateClientSession(endpt)); -} -``` - -, this jumps to `RPCClientSession` class defined in [https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L994](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L994) - -```cpp -class RPCClientSession : public RPCSession, public DeviceAPI { - ... -``` - -`rpc_endpoint.cc` is a very important file. It contains the core RPC protocol logic. `CopyDataFromTo` in `rpc_device_api.cc` jumps to - -[https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L1060-L1062](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L1060-L1062) - -```cpp -void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final { - endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); -} -``` - -from which things transfer to the Android side. - -Here is where `RPCCode::kCopyAmongRemote` is handled: - -[https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L979-L981](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L979-L981) - -```cpp -case RPCCode::kCopyAmongRemote: - SysCallHandler(RPCCopyAmongRemote); - break; -``` - -The handler is represented by `serving_session_`, which is initialized during server initialization at - -[https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L541](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L541) - -```cpp -serving_session_ = RPCModuleGetSession(mod); -``` - -which corresponds to the Hexagon session created before in [https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L106). - -The handler is passed to the following function - -[https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L909-L922](https://github.com/apache/tvm/blob/899bc064e1bf8df915bcadc979a6f37210cdce33/src/runtime/rpc/rpc_endpoint.cc#L909-L922) - -```cpp -void RPCCopyAmongRemote(RPCSession* handler, ffi::PackedArgs args, ffi::Any* rv) { - auto from = args[0].cast(); - auto to = args[1].cast(); - ... - handler->GetDeviceAPI(dev)->CopyDataFromTo(from, to, stream); -} -``` - -This is an interesting function. Here, `handler` is again `RPCClientSession` due to the line in - -[https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L114](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L114) - -```cpp -auto sess = CreateClientSession(ep); -``` - -so apparently, things might look like it is looping back to `RPCClientSession::CopyDataFromTo`: - -```cpp -void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final { - endpoint_->SysCallRemote(RPCCode::kCopyAmongRemote, from, to, stream); - } -``` - -But this time, `endpoint_` is different. Previously, this `endpoint_` represented the connection between x86 and Android (created in [https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L99-L100](https://github.com/apache/tvm/blob/2cca934aad1635e3a83b712958ea83ff65704316/src/runtime/rpc/rpc_socket_impl.cc#L99-L100)), but this `endpoint_` belongs to the Hexagon session created in [https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L113](https://github.com/apache/tvm/blob/cd2fa69677516048e165e84a88c774dfb0ee65d1/src/runtime/hexagon/rpc/android/session.cc#L113). So this is where the RPC communication between Android and Hexagon starts. - -## Android → Hexagon - -Recall that the `endpoint_` owned by the Hexagon session is created via `tvm.contrib.hexagon.create_hexagon_session` when the Android RPC server is being initialized. The `endpoint_` is represented by the following class: - -[https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/android/session.cc#L46](https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/android/session.cc#L46) - -```cpp -class HexagonTransportChannel : public RPCChannel { - public: - explicit HexagonTransportChannel(const std::string& uri, int remote_stack_size_bytes) { - ... - hexagon_rpc_open(uri.c_str(), &_handle); - ... - } - - size_t Send(const void* data, size_t size) override { - hexagon_rpc_send(_handle, static_cast(data), static_cast(size)); - ... - } -``` - -On construction, `hexagon_rpc_open` is called, which will initialize the TVM MinRPC server on Hexagon and overwrites `device_api.hexagon` registry to point to the call to `HexagonDeviceAPI`. [https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L210-L213](https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L210-L213) - -The endpoint routes each RPC packet by `Send` function, which in turn calls `hexagon_rpc_send(...)` defined in: - -[https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L243](https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L243) - -```cpp -AEEResult hexagon_rpc_send(remote_handle64 _handle, const unsigned char* data, - int dataLen) { - get_hexagon_rpc_server()->Write(reinterpret_cast(data), - static_cast(dataLen)); - ... -} -``` - -This is where FastRPC comes into play and things get very confusing. The endpoint lives in Android, so `hexagon_rpc_send` call (also `hexagon_rpc_open`) happens at Android. But the implementations of these functions in `rpc_server.cc` describe the behavior on the Hexagon side... What’s happening is that FastRPC “stub” and “skel” (see the overview at the top) API intercept those calls and play some magic behind the scene to make RPC call look transparent from the client (Android) perspective. - -So when the control comes to the point of definition of `hexagon_rpc_send` in `rpc_server.cc`, FastRPC has already finished its job and so we are really on the Hexagon side now. We come to `HexagonRPCServer::Write(...)` function, which in turn calls into TVM MinRPC server instance `rpc_server_` to process the incoming packet: - -[https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L167](https://github.com/apache/tvm/blob/c20cbc55c03f9f048b151a1221469b9888123608/src/runtime/hexagon/rpc/hexagon/rpc_server.cc#L167) - -```cpp -int64_t Write(const uint8_t* data, size_t data_size_bytes) { - if (io_.SetReadBuffer(data, data_size_bytes) != AEE_SUCCESS) { - return -1; - } - rpc_server_.ProcessOnePacket(); - return (int64_t)data_size_bytes; -} -``` - -`MinRPCServer::ProcessOnePacket()` function dispatches to `HandleCopyFromRemote()` upon receiving `kCopyFromRemote` request: - -[https://github.com/apache/tvm/blob/8c125ca6090a29f38a66d26138b056b7de27cb0b/src/runtime/minrpc/minrpc_server.h#L87](https://github.com/apache/tvm/blob/8c125ca6090a29f38a66d26138b056b7de27cb0b/src/runtime/minrpc/minrpc_server.h#L87) - -```cpp -bool ProcessOnePacket() { - ... - - if (...) { - ... - } else { - switch (code) { - ... - case RPCCode::kCopyFromRemote: { - HandleCopyFromRemote(); - break; - } - ... -``` - -[https://github.com/apache/tvm/blob/8c125ca6090a29f38a66d26138b056b7de27cb0b/src/runtime/minrpc/minrpc_server.h#L178](https://github.com/apache/tvm/blob/8c125ca6090a29f38a66d26138b056b7de27cb0b/src/runtime/minrpc/minrpc_server.h#L178) - -```cpp -void HandleCopyFromRemote() { - DLTensor* arr = this->ArenaAlloc(1); - uint64_t data_handle; - this->Read(&data_handle); - arr->data = reinterpret_cast(data_handle); - ... - this->ReadArray(arr->shape, arr->ndim); - - if (...) { - ... - } else { - data_ptr = this->ArenaAlloc(num_bytes); - DLTensor temp; - ... - call_ecode = TVMDeviceCopyDataFromTo(arr, &temp, nullptr); - // need sync to make sure that the copy is completed. - if (call_ecode == 0) { - call_ecode = TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); - } - } -``` - -And finally we see a call to `DeviceAPIManager::Get(dev)->CopyDataFromTo` which translates to `HexagonDeviceAPI::CopyDataFromTo` . - -[https://github.com/apache/tvm/blob/f929b0fc8e7a600978c9ac0418469bd70d046446/src/runtime/c_runtime_api.cc#L623-L630](https://github.com/apache/tvm/blob/f929b0fc8e7a600978c9ac0418469bd70d046446/src/runtime/c_runtime_api.cc#L623-L630) - -```cpp -int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { - ... - DeviceAPIManager::Get(dev)->CopyDataFromTo(from, to, stream); - ... -} -``` diff --git a/tests/python/contrib/test_hexagon/__init__.py b/tests/python/contrib/test_hexagon/__init__.py deleted file mode 100644 index 07fb45c52d96..000000000000 --- a/tests/python/contrib/test_hexagon/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Testing infrastructure for Hexagon""" diff --git a/tests/python/contrib/test_hexagon/benchmark_util.py b/tests/python/contrib/test_hexagon/benchmark_util.py deleted file mode 100644 index c9be25efd3c4..000000000000 --- a/tests/python/contrib/test_hexagon/benchmark_util.py +++ /dev/null @@ -1,277 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: RUF012 - -"""Utility functions used for benchmarks""" - -import csv -import os -import tempfile - -import pytest - -from tvm.contrib.hexagon.tools import HEXAGON_SIMULATOR_NAME - - -def skip_benchmarks_flag_and_reason(): - """ - Returns one of these tuples: - (False, '') or - (True, (a string describing why the test should be skipped)) - - NOTE: This function is a temporary measure to prevent the TVM CI system - running benchmark scripts every time the CI pre-commit hook executes. - This should go away when a better system is in place to govern when various - tests / benchmarks are executed. - """ - asn = os.environ.get("ANDROID_SERIAL_NUMBER") - - if asn == HEXAGON_SIMULATOR_NAME: - return (True, "Skipping benchmarks when ANDROID_SERIAL_NUMBER='simluator'") - - return (False, "") - - -class UnsupportedException(Exception): - """ - Indicates that the specified benchmarking configuration is known to - currently be unsupported. The Exception message may provide more detail. - """ - - -class NumericalAccuracyException(Exception): - """ - Indicates that the benchmarking configuration appeared to run successfully, - but the output data didn't have the expected accuracy. - """ - - -class BenchmarksTable: - """ - Stores/reports the result of benchmark runs. - - Each line item has a status: success, fail, or skip. - - Each 'success' line item must include benchmark data, - in the form provided by TVM's `time_evaluator` mechanism. - - Each line item may also specify values for any subset of - the columns provided to the table's construstor. - """ - - BUILTIN_COLUMN_NAMES = set( - [ - "row_status", - "timings_min_usecs", - "timings_max_usecs", - "timings_median_usecs", - "timings_mean_usecs", - "timings_stddev_usecs", - ] - ) - - def __init__(self): - self._line_items = [] - - def validate_user_supplied_kwargs(self, kwarg_dict): - name_conflicts = set(kwarg_dict).intersection(self.BUILTIN_COLUMN_NAMES) - - if name_conflicts: - name_list = ", ".join(name_conflicts) - raise Exception(f"Attempting to supply values for built-in column names: {name_list}") - - def record_success(self, timings, **kwargs): - """ - `timings` : Assumed to have the structure and meaning of - the timing results provided by TVM's `time_evaluator` - mechanism. - - `kwargs` : Optional values for any of the other columns - defined for this benchmark table. - """ - self.validate_user_supplied_kwargs(kwargs) - line_item = kwargs - - line_item["row_status"] = "SUCCESS" - - line_item["timings_min_usecs"] = timings.min * 1000000 - line_item["timings_max_usecs"] = timings.max * 1000000 - line_item["timings_median_usecs"] = timings.median * 1000000 - line_item["timings_stddev_usecs"] = timings.std * 1000000 - line_item["timings_mean_usecs"] = timings.mean * 1000000 - - self._line_items.append(line_item) - - def record_skip(self, **kwargs): - self.validate_user_supplied_kwargs(kwargs) - - line_item = dict(kwargs) - line_item["row_status"] = "SKIP" - self._line_items.append(line_item) - - def record_fail(self, **kwargs): - self.validate_user_supplied_kwargs(kwargs) - - line_item = dict(kwargs) - line_item["row_status"] = "FAIL" - self._line_items.append(line_item) - - def has_fail(self): - """ - Returns True if the table contains at least one 'fail' line item, - otherwise returns False. - """ - return any(item["row_status"] == "FAIL" for item in self._line_items) - - def print_csv(self, f, column_name_order, timing_decimal_places=3): - """ - Print the benchmark results as a csv. - - `f` : The output stream. - - `column_name_order`: an iterable sequence of column names, indicating the - left-to-right ordering of columns in the CSV output. - - The CSV output will contain only those columns that are mentioned in - this list. - - `timing_decimal_places`: for the numeric timing values, this is the - number of decimal places to provide in the printed output. - For example, a value of 3 is equivalent to the Python formatting string - `'{:.3f}'` - """ - writer = csv.DictWriter( - f, column_name_order, dialect="excel-tab", restval="", extrasaction="ignore" - ) - - writer.writeheader() - - for line_item_dict in self._line_items: - # Use a copy of the line-item dictionary, because we might do some modifications - # for the sake of rendering... - csv_line_dict = dict(line_item_dict) - - for col_name in [ - "timings_min_usecs", - "timings_max_usecs", - "timings_median_usecs", - "timings_stddev_usecs", - "timings_mean_usecs", - ]: - if col_name in csv_line_dict: - old_value = csv_line_dict[col_name] - assert isinstance(old_value, float), ( - f"Formatting code assumes that column {col_name} is" - f" some col_nameind of float, but its actual type is {type(old_value)}" - ) - str_value = f"{old_value:>0.{timing_decimal_places}f}" - csv_line_dict[col_name] = str_value - - writer.writerow(csv_line_dict) - - -def get_benchmark_id(keys_dict): - """ - Given a dictionary with the distinguishing characteristics of a particular benchmark - line item, compute a string that uniquely identifies the benchmark. - - The returned string: - - is a valid directory name on the host's file systems, and - - should be easy for humans to parse - - Note that the insertion order for `keys_dict` affects the computed name. - """ - # Creat a copy, because we might be modifying it. - keys_dict_copy = dict(keys_dict) - - # Sniff for shape-like lists, because we want them in a form that's both - # readable and filesystem-friendly... - for k, v in keys_dict_copy.items(): - if isinstance(v, list | tuple): - v_str = "_".join([str(x) for x in v]) - keys_dict_copy[k] = v_str - - return "-".join([f"{k}:{v}" for k, v in keys_dict_copy.items()]) - - -def get_benchmark_decription(keys_dict): - """ - Similar to `get_benchmark_id`, but the focus is on human-readability. - - The returned string contains no line-breaks, but may contain spaces and - other characters that make it unsuitable for use as a filename. - """ - return " ".join([f"{k}={v}" for k, v in keys_dict.items()]) - - -@pytest.fixture(scope="class") -def benchmark_group(request): - """This fixture provides some initialization / finalization logic for groups of related - benchmark runs. - See the fixture implementation below for details. - - The fixture's mechanics are described here: https://stackoverflow.com/a/63047695 - - TODO: There may be cleaner ways to let each class that uses this fixture provide its - own value for `csv_column_order`. - - TODO: In the future we may wish to break this fixture up in to several smaller ones. - - The overall contract for a class (e.g. `MyTest`) using this fixture is as follows: - - https://stackoverflow.com/a/63047695 - - @pytest.mark.usefixtures("benchmark_group") - class MyTest: - - # The fixture requires that this class variable is defined before - # the fixture's finalizer-logic executes. - # - # This is used as an argument to BenchmarkTable.print_csv(...) after - # all of MyTest's unit tests have executed. - csv_column_order = [ - ... - ] - - # Before the MyTest's first unit test executes, the fixture will populate the - # following class variables: - MyTest.working_dir : str - MyTest.benchmark_table : BenchmarkTable""" - working_dir = tempfile.mkdtemp() - table = BenchmarksTable() - - request.cls.working_dir = working_dir - request.cls.benchmark_table = table - - yield - - tabular_output_filename = os.path.join(working_dir, "benchmark-results.csv") - - if not hasattr(request.cls, "csv_column_order"): - raise Exception('Classes using this fixture must have a member named "csv_column_order"') - - with open(tabular_output_filename, "w", encoding="UTF-8") as csv_file: - table.print_csv(csv_file, request.cls.csv_column_order) - - print() - print("*" * 80) - print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}") - print("*" * 80) - print() - - if table.has_fail() > 0: - pytest.fail("At least one benchmark configuration failed", pytrace=False) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py deleted file mode 100644 index f187cc1c3943..000000000000 --- a/tests/python/contrib/test_hexagon/conftest.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Hexagon testing fixtures used to deduce testing argument -values from testing parameters""" - -# pytest 9 no longer supports declaring ``pytest_plugins`` in a -# non-rootdir conftest, so re-export the plugin's hooks and fixtures -# here instead. The explicit import covers the underscore-prefixed -# fixture that the wildcard import skips. -# ruff: noqa: F401, F403 -from tvm.contrib.hexagon.pytest_plugin import * -from tvm.contrib.hexagon.pytest_plugin import _tracker_info diff --git a/tests/python/contrib/test_hexagon/conv2d/README.md b/tests/python/contrib/test_hexagon/conv2d/README.md deleted file mode 100644 index d29d8b9c8604..000000000000 --- a/tests/python/contrib/test_hexagon/conv2d/README.md +++ /dev/null @@ -1,37 +0,0 @@ - - - - - - - - - - - - - - - - - -Documents manual TE schedule to illustrate Hexagon operator slicing. - -High Level Notes: -* Using float32 (for now) so that tests will pass on CPU -* Using global storage scope (for now) which means "cache" reads and writes from global, to global -* TIR is pending changes from the work-in-progress layout RFC - (https://github.com/apache/tvm-rfcs/pull/39) -* TIR has been hand-edited for context and clarity - * Added C-style comments - * Changed variable names - * Added spacing and line breaks -* Naming conventions - * Using input (instead of activation) - * Using filter (instead of weight, kernel) - * Using `k` to denote channel-out and `c` or `rc` (reduction channel) to denote channel-in - * Using `rh` and `rw` (reduction height / width) to denote filter height and width - -[Conv2d](test_conv2d_blocked.md) - -[Conv2d -> Conv2d](test_conv2d_conv2d.md) diff --git a/tests/python/contrib/test_hexagon/conv2d/__init__.py b/tests/python/contrib/test_hexagon/conv2d/__init__.py deleted file mode 100644 index b17055a624af..000000000000 --- a/tests/python/contrib/test_hexagon/conv2d/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# isort: skip_file -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Testing infrastructure for Hexagon/TOPI/Conv2d""" diff --git a/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md deleted file mode 100644 index 417ce0b12310..000000000000 --- a/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md +++ /dev/null @@ -1,494 +0,0 @@ - - - - - - - - - - - - - - - - - -Hexagon conv2d schedules - -# Baseline conv2d - -This is a baseline 1x1 conv2d schedule for Hexagon. - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[1-64-64-0-1-1-128-1-1-float32-llvm]" - -## Parameters - -| Parameter | Value | -| --------- | ----- | -| Batch | 1 | -| Spatial | 64x64 | -| Input Ch | 64 | -| Padding | 0 | -| Stride | 1 | -| Filter | 1x1 | -| Output Ch | 128 | - -## Assumptions - -* n/a - -## To Do - -* n/a - -## Annotated TIR - -``` -primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c - filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) - buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { - allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; - allocate(filter.cache: Pointer(global float32), float32, [2048]), storage_scope = global; - allocate(output.cache: Pointer(global float32), float32, [16384]), storage_scope = global; - - for (ko.outer: int32, 0, 4) { - for (ho.outer: int32, 0, 8) { - - // input cache read - // NHWC -> NHWC8h8w32c (pending RFC) - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] - } - } - } - } - } - - // filter cache read - for (co: int32, 0, 2) { - for (ci8: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (ci4: int32, 0, 4) { - filter.cache[((((co*1024) + (ci8*128)) + (ki*4)) + ci4)] = - (float32*)filter_pointer[(((((ko.outer*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] - } - } - } - } - - // compute - for (wo.c: int32, 0, 8) { - - // init output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // convolution - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } // end wo.c - - // cache write - for (wo: int32, 0, 8) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)] - } - } - } - } - } // end ho.outer - } // end ko.outer -} -``` - -# Split on Channel Out and Height - "Full Output Slice" - -Adds new parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. - -The key changes in TIR versus the above are... - -1) Increased cache allocations: - -``` - // input cache grows by factor of h_split = 2 - allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - - // filter cache grows by factor of k_split = 2 - allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; - - // output cache grows by factor of h_split * k_split = 4 - allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; -``` - -2) Outer loop splits using k_split and h_split factors - -``` - // ko.outer = outer loop split on ko using k_split factor - for (ko.outer: int32, 0, 2) { - // ho.outer = outer loop split on ho using h_split factor - for (ho.outer: int32, 0, 4) { -``` - -3) Inner loop splits in both cache read / write and compute schedules. This is taken from the compute schedule e.g. -``` - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { -``` - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[1-64-64-0-1-1-128-2-2-float32-llvm]" - -## Parameters - -| Parameter | Value | -| --------- | ----- | -| Batch | 1 | -| Spatial | 64x64 | -| Input Ch | 64 | -| Padding | 0 | -| Stride | 1 | -| Filter | 1x1 | -| Output Ch | 128 | -| k_split | 2 | -| h_split | 2 | - -## Assumptions - -* n/a - -## To Do - -* n/a - -## Annotated TIR - -``` -primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c - filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) - buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { - - // input cache grows by factor of h_split = 2 - allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - - // filter cache grows by factor of k_split = 2 - allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; - - // output cache grows by factor of h_split * k_split = 4 - allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - - // ko.outer = outer loop split on ko using k_split factor - for (ko.outer: int32, 0, 2) { - // ho.outer = outer loop split on ho using h_split factor - for (ho.outer: int32, 0, 4) { - - // input cache read - // NHWC -> NHWC8h8w32c (pending RFC) - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] - } - } - } - } - } - } // end ho.inner - - // filter cache read - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 2) { - for (ci8: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (ci4: int32, 0, 4) { - filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = - (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] - } - } - } - } - } // end ko.inner - - // compute - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - - // init output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // convolution - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } // end wo.c - } // end ho.c.inner - } // end ko.c.inner - - // cache write - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] - } - } - } - } - } // end ho.inner - } // end ko.inner - } // end ho.outer - } // end ko.outer -} -``` - -# 3x3 conv2d (no padding) - -Change from a 1x1 filter to a 3x3 filter. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 filter will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. - -The key changes in TIR versus the above are... - -1) Increased input cache size to hold the vertically adjacent slice - -``` - // input cache grows to hold vertically adjacent slice - allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; -``` - -2) Loop over `ho.inner` upper bound increased from `h_split` = 2 to `h_split + 1` = 3 - -``` - for (ho.outer: int32, 0, 4) { - for (ho.inner: int32, 0, 3) { - if (((ho.outer*2) + ho.inner) < 8) { -``` - -The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. - - -3) Increased filter cache size to hold 3x3 filter - -``` - // filter cache grows to hold larger 3x3 filter - allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; -``` - -4) Loops over `rh` and `rw` the kernel spatial dimensions: -``` - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { -``` - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[1-64-64-0-1-3-128-2-2-float32-llvm]" - -## Parameters - -| Parameter | Value | -| --------- | ----- | -| Batch | 1 | -| Spatial | 64x64 | -| Input Ch | 64 | -| Padding | 0 | -| Stride | 1 | -| Filter | 1x1 | -| Output Ch | 128 | -| k_split | 2 | -| h_split | 2 | - -## Assumptions - -* n/a - -## To Do - -There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: - -| ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | -| -------- | -------- | ------------------------------------- | -| 0 | 0 | 0 | -| 0 | 1 | 32k | -| 0 | 2 | 64k (vertical adjacent slice loop 0) | -| 1 | 0 | 64k | -| 1 | 1 | 96k | -| 1 | 2 | 128k (vertical adjacent slice loop 1) | -| 2 | 0 | 128k | -| 2 | 1 | 160k | -| 2 | 2 | 192k (vertical adjacent slice loop 2) | -| 3 | 0 | 192k | -| 3 | 1 | 224k | -| 3 | 2 | (No vertical adjacent slice loop 3) | - -Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` = N) is reused in loop N + 1. - -## Annotated TIR - -``` -primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c - filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 3, 3, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) - buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { - // input cache grows to hold vertically adjacent slice - allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; - // filter cache grows to hold larger 3x3 filter - allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; - allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - for (ko.outer: int32, 0, 2) { - for (ho.outer: int32, 0, 4) { - // input cache read - // NHWC -> NHWC8h8w32c (pending RFC) - for (ho.inner: int32, 0, 3) { - if (((ho.outer*2) + ho.inner) < 8) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] - } - } - } - } - } - } - } - // filter cache read - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 2) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ci8: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (ci4: int32, 0, 4) { - filter.cache[(((((((ko.inner*18432) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] = - (float32*)filter_pointer[((((((((ko.outer*36864) + (ko.inner*18432)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] - } - } - } - } // end rw - } // end rh - } - } - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * - (float32*)filter.cache[(((((((ko.c.inner*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } // end rw - } // end rh - } - } - } - } // end wo.c - } // end ho.c.inner - } // end ko.c.inner - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] - } - } - } - } - } // end ho.inner - } // end ko.inner - } // end ho.outer - } // end ko.outer -}``` diff --git a/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md deleted file mode 100644 index 3671d90c2408..000000000000 --- a/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md +++ /dev/null @@ -1,986 +0,0 @@ - - - - - - - - - - - - - - - - - -Hexagon conv2d -> conv2d schedules - -# Baseline conv2d -> conv2d - -This is a baseline 1x1 conv2d -> 1x1 conv2d schedule for Hexagon. - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-1-128-1-1-128-1-1-float32-llvm]" - -## Parameters - -| Parameter | Value | -| ------------------------ | ----- | -| Batch | 1 | -| Input Size | 64x64 | -| Input Channel | 128 | -| Conv2d #1 Pad | 0 | -| Conv2d #1 Stride | 1 | -| Conv2d #1 Kernel Size | 1 | -| Conv2d #1 Output Channel | 128 | -| Conv2d #2 Stride | 1 | -| Conv2d #2 Kernel Size | 1 | -| Conv2d #2 Output Channel | 128 | -| k_split | 1 | -| h_split | 1 | - -## Constants - -| Constant | Value | -| ------------------ | ----- | -| Conv2d #2 Pad | 0 | -| Conv2d #1 Dilation | 1 | -| Conv2d #2 Dilation | 1 | - -## Shapes and Layouts - -The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors. - -| Tensor | Type | Layout | Shape | Logical Layout | Logical Shape | -| ------------ | -------- | ----------- | ---------------------- | -------------- | ---------------- | -| Input | Logical | NHWC | [1, 64, 64, 128] | | | -| Padded Input | Logical | NHWC | [1, 64, 64, 128] | | | -| Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | -| Filter 1 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | -| Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | -| Filter 2 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | -| Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | - -## Schedule - -This is the conv2d compute schedule: - -``` - for (ko.outer: int32, 0, 4) { - for (ho.outer: int32, 0, 8) { - - // input cache read - - for (ko.outer_1: int32, 0, 4) { - - // filter #1 cache read - - // conv2d #1 - for (wo: int32, 0, 8) { - for (rc.outer: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - } // end ko.outer_1 - - // filter #2 cache read - - // conv2d #2 - for (wo.c: int32, 0, 8) { - for (rc.outer_1: int32, 0, 4) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner_1: int32, 0, 32) { - - // write back output cache - - } // end ho.outer - } // end ko.outer -``` - -Note that conv2d #1 has an independent loop over the channel out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2. - -``` - for (ko.outer_1: int32, 0, 2) { -``` - -## Cache Usage - -*Input Cache* - -We compute over the WC8h8w32c portion of the input so we need 8 * 4 * 8 * 8 * 32 = 64kb for the input cache. - -``` - allocate(packed_input.global: Pointer(global float32), float32, [65536]), storage_scope = global; -``` - -*Filter Cache* - -We compute over the IHW8i32o4i portion of each filter so we need 4 * 1 * 1 * 8 * 32 * 4 = 4kb filter cache. - -``` - allocate(packed_filter.global: Pointer(global float32), float32, [4096]), storage_scope = global; -``` - -Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2. - -*Output Cache* - -We compute over the WK8h832k portion of the output where `k` denotes the output channel. The output cache is computed for each `ko.outer` which means it should be W * 8h * 8w * 32k = 8 * 8 * 8 * 32 = 16kb. And, in fact, this is the case for a single conv2d case. But, as already noted, for this conv2d -> conv2d case "the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2". This means that the output cache must grow accordingly to K * W * 8h * 8w * 32k = 4 * 8 * 8 * 8 * 32 = 64kb. There is a temporary allocation to store the results of conv2d #1: - -``` - allocate(temp_output: Pointer(global float32), float32, [65536]), storage_scope = global; -``` - -Note that the input cache is reused to store the results of conv2d #2. - -## Assumptions - -* n/a - -## To Do - -* Reuse of the input cache to store the results of conv2d #2 could be problematic for async copy. e.g. - -``` -slice 0: global -> load -> cache0 -> conv2d_0 -> cache1 -> conv2d_1 -> cache0 -> store -> global -slice 1: global -> load -> cache0 -> conv2d_0 -> cache1 -> conv2d_1 -> cache0 -> store -> global -``` - -In this case the store from slice 0: cache0 -> store -> global -can potentially block the load in slice 1: global -> load -> cache0 - -StorageRewrite is responsible for planning these caches, we'll need to understand how to avoid this for the async case. - -## Annotated TIR - -``` -primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} - buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c - placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i - placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i - placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc - buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} { - allocate(packed_input.global: Pointer(global float32), float32, [65536]), storage_scope = global; - allocate(temp_output: Pointer(global float32), float32, [65536]), storage_scope = global; - allocate(packed_filter.global: Pointer(global float32), float32, [4096]), storage_scope = global; - for (ko.outer: int32, 0, 4) { - for (ho.outer: int32, 0, 8) { - - // input cache read - for (wo: int32, 0, 8) { - for (co: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - packed_input.global[(((((wo*8192) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)placeholder_8[((((((ho.outer*65536) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)] - } - } - } - } - } - - // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2 - for (ko.outer_1: int32, 0, 4) { - - // filter #1 cache read - for (co: int32, 0, 4) { - for (cio: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (cii: int32, 0, 4) { - packed_filter.global[((((co*1024) + (cio*128)) + (ki*4)) + cii)] = - (float32*)placeholder_7[(((((ko.outer_1*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] - } - } - } - } - - // conv2d #1 - for (wo: int32, 0, 8) { - - // init temp output to zero - for (hi.init: int32, 0, 8) { - for (wi.init: int32, 0, 8) { - for (ki.init: int32, 0, 32) { - temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32 - } - } - } - - // compute - for (rc.outer: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi*256)) + (wi*32)) + ki)] = - ( - (float32*)temp_output[(((((wo*8192) + (ko.outer_1*2048)) + (hi*256)) + (wi*32)) + ki)] + - ( - (float32*)packed_input.global[(((((wo*8192) + (rc.outer*2048)) + (hi*256)) + (wi*32)) + rc.inner)] * - (float32*)packed_filter.global[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } - } - - // filter #2 cache read - // NOTE: reusing same filter cache - for (co: int32, 0, 4) { - for (cio: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (cii: int32, 0, 4) { - packed_filter.global[((((co*1024) + (cio*128)) + (ki*4)) + cii)] = - (float32*)placeholder_6[(((((ko.outer*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] - } - } - } - } - - // conv2d #2 - for (wo.c: int32, 0, 8) { - - // init output cache to zero - // NOTE: reusing the input cache as the output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - packed_input.global[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // compute - for (rc.outer_1: int32, 0, 4) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner_1: int32, 0, 32) { - packed_input.global[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)packed_input.global[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)temp_output[(((((wo.c*8192) + (rc.outer_1*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner_1)] * - (float32*)packed_filter.global[((((rc.outer_1*1024) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))] - ) - ) - } - } - } - } - } - } - - // write back output cache - for (wo_1: int32, 0, 8) { - for (hi_1: int32, 0, 8) { - for (wi_1: int32, 0, 8) { - for (ki_1: int32, 0, 32) { - output_2[((((((ho.outer*65536) + (wo_1*8192)) + (ko.outer*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] = - (float32*)packed_input.global[((((wo_1*2048) + (hi_1*256)) + (wi_1*32)) + ki_1)] - } - } - } - } - } - } -} -``` - -# Split on Channel Out and Height - -Uses parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-1-128-1-1-128-2-2-float32-llvm]" - -## Parameters - -| Parameter | Value | -| ------------------------ | ----- | -| Batch | 1 | -| Input Size | 64x64 | -| Input Channel | 128 | -| Conv2d #1 Pad | 0 | -| Conv2d #1 Stride | 1 | -| Conv2d #1 Kernel Size | 1 | -| Conv2d #1 Output Channel | 128 | -| Conv2d #2 Stride | 1 | -| Conv2d #2 Kernel Size | 1 | -| Conv2d #2 Output Channel | 128 | -| k_split | 2 ^ | -| h_split | 2 ^ | - -^ Changes from above - -## Constants - -| Constant | Value | -| ------------------ | ----- | -| Conv2d #2 Pad | 0 | -| Conv2d #1 Dilation | 1 | -| Conv2d #2 Dilation | 1 | - -## Shapes and Layouts - -The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors. - -| Tensor | Type | Layout | Shape | Logical Layout | Logical Shape | -| ------------ | -------- | ----------- | ---------------------- | -------------- | ---------------- | -| Input | Logical | NHWC | [1, 64, 64, 128] | | | -| Padded Input | Logical | NHWC | [1, 64, 64, 128] | | | -| Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | -| Filter 1 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | -| Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | -| Filter 2 | Physical | OIHW8i32o4i | [4, 4, 1, 1, 8, 32, 4] | OIHW | [128, 128, 1, 1] | -| Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | - -## Schedule - -This is the conv2d compute schedule: - -``` - for (ko.outer: int32, 0, 2) { - for (ho.outer: int32, 0, 4) { - - // input cache read - for (ho.inner: int32, 0, 2) { - ... - } - - for (ko.outer_1: int32, 0, 2) { - - // filter #1 cache read - for (ko.inner: int32, 0, 2) { - ... - } - - // conv2d #1 - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (rc.outer: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - } // end ko.outer_1 - - // filter #2 cache read - for (ko.inner: int32, 0, 2) { - ... - } - - // conv2d #2 - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (rc.outer_1: int32, 0, 4) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner_1: int32, 0, 32) { - - // write back output cache - - } // end ho.outer - } // end ko.outer -``` - -The major change here versus above is the presence of `inner` loops for both channel out `ko` and height `ho` dimensions created from the `k_split` and `h_split` schedule parameters respectively, for example: - - -``` - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { -``` - -The effect of this change is increased cache usage given where the caches are computed in the schedule. Specifically, the input cache is now computed over `ho.inner` and the filter caches are computed over `ko.inner` which will grow the size of the cache. Details below. - -(Same as above) Note that conv2d #1 has an independent loop over the channel out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2. - -``` - for (ko.outer_1: int32, 0, 2) { -``` - -## Cache Usage - -*Input Cache* - -The input cache grows by a factor of `h_split = 2` compared with above: - -``` - allocate(packed_input.global: Pointer(global float32), float32, [131072]), storage_scope = global; -``` - -*Filter Cache* - -The filter cache grows by a factor of `k_split = 2` compared with above: - -``` - allocate(packed_filter.global: Pointer(global float32), float32, [8192]), storage_scope = global; -``` - -(Same as above) Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2. - -*Output Cache* - -The output cache grows by a factor of `k_split = 2` compared with above: - -``` - allocate(temp_output: Pointer(global float32), float32, [131072]), storage_scope = global; -``` - -(Same as above) Note that the input cache is reused to store the results of conv2d #2. - -## Assumptions - -* n/a - -## To Do - -* n/a - -## Annotated TIR - -``` -primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} - buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c - placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i - placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 1, 1, 8, 32, 4], []), // oihw8i32o4i - placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc - buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} { - allocate(packed_input.global: Pointer(global float32), float32, [131072]), storage_scope = global; - allocate(temp_output: Pointer(global float32), float32, [131072]), storage_scope = global; - allocate(packed_filter.global: Pointer(global float32), float32, [8192]), storage_scope = global; - for (ko.outer: int32, 0, 2) { - for (ho.outer: int32, 0, 4) { - - // input cache read - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)placeholder_8[(((((((ho.outer*131072) + (ho.inner*65536)) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)] - } - } - } - } - } - } - - // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2 - for (ko.outer_1: int32, 0, 2) { - - // filter #1 cache read - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 4) { - for (cio: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (cii: int32, 0, 4) { - packed_filter.global[(((((ko.inner*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] = - (float32*)placeholder_7[((((((ko.outer_1*8192) + (ko.inner*4096)) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] - } - } - } - } - } - - // conv2d #1 - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - - // init temp output to zero - for (hi.init: int32, 0, 8) { - for (wi.init: int32, 0, 8) { - for (ki.init: int32, 0, 32) { - temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32 - } - } - } - - // compute - for (rc.outer: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = - ( - (float32*)temp_output[(((((((ho.inner*65536) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + - ( - (float32*)packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (rc.outer*2048)) + (hi*256)) + (wi*32)) + rc.inner)] * - (float32*)packed_filter.global[(((((ko.inner*4096) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } - } - } - } - - // filter #2 cache read - // NOTE: reusing same filter cache - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 4) { - for (cio: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (cii: int32, 0, 4) { - packed_filter.global[(((((ko.inner*4096) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] = - (float32*)placeholder_6[((((((ko.outer*8192) + (ko.inner*4096)) + (co*1024)) + (cio*128)) + (ki*4)) + cii)] - } - } - } - } - } - - // conv2d #2 - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - - // init output cache to zero - // NOTE: reusing the input cache as the output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // compute - for (rc.outer_1: int32, 0, 4) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner_1: int32, 0, 32) { - packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)temp_output[((((((ho.c.inner*65536) + (wo.c*8192)) + (rc.outer_1*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner_1)] * - (float32*)packed_filter.global[(((((ko.c.inner*4096) + (rc.outer_1*1024)) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))] - ) - ) - } - } - } - } - } - } - } - } - - // write back output cache - for (ko.inner_1: int32, 0, 2) { - for (ho.inner_1: int32, 0, 2) { - for (wo_1: int32, 0, 8) { - for (hi_1: int32, 0, 8) { - for (wi_1: int32, 0, 8) { - for (ki_1: int32, 0, 32) { - output_2[((((((((ho.outer*131072) + (ho.inner_1*65536)) + (wo_1*8192)) + (ko.outer*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] = - (float32*)packed_input.global[((((((ho.inner_1*32768) + (wo_1*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] - } - } - } - } - } - } - } - } -} -``` - -# 3x3 conv2d -> conv2d (no padding) - -Change from a 1x1 filter to a 3x3 filter. - -## Command - -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_conv2d.py::TestConv2dConv2dPackedFilter::test_conv2d[1-64-128-0-1-3-128-1-3-128-2-2-float32-llvm]" - -## Parameters - -| Parameter | Value | -| ------------------------ | ----- | -| Batch | 1 | -| Input Size | 64x64 | -| Input Channel | 128 | -| Conv2d #1 Pad | 0 | -| Conv2d #1 Stride | 1 | -| Conv2d #1 Kernel Size | 3 ^ | -| Conv2d #1 Output Channel | 128 | -| Conv2d #2 Stride | 1 | -| Conv2d #2 Kernel Size | 3 ^ | -| Conv2d #2 Output Channel | 128 | -| k_split | 2 | -| h_split | 2 | - -^ Changes from above - -## Constants - -| Constant | Value | -| ------------------ | ----- | -| Conv2d #2 Pad | 0 | -| Conv2d #1 Dilation | 1 | -| Conv2d #2 Dilation | 1 | - -## Shapes and Layouts - -The input is provided and padded in logical layout and then packed into its physical layout prior to compute. Logical layout / shape information is provided as a reference for physical tensors. - -| Tensor | Type | Layout | Shape | Logical Layout | Logical Shape | -| ------------ | -------- | ----------- | ---------------------- | -------------- | ---------------- | -| Input | Logical | NHWC | [1, 64, 64, 128] | | | -| Padded Input | Logical | NHWC | [1, 64, 64, 128] | | | -| Packed Input | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 64, 64, 128] | -| Filter 1 | Physical | OIHW8i32o4i | [4, 4, 3, 3, 8, 32, 4] | OIHW | [128, 128, 3, 3] | -| Temp Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 62, 62, 128] | -| Filter 2 | Physical | OIHW8i32o4i | [4, 4, 3, 3, 8, 32, 4] | OIHW | [128, 128, 3, 3] | -| Output | Physical | NHWC8h8w32c | [1, 8, 8, 4, 8, 8, 32] | NHWC | [1, 60, 60, 128] | - -## Schedule - -This is the conv2d compute schedule: - -``` - for (ko.outer: int32, 0, 2) { - for (ho.outer: int32, 0, 4) { - - for (ko.outer_1: int32, 0, 2) { - for (ho.outer_1: int32, 0, 2) { - - // input cache read - for (ho.inner: int32, 0, 3) { - if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { - ... - } - } - - // filter #1 cache read - for (ko.inner: int32, 0, 2) { - ... - } - - // conv2d #1 - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - if (((ho.outer_1*2) + ho.inner) < 3) { - if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { - for (rc.outer: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ki: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - } // end ho.outer_1 - } // end ko.outer_1 - - // filter #2 cache read - for (ko.inner: int32, 0, 2) { - ... - } - - // conv2d #2 - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (rc.outer_1: int32, 0, 4) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (rh_1: int32, 0, 3) { - for (rw_1: int32, 0, 3) { - for (ki.c: int32, 0, 32) { - for (rc.inner_1: int32, 0, 32) { - - // write back output cache - - } // end ho.outer - } // end ko.outer -``` - -There are two major changes here: - -1) The first change is the farily obvious presence of the kernel height `rh` and width `rw` iterators, for example: - -``` - for (rh_1: int32, 0, 3) { - for (rw_1: int32, 0, 3) { -``` - -The effect of this change is to grow the filter cache by the size of the kernel. Details below. - -2) The second change is a bit more tricky. Remember that we want to produce `h_split` (2) "full width" and "full channel depth" slices from each conv2d. Given the 3x3 kernel size there are several changes to the schedule regarding the handling of the height dimension. - -First, notice that in order to produce `h_split` (2) "full width" and "full channel depth" slices for conv2d #1 we will need `h_split + 1` (3) "full width" and "full channel depth" slices of the input. This is because a 3x3 kernel (as opposed to a 1x1 kernel) creates a many-to-one relationship between the spatial coordinates of the input relative to the output. To illustrate, the 3x3 kernel will "fall off the bottom" of the 2nd input slice requiring values from the vertically adjacent 3rd input slice in order to produce the 2nd full output slice. Hence, we have the following input cache read over `h_split + 1` (3) input slices: - -``` - for (ho.inner: int32, 0, 3) { - if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { -``` - -The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. - -Second, notice that conv2d #1 must produce sufficient output in the height dimension before conv2d #2 can proceed. This is similar to the requirement that conv2d #1 in regard to the channel out dimension, but also different because we do not require *all* output in the height dimenson only *sufficient* output in the height dimension. How much output in the height dimension is required? The intuitive guess might be `h_split + 1` (3) slices but that is wrong and the reason is that the output spatial coordinates are "shrinking" relative to the input coordinate space due to lack of padding. Hence 2 output slices from conv2d #1 are sufficient as intput to calculate 2 output slices from conv2d #2 and we get the following independent loop over `ho.outer_1` for conv2d #1: - -``` - for (ho.outer_1: int32, 0, 2) { -``` - -There are similar `if` statements in the conv2d compute schedule to prevent computing off the "bottom" of the input and output. - -(Same as above) Note that conv2d #1 has an independent loop over the channel out `ko.outer_1` dimension. This is because the output channels of conv2d #1 are the input channels to conv2d #2 and we compute over all input channels for each conv2d so we must compute over all output channels of conv2d #1 before we compute conv2d #2. - -``` - for (ko.outer_1: int32, 0, 2) { -``` - -## Cache Usage - -*Input Cache* - -The input cache grows to hold the vertically adjacent slice: - -``` - allocate(packed_input.global: Pointer(global float32), float32, [196608]), storage_scope = global; -``` - -*Filter Cache* - -The filter cache grows to hold the 3x3 filter: - -``` - allocate(packed_filter.global: Pointer(global float32), float32, [73728]), storage_scope = global; -``` - -(Same as above) Note that there is just one cache which is reused for conv2d / filter #1 and conv2d / filter #2. - -*Output Cache* - -The output cache scales with the input cache: - -``` - allocate(temp_output: Pointer(global float32), float32, [196608]), storage_scope = global; -``` - -(Same as above) Note that the input cache is reused to store the results of conv2d #2. - -## Assumptions - -* n/a - -## To Do - -* There may be some opportunity to optimized cache reuse in this case as the vertically adjacent input slice from a previous input cache read will be reloaded as in a subsequent input cache read - -## Annotated TIR - -``` -primfn(placeholder_3: handle, placeholder_4: handle, placeholder_5: handle, output_1: handle) -> () - attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} - buffers = {output: Buffer(output_2: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // nhw8h8w32c - placeholder_2: Buffer(placeholder_6: Pointer(float32), float32, [4, 4, 3, 3, 8, 32, 4], []), // oihw8i32o4i - placeholder_1: Buffer(placeholder_7: Pointer(float32), float32, [4, 4, 3, 3, 8, 32, 4], []), // oihw8i32o4i - placeholder: Buffer(placeholder_8: Pointer(float32), float32, [1, 64, 64, 128], [])} // nhwc - buffer_map = {placeholder_3: placeholder, placeholder_4: placeholder_1, placeholder_5: placeholder_2, output_1: output} { - allocate(packed_input.global: Pointer(global float32), float32, [196608]), storage_scope = global; - allocate(temp_output: Pointer(global float32), float32, [196608]), storage_scope = global; - allocate(packed_filter.global: Pointer(global float32), float32, [73728]), storage_scope = global; - for (ko.outer: int32, 0, 2) { - for (ho.outer: int32, 0, 4) { - // NOTE: compute over all output channels of conv2d #1 before computing conv2d #2 - for (ko.outer_1: int32, 0, 2) { - // NOTE: compute enough height of conv2d #1 before computing conv2d #2 - for (ho.outer_1: int32, 0, 2) { - - // input cache read - for (ho.inner: int32, 0, 3) { - if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - packed_input.global[((((((ho.inner*65536) + (wo*8192)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)placeholder_8[((((((((ho.outer_1*131072) + (ho.outer*131072)) + (ho.inner*65536)) + (hi*8192)) + (wo*1024)) + (wi*128)) + (co*32)) + ci)] - } - } - } - } - } - } - } - - // filter #1 cache read - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 4) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (cio: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (cii: int32, 0, 4) { - packed_filter.global[(((((((ko.inner*36864) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] = - (float32*)placeholder_7[((((((((ko.outer_1*73728) + (ko.inner*36864)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] - } - } - } - } - } - } - } - - // conv2d #1 - for (ko.inner: int32, 0, 2) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - - // init temp output to zero - if (((ho.outer_1*2) + ho.inner) < 3) { - for (hi.init: int32, 0, 8) { - for (wi.init: int32, 0, 8) { - for (ki.init: int32, 0, 32) { - temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi.init*256)) + (wi.init*32)) + ki.init)] = 0f32 - } - } - } - } - - // compute - if (((ho.outer_1*2) + ho.inner) < 3) { - if ((((ho.outer_1*2) + (ho.outer*2)) + ho.inner) < 8) { - for (rc.outer: int32, 0, 4) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ki: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = - ( - (float32*)temp_output[((((((((ho.outer_1*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer_1*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + - ( - (float32*)packed_input.global[((((((((floordiv((hi + rh), 8)*65536) + (ho.inner*65536)) + (floordiv((wi + rw), 8)*8192)) + (wo*8192)) + (rc.outer*2048)) + (floormod((hi + rh), 8)*256)) + (floormod((wi + rw), 8)*32)) + rc.inner)] * - (float32*)packed_filter.global[(((((((ko.inner*36864) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki*4)) + floormod(rc.inner, 4))] - ) - ) - } - } - } - } - } - } - } - } - } - } - } - } - } - } - - // filter #2 cache read - // NOTE: reusing same filter cache - for (ko.inner: int32, 0, 2) { - for (co: int32, 0, 4) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (cio: int32, 0, 8) { - for (ki: int32, 0, 32) { - for (cii: int32, 0, 4) { - packed_filter.global[(((((((ko.inner*36864) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] = - (float32*)placeholder_6[((((((((ko.outer*73728) + (ko.inner*36864)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (cio*128)) + (ki*4)) + cii)] - } - } - } - } - } - } - } - - // conv2d #2 - for (ko.c.inner: int32, 0, 2) { - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - - // init output cache to zero - // NOTE: reusing the input cache as the output cache - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 - } - } - } - - // compute - for (rc.outer_1: int32, 0, 4) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (rh_1: int32, 0, 3) { - for (rw_1: int32, 0, 3) { - for (ki.c: int32, 0, 32) { - for (rc.inner_1: int32, 0, 32) { - packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)packed_input.global[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + - ( - (float32*)temp_output[((((((((floordiv((hi.c + rh_1), 8)*65536) + (ho.c.inner*65536)) + (floordiv((wi.c + rw_1), 8)*8192)) + (wo.c*8192)) + (rc.outer_1*2048)) + (floormod((hi.c + rh_1), 8)*256)) + (floormod((wi.c + rw_1), 8)*32)) + rc.inner_1)] * - (float32*)packed_filter.global[(((((((ko.c.inner*36864) + (rc.outer_1*9216)) + (rh_1*3072)) + (rw_1*1024)) + (floordiv(rc.inner_1, 4)*128)) + (ki.c*4)) + floormod(rc.inner_1, 4))] - ) - ) - } - } - } - } - } - } - } - } - } - } - - // write back output cache - for (ko.inner_1: int32, 0, 2) { - for (ho.inner_1: int32, 0, 2) { - for (wo_1: int32, 0, 8) { - for (hi_1: int32, 0, 8) { - for (wi_1: int32, 0, 8) { - for (ki_1: int32, 0, 32) { - output_2[((((((((ho.outer*131072) + (ho.inner_1*65536)) + (wo_1*8192)) + (ko.outer*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] = - (float32*)packed_input.global[((((((ho.inner_1*32768) + (wo_1*4096)) + (ko.inner_1*2048)) + (hi_1*256)) + (wi_1*32)) + ki_1)] - } - } - } - } - } - } - } - } -} -``` diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py deleted file mode 100644 index 57332eb9f152..000000000000 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ /dev/null @@ -1,376 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=invalid-name - -"""Hexagon testing infrastructure""" - -import numpy - -import tvm -from tvm import te - - -def ceildiv(o, d): - assert o >= 0 - assert d >= 0 - return tvm.tirx.floordiv(o + d - 1, d) - - -# defines inner block shape: 8h8w32c -def get_sblock_shape(): - return 8, 8, 32 - - -# defines inner filter block shape: 8i32o41 -def get_filter_block_shape(): - return 8, 32, 4 - - -# input: locgical shape in nhwc layout -# output: physical packed shape in nhw8h8w32c layout -def get_packed_shape(logical_shape_nhwc): - assert len(logical_shape_nhwc) == 4 - physical_shape_nhwc8h8w32c = [logical_shape_nhwc[0]] - block_shape = get_sblock_shape() - off_h, off_w, off_c = block_shape - physical_shape_nhwc8h8w32c.append(ceildiv(logical_shape_nhwc[1], off_h)) - physical_shape_nhwc8h8w32c.append(ceildiv(logical_shape_nhwc[2], off_w)) - physical_shape_nhwc8h8w32c.append(ceildiv(logical_shape_nhwc[3], off_c)) - physical_shape_nhwc8h8w32c.extend(block_shape) - return physical_shape_nhwc8h8w32c - - -# input: physical packed shape in nhw8h8w32c layout -# output: logical shape in nhwc layout -def get_logical_shape(physical_shape_nhwc8h8w32c): - assert len(physical_shape_nhwc8h8w32c) == 7 - logical_shape_nhwc = [physical_shape_nhwc8h8w32c[0]] - logical_shape_nhwc.append(physical_shape_nhwc8h8w32c[1] * physical_shape_nhwc8h8w32c[4]) - logical_shape_nhwc.append(physical_shape_nhwc8h8w32c[2] * physical_shape_nhwc8h8w32c[5]) - logical_shape_nhwc.append(physical_shape_nhwc8h8w32c[3] * physical_shape_nhwc8h8w32c[6]) - return logical_shape_nhwc - - -def get_packed_filter_shape(logical_shape_oihw): - """return packed filter shape - - Parameters - ---------- - logical_shape_oihw : - logical shape in oihw layout - - Returns - ------- - physical_shape_oihw8i32o4i : - physical packed shape in oihw8i3204i layout - """ - assert len(logical_shape_oihw) == 4 - filter_block_shape = get_filter_block_shape() - filter_Cio, filter_Ki, filter_Cii = filter_block_shape - filter_Ci = filter_Cio * filter_Cii - physical_shape_oihw8i32o4i = [] - physical_shape_oihw8i32o4i.append(int(ceildiv(logical_shape_oihw[0], filter_Ki))) - physical_shape_oihw8i32o4i.append(int(ceildiv(logical_shape_oihw[1], filter_Ci))) - physical_shape_oihw8i32o4i.append(logical_shape_oihw[2]) - physical_shape_oihw8i32o4i.append(logical_shape_oihw[3]) - physical_shape_oihw8i32o4i.extend(filter_block_shape) - return physical_shape_oihw8i32o4i - - -def build_and_run(inputs, func, target: str, target_host: str, *args, **kwargs): - """build and run the function func""" - schedule, placeholders, binds = func(*args, **kwargs) - - func = tvm.compile( - schedule, placeholders, target=tvm.target.Target(target, host=target_host), binds=binds - ) - dev = tvm.device(target) - tensors = [] - for tensor in inputs: - tensors.append(tvm.runtime.tensor(tensor, dev)) - tensors.append( - tvm.runtime.tensor( - numpy.zeros([i.value for i in placeholders[-1].shape], dtype=placeholders[-1].dtype), - dev, - ) - ) - func(*tensors) - - return tensors[-1].numpy() - - -def run_module(mod, inputs): - """invokes run function of specified module with inputs provided""" - mod.set_input(**inputs) - mod.run() - output = mod.get_output(0).numpy() - return output - - -def get_conv2d_nhwc_shape(shape_nhwc, kernel_size, strides, padding, dilation, out_channels): - assert len(shape_nhwc) == 4 - kernel = [] - kernel.append((kernel_size[0] - 1) * dilation[0] + 1) - kernel.append((kernel_size[1] - 1) * dilation[1] + 1) - return ( - shape_nhwc[0], - (shape_nhwc[1] - kernel[0] + padding[0] + padding[1]) // strides[0] + 1, - (shape_nhwc[2] - kernel[1] + padding[2] + padding[3]) // strides[1] + 1, - out_channels, - ) - - -def conv2d_verify(output, ref_output, dtype): - """transpose and reshape output and compare with ref_output""" - # nhwc8h8w32c -> nhwc - logical_output_shape = get_logical_shape(output.shape) - output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape(logical_output_shape) - - # slice output to match ref_output shape - # e.g. 8x8 spatial 3x3 filter = 6x6 ref output - # but still 8x8 output given the blocked layout - output = output[ - 0 : ref_output.shape[0] : 1, - 0 : ref_output.shape[1] : 1, - 0 : ref_output.shape[2] : 1, - 0 : ref_output.shape[3] : 1, - ] - - if "int" in dtype: - tol = {"atol": 0, "rtol": 0} - elif dtype == "float32": - tol = {"rtol": 1e-4, "atol": 2e-4} - tvm.testing.assert_allclose(output, ref_output, **tol) - - -def conv2d_compute(X, filt, pad, stride, dilation): - """Define conv2d compute""" - block_shape = get_sblock_shape() - block_H, block_W, block_C = block_shape - filter_c_io, _, filter_c_ii = get_filter_block_shape() - filter_c_i = filter_c_io * filter_c_ii - - shape_filter = filt.shape - kernel_size = tuple(shape_filter[2:4]) - out_channels = shape_filter[0] * shape_filter[5] - - logical_input_shape = get_logical_shape(X.shape) - logical_output_shape = get_conv2d_nhwc_shape( - logical_input_shape, - kernel_size, - stride, - pad, - dilation, - out_channels, - ) - - output_shape = get_packed_shape(logical_output_shape) - rh = te.reduce_axis((0, kernel_size[0]), name="rh") - rw = te.reduce_axis((0, kernel_size[1]), name="rw") - rc = te.reduce_axis((0, logical_input_shape[3]), name="rc") - - def compute(n, ho, wo, ko, hi, wi, ki): - h = ho * block_H + hi - h_contig = h * stride[0] + rh - h_block_id = h_contig // block_H - h_block_offset = h_contig % block_H - - w = wo * block_W + wi - w_contig = w * stride[1] + rw - w_block_id = w_contig // block_W - w_block_offset = w_contig % block_W - - c_block_id = rc // block_C - c_block_offset = rc % block_C - - rco = rc // filter_c_i - rcio = (rc % filter_c_i) // filter_c_ii - rcii = rc % filter_c_ii - - return te.sum( - X[ - n, - h_block_id, - w_block_id, - c_block_id, - h_block_offset, - w_block_offset, - c_block_offset, - ] - * filt[ko, rco, rh, rw, rcio, ki, rcii], - axis=[rh, rw, rc], - ) - - return output_shape, compute - - -def transform_numpy(arr_np, current_layout: str, new_layout: str): - """Reshape and transpose numpy array according to the specified layout""" - if current_layout == "nhwc": - if new_layout == "nhwc": - return arr_np - if new_layout in ["nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-1d"]: - n, h, w, c = arr_np.shape - return arr_np.reshape([n, h // 8, 8, w // 4, 2, 2, c // 32, 32]).transpose( - 0, 1, 3, 6, 2, 4, 7, 5 - ) - if new_layout in ["nhwc-4h2w32c2w-2d"]: - n, h, w, c = arr_np.shape - return arr_np.reshape([n, h // 4, 4, w // 4, 2, 2, c // 32, 32]).transpose( - 0, 1, 3, 6, 2, 4, 7, 5 - ) - if new_layout in ["n11c-1024c-2d", "n11c-1024c-1d"]: - n, h, w, c = arr_np.shape - assert h == 1 and w == 1, "The size of h and w must be 1" - return arr_np.reshape([n, 1, 1, c // 1024, 1024]) - if new_layout == "nc-1024-2d": - n, c = arr_np.shape - return arr_np.reshape([n, c // 1024, 1024]) - if new_layout == "nhwc-1024c-2d": - N, H, W, C = arr_np.shape - return arr_np.reshape([N, H, W, C // 1024, 1024]) - if new_layout == "nc-2048-2d": - N, C = arr_np.shape - return arr_np.reshape([N, C // 2048, 2048]) - if new_layout == "nhwc-2048c-2d": - N, H, W, C = arr_np.shape - return arr_np.reshape([N, H, W, C // 2048, 2048]) - if new_layout == "nhwc-8h8w32c-2d": - n, h, w, c = arr_np.shape - return arr_np.reshape([n, h // 8, 8, w // 8, 8, c // 32, 32]).transpose( - 0, 1, 3, 5, 2, 4, 6 - ) - if new_layout == "n11c-2048c-2d": - n, h, w, c = arr_np.shape - assert h == 1 and w == 1, "The size of h and w must be 1" - return arr_np.reshape([n, h, w, c // 2048, 2048]) - raise RuntimeError(f"Unexpected new_layout '{new_layout}'") - - if current_layout == "nc": - n, c = arr_np.shape - if new_layout in ["nc-2048c-1d"]: - return arr_np.reshape([n, c // 2048, 2048]) - if new_layout in ["nc-2048c-2d"]: - return arr_np.reshape([n, c // 2048, 2048]) - if new_layout in ["nc-1024c-2d"]: - return arr_np.reshape([n, c // 1024, 1024]) - if new_layout in ["nc-1024c-1d"]: - return arr_np.reshape([n, c // 1024, 1024]) - if new_layout in ["nc-512c-2d"]: - return arr_np.reshape([n, c // 512, 512]) - if new_layout in ["nc-2048c-2d"]: - return arr_np.reshape([n, c // 2048, 2048]) - raise RuntimeError(f"Unexpected new_layout '{new_layout}'") - - if current_layout == "nhw": - if new_layout in ["nhw-32h16w-2d"]: - n, h, w = arr_np.shape - return arr_np.reshape([n, h // 32, 32, w // 16, 16]).transpose(0, 1, 3, 2, 4) - - raise RuntimeError(f"Unexpected new_layout '{new_layout}'") - - if current_layout == "ncw": - if new_layout == "ncw": - return arr_np - if new_layout in ["ncw-32c64w-2d"]: - n, c, w = arr_np.shape - return arr_np.reshape([n, c // 32, 32, w // 64, 64]).transpose(0, 1, 3, 2, 4) - - raise RuntimeError(f"Unexpected new_layout '{new_layout}'") - - if current_layout == "nchw": - if new_layout in ["nchw-32c8h8w-2d", "nchw-32c8h8w-1d"]: - n, c, h, w = arr_np.shape - return arr_np.reshape([n, c // 32, 32, h // 8, 8, w // 8, 8]).transpose( - 0, 1, 3, 5, 2, 4, 6 - ) - if new_layout in ["nchw-32c8h4w-2d", "nchw-32c8h4w-1d"]: - n, c, h, w = arr_np.shape - return arr_np.reshape([n, c // 32, 32, h // 8, 8, w // 4, 4]).transpose( - 0, 1, 3, 5, 2, 4, 6 - ) - raise RuntimeError(f"Unexpected new_layout '{new_layout}'") - - raise RuntimeError(f"Unexpected current_layout '{current_layout}'") - - -def quantize_np(arr_np: numpy.ndarray, dtype: str): - """ - Returns quantized array along with scale and zero-point - - Parameters - ---------- - arr_np: numpy.ndarray - Input numpy array to be quantized - dtype: str - dtype of the quantized array: "uint8", "int8", etc - - Returns - ------- - quant_np: numpy.ndarray - Quantized numpy array - scale: float - Scale - zero_point: int - Value corresponding to float 0 - - """ - if dtype == "uint8": - qmax = 255 - qmin = 0 - elif dtype == "int8": - qmax = 127 - qmin = -128 - else: - raise RuntimeError(f"Unsupported quantized data type '{dtype}'") - fmin = numpy.amin(arr_np) - fmax = numpy.amax(arr_np) - - # Include floating-point zero in the range - if fmax < 0: - fmax = 0.0 - elif fmin > 0: - fmin = 0.0 - - scale = (fmax - fmin) / (qmax - qmin) - zero_point = numpy.rint((fmax * qmin - fmin * qmax) / (fmax - fmin)).astype("int32") - quant_np = numpy.clip(((arr_np / scale).round() + zero_point), qmin, qmax).astype(dtype) - return quant_np, scale, zero_point - - -def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target: - """Creates a Hexagon target from a registered tag. - - Parameters - ---------- - cpu_ver : str - Hexagon CPU version, e.g. "v68", "v69". - **kwargs : - Optional target attribute overrides (e.g. vtcm_capacity=1024). - """ - tag = "qcom/hexagon-" + cpu_ver - if kwargs: - config = {"tag": tag} - if "vtcm_capacity" in kwargs: - config["vtcm-capacity"] = kwargs.pop("vtcm_capacity") - if "num_cores" in kwargs: - config["num-cores"] = kwargs.pop("num_cores") - config.update(kwargs) - target = tvm.target.Target(config) - else: - target = tvm.target.Target(tag) - return tvm.target.Target(target, host=target) diff --git a/tests/python/contrib/test_hexagon/pytest_util.py b/tests/python/contrib/test_hexagon/pytest_util.py deleted file mode 100644 index 1c7b92a529b1..000000000000 --- a/tests/python/contrib/test_hexagon/pytest_util.py +++ /dev/null @@ -1,176 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Hexagon pytest utility functions""" - -import collections - -import numpy as np - - -def get_test_id(*test_params, test_param_descs: list[str | None] | None = None) -> str: - """ - An opinionated alternative to pytest's default algorithm for generating a - test's ID string. Intended to make it easier for human readers to - interpret the test IDs. - - 'test_params': The sequence of pytest parameter values supplied to some unit - test. - - 'test_param_descs': An (optional) means to provide additional text for some/all of the - paramuments in 'test_params'. - - If provided, then len(test_params) must equal len(test_param_descs). - Each element test_param_descs that is a non-empty string will be used - in some sensible way in this function's returned string. - """ - - assert len(test_params) > 0 - - if test_param_descs is None: - test_param_descs = [None] * len(test_params) - else: - assert len(test_param_descs) == len(test_params) - - def get_single_param_chunk(param_val, param_desc: str | None): - if isinstance(param_val, list): - # Like str(list), but avoid the whitespace padding. - val_str = "[" + ",".join(str(x) for x in param_val) + "]" - need_prefix_separator = False - - elif isinstance(param_val, bool): - if param_val: - val_str = "T" - else: - val_str = "F" - need_prefix_separator = True - - elif isinstance(param_val, TensorContentConstant): - val_str = f"const[{param_val.elem_value}]" - need_prefix_separator = True - - elif isinstance(param_val, TensorContentDtypeMin): - val_str = "min" - need_prefix_separator = True - - elif isinstance(param_val, TensorContentDtypeMax): - val_str = "max" - need_prefix_separator = True - - elif isinstance(param_val, TensorContentRandom): - val_str = "random" - need_prefix_separator = True - - elif isinstance(param_val, TensorContentSequentialCOrder): - val_str = f"seqC[start:{param_val.start_value},inc:{param_val.increment}]" - need_prefix_separator = True - - else: - val_str = str(param_val) - need_prefix_separator = True - - if param_desc and need_prefix_separator: - return f"{param_desc}:{val_str}" - elif param_desc and not need_prefix_separator: - return f"{param_desc}{val_str}" - else: - return val_str - - chunks = [ - get_single_param_chunk(param_val, param_desc) - for param_val, param_desc in zip(test_params, test_param_descs) - ] - return "-".join(chunks) - - -def get_multitest_ids( - multitest_params_list: list[list], param_descs: list[str | None] | None -) -> list[str]: - """ - A convenience function for classes that use both 'tvm.testing.parameters' and 'get_test_id'. - - This function provides a workaround for a specific quirk in Python, where list-comprehension - can't necessarily access the value of another class-variable, discused here: - https://stackoverflow.com/q/13905741 - """ - return [ - get_test_id(*single_test_param_list, test_param_descs=param_descs) - for single_test_param_list in multitest_params_list - ] - - -def get_numpy_dtype_info(dtype) -> np.finfo | np.iinfo: - """ - Return an appropriate 'np.iinfo' or 'np.finfo' object corresponding to - the specified Numpy dtype. - - 'dtype' must be a value that 'numpy.dtype(...)' can handle. - """ - np_dtype = np.dtype(dtype) - kind = np_dtype.kind - - if kind == "f": - return np.finfo(np_dtype) - elif kind == "i": - return np.iinfo(np_dtype) - else: - raise TypeError(f"dtype ({dtype}) must indicate some floating-point or integral data type") - - -TensorContentConstant = collections.namedtuple("TensorContentConstant", ["elem_value"]) -TensorContentSequentialCOrder = collections.namedtuple( - "TensorContentSequentialCOrder", ["start_value", "increment"] -) -TensorContentRandom = collections.namedtuple("TensorContentRandom", []) -TensorContentDtypeMin = collections.namedtuple("TensorContentDtypeMin", []) -TensorContentDtypeMax = collections.namedtuple("TensorContentDtypeMax", []) - - -def create_populated_numpy_tensor( - input_shape: list | tuple, dtype: str, input_tensor_populator -) -> np.ndarray: - """ - Create a numpy tensor with the specified shape, dtype, and content. - """ - itp = input_tensor_populator # just for brevity - - if isinstance(itp, TensorContentConstant): - return np.full(tuple(input_shape), itp.elem_value, dtype=dtype) - - elif isinstance(itp, TensorContentDtypeMin): - info = get_numpy_dtype_info(dtype) - return np.full(tuple(input_shape), info.min, dtype=dtype) - - elif isinstance(itp, TensorContentDtypeMax): - info = get_numpy_dtype_info(dtype) - return np.full(tuple(input_shape), info.max, dtype=dtype) - - elif isinstance(itp, TensorContentRandom): - return np.random.random(input_shape).astype(dtype) - - elif isinstance(itp, TensorContentSequentialCOrder): - a = np.empty(tuple(input_shape), dtype) - - with np.nditer(a, op_flags=["writeonly"], order="C") as iterator: - next_elem_val = itp.start_value - for elem in iterator: - elem[...] = next_elem_val - next_elem_val += itp.increment - return a - - else: - raise ValueError(f"Unexpected input_tensor_populator type: {type(itp)}") diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py deleted file mode 100644 index fba8ccb47dc0..000000000000 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ /dev/null @@ -1,889 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F841 - -"""Test different strategies for loading data into vtcm before running HVX workloads.""" - -import numpy as np -import pytest - -import tvm -from tvm.script import tirx as T -from tvm.testing import env - -VRMPY_SIZE_B = 128 -VRMPY_SIZE_INT32 = 32 - - -# pylint: disable=invalid-name -@T.prim_func(s_tir=True) -def conv2d_async_non_contig( - p0: T.Buffer((T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)), "uint8"), - fused_constant_1: T.Buffer( - (T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)), - "uint8", - ), - conv2d_NCHWc_int8: T.Buffer( - (T.int64(1), T.int64(1), T.int64(54), T.int64(54), T.int64(32)), "int32" - ), -): - """Non contiguous memory access is used in this conv2d taken from MS.""" - # pylint: disable=no-self-argument - # function attr dict - T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) - # body - # with T.sblock("root") - p0_global_vtcm = T.sblock_alloc_buffer( - [T.int64(1), T.int64(1), T.int64(56), T.int64(56), T.int64(4)], - dtype="uint8", - scope="global.vtcm", - ) - fused_constant_global_vtcm = T.sblock_alloc_buffer( - [T.int64(1), T.int64(1), T.int64(3), T.int64(3), T.int64(1), T.int64(32), T.int64(4)], - dtype="uint8", - scope="global.vtcm", - ) - for oh_0 in T.serial(T.int64(3)): - for ow_0 in T.serial( - T.int64(3), - annotations={ - "software_pipeline_async_stages": [0], - "software_pipeline_order": [0, 1, 2], - "software_pipeline_stage": [0, 0, 1], - }, - ): - for ax0_ax1_ax2_ax3_ax4_fused in T.serial(T.int64(1600)): - with T.sblock("p0_global.vtcm"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial( - T.int64(56), oh_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused // T.int64(80) - ) - v3 = T.axis.spatial( - T.int64(56), - ow_0 * T.int64(18) + ax0_ax1_ax2_ax3_ax4_fused % T.int64(80) // T.int64(4), - ) - v4 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_fused % T.int64(4)) - T.reads(p0[v0, v1, v2, v3, v4]) - T.writes(p0_global_vtcm[v0, v1, v2, v3, v4]) - p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4] - for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(T.int64(1152)): - with T.sblock("fused_constant_global.vtcm"): - v0 = T.axis.spatial(T.int64(1), T.int64(0)) - v1 = T.axis.spatial(T.int64(1), T.int64(0)) - v2 = T.axis.spatial( - T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // T.int64(384) - ) - v3 = T.axis.spatial( - T.int64(3), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(384) // T.int64(128) - ) - v4 = T.axis.spatial(T.int64(1), T.int64(0)) - v5 = T.axis.spatial( - T.int64(32), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(128) // T.int64(4) - ) - v6 = T.axis.spatial(T.int64(4), ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % T.int64(4)) - T.reads(fused_constant_1[v0, v1, v2, v3, v4, v5, v6]) - T.writes(fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6]) - fused_constant_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = fused_constant_1[ - v0, v1, v2, v3, v4, v5, v6 - ] - for oh_1, ow_1 in T.grid(T.int64(3), T.int64(6)): - for oh_2_init, ow_2_init in T.grid(T.int64(6), T.int64(3)): - with T.sblock("conv2d_NCHWc_int8_o_init"): - v_n = T.axis.spatial(T.int64(1), T.int64(0)) - v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0)) - v_oh = T.axis.spatial( - T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2_init - ) - v_ow = T.axis.spatial( - T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2_init - ) - T.reads() - T.writes( - conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)] - ) - for oc_block_1 in T.vectorized(T.int64(32)): - with T.sblock("conv2d_NCHWc_int8_init"): - v_oc_block_i_init = T.axis.spatial(T.int64(32), oc_block_1) - T.reads() - T.writes( - conv2d_NCHWc_int8[ - v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init - ] - ) - conv2d_NCHWc_int8[ - v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init - ] = 0 - for kh_1, kw_1, oh_2, ow_2 in T.grid( - T.int64(3), T.int64(3), T.int64(6), T.int64(3) - ): - with T.sblock("conv2d_NCHWc_int8_o_update"): - v_n = T.axis.spatial(T.int64(1), T.int64(0)) - v_oc_chunk = T.axis.spatial(T.int64(1), T.int64(0)) - v_oh = T.axis.spatial( - T.int64(54), oh_0 * T.int64(18) + oh_1 * T.int64(6) + oh_2 - ) - v_ow = T.axis.spatial( - T.int64(54), ow_0 * T.int64(18) + ow_1 * T.int64(3) + ow_2 - ) - v_kh, v_kw = T.axis.remap("RR", [kh_1, kw_1]) - v_ic_outer = T.axis.reduce(T.int64(1), T.int64(0)) - v_ic_f_inner = T.axis.reduce(T.int64(1), T.int64(0)) - T.reads( - conv2d_NCHWc_int8[ - v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32) - ], - p0_global_vtcm[ - v_n, - v_ic_outer, - v_oh + v_kh, - v_ow + v_kw, - v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4), - ], - fused_constant_global_vtcm[ - v_oc_chunk, - v_ic_outer, - v_kh, - v_kw, - v_ic_f_inner, - T.int64(0) : T.int64(32), - T.int64(0) : T.int64(4), - ], - ) - T.writes( - conv2d_NCHWc_int8[v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32)] - ) - A = T.match_buffer( - p0_global_vtcm[ - v_n, - v_ic_outer, - v_oh + v_kh, - v_ow + v_kw, - v_ic_f_inner * T.int64(4) : v_ic_f_inner * T.int64(4) + T.int64(4), - ], - [T.int64(4)], - dtype="uint8", - scope="global.vtcm", - offset_factor=1, - ) - B = T.match_buffer( - fused_constant_global_vtcm[ - v_oc_chunk, - v_ic_outer, - v_kh, - v_kw, - v_ic_f_inner, - T.int64(0) : T.int64(32), - T.int64(0) : T.int64(4), - ], - [T.int64(32), T.int64(4)], - dtype="uint8", - scope="global.vtcm", - offset_factor=1, - ) - C = T.match_buffer( - conv2d_NCHWc_int8[ - v_n, v_oc_chunk, v_oh, v_ow, T.int64(0) : T.int64(32) - ], - [T.int64(32)], - dtype="int32", - offset_factor=1, - ) - A_u8x4: T.uint8x4 = A[T.int64(0) : T.int64(4)] - A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") - B_i8x128 = B[T.int64(0), T.int64(0) : T.int64(128)] - B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32") - C[0:32] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - C[0:32], - B_i32x32, - A_i32, - dtype="int32x32", - ) - - -def conv_approximation(size_a, size_w): - """Conv approximation.""" - a_shape = (size_a, VRMPY_SIZE_B) - w_shape = (size_w, VRMPY_SIZE_B) - out_shape = (size_a, VRMPY_SIZE_INT32) - - @T.prim_func(s_tir=True) - def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8") - w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8") - c_buffer = T.match_buffer(c_output, out_shape, dtype="int32") - for n, index_0 in T.grid(size_a, size_w): - with T.sblock("c_buffer"): - vn_index, vi_index = T.axis.remap("SR", [n, index_0]) - T.reads( - a_buffer[vn_index, 0:VRMPY_SIZE_B], - w_buffer[vi_index, 0:VRMPY_SIZE_B], - c_buffer[vn_index, 0:VRMPY_SIZE_INT32], - ) - T.writes(c_buffer[vn_index, 0:VRMPY_SIZE_INT32]) - with T.init(): - for x in T.serial(VRMPY_SIZE_INT32): - c_buffer[vn_index, x] = 0 - c_buffer[vn_index, T.ramp(0, 1, 32)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - c_buffer[vn_index, T.ramp(0, 1, 32)], - T.reinterpret(a_buffer[vn_index, T.ramp(0, 1, 128)], dtype="int32x32"), - T.reinterpret(w_buffer[vi_index, T.ramp(0, 1, 128)], dtype="int32x32"), - dtype="int32x32", - ) - # Currently async DMA lowering does not add any wait to the end of schedules so - # for timing purposes we are manually adding a wait to ensure that all copies - # are complete when the schedule exits. - T.evaluate( - T.tvm_call_packed( - "device_api.hexagon.dma_wait", - 0, # QueueId - 0, # Wait for 0 in flight - dtype="int32", - ) - ) - - return tvm.s_tir.Schedule(operator) - - -def evaluate( - hexagon_session, - sch, - a_data, - b_data, - c_data, - expected_output=None, - use_async_copy=0, -): - """Evaluate function.""" - target_hexagon = tvm.target.Target("qcom/hexagon-v68") - with tvm.transform.PassContext( - config={ - "tirx.use_async_copy": use_async_copy, - "tirx.experimental_dma_bypass_cache": 1, - } - ): - func_tir = tvm.compile( - sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon) - ) - module = hexagon_session.load_module(func_tir) - - a_hexagon = tvm.runtime.tensor(a_data, device=hexagon_session.device) - b_hexagon = tvm.runtime.tensor(b_data, device=hexagon_session.device) - c_hexagon = tvm.runtime.tensor(c_data, device=hexagon_session.device) - - if tvm.testing.utils.IS_IN_CI: - # Run with reduced number and repeat for CI - timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) - else: - timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) - - time = timer(a_hexagon, b_hexagon, c_hexagon) - if expected_output is not None: - tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output) - return round(time.mean * 1000, 4) - - -def get_fake_conv_vtcm_schedule(size_a, size_w, blocks=2): - """Generate fake conv schedule with VTCM.""" - sch = conv_approximation(size_a, size_w) - - compute_block = sch.get_sblock("c_buffer") - sch.cache_read(compute_block, 1, "global.vtcm") - - n = sch.get_loops(compute_block)[0] - n_outer, _ = sch.split(n, [blocks, None]) - - cache_read_block_a = sch.cache_read(compute_block, 0, "global.vtcm") - sch.compute_at(cache_read_block_a, n_outer) - sch.fuse(*sch.get_loops(cache_read_block_a)[1:]) - - cache_write_block_c = sch.cache_write(compute_block, 0, "global.vtcm") - sch.reverse_compute_at(cache_write_block_c, n_outer) - sch.fuse(*sch.get_loops(cache_write_block_c)[1:]) - - return sch - - -def get_multi_input_fake_conv_vtcm_schedule(size_a, size_w, blocks=2): - """Generate multi input fake Conv using VTCM.""" - sch = conv_approximation(size_a, size_w) - - compute_block = sch.get_sblock("c_buffer") - - n = sch.get_loops(compute_block)[0] - n_outer, _ = sch.split(n, [blocks, None]) - - cache_read_block_a = sch.cache_read(compute_block, 0, "global.vtcm") - sch.compute_at(cache_read_block_a, n_outer) - sch.fuse(*sch.get_loops(cache_read_block_a)[1:]) - - cache_read_block_b = sch.cache_read(compute_block, 1, "global.vtcm") - sch.compute_at(cache_read_block_b, n_outer) - sch.fuse(*sch.get_loops(cache_read_block_b)[1:]) - - cache_write_block_c = sch.cache_write(compute_block, 0, "global.vtcm") - sch.reverse_compute_at(cache_write_block_c, n_outer) - sch.fuse(*sch.get_loops(cache_write_block_c)[1:]) - - return sch - - -def print_results(test_key, runtimes): - print(test_key) - for runtime in runtimes.items(): - print(f"-{runtime[0]} took {runtime[1]} ms") - print() - - -class TestAsyncDMAPipeline: - """Async DMA pipeline test class.""" - - # Removed most of these to speedup CI. - size_a = tvm.testing.parameter( - 1024, - 64 * 64, - # 128 * 64, # Only works on 8Gen1 HDK's - ) - - size_w = tvm.testing.parameter( - 1 * 1, - 3 * 3, - 9 * 9, - ) - - @tvm.testing.fixture - def input_a(self, size_a): - return np.random.randint(0, 8, (size_a, VRMPY_SIZE_B), dtype="uint8") - - @tvm.testing.fixture - def input_w(self, size_w): - return np.random.randint(0, 8, (size_w, VRMPY_SIZE_B), dtype="uint8") - - @tvm.testing.fixture - def expected_output(self, size_a, size_w, input_a, input_w): - """Generate expected output.""" - if tvm.testing.utils.IS_IN_CI and (size_a > 1024 or size_w > 1): - pytest.skip("Skipping test since it takes too long in CI.") - expected_result = np.zeros((size_a, VRMPY_SIZE_INT32), dtype="int32") - for n in range(size_a): - for x in range(size_w): - for index_0 in range(VRMPY_SIZE_INT32): - for r_index in range(4): - expected_result[n, index_0] += np.uint32( - input_a[n, index_0 * 4 + r_index] - ) * np.uint32(input_w[x, index_0 * 4 + r_index]) - return expected_result - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_loading_vtcm_for_vrmpy( - self, - hexagon_session, - size_a, - size_w, - input_a, - input_w, - expected_output, - ): - """VTCM for VRMPY test.""" - - if tvm.testing.utils.IS_IN_CI and (size_a > 1024 or size_w > 1): - pytest.skip("Skipping test since it takes too long in CI.") - - sch = conv_approximation(size_a, size_w) - base_runtime = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - ) - - sch = get_fake_conv_vtcm_schedule(size_a, size_w) - base_vtcm_runtime = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - use_async_copy=1, - ) - - sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_sblock("c_buffer"))[0] - sch.annotate(n, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(n, "software_pipeline_order", [0, 1, 2]) - sch.annotate(n, "software_pipeline_async_stages", [0]) - async_input_runtime = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - use_async_copy=1, - ) - - sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_sblock("c_buffer"))[0] - sch.annotate(n, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(n, "software_pipeline_order", [0, 1, 2]) - sch.annotate(n, "software_pipeline_async_stages", [0, 2]) - async_input_output = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - use_async_copy=1, - ) - - sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_sblock("c_buffer"))[0] - sch.annotate(n, "software_pipeline_stage", [0, 3, 6]) - sch.annotate(n, "software_pipeline_order", [0, 1, 2]) - sch.annotate(n, "software_pipeline_async_stages", [0, 6]) - async_larger_buffers = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - use_async_copy=1, - ) - - sch = get_multi_input_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_sblock("c_buffer"))[0] - sch.annotate(n, "software_pipeline_stage", [0, 0, 1, 2]) - sch.annotate(n, "software_pipeline_order", [0, 1, 2, 3]) - sch.annotate(n, "software_pipeline_async_stages", [0, 2]) - async_multi_input_output = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - use_async_copy=1, - ) - - sch = get_fake_conv_vtcm_schedule(size_a, size_w) - n = sch.get_loops(sch.get_sblock("c_buffer"))[0] - sch.annotate(n, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(n, "software_pipeline_order", [0, 1, 2]) - sch.annotate(n, "software_pipeline_async_stages", [2]) - async_output_runtime = evaluate( - hexagon_session, - sch, - input_a, - input_w, - np.zeros(expected_output.shape, "int32"), - expected_output, - use_async_copy=1, - ) - - # Total transfer size is equal to the size of - # a_buffer + w_buffer + c_buffer which is equal to 2 * size_a * 128 + size_w * 128 - transfer_mb = round((2 * size_a * VRMPY_SIZE_B + size_w * VRMPY_SIZE_B) / 1e6, 2) - - # Total number of operations can be calculated given - # the total number of vrmpy calls (size_a * size_w) * operations - # per vrmpy accumulate (128 multiplies + 3 adds for reduction - # per lane + 1 add for accumulate per lane) - complexity = round(size_a * size_w * (VRMPY_SIZE_B * 4) / 1e9, 3) - print_results( - ( - f"Test with a_buffer.size: {size_a * VRMPY_SIZE_B}, w_buffer.size:" - f" {size_w * VRMPY_SIZE_B}, computational complexity of {complexity} GOPs" - f", and total memory transfer of {transfer_mb} MB..." - ), - { - "without_vtcm": base_runtime, - "base_vtcm": base_vtcm_runtime, - "async_dma_input": async_input_runtime, - "async_dma_output": async_output_runtime, - "async_dma_input_output": async_input_output, - "async_dma_multi_input_output": async_multi_input_output, - "async_input_output_runtime_larger_buffers": async_larger_buffers, - }, - ) - - -# from tvm.script import tirx as T -@tvm.script.ir_module -class ModulePipelined: - """Pipelined module class.""" - - # pylint: disable=no-self-argument - @T.prim_func(s_tir=True) - def main( - p0_buffer: T.Buffer((1, 1, 230, 230, 4), "uint8"), - p1_buffer: T.Buffer((2, 1, 7, 7, 1, 32, 4), "int8"), - t_cast: T.Buffer((1, 2, 112, 112, 32), "int32"), - ) -> None: - # pylint: disable=missing-function-docstring - # function attr dict - T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) - # body - # with T.sblock("root") - conv2d_nchwc_int8 = T.sblock_alloc_buffer( - [1, 2, 112, 112, 32], dtype="int32", scope="global.vtcm" - ) - p0_global_vtcm = T.sblock_alloc_buffer( - [1, 1, 230, 230, 4], dtype="uint8", scope="global.vtcm" - ) - p1_global_vtcm = T.sblock_alloc_buffer( - [2, 1, 7, 7, 1, 32, 4], dtype="int8", scope="global.vtcm" - ) - for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(2, 1, 7, 7, 1, 32, 4): - with T.sblock("p1_global.vtcm"): - v0_ind, v1_ind, v2_ind, v3_ind, v4_ind, v5_ind, v6_ind = T.axis.remap( - "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6] - ) - T.reads(p1_buffer[v0_ind, v1_ind, v2_ind, v3_ind, v4_ind, v5_ind, v6_ind]) - T.writes(p1_global_vtcm[v0_ind, v1_ind, v2_ind, v3_ind, v4_ind, v5_ind, v6_ind]) - p1_global_vtcm[v0_ind, v1_ind, v2_ind, v3_ind, v4_ind, v5_ind, v6_ind] = p1_buffer[ - v0_ind, v1_ind, v2_ind, v3_ind, v4_ind, v5_ind, v6_ind - ] - for p_outer in T.serial(4): - for index_0 in T.serial(55876): - with T.sblock("p0_global.vtcm"): - v0_ind = T.axis.spatial(1, 0) - v1_ind = T.axis.spatial(1, 0) - v2_ind = T.axis.spatial(230, p_outer * 56 + index_0 // 916) - v3_ind = T.axis.spatial(230, index_0 % 916 // 4) - v4_ind = T.axis.spatial(4, index_0 % 4) - T.reads(p0_buffer[v0_ind, v1_ind, v2_ind, v3_ind, v4_ind]) - T.writes(p0_global_vtcm[v0_ind, v1_ind, v2_ind, v3_ind, v4_ind]) - p0_global_vtcm[v0_ind, v1_ind, v2_ind, v3_ind, v4_ind] = p0_buffer[ - v0_ind, v1_ind, v2_ind, v3_ind, v4_ind - ] - for index_0 in T.parallel(28): - for index_1, index_2, index_3 in T.grid(2, 14, 8): - with T.sblock("conv2d_NCHWc_int8_o_init"): - n = T.axis.spatial(1, 0) - oc_chunk = T.axis.spatial(2, index_1) - o_height = T.axis.spatial( - 112, (p_outer * 28 + index_0) // 14 * 14 + index_2 - ) - o_width = T.axis.spatial(112, (p_outer * 28 + index_0) % 14 * 8 + index_3) - oc_block_o = T.axis.spatial(1, 0) # pylint: disable=unused-variable - T.reads() - T.writes(conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32]) - for i4_1 in T.vectorized(32): - with T.sblock("conv2d_NCHWc_int8_init"): - oc_block_i_init = T.axis.spatial(32, i4_1) - T.reads() - T.writes( - conv2d_nchwc_int8[ - n, oc_chunk, o_height, o_width, oc_block_i_init - ] - ) - conv2d_nchwc_int8[ - n, oc_chunk, o_height, o_width, oc_block_i_init - ] = 0 - for i1_1, i5_1, i6_1, i2_2, i3_2 in T.grid(2, 7, 7, 14, 8): - with T.sblock("conv2d_NCHWc_int8_o_update"): - n = T.axis.spatial(1, 0) - oc_chunk = T.axis.spatial(2, i1_1) - o_height = T.axis.spatial(112, (p_outer * 28 + index_0) // 14 * 14 + i2_2) - o_width = T.axis.spatial(112, (p_outer * 28 + index_0) % 14 * 8 + i3_2) - oc_block_o = T.axis.spatial(1, 0) # pylint: disable=unused-variable - k_height = T.axis.reduce(7, i5_1) - k_width = T.axis.reduce(7, i6_1) - ic_outer = T.axis.reduce(1, 0) - ic_f_inner = T.axis.reduce(1, 0) - ic_s_inner_o = T.axis.reduce(1, 0) # pylint: disable=unused-variable - T.reads( - conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32], - p0_global_vtcm[ - n, - ic_outer, - o_height * 2 + k_height, - o_width * 2 + k_width, - ic_f_inner * 4 : ic_f_inner * 4 + 4, - ], - p1_global_vtcm[ - oc_chunk, ic_outer, k_height, k_width, ic_f_inner, 0:32, 0:4 - ], - ) - T.writes(conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32]) - a_buffer = T.match_buffer( - p0_global_vtcm[ - n, - ic_outer, - o_height * 2 + k_height, - o_width * 2 + k_width, - ic_f_inner * 4 : ic_f_inner * 4 + 4, - ], - [4], - dtype="uint8", - offset_factor=1, - scope="global.vtcm", - ) - b_buffer = T.match_buffer( - p1_global_vtcm[ - oc_chunk, ic_outer, k_height, k_width, ic_f_inner, 0:32, 0:4 - ], - [32, 4], - dtype="int8", - offset_factor=1, - scope="global.vtcm", - ) - c_buffer = T.match_buffer( - conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32], - [32], - dtype="int32", - offset_factor=1, - scope="global.vtcm", - ) - a_u8x4: T.uint8x4 = a_buffer[0:4] - a_i32: T.int32 = T.reinterpret(a_u8x4, dtype="int32") - b_i8x128 = b_buffer[0, 0:128] - b_i32x32: T.int32x32 = T.reinterpret(b_i8x128, dtype="int32x32") - c_buffer[0:32] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - c_buffer[0:32], - T.broadcast(a_i32, 32), - b_i32x32, - dtype="int32x32", - ) - for index_0 in T.serial(200704): - with T.sblock("conv2d_nchwc_int8.vtcm"): - ax0_1 = T.axis.spatial(1, 0) - ax1_1 = T.axis.spatial(2, index_0 % 7168 // 3584) - ax2_1 = T.axis.spatial( - 112, (p_outer * 28 + index_0 // 7168) // 14 * 14 + index_0 % 3584 // 256 - ) - ax3_1 = T.axis.spatial( - 112, (p_outer * 28 + index_0 // 7168) % 14 * 8 + index_0 % 256 // 32 - ) - ax4 = T.axis.spatial(32, index_0 % 32) - T.reads(conv2d_nchwc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) - T.writes(t_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) - t_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = conv2d_nchwc_int8[ - ax0_1, ax1_1, ax2_1, ax3_1, ax4 - ] - - -# from tvm.script import tirx as T -@tvm.script.ir_module -class ModuleBase: - """Base module test class.""" - - # pylint: disable=no-self-argument - @T.prim_func(s_tir=True) - def main( - p0_buffer: T.Buffer((1, 1, 230, 230, 4), "uint8"), - p1_buffer: T.Buffer((2, 1, 7, 7, 1, 32, 4), "int8"), - t_cast: T.Buffer((1, 2, 112, 112, 32), "int32"), - ) -> None: - # pylint: disable=missing-function-docstring - # function attr dict - T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) - # buffer definition - # body - # with T.sblock("root") - conv2d_nchwc_int8 = T.sblock_alloc_buffer([1, 2, 112, 112, 32], dtype="int32") - for i0_0_i1_0_i2_0_i3_0_fused in T.parallel( - 112, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} - ): - for i4_0_0 in T.serial(1): # pylint: disable=unused-variable - for i1_1_init, i2_1_init, i3_1_init, i1_2_init, i2_2_init, i3_2_init in T.grid( - 2, 1, 1, 1, 14, 8 - ): - with T.sblock("conv2d_NCHWc_int8_o_init"): - n = T.axis.spatial(1, 0) - oc_chunk = T.axis.spatial(2, i1_1_init + i1_2_init) - o_height = T.axis.spatial( - 112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + i2_1_init * 14 + i2_2_init - ) - o_width = T.axis.spatial( - 112, i0_0_i1_0_i2_0_i3_0_fused % 14 * 8 + i3_1_init * 8 + i3_2_init - ) - oc_block_o = T.axis.spatial(1, 0) # pylint: disable=unused-variable - T.reads() - T.writes(conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32]) - for i4_1 in T.vectorized(32): - with T.sblock("conv2d_NCHWc_int8_init"): - oc_block_i_init = T.axis.spatial(32, i4_1) - T.reads() - T.writes( - conv2d_nchwc_int8[ - n, oc_chunk, o_height, o_width, oc_block_i_init - ] - ) - conv2d_nchwc_int8[ - n, oc_chunk, o_height, o_width, oc_block_i_init - ] = 0 - for i5_0, i6_0, i7_0, i8_0, i9_0_0 in T.grid( # pylint: disable=unused-variable - 1, 1, 1, 1, 1 - ): # pylint: disable=unused-variable - for ( - i0_1, # pylint: disable=unused-variable - i1_1, - i2_1, - i3_1, - i4_0_1, # pylint: disable=unused-variable - i5_1, - i6_1, - i7_1, # pylint: disable=unused-variable - i8_1, # pylint: disable=unused-variable - i9_0_1, # pylint: disable=unused-variable - i0_2, # pylint: disable=unused-variable - i1_2, - i2_2, - i3_2, - i4_0_2, # pylint: disable=unused-variable - ) in T.grid(1, 2, 1, 1, 1, 7, 7, 1, 1, 1, 1, 1, 14, 8, 1): - with T.sblock("conv2d_NCHWc_int8_o_update"): - n = T.axis.spatial(1, 0) - oc_chunk = T.axis.spatial(2, i1_1 + i1_2) - o_height = T.axis.spatial( - 112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + i2_1 * 14 + i2_2 - ) - o_width = T.axis.spatial( - 112, i0_0_i1_0_i2_0_i3_0_fused % 14 * 8 + i3_1 * 8 + i3_2 - ) - oc_block_o = T.axis.spatial(1, 0) # pylint: disable=unused-variable - k_height = T.axis.reduce(7, i5_0 * 7 + i5_1) - k_width = T.axis.reduce(7, i6_0 * 7 + i6_1) - ic_outer = T.axis.reduce(1, 0) - ic_f_inner = T.axis.reduce(1, 0) - ic_s_inner_o = T.axis.reduce(1, 0) # pylint: disable=unused-variable - T.reads( - conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32], - p0_buffer[ - n, - ic_outer, - o_height * 2 + k_height, - o_width * 2 + k_width, - ic_f_inner * 4 : ic_f_inner * 4 + 4, - ], - p1_buffer[ - oc_chunk, ic_outer, k_height, k_width, ic_f_inner, 0:32, 0:4 - ], - ) - T.writes(conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32]) - a_buffer = T.match_buffer( - p0_buffer[ - n, - ic_outer, - o_height * 2 + k_height, - o_width * 2 + k_width, - ic_f_inner * 4 : ic_f_inner * 4 + 4, - ], - [4], - dtype="uint8", - offset_factor=1, - ) - b_buffer = T.match_buffer( - p1_buffer[ - oc_chunk, ic_outer, k_height, k_width, ic_f_inner, 0:32, 0:4 - ], - [32, 4], - dtype="int8", - offset_factor=1, - ) - c_buffer = T.match_buffer( - conv2d_nchwc_int8[n, oc_chunk, o_height, o_width, 0:32], - [32], - dtype="int32", - offset_factor=1, - ) - a_u8x4: T.uint8x4 = a_buffer[0:4] - a_i32: T.int32 = T.reinterpret(a_u8x4, dtype="int32") - b_i8x128 = b_buffer[0, 0:128] - b_i32x32: T.int32x32 = T.reinterpret(b_i8x128, dtype="int32x32") - c_buffer[0:32] = T.call_llvm_pure_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.acc.128B"), - c_buffer[0:32], - T.broadcast(a_i32, 32), - b_i32x32, - dtype="int32x32", - ) - for ax0, ax1, ax2, ax3 in T.grid(1, 2, 14, 8): - for ax4_fused in T.vectorized(32): - with T.sblock("T_cast_2"): - ax0_1, ax1_1 = T.axis.remap("SS", [ax0, ax1]) - ax2_1 = T.axis.spatial( - 112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + ax2 - ) - ax3_1 = T.axis.spatial( - 112, i0_0_i1_0_i2_0_i3_0_fused % 14 * 8 + ax3 - ) - ax4 = T.axis.spatial(32, ax4_fused) - T.reads(conv2d_nchwc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) - T.writes(t_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) - t_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = conv2d_nchwc_int8[ - ax0_1, ax1_1, ax2_1, ax3_1, ax4 - ] - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_meta(hexagon_session): - """Test meta.""" - if tvm.testing.utils.IS_IN_CI: - pytest.skip("Skipping test since it takes too long in CI.") - - a_data = np.random.randint(1, 8, (1, 1, 230, 230, 4), dtype="uint8") - w_data = np.random.randint(1, 8, (2, 1, 7, 7, 1, 32, 4), dtype="int8") - c_data = np.zeros((1, 2, 112, 112, 32), dtype="int32") - - sch = tvm.s_tir.Schedule(ModuleBase) - base_runtime = evaluate(hexagon_session, sch, a_data, w_data, c_data) - - sch = tvm.s_tir.Schedule(ModulePipelined) - compute_block = sch.get_sblock("conv2d_NCHWc_int8_o_update") - outer = sch.get_loops(compute_block)[0] - - unscheduled_vtcm_runtime = evaluate( - hexagon_session, sch, a_data, w_data, c_data, use_async_copy=1 - ) - - sch = tvm.s_tir.Schedule(ModulePipelined) - compute_block = sch.get_sblock("conv2d_NCHWc_int8_o_update") - outer = sch.get_loops(compute_block)[0] - - sch.annotate(outer, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(outer, "software_pipeline_order", [0, 1, 2]) - sch.annotate(outer, "software_pipeline_async_stages", [0, 2]) - - pipeline_runtime = evaluate(hexagon_session, sch, a_data, w_data, c_data, use_async_copy=1) - - transfer_mb = round((a_data.size + w_data.size + c_data.size) / 1e6, 2) - print_results( - ( - f"Test with a_buffer.size: {a_data.size}, w_buffer.size: {w_data.size}" - f", and total memory transfer of {transfer_mb} MB..." - ), - { - "without_vtcm": base_runtime, - "unscheduled_vtcm_runtime": unscheduled_vtcm_runtime, - "pipeline_runtime": pipeline_runtime, - }, - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py deleted file mode 100644 index 39b4eef55568..000000000000 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ /dev/null @@ -1,425 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""benchmark_elemwise_add""" - -import os -import os.path -import sys -import tempfile - -import numpy as np -import pytest - -import tvm.script -import tvm.testing -from tvm.contrib.hexagon.session import Session -from tvm.script import tirx as T -from tvm.testing import env - -from . import benchmark_util as bu -from .infrastructure import get_hexagon_target - -_SHOULD_SKIP_BENCHMARKS, _SKIP_BENCHMARKS_REASON = bu.skip_benchmarks_flag_and_reason() - -# This is a fixed detail of the v68 architecture. -HVX_VECTOR_BYTES = 128 - -# NOTE on server ports: -# These tests use different port numbers for the RPC server (7070 + ...). -# The reason is that an RPC session cannot be gracefully closed without -# triggering TIME_WAIT state on the server socket. This prevents another -# server to bind to the same port until the wait time elapses. - -_BT = bu.BenchmarksTable() - -_CSV_COLUMN_ORDER = [ - # Identifies which TE-compute / TIRScript is used as the basis for the - # benchmarked primfunc. Only needs to be meaningful to humans. - "basic_kernel", - # The tensors' element type - "dtype", - # When applicable, indicates the particular variation of schedules - # apply by the Python code. Decoding this may require looking at this - # script's source code. - "sched_type", - # The memory location of the tensors used during the execution of - # the primfunc. We currently assume just one location. - # This will likely need to be generalized as we add more sophisticated - # primfuncs. - "mem_scope", - # For primfuncs that treat tensor buffers as collections of 1D vectors, - # this is the number of vectors in each tensor. - # This will likely need to be generalized as we add more sophisticated - # primfuncs. - "num_vectors_per_tensor", - # Reserved columns defined by the BenchmarksTable class. - "row_status", - "timings_min_usecs", - "timings_max_usecs", - "timings_median_usecs", - "timings_mean_usecs", - "timings_stddev_usecs", - # For benchmarks that produce files on the host file system, this indicates - # their location. Useful for post-mortem investigation of benchmark results. - "host_files_dir_path", - # Miscellaneous comments about the benchmark. - "comments", -] - -_HOST_OUTPUT_DIR = tempfile.mkdtemp() - -_PRIMFUNC_NAME = "elemwise_add" - -print("-" * 80) -print(f"OUTPUT DIRECTORY: {_HOST_OUTPUT_DIR}") -print("-" * 80) -print() - - -def _get_irmod_elemwise_add(shape: list, dtype: str, mem_scope: str) -> tvm.ir.module.IRModule: - """ - Return an IRModule containing a single primfunc, expressed as NS-TIR. - - The primfunc implements elementwise-add. Its signature is (A,B,C), where - A and B are the input tensors, and C is the output tensor. - All three tensors have the specfied shape, dtype, and mem_scope. - - If the specified primfunc is known to be unsupported, raise an UnsupportedExcetion. - """ - assert len(shape) == 2 - - # TVMScript can reference simple Python variables, but it doesn't - # curently support more complex Python expressions... - ( - dim0_size, - dim1_size, - ) = shape - - if mem_scope == "global.vtcm": - raise bu.UnsupportedException("This benchmark kernel does not yet support VTCM buffers.") - - # This check is currently elided by the one above, but it should become relevant as soon - # as we add VTCM support to this kernel generator. - # - # Also: The VTCM budget is a very rough estimate, based only on experience. - # Assuming that it's even reasonable to use a hard-coded estimate AT ALL, this number - # may need tweaking. - - # The below code is commented is commented to avoid unreachable error - # with pylint. Please enable this once the kernel starts supporting - # VTCM buffers - - # Code starts below: - # ---- ------ ----- - # estimated_vtcm_budget_bytes = HVX_VECTOR_BYTES * 1024 - - # dtype_bits = tvm.runtime.DataType(dtype).bits - # assert dtype_bits % 8 == 0 - # dtype_bytes = dtype_bits // 8 - - # num_vtcm_tensors = 3 - # estimated_vtcm_needed_bytes = shape[0] * shape[1] * dtype_bytes * num_vtcm_tensors - - # if estimated_vtcm_needed_bytes > estimated_vtcm_budget_bytes: - # raise bu.UnsupportedException("Expect to exceed VTCM budget.") - - @tvm.script.ir_module - class BenchmarkModule: - """Elementwise STIR module for benchmarking""" - - # pylint: disable=no-self-argument,invalid-name,missing-function-docstring - @T.prim_func(s_tir=True) - def main(a: T.handle, b: T.handle, c: T.handle): - # We exchange data between function by handles, which are similar to pointer. - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - - A = T.match_buffer(a, shape, dtype=dtype) - B = T.match_buffer(b, shape, dtype=dtype) - C = T.match_buffer(c, shape, dtype=dtype) - - for i in range(dim0_size): - for j in range(dim1_size): - C[i, j] = A[i, j] + B[i, j] - - # pylint: enable=no-self-argument,invalid-name,missing-function-docstring - - return BenchmarkModule - - -def _benchmark_hexagon_elementwise_add_kernel( - hexagon_session: Session, shape: list, dtype: str, mem_scope: str -): - """ - Generate and benchmark a single elementwise-add kernel for Hexagon. - - Produce these outputs: - - Printed status updates / results to stdout and/or stderr. - - - Create a new subdirectory under _HOST_OUTPUT_DIR, and populate it with - various logs and intermediate files. - - - Add to _BT a row describing this benchmark run. - """ - # Represent the benchmark details in a form required by the benchmark table - # and for other logging... - keys_dict = { - "basic_kernel": "ewise-add", - "dtype": dtype, - "shape": shape, - "mem_scope": mem_scope, - } - - desc = bu.get_benchmark_decription(keys_dict) - - # Create the host-side directory for this benchmark run's files / logs... - host_files_dir_name = bu.get_benchmark_id(keys_dict) - host_files_dir_path = os.path.join(_HOST_OUTPUT_DIR, host_files_dir_name) - os.mkdir(host_files_dir_path) - - keys_dict["host_files_dir_path"] = host_files_dir_path - - log_file_path = os.path.join(host_files_dir_path, "out.txt") - with open(log_file_path, "w", encoding="UTF-8") as log_file: - print(f"CONFIGURATION: {desc}") - log_file.write(f"CONFIGURATION: {desc}\n") - - try: - ns_tir_module = _get_irmod_elemwise_add(shape, dtype, mem_scope) - - # Lower the primfunc's IRModule to Hexagon object code... - input1 = tvm.te.placeholder(shape, dtype=dtype) - input2 = tvm.te.placeholder(shape, dtype=dtype) - output = tvm.te.placeholder(shape, dtype=dtype) - - built_module: tvm.driver.build_module.OperatorModule = tvm.compile( - ns_tir_module, - [ - input1, - input2, - output, - ], - get_hexagon_target("v69"), - name=_PRIMFUNC_NAME, - ) - - # Create an actual Hexagon-native shared object file, initially stored on the - # host's file system... - host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") - built_module.write_to_file(host_dso_binary_path) - print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") - - # Upload the .so to the Android device's file system (or wherever is appropriate - # when using the Hexagon simulator)... - target_dso_binary_filename = "test_binary.so" - target_dso_binary_pathname = hexagon_session.upload( - host_dso_binary_path, target_dso_binary_filename - ) - - # Generate our testing / validation data... - ( - host_numpy_input1_data, - host_numpy_input2_data, - host_numpy_output_data_expected, - ) = _get_elemwise_add_reference_value_tensors(shape, dtype) - - # On the target device / simulator, make our Hexagon-native shared object - # available for use... - loaded_hexagon_module: tvm.runtime.module.Module = hexagon_session.load_module( - target_dso_binary_pathname - ) - - # Create the target-side tensors to hold the primfunc's inputs and outputs... - input1_data = tvm.runtime.empty(shape, dtype, hexagon_session.device, mem_scope) - input2_data = tvm.runtime.empty(shape, dtype, hexagon_session.device, mem_scope) - output_data = tvm.runtime.empty(shape, dtype, hexagon_session.device, mem_scope) - - # Populate the primfunc's input tensors... - input1_data.copyfrom(host_numpy_input1_data) - input2_data.copyfrom(host_numpy_input2_data) - - # Actually benchmark the primfunc... - timer = loaded_hexagon_module.time_evaluator( - "main", hexagon_session.device, number=10, repeat=1 - ) - timing_result = timer(input1_data, input2_data, output_data) - - print(f"TIMING RESULT: {timing_result}") - log_file.write(f"TIMING RESULT: {timing_result}\n") - - # Verify that the computation actually happened, and produced the correct result. - result = output_data.numpy() - - if dtype == "float16": - # These are the closest tolerance we currently expect / require for these - # kernels. They may be changed in the future. - rel_tolerance = 0.005 - abs_tolerance = 2.0 - elif dtype == "int8": - rel_tolerance = 0 - abs_tolerance = 0 - else: - raise Exception(f"Unexpected dtype: {dtype}") - - # TODO: We're assuming that *any* assertion thrown by 'assert_allclose' is because - # the numerical differences were too large. But ideally this code would - # differentiate between (a) numerical difference errors, which should simply be - # recorded as a failed benchmark run, vs. (b) more serious errors that should - # kill the overall script. - try: - tvm.testing.assert_allclose( - result, host_numpy_output_data_expected, rel_tolerance, abs_tolerance - ) - except AssertionError as err: - raise bu.NumericalAccuracyException(str(err)) - - _BT.record_success(timing_result, **keys_dict) - - except bu.NumericalAccuracyException as err: - print() - print("FAIL: Numerical accuracy error. See log file.") - - log_file.write("\n") - log_file.write(f"FAIL: {err}\n") - - _BT.record_fail(**keys_dict, comments="Numerical accuracy error. See log file.") - - except bu.UnsupportedException as err: - print() - print(f"SKIP: {err}") - - log_file.write("\n") - log_file.write(f"SKIP: {err}\n") - - _BT.record_skip(**keys_dict, comments=f"Unsupported configuration: {err}") - - -def _get_elemwise_add_reference_value_tensors(shape: list, dtype: str): - """ - Return [A:np.array, B:np.array, C:np.array] - - `A`, `B`, and `C` are reference data used to exercise and validate - an elementwise-add kernel: C = A+B. - - NOTE: These data are primarily meant for performance testing. - The values may be helpful in detecting correctness issues, but that's - a secondary consideration here. - """ - assert len(shape) == 2 - - input1 = np.ndarray(shape, dtype=dtype) - input2 = np.ndarray(shape, dtype=dtype) - - np_dtype = input1.dtype - - if np_dtype.kind in ["i", "u"]: - # We allow overflow for integer types because it tends to be well-behaved - # and well-understood... - min_value = np.iinfo(np_dtype).min - max_value = np.iinfo(np_dtype).max - - next_value = min_value - - for i in range(shape[0]): - for j in range(shape[1]): - input1[i, j] = next_value - input2[i, j] = next_value * 2 - next_value += 1 - - elif np_dtype.kind == "f": - # NOTE: For simplicity, we avoid test data that require - # well-defined behavior on floating-point overflow. - # But it may be reasonable to test that in the future. - min_value = np.finfo(np_dtype).min - max_value = np.finfo(np_dtype).max - - min_input_value = min_value / 2.0 + 1 - max_input_value = max_value / 2.0 - 2 - delta = (max_input_value - min_input_value) / (shape[0] * shape[1]) - - next_value = min_input_value - - for i in range(shape[0]): - for j in range(shape[1]): - input1[i, j] = next_value - input2[i, j] = next_value + 1 - next_value += delta - - else: - assert False, f"Unexpected data type: {np_dtype}" - - output = input1 + input2 - return [ - input1, - input2, - output, - ] - - -@pytest.mark.skipif(_SHOULD_SKIP_BENCHMARKS, reason=_SKIP_BENCHMARKS_REASON) -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_elemwise_add(hexagon_session: Session): - """Main elementwise add test function""" - for dtype in [ - "int8", - "float16", - ]: - for mem_scope in [ - "global", - "global.vtcm", - ]: - # These numbers are fairly arbitrary, but they're meant to stress memory/caches to - # various extents. - for num_vectors_per_tensor in [ - 1, - 16, - 64, - 512, - 2048, - ]: - dtype_bits = tvm.runtime.DataType(dtype).bits - assert dtype_bits % 8 == 0 - dtype_bytes = dtype_bits // 8 - - elem_per_hvx_vector = HVX_VECTOR_BYTES // dtype_bytes - - shape = [ - num_vectors_per_tensor, - elem_per_hvx_vector, - ] - - print() - _benchmark_hexagon_elementwise_add_kernel(hexagon_session, shape, dtype, mem_scope) - - print("-" * 80) - print(f"OUTPUT DIRECTORY: {_HOST_OUTPUT_DIR}") - print("-" * 80) - print() - - tabular_output_filename = os.path.join(_HOST_OUTPUT_DIR, "benchmark-results.csv") - with open(tabular_output_filename, "w", encoding="UTF-8") as csv_file: - _BT.print_csv(csv_file, _CSV_COLUMN_ORDER) - - print(f"BENCHMARK RESULTS FILE: {tabular_output_filename}") - - _BT.print_csv(sys.stdout, _CSV_COLUMN_ORDER) - - if _BT.has_fail() > 0: - pytest.fail("At least one benchmark configuration failed", pytrace=False) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py deleted file mode 100644 index 91830e5d1a9d..000000000000 --- a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py +++ /dev/null @@ -1,358 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F401, RUF012 - -""" -This module serves two purposes: - (1) Demonstrates how to write Python code that exercises various - Hexagon-related algorithms / features. - - (2) Benchmark the resulting primfuncs. - -Current limitations: - - Input shapes are limited to NHWC --> NHWC_8h8w32c. - - - Testing parameters (input shapes, dtypes, etc.) currently - support only one value for each parameter. - - - height, width, channel must be integer multiples of 8, 8, and 32, - respectively. I.e., partial blocks aren't currently - supported by this script. - - - Requires that I/O tensors reside in "global.VTCM" memory, - rather than "global" memory. - This prevents benchmarking with I/O tensors that are too - large to fit into availble VTCM. - - - The script only develops one primfunc. - Future revisions to this script are expected to add more - primfuncs and demonstrate more coding strategies. -""" - -import copy -import os - -import numpy as np -import pytest - -pytest.importorskip("scipy") # tvm.topi.testing imports scipy - -import tvm.testing -from tvm import te, tirx, topi -from tvm.contrib.hexagon import allocate_hexagon_array -from tvm.contrib.hexagon.session import Session -from tvm.testing import env -from tvm.topi import testing - -from . import benchmark_util as bu -from .infrastructure import get_hexagon_target - -# Pytest seems to require that fixture names exist in the current module. -# E.g., it doesn't allow: @pytest.mark.usefixtures("bu.benchmark_group") -BENCHMARK_GROUP = bu.benchmark_group - -_SHOULD_SKIP_BENCHMARKS, _SKIP_BENCHMARKS_REASON = bu.skip_benchmarks_flag_and_reason() - - -def _ceil_div(numerator, denominator): - return (numerator + (denominator - 1)) // denominator - - -def _int8_nhwc_8h8w32c_map(n_batch, height, width, channel): - return [ - n_batch, - height // 8, - width // 8, - channel // 32, - te.AXIS_SEPARATOR, - height % 8, - width % 8, - channel % 32, - ] - - -def _int8_nhwc_8h8w32c_shape(n_batch, height, width, channel) -> list[int]: - return [ - n_batch, - _ceil_div(height, 8), - _ceil_div(width, 8), - _ceil_div(channel, 32), - 8, - 8, - 32, - ] - - -def _int8_nhwc_8h8w32c_xform_immediate(arr_in: np.ndarray) -> np.ndarray: - """ - Return a deep copy of 'arr_in', transformed from a NWHC to - NHWC-8h8wc32 shape. Any newly created array elements have value 0. - """ - stage1 = copy.copy(arr_in) - - ( - n_batch, - height, - width, - channel, - ) = stage1.shape - - ( - h_minor, - w_minor, - c_minor, - ) = [8, 8, 32] - - h_major = _ceil_div(height, h_minor) - w_major = _ceil_div(width, w_minor) - c_major = _ceil_div(channel, c_minor) - - # This handles cases where the dimensions of arr_in are not cleanly divided - # by the minor block size, i.e. [8, 8, 32]. - # - # Any additional array elements that this creates will ahve value 0. - # We shouldn't actually care what value is used for those elements, because they - # shouldn't be treated as meaningful by any of our algorithms. - if (height % h_minor) or (width % w_minor) or (channel % c_minor): - stage1.resize( - (n_batch, h_major * h_minor, w_major * w_minor, c_major * c_minor), refcheck=False - ) - - stage2 = stage1.reshape(n_batch, h_major, h_minor, w_major, w_minor, c_major, c_minor) - stage3 = stage2.transpose(0, 1, 3, 5, 2, 4, 6) - return stage3 - - -def _create_test_input(shape, dtype: str) -> np.ndarray: - np_dtype = np.dtype(dtype) - min_value = np.iinfo(np_dtype).min - max_value = np.iinfo(np_dtype).max - return np.random.randint(low=min_value, high=max_value, size=tuple(shape), dtype=np.int8) - - -@pytest.mark.usefixtures("BENCHMARK_GROUP") -class TestMaxPool2D: - """maxpool2D base test class""" - - csv_column_order = [ - # Identifies which TE-compute / TIRScript is used as the basis for the - # benchmarked primfunc. Only needs to be meaningful to humans. - "basic_kernel", - # When applicable, indicates the particular variation of schedules - # apply by the Python code. Decoding this may require looking at this - # script's source code. - "sched_type", - # Values directly based on test parameters... - "input_shape_4d", - "block_shape", - "dtype", - "kernel", - "stride", - "dilation", - "padding", - "io_tensor_mem_scope", - # Reserved columns defined by the BenchmarksTable class. - "row_status", - "timings_min_usecs", - "timings_max_usecs", - "timings_median_usecs", - "timings_mean_usecs", - "timings_stddev_usecs", - # For benchmarks that produce files on the host file system, this indicates - # their location. Useful for post-mortem investigation of benchmark results. - "host_files_dir_path", - # Miscellaneous comments about the benchmark. - "comments", - ] - - dtype = tvm.testing.parameter("int8") - - # FIXME(cconvey): The script currently fails when height, width, or channel is not an - # integer multiple of 8, 8, or 32, respectively. - n_batch = tvm.testing.parameter(1) - height = tvm.testing.parameter(*[x * 8 for x in [1, 4, 16]]) - width = tvm.testing.parameter(*[x * 8 for x in [1, 4, 16]]) - channel = tvm.testing.parameter(*[x * 32 for x in [1, 2]]) - - kernel = tvm.testing.parameter((1, 1), (3, 3)) - stride = tvm.testing.parameter((1, 1)) - dilation = tvm.testing.parameter((1, 1)) - padding = tvm.testing.parameter((0, 0, 0, 0)) - io_tensor_mem_scope = tvm.testing.parameter("global.vtcm") - - @pytest.mark.skipif(_SHOULD_SKIP_BENCHMARKS, reason=_SKIP_BENCHMARKS_REASON) - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_maxpool2d_nhwc( - self, - n_batch, - height, - width, - channel, - dtype, - kernel, - stride, - dilation, - padding, - io_tensor_mem_scope, - hexagon_session: Session, - ): - """Test maxpool2d NHWC""" - - keys_dict = { - "basic_kernel": "max_pool2d", - "sched_type": 1, - "input_shape_4d": [n_batch, height, width, channel], - "block_shape": [8, 8, 32], - "dtype": dtype, - "kernel": kernel, - "stride": stride, - "dilation": dilation, - "padding": padding, - "io_tensor_mem_scope": io_tensor_mem_scope, - } - - desc = bu.get_benchmark_decription(keys_dict) - - # Create the host-side directory for this benchmark run's files / logs... - host_files_dir_name = bu.get_benchmark_id(keys_dict) - host_files_dir_path = os.path.join(self.working_dir, host_files_dir_name) - os.mkdir(host_files_dir_path) - - keys_dict["host_files_dir_path"] = host_files_dir_path - - log_file_path = os.path.join(host_files_dir_path, "out.txt") - with open(log_file_path, "w") as log_file: - print(f"CONFIGURATION: {desc}") - log_file.write(f"CONFIGURATION: {desc}\n") - - try: - input_tensor_shape_4d = [n_batch, height, width, channel] - input_tensor_shape_7d = _int8_nhwc_8h8w32c_shape(n_batch, height, width, channel) - - data = te.placeholder(tuple(input_tensor_shape_4d), dtype=dtype) - - output = topi.nn.pool2d( - data, kernel, stride, dilation, padding, "max", layout="NHWC" - ) - primfunc = te.create_prim_func([data, output]) - - sch = tvm.s_tir.Schedule(primfunc, debug_mask="all") - - sch.transform_layout( - block="tensor", buffer="placeholder", index_map=_int8_nhwc_8h8w32c_map - ) - - built_module = tvm.compile( - sch.mod, - target=get_hexagon_target("v69"), - ) - - # Save a local copy of the Hexagon object code (in the form of a .so file) - # to allow post-mortem inspection. - host_dso_binary_path = os.path.join(host_files_dir_path, "test_binary.so") - built_module.write_to_file(host_dso_binary_path) - print(f"SAVED BINARY TO HOST PATH: {host_dso_binary_path}") - - hexagon_mod = hexagon_session.load_module(built_module) - - # Generate the input tensor's data. - # Note that we'll eventually need it in two different layouts: - # (1) NHWC as an argument to testing.poolnd_python. - # (2) NHWC_8h8w32c for as an argument to our Hexagon primfunc. - # a_numpy_4d = np.random.randint(low=-128, high=127, - # size=input_tensor_shape_4d, dtype=np.int8) - a_numpy_4d = _create_test_input(input_tensor_shape_4d, dtype) - - ref_output_4d = testing.poolnd_python( - a_numpy_4d.astype("int32"), - kernel, - stride, - dilation, - padding[0:2], - padding[2:], - pool_type="max", - dtype="int32", - layout="NHWC", - ).astype(dtype) - - output_tensor_shape_4d = ref_output_4d.shape - - a_numpy_7d = _int8_nhwc_8h8w32c_xform_immediate(a_numpy_4d) - - a_hexagon_7d = allocate_hexagon_array( - hexagon_session.device, - tensor_shape=input_tensor_shape_7d, - axis_separators=[4], - dtype=dtype, - mem_scope=io_tensor_mem_scope, - ) - - c_hexagon_4d = allocate_hexagon_array( - hexagon_session.device, - tensor_shape=output_tensor_shape_4d, - axis_separators=[], - dtype=dtype, - mem_scope=io_tensor_mem_scope, - ) - - a_hexagon_7d.copyfrom(a_numpy_7d) - - if dtype == "int8": - rel_tolerance = 0 - abs_tolerance = 0 - else: - assert False, f"TODO: decide acceptable tolerances for dtype {dtype}" - - timer = hexagon_mod.time_evaluator( - "main", hexagon_session.device, number=10, repeat=1 - ) - timing_result = timer(a_hexagon_7d, c_hexagon_4d) - - try: - tvm.testing.assert_allclose( - ref_output_4d, c_hexagon_4d.numpy(), rtol=rel_tolerance, atol=abs_tolerance - ) - except AssertionError as exception: - raise bu.NumericalAccuracyException(str(exception)) - - except bu.NumericalAccuracyException as exception: - print() - print("FAIL: Numerical accuracy error. See log file.") - - log_file.write("\n") - log_file.write(f"FAIL: {exception}\n") - - self.benchmark_table.record_fail( - **keys_dict, comments="Numerical accuracy error. See log file." - ) - - except bu.UnsupportedException as exception: - print() - print(f"SKIP: {exception}") - - log_file.write("\n") - log_file.write(f"SKIP: {exception}\n") - - self.benchmark_table.record_skip( - **keys_dict, comments=f"Unsupported configuration: {exception}" - ) - - self.benchmark_table.record_success(timing_result, **keys_dict) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py deleted file mode 100644 index 961d3bb5d602..000000000000 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ /dev/null @@ -1,190 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Test relax vm builtin to enable DMA copy and wait operations. -""" - -import numpy as np -import pytest - -import tvm -import tvm.contrib.hexagon -import tvm.script -import tvm.testing -from tvm import relax -from tvm.script.parser import ir as I -from tvm.script.parser import relax as R -from tvm.script.parser import tirx as T -from tvm.testing import env - -# pylint: disable=invalid-name, missing-class-docstring, missing-function-docstring, no-self-argument - -data_type = "int32" - - -@I.ir_module(s_tir=True) -class Module_1D: - @T.prim_func(s_tir=True) - def compute_add_in_vtcm(a: T.handle, b: T.handle, c: T.handle) -> None: - m = T.int32() - A = T.match_buffer(a, (m,), data_type, scope="global.vtcm") - B = T.match_buffer(b, (m,), data_type, scope="global.vtcm") - C = T.match_buffer(c, (m,), data_type, scope="global.vtcm") - for ax0 in T.grid(m): - with T.sblock("T_add"): - v_ax0 = T.axis.remap("S", [ax0]) - T.reads(A[v_ax0], B[v_ax0]) - T.writes(C[v_ax0]) - C[v_ax0] = A[v_ax0] + B[v_ax0] - - @R.function(pure=False) - def main( - x: R.Tensor((12800,), data_type), - y: R.Tensor((12800,), data_type), - ) -> R.Tensor((12800,), data_type): - cls = Module_1D - vtcm_obj: R.Object = R.vm.alloc_storage( - R.shape( - [ - 3 * 12800, # 3 = 2 inputs + 1 output - ] - ), - runtime_device_index=0, - dtype=data_type, - storage_scope="global.vtcm", - ) - a: R.Tensor( - [ - 12800, - ], - dtype=data_type, - ) = R.vm.alloc_tensor( - vtcm_obj, - offset=0, - shape=R.shape( - [ - 12800, - ] - ), - dtype=data_type, - ) - __: R.Tuple = R.call_builtin_with_ctx( - "vm.builtin.hexagon.dma_copy", - [x, a, 0, True], - sinfo_args=[], - ) - b: R.Tensor( - [ - 12800, - ], - dtype=data_type, - ) = R.vm.alloc_tensor( - vtcm_obj, - offset=12800 * 4, - shape=R.shape( - [ - 12800, - ] - ), - dtype=data_type, - ) - __: R.Tuple = R.call_builtin_with_ctx( - "vm.builtin.hexagon.dma_copy", - [y, b, 1, True], - sinfo_args=[], - ) - c: R.Tensor( - [ - 12800, - ], - dtype=data_type, - ) = R.vm.alloc_tensor( - vtcm_obj, - offset=2 * 12800 * 4, - shape=R.shape( - [ - 12800, - ] - ), - dtype=data_type, - ) - __: R.Tuple = R.call_builtin_with_ctx( - "vm.builtin.hexagon.dma_wait", - [0, 2, True, x, a], - sinfo_args=[], - ) - __: R.Tuple = R.call_builtin_with_ctx( - "vm.builtin.hexagon.dma_wait", - [1, 1, True, y, b], - sinfo_args=[], - ) - ___: R.Tuple = cls.compute_add_in_vtcm(a, b, c) - ret_val: R.Tensor((12800,), dtype=data_type) = R.builtin.alloc_tensor( - R.shape( - [ - 12800, - ] - ), - R.dtype(data_type), - R.prim_value(0), - ) - __: R.Tuple = R.call_builtin_with_ctx( - "vm.builtin.hexagon.dma_copy", - [c, ret_val, 0, True], - sinfo_args=[], - ) - __: R.Tuple = R.call_builtin_with_ctx( - "vm.builtin.hexagon.dma_wait", - [0, 1, True, c, ret_val], - sinfo_args=[], - ) - _t3: R.Tuple = R.vm.kill_object(vtcm_obj) - _t6: R.Tuple = R.vm.kill_object(a) - _t7: R.Tuple = R.vm.kill_object(b) - _t8: R.Tuple = R.vm.kill_object(c) - lv: R.Tensor((12800,), dtype=data_type) = ret_val - return lv - - -class TestDMACopyWait: - """Tests for Copy and wait""" - - mode = tvm.testing.parameter("bytecode", "compiled") - module = tvm.testing.parameter(Module_1D) - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_vtcm_alloc_compute(self, hexagon_launcher, mode, module): - target_hexagon = tvm.target.Target("qcom/hexagon-v69") - target = tvm.target.Target(target_hexagon, host=target_hexagon) - with tvm.transform.PassContext(opt_level=3, config=[]): - ex = tvm.compile(mod=module, target=target, exec_mode=mode) - with hexagon_launcher.create_session() as session: - dev = session.device - input_arg0_data = np.random.randint(0, 9, size=(12800,), dtype=data_type) - input_arg1_data = np.random.randint(0, 9, size=(12800,), dtype=data_type) - output_data = np.add(input_arg0_data, input_arg1_data) - vm_mod = session.get_executor_from_factory(ex) - vm_rt = relax.VirtualMachine( - vm_mod, dev, "naive" - ) # Use naive allocator to exercise VTCM allocation in relax - data0 = tvm.runtime.tensor(input_arg0_data, dev) - data1 = tvm.runtime.tensor(input_arg1_data, dev) - vm_rt.set_input("main", data0, data1) - vm_rt.invoke_stateful("main") - hexagon_output = vm_rt.get_outputs("main").numpy() - tvm.testing.assert_allclose(output_data, hexagon_output) diff --git a/tests/python/contrib/test_hexagon/test_memory_alloc.py b/tests/python/contrib/test_hexagon/test_memory_alloc.py deleted file mode 100644 index 3030f9a6cbc4..000000000000 --- a/tests/python/contrib/test_hexagon/test_memory_alloc.py +++ /dev/null @@ -1,84 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test memory allocation.""" - -import numpy as np - -import tvm -from tvm.contrib.hexagon import allocate_hexagon_array -from tvm.script import tirx as T - -from .infrastructure import get_hexagon_target - - -def generated_func(shape: tuple, dtype: str, axis_separators: list): - """Generate element wise function.""" - dim0, dim1 = shape - - @T.prim_func(s_tir=True) - def elwise(a: T.handle, b: T.handle): - a_buffer = T.match_buffer(a, shape, dtype=dtype, axis_separators=axis_separators) - b_buffer = T.match_buffer(b, shape, dtype=dtype, axis_separators=axis_separators) - - for i, j in T.grid(dim0, dim1): - with T.sblock("compute"): - b_buffer[i, j] = a_buffer[i, j] * T.cast(2, dtype=dtype) - - return elwise - - -class TestMemoryAlloc: - """Memory allocation test.""" - - dtype = tvm.testing.parameter("int8") - shape = tvm.testing.parameter((128, 128)) - - ( - scope, - axis_separators, - ) = tvm.testing.parameters( - ("global", []), - ("global.vtcm", []), - ("global.vtcm", [1]), - ("global.ddr", []), - ("global.ddr", [1]), - ) - - def test_global_axis_separator(self, hexagon_session, shape, dtype, scope, axis_separators): - """Test with global axis separator.""" - mod1 = tvm.compile( - generated_func(shape, dtype, axis_separators), - target=get_hexagon_target("v69"), - ) - mod2 = hexagon_session.load_module(mod1) - - a_np = np.ones(shape=shape, dtype=dtype) - a = allocate_hexagon_array( - hexagon_session.device, data=a_np, mem_scope=scope, axis_separators=axis_separators - ) - - b_np = np.zeros(shape=shape, dtype=dtype) - b = allocate_hexagon_array( - hexagon_session.device, data=b_np, mem_scope=scope, axis_separators=axis_separators - ) - - mod2(a, b) - tvm.testing.assert_allclose(a.numpy() * 2, b.numpy(), atol=1e-4, rtol=1e-4) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py deleted file mode 100644 index 4c9f7d2ee358..000000000000 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ /dev/null @@ -1,370 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test rpc based launcher for hexagon""" - -import tempfile - -import numpy as np -import pytest - -pytest.importorskip("scipy") # tvm.topi.testing imports scipy - -import tvm.testing -import tvm.topi.testing -from tvm import te -from tvm.contrib.hexagon.meta_schedule import ( - get_hexagon_local_builder, - get_hexagon_rpc_runner, -) -from tvm.s_tir import meta_schedule as ms -from tvm.s_tir.meta_schedule import postproc, schedule_rule -from tvm.s_tir.meta_schedule.arg_info import TensorInfo -from tvm.s_tir.meta_schedule.builder import BuilderInput -from tvm.s_tir.meta_schedule.runner import RunnerInput -from tvm.s_tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN -from tvm.script import tirx as T -from tvm.testing import env -from tvm.tirx import FloatImm - -from .infrastructure import get_hexagon_target - -MATMUL_N = 16 -MATMUL_M = 32 - - -@tvm.script.ir_module -class MatmulModule: - """Matmultest class""" - - # pylint: disable=no-self-argument - @T.prim_func(s_tir=True) - def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore - # pylint: disable=missing-function-docstring - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, (16, 16), "float32") - b_buffer = T.match_buffer(b, (16, 16), "float32") - c_buffer = T.match_buffer(c, (16, 16), "float32") - for i, j, k in T.grid(16, 16, 16): - with T.sblock("matmul"): - vi_axis, vj_axis, vk_axis = T.axis.remap("SSR", [i, j, k]) - with T.init(): - c_buffer[vi_axis, vj_axis] = 0.0 # type: ignore - c_buffer[vi_axis, vj_axis] = ( - c_buffer[vi_axis, vj_axis] - + a_buffer[vi_axis, vk_axis] * b_buffer[vk_axis, vj_axis] - ) - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_builder_runner(hexagon_launcher): - """Test builder and runner.""" - if hexagon_launcher.is_simulator(): - pytest.skip("Tuning on simulator not supported.") - - mod = MatmulModule - - max_workers = 4 - builder = get_hexagon_local_builder(max_workers=max_workers) - runner = get_hexagon_rpc_runner( - hexagon_launcher, number=1, repeat=1, min_repeat_ms=0, max_workers=max_workers - ) - - (builder_result,) = builder.build([BuilderInput(mod, get_hexagon_target("v68"))]) - assert builder_result.artifact_path is not None - assert builder_result.error_msg is None - - runner_input = RunnerInput( - builder_result.artifact_path, - "llvm", - [ - TensorInfo("float32", (MATMUL_N, MATMUL_N)), - TensorInfo("float32", (MATMUL_N, MATMUL_N)), - TensorInfo("float32", (MATMUL_N, MATMUL_N)), - ], - ) - - # Run the module - (runner_future,) = runner.run([runner_input]) - runner_result = runner_future.result() - - assert runner_result.error_msg is None - for result in runner_result.run_secs: - if isinstance(result, FloatImm): - result = result.value - assert isinstance(result, float) - assert result >= 0.0 - - -def dense_compute(m, n, k): - """dense compute""" - X = te.placeholder((m, k), name="X", dtype="uint8") - packed_width = te.placeholder((n // 32, k // 4, 32, 4), name="packed_width", dtype="uint8") - - axis_k = te.reduce_axis((0, k), name="k") - out = te.compute( - (m, n), - lambda i, j: te.sum( - X[i, axis_k].astype("int32") - * packed_width[ - tvm.tirx.indexdiv(j, 32), tvm.tirx.indexdiv(axis_k, 4), j % 32, axis_k % 4 - ].astype("int32"), - axis=axis_k, - ), - name="compute", - ) - return [X, packed_width, out] - - -def schedule_dense(sch, block, m_size, do_tune): - """dense schedule""" - a_y, a_x, _ = sch.get_loops(block)[-3:] - - if do_tune: - y_factors = sch.sample_perfect_tile(a_y, n=2, max_innermost_factor=128) - a_yo, a_yi = sch.split(a_y, factors=y_factors) - else: - a_yo, a_yi = sch.split(a_y, factors=[None, min(m_size, 32)]) - - a_xo, a_xi = sch.split(a_x, factors=[None, 32]) - sch.reorder(a_yo, a_xo, a_yi, a_xi) - - a_xi, a_k = sch.get_loops(block)[-2:] - a_ko, a_ki = sch.split(a_k, factors=[None, 4]) - sch.reorder(a_ko, a_xi, a_ki) - - fused = sch.fuse(a_yo, a_xo) - - sch.parallel(fused) - - dec = sch.decompose_reduction(block, a_ko) - - init_loop = sch.get_loops(dec)[-1] - sch.vectorize(init_loop) - - sch.tensorize(a_xi, VRMPY_u8u8i32_INTRIN) - - -def verify_dense(sch, target, m_size, n_size, k_size, hexagon_session): - """Verify dense operator.""" - f = tvm.compile(sch.mod["main"], target=target) - mod = hexagon_session.load_module(f) - dev = hexagon_session.device - - a_np = np.random.uniform(1, 10, size=(m_size, k_size)).astype("uint8") - b_np = np.random.uniform(1, 10, size=(n_size, k_size)).astype("uint8") - c_np = np.dot(a_np.astype("int32"), b_np.transpose().astype("int32")) - - pack_width = np.random.uniform(1, 10, size=(n_size // 32, (k_size // 4), 32, 4)).astype("uint8") - - for r_idx in range(n_size // 32): - for k_output in range(k_size // 4): - for s_idx in range(32): - for t_idx in range(4): - pack_width[r_idx][k_output][s_idx][t_idx] = b_np[r_idx * 32 + s_idx][ - k_output * 4 + t_idx - ] - - a = tvm.runtime.tensor(a_np, dev) - b = tvm.runtime.tensor(pack_width, dev) - c = tvm.runtime.tensor(np.zeros((m_size, n_size), dtype="int32"), dev) - - mod(a, b, c) - np.testing.assert_equal(c.numpy(), c_np) - - evaluator = mod.time_evaluator(mod.entry_name, dev, number=10) - gflops = (n_size * m_size * k_size) * 2 / 1e9 - time_ms = evaluator(a, b, c).mean * 1e3 - print(f"{time_ms:f} ms, {gflops / (time_ms / 1e3):f} GOPS") - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_vrmpy_dense(hexagon_launcher): - """Test vector reduce muliply dense.""" - if hexagon_launcher.is_simulator(): - pytest.skip("Tuning on simulator not supported.") - - do_tune = True - - m_size, n_size, k_size = 128, 768, 768 - workload = te.create_prim_func(dense_compute(m_size, n_size, k_size)) - - if not do_tune: - ir_module = tvm.IRModule({"main": workload}) - sch = tvm.s_tir.Schedule(ir_module) - block = sch.get_sblock("compute") - schedule_dense(sch, block, m_size, do_tune) - else: - with tempfile.TemporaryDirectory() as work_dir: - - def schedule_dense_for_tune(sch): - block = sch.get_sblock("compute") - return schedule_dense(sch, block, None, True) - - target = get_hexagon_target("v69") - database = ms.tir_integration.tune_tir( - mod=workload, - target=target, - work_dir=work_dir, - max_trials_global=8, - space=ms.space_generator.ScheduleFn( - schedule_dense_for_tune, - sch_rules=[], - postprocs=[], - mutator_probs={}, - ), - strategy="replay-trace", - builder=get_hexagon_local_builder(), - runner=get_hexagon_rpc_runner(hexagon_launcher, number=10), - ) - sch = ms.tir_integration.compile_tir(database, workload, target) - - with hexagon_launcher.create_session() as session: - verify_dense(sch, get_hexagon_target("v68"), m_size, n_size, k_size, session) - - -# This is an example of a schedule found by vrmpy auto tensorization. -# It gets 440 GFLOPS on SD888. -@tvm.script.ir_module -class ModuleVRMPYAutoTensorize: - """Vector Reduce Multimply auto tensorize test class.""" - - # pylint: disable=no-self-argument - @T.prim_func(s_tir=True) - def main( # type: ignore - X: T.Buffer((128, 768), "uint8"), # type: ignore - packed_width: T.Buffer((24, 192, 32, 4), "uint8"), # type: ignore - compute: T.Buffer((128, 768), "int32"), # type: ignore - ) -> None: - # pylint: disable=missing-function-docstring - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - for i0_0_i1_0_0_fused in T.parallel( - 512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} - ): - for i0_1_init, i1_0_1_init, i0_2_init, i1_0_2_init in T.grid(2, 3, 1, 1): - with T.sblock("compute_o_init"): - i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1_init + i0_2_init) - j_o = T.axis.spatial(24, i1_0_2_init + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1_init) - T.reads() - T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) # type: ignore - for i1_1 in T.vectorized(32): - with T.sblock("compute_init"): - j_i_init = T.axis.spatial(32, i1_1) - T.reads() - T.writes(compute[i, j_o * 32 + j_i_init]) - compute[i, j_o * 32 + j_i_init] = 0 # type: ignore - for i2_0_0, i0_1, i1_0_1, i2_0_1, i0_2, i1_0_2 in T.grid(32, 2, 3, 6, 1, 1): - with T.sblock("compute_o_update"): - i = T.axis.spatial(128, i0_0_i1_0_0_fused // 8 * 2 + i0_1 + i0_2) - j_o = T.axis.spatial(24, i1_0_2 + i0_0_i1_0_0_fused % 8 * 3 + i1_0_1) - k_o = T.axis.reduce(192, i2_0_0 * 6 + i2_0_1) - T.reads( - compute[i, j_o * 32 : j_o * 32 + 32], # type: ignore - X[i, k_o * 4 : k_o * 4 + 4], # type: ignore - packed_width[j_o, k_o, 0:32, 0:4], # type: ignore - ) - T.writes(compute[i, j_o * 32 : j_o * 32 + 32]) # type: ignore - a_buffer = T.match_buffer( - X[i, k_o * 4 : k_o * 4 + 4], - [4], - dtype="uint8", - offset_factor=1, # type: ignore - ) - b_buffer = T.match_buffer( - packed_width[j_o, k_o, 0:32, 0:4], [32, 4], dtype="uint8", offset_factor=1 - ) - c_buffer = T.match_buffer( - compute[i, j_o * 32 : j_o * 32 + 32], - [32], - dtype="int32", - offset_factor=1, # type: ignore - ) - a_u8x4: T.uint8x4 = a_buffer[0:4] # type: ignore - a_i32: T.int32 = T.reinterpret(a_u8x4, dtype="int32") # type: ignore - b_i32x32: T.int32x32 = T.reinterpret(b_buffer[0, 0:128], dtype="int32x32") # type: ignore - c_buffer[0:32] = T.call_llvm_pure_intrin( # type: ignore - 4390, c_buffer[0:32], b_i32x32, a_i32, dtype="int32x32" - ) - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_vrmpy_dense_auto_tensorize(hexagon_launcher): - """Test VRMPY dense operator.""" - if hexagon_launcher.is_simulator(): - pytest.skip("Tuning on simulator not supported.") - - m_size, n_size, k_size = 128, 768, 768 - workload = te.create_prim_func(dense_compute(m_size, n_size, k_size)) - - sch_rules = [ - schedule_rule.MultiLevelTilingWithIntrin( - VRMPY_u8u8i32_INTRIN, - structure="SRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=schedule_rule.ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ), - schedule_rule.ParallelizeVectorizeUnroll( - max_jobs_per_core=16, - max_vectorize_extent=128, - unroll_max_steps=[0, 16, 64, 512], - unroll_explicit=True, - ), - ] - - postprocs = [ - postproc.RewriteParallelVectorizeUnroll(), - postproc.RewriteReductionBlock(), - postproc.RewriteTensorize(vectorize_init_loop=True), - ] - - # Make this to False to compile and run the best tuned schedule - run_tuning = True - if run_tuning: - with tempfile.TemporaryDirectory() as work_dir: - target = get_hexagon_target("v68") - database = ms.tir_integration.tune_tir( - mod=workload, - target=target, - max_trials_global=8, - num_trials_per_iter=8, - work_dir=work_dir, - space=ms.space_generator.PostOrderApply( - f_block_filter=None, - sch_rules=sch_rules, - postprocs=postprocs, - mutator_probs={}, - ), - builder=get_hexagon_local_builder(), - runner=get_hexagon_rpc_runner(hexagon_launcher, number=10), - ) - sch = ms.tir_integration.compile_tir(database, workload, target) - else: - sch = tvm.s_tir.Schedule(ModuleVRMPYAutoTensorize, debug_mask="all") - - with hexagon_launcher.create_session() as session: - verify_dense(sch, get_hexagon_target("v68"), m_size, n_size, k_size, session) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py deleted file mode 100644 index 4f8747c4e034..000000000000 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ /dev/null @@ -1,240 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Test parallelizing HVX workloads and compare them to single thread examples. -""" - -import numpy as np -import pytest - -import tvm -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - -TEST_OUTPUT_TEMPLATE = ( - "Test {} with {} operations... \n" - " -Single Thread: {} ms \n" - " -Parallel: {} ms\n -Speedup: {}x\n" -) - - -def get_vrmpy_shape_dtypes(operations): - return ((operations, 128), "uint8", (operations, 128), "uint8", (operations, 32), "int32") - - -def get_vmpy_vadd_shape_dtype(operations): - return ((operations, 128), "uint8", (operations, 128), "uint8", (operations, 128), "int16") - - -def vmpy_expected_producer(shape, a, b): - expected = np.zeros(shape, dtype="int16") - for n in range(shape[0]): - for i in range(0, 128, 2): - expected[n, i // 2] = np.int16(a[n, i]) * np.int16(b[n, i]) - for i in range(1, 128, 2): - expected[n, i // 2 + 64] = np.int16(a[n, i]) * np.int16(b[n, i]) - return expected - - -def vadd_expected_producer(shape, a, b): - expected = np.zeros(shape, dtype="int16") - for n in range(shape[0]): - for i in range(0, 128, 2): - expected[n, i // 2] = np.int16(a[n, i]) + np.int16(b[n, i]) - for i in range(1, 128, 2): - expected[n, i // 2 + 64] = np.int16(a[n, i]) + np.int16(b[n, i]) - return expected - - -def vrmpy_expected_producer(shape, a, b): - expected = np.zeros(shape, dtype="int32") - for n in range(shape[0]): - for i in range(32): - for r_ind in range(4): - expected[n, i] = expected[n, i] + np.uint32(a[n, i * 4 + r_ind]) * np.uint32( - b[n, i * 4 + r_ind] - ) - return expected - - -def get_vmpy_operator(operations): - """Generate vector multiply operator""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8") - b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") - c_buffer = T.match_buffer(c, [operations, 128], dtype="int16") - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vmpybusv.128B"), - T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - dtype="int16x128", - ) - - return operator - - -def get_vadd_operator(operations): - """Generate vadd operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8") - b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") - c_buffer = T.match_buffer(c, [operations, 128], dtype="int16") - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind, T.ramp(0, 1, 128)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vaddubh.128B"), - T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - dtype="int16x128", - ) - - return operator - - -def get_vrmpy_operator(operations): - """Generate vrmpy operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8") - b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") - c_buffer = T.match_buffer(c, [operations, 32], dtype="int32") - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - dtype="int32x32", - ) - - return operator - - -def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch): - """Evaluate schedule.""" - a_shape, a_dtype, b_shape, b_dtype, c_shape, c_dtype = shape_dtypes - - func_tir = tvm.compile(sch.mod["main"], target=get_hexagon_target("v68")) - module = hexagon_session.load_module(func_tir) - - a = np.random.randint(0, 16, a_shape, dtype=a_dtype) - b = np.random.randint(0, 16, b_shape, dtype=b_dtype) - c = np.zeros(c_shape, dtype=c_dtype) - - a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device) - b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device) - c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device) - - # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. - number = 1 - repeat = 1 - - timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) - runtime = timer(a_hexagon, b_hexagon, c_hexagon) - tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b)) - - return round(runtime.mean * 1000, 6) - - -class TestMatMulVec: - """MatMul test class.""" - - ( - operation_name, - operator_producer, - shape_dtypes_producer, - expected_output_producer, - ) = tvm.testing.parameters( - ("vrmpy", get_vrmpy_operator, get_vrmpy_shape_dtypes, vrmpy_expected_producer), - ("vmpy", get_vmpy_operator, get_vmpy_vadd_shape_dtype, vmpy_expected_producer), - ("vadd", get_vadd_operator, get_vmpy_vadd_shape_dtype, vadd_expected_producer), - ) - - # Experimentally best split factor but all multiples of 4 perform pretty well. - # This is because there are 4 HVX untis available on the device and pipelining - # works best with parallels of the number of available HVX. - split_factor = tvm.testing.parameter(4) - - # Removed most of these to speedup CI. - operation_count = tvm.testing.parameter( - 128, - # 256, - # 512, - # Single thread runs faster since L2 cache can handle the entire request quickly - # 1024, - # 2048, - # Significant performance degredation once the inputs and outputs cannot all fit in L2 - # 4096, - # 8192, - # 16384, - ) - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test( - self, - hexagon_session, - operation_count, - operation_name, - operator_producer, - shape_dtypes_producer, - expected_output_producer, - split_factor, - ): - """Test function handler.""" - - sch = tvm.s_tir.Schedule(operator_producer(operation_count)) - single_thread_runtime = evaluate( - hexagon_session, shape_dtypes_producer(operation_count), expected_output_producer, sch - ) - - sch = tvm.s_tir.Schedule(operator_producer(operation_count)) - block = sch.get_sblock("c_buffer") - b = sch.get_loops(block) - b_output, _ = sch.split(b[0], factors=[split_factor, None]) - sch.parallel(b_output) - - parallel_runtime = evaluate( - hexagon_session, shape_dtypes_producer(operation_count), expected_output_producer, sch - ) - - speedup = round(single_thread_runtime / parallel_runtime, 2) - - print( - TEST_OUTPUT_TEMPLATE.format( - operation_name, operation_count, single_thread_runtime, parallel_runtime, speedup - ) - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py deleted file mode 100644 index 39755e28111f..000000000000 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ /dev/null @@ -1,562 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test different strategies for loading data into vtcm before running HVX workloads.""" - -import numpy as np -import pytest - -import tvm -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - -TEST_OUTPUT_TEMPLATE = ( - "Test with {} MB of data to load... \n" - " -No VTCM: {} Gops \n -Basic VTCM: {} Gops \n" - " -Vectorized: {} Gops\n -Vectorized and" - " Parallelized: {} Gops\n -Preallocated and Vectorized: {} Gops\n" - " -Preallocated, Vectorized, and Parallelized: {} Gops\n" - " -Single DMA: {} Gops\n -Preloaded: {} Gops\n" -) - - -def apply_parallel_unroll_vectorize(sch, blocks, outer_split, unroll_split, vector_split): - """Apply parallel unroll vectorized.""" - for block in blocks: - vb_index, vi_index = sch.get_loops(block) - v = sch.fuse(vb_index, vi_index) - vbo, vbi, vio, vii = sch.split( # pylint: disable=unused-variable - v, factors=[outer_split, None, unroll_split, vector_split] - ) # pylint: disable=unused-variable - sch.vectorize(vii) - sch.unroll(vio) - sch.parallel(vbo) - return sch - - -def apply_unroll_vectorize(sch, blocks, unroll_split, vector_split): - for block in blocks: - vb_index, vi_index = sch.get_loops(block) - v = sch.fuse(vb_index, vi_index) - _, vio, vii = sch.split(v, factors=[None, unroll_split, vector_split]) - sch.vectorize(vii) - sch.unroll(vio) - return sch - - -def apply_vrmpy_parallelization(sch): - block = sch.get_sblock("c_buffer") - b = sch.get_loops(block) - b_outer, _ = sch.split(b[0], factors=[4, None]) - sch.parallel(b_outer) - return sch - - -def apply_vtcm_cache_read_write(sch): - block = sch.get_sblock("c_buffer") - sch.cache_read(block, 0, "global.vtcm") - sch.cache_read(block, 1, "global.vtcm") - sch.cache_write(block, 0, "global.vtcm") - return sch - - -def vrmpy(operations): - """Generate VRMPY operator""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128) - b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128) - c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128) - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind, T.ramp(0, 1, 32)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.reinterpret(a_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - T.reinterpret(b_buffer[vn_ind, T.ramp(0, 1, 128)], dtype="int32x32"), - dtype="int32x32", - ) - - return operator - - -def preloaded_vrmpy(operations): - """Generate preloaded VRMPY operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer( - a, - [T.cast(operations, "int32") * 128], - dtype="uint8", - align=128, - scope="global.vtcm", - ) - b_buffer = T.match_buffer( - b, - [T.cast(operations, "int32") * 128], - dtype="uint8", - align=128, - scope="global.vtcm", - ) - c_buffer = T.match_buffer( - c, [T.cast(operations, "int32") * 32], dtype="int32", align=128, scope="global.vtcm" - ) - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.reinterpret( - a_buffer[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], dtype="int32x32" - ), - T.reinterpret( - b_buffer[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], dtype="int32x32" - ), - dtype="int32x32", - ) - - return operator - - -def preallocated_vrmpy(operations): - """Generate preallocated VRMPY operator.""" - size = operations * 128 - out_size = operations * 32 - - @T.prim_func(s_tir=True) - def operator( - a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, c_v: T.handle - ) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global") - b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global") - c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global") - a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm") - b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm") - c_global_vtcm = T.match_buffer( - c_v, [out_size], dtype="int32", align=128, scope="global.vtcm" - ) - for n, i in T.grid(operations, 128): - with T.sblock("a_buffer_global.vtcm"): - vn_ind, vi_index = T.axis.remap("SS", [n, i]) - a_global_vtcm[vn_ind * 128 + vi_index] = a_buffer[vn_ind, vi_index] - for n, i in T.grid(operations, 128): - with T.sblock("b_buffer_global.vtcm"): - vn_ind, vi_index = T.axis.remap("SS", [n, i]) - b_global_vtcm[vn_ind * 128 + vi_index] = b_buffer[vn_ind, vi_index] - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.reinterpret( - a_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], - dtype="int32x32", - ), - T.reinterpret( - b_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], - dtype="int32x32", - ), - dtype="int32x32", - ) - for n, i in T.grid(operations, 32): - with T.sblock("c_buffer_global.vtcm"): - vn_ind, vi_index = T.axis.remap("SS", [n, i]) - c_buffer[vn_ind, vi_index] = c_global_vtcm[vn_ind * 32 + vi_index] - - return operator - - -def preallocated_single_dma_vrmpy(operations): - """Generate preallocated single DMA VRMPY operator.""" - size = operations * 128 - out_size = operations * 32 - - @T.prim_func(s_tir=True) - def operator( - a: T.handle, - b: T.handle, - c: T.handle, - a_v: T.handle, - b_v: T.handle, - c_v: T.handle, - ) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global") - b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global") - c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global") - a_global_vtcm = T.match_buffer(a_v, [size], dtype="uint8", align=128, scope="global.vtcm") - b_global_vtcm = T.match_buffer(b_v, [size], dtype="uint8", align=128, scope="global.vtcm") - c_global_vtcm = T.match_buffer( - c_v, [out_size], dtype="int32", align=128, scope="global.vtcm" - ) - T.evaluate( - T.tvm_call_packed( - "device_api.hexagon.dma_copy_dltensor", - T.tvm_stack_make_array( - a_global_vtcm.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - a_global_vtcm.dtype, - 0, - dtype="handle", - ), - T.tvm_stack_make_array( - a_buffer.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - a_buffer.dtype, - 0, - dtype="handle", - ), - T.cast(size, dtype="int"), - True, # bypass cache - dtype="int32", - ) - ) - T.evaluate( - T.tvm_call_packed( - "device_api.hexagon.dma_copy_dltensor", - T.tvm_stack_make_array( - b_global_vtcm.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - b_global_vtcm.dtype, - 0, - dtype="handle", - ), - T.tvm_stack_make_array( - b_buffer.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - b_buffer.dtype, - 0, - dtype="handle", - ), - T.cast(size, dtype="int"), - True, # bypass cache - dtype="int32", - ) - ) - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 32, 1, 32)] = T.call_llvm_intrin( - T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vrmpyubv.128B"), - T.reinterpret( - a_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], - dtype="int32x32", - ), - T.reinterpret( - b_global_vtcm[T.ramp(T.cast(vn_ind, "int32") * 128, 1, 128)], - dtype="int32x32", - ), - dtype="int32x32", - ) - T.evaluate( - T.tvm_call_packed( - "device_api.hexagon.dma_copy_dltensor", - T.tvm_stack_make_array( - c_buffer.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - c_buffer.dtype, - 0, - dtype="handle", - ), - T.tvm_stack_make_array( - c_global_vtcm.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - c_global_vtcm.dtype, - 0, - dtype="handle", - ), - T.cast(size, dtype="int"), - True, # bypass cache - dtype="int32", - ) - ) - - return operator - - -def evaluate_result(operations, tag, time, result, expected_output): - transfer_mb = round(3 * operations * 128 / 1e6, 2) - gops = round(operations * 128 * 3 / time.mean / 1e9, 3) - mean_ms = round(time.mean * 1000, 6) - - print(f"\ntest_{transfer_mb}MB_{tag} took {mean_ms} ms @ GOPS: {gops}") - tvm.testing.assert_allclose(result, expected_output) - - -def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global"): - """Setup and run operator.""" - func_tir = tvm.compile(sch.mod["main"], target=get_hexagon_target("v69")) - module = hexagon_session.load_module(func_tir) - - a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device, mem_scope=mem_scope) - b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device, mem_scope=mem_scope) - c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device, mem_scope=mem_scope) - - # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. - number = 1 - repeat = 1 - - timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) - time = timer(a_hexagon, b_hexagon, c_hexagon) - gops = round(operations * 128 * 3 / time.mean / 1e9, 4) - return gops, c_hexagon.numpy() - - -def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): - """Setup and run for preallocated.""" - func_tir = tvm.compile(sch.mod["main"], target=get_hexagon_target("v69")) - module = hexagon_session.load_module(func_tir) - - a_vtcm = np.zeros((a.size), dtype="uint8") - b_vtcm = np.zeros((b.size), dtype="uint8") - c_vtcm = np.zeros((c.size), dtype="int32") - - a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device, mem_scope="global") - b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device, mem_scope="global") - c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device, mem_scope="global") - a_vtcm_hexagon = tvm.runtime.tensor( - a_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" - ) - b_vtcm_hexagon = tvm.runtime.tensor( - b_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" - ) - c_vtcm_hexagon = tvm.runtime.tensor( - c_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" - ) - - # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. - number = 1 - repeat = 1 - - timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) - time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon) - gops = round(operations * 128 * 3 / time.mean / 1e9, 4) - return gops, c_hexagon.numpy() - - -class TestMatMulVec: - """MatMul test class.""" - - # Removed most of these to speedup CI. - operations = tvm.testing.parameter( - 1024, - # 2048, - # 4096, - # 5 * 2048, # 3.93MB of total transfer - # 16384, #Only works on 8Gen1 HDK's - # 5 * 4096, # 7.86MB of total transfer. Only works on 8Gen1 HDK's - ) - - # Experimentally best configurations for the memcopy - outer_split = tvm.testing.parameter(4) - unroll_split = tvm.testing.parameter(8) - vector_split = tvm.testing.parameter(64) - c_vector_split = tvm.testing.parameter(16) - c_vector_split_unallocated = tvm.testing.parameter(8) - - @tvm.testing.fixture - def input_a(self, operations): - return np.random.randint(0, 16, (operations, 128), dtype="uint8") - - @tvm.testing.fixture - def input_b(self, operations): - return np.random.randint(0, 16, (operations, 128), dtype="uint8") - - @tvm.testing.fixture - def input_c(self, operations): - return np.zeros((operations, 32), dtype="int32") - - @tvm.testing.fixture - def expected_output(self, operations, input_a, input_b, input_c): - expected_output = np.zeros(input_c.shape, dtype="int32") - for n in range(operations): - for i in range(32): - for r_ind in range(4): # pylint: disable=unused-variable - expected_output[n, i] = expected_output[n, i] + np.uint32( - input_a[n, i * 4 + r_ind] - ) * np.uint32(input_b[n, i * 4 + r_ind]) - return expected_output - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_loading_vtcm_for_vrmpy( - self, - hexagon_session, - operations, - input_a, - input_b, - input_c, - expected_output, - outer_split, - unroll_split, - vector_split, - c_vector_split, - c_vector_split_unallocated, - ): - """Load VTCM for VRMPY operator test.""" - # Run parallel vrmpy without loading to VTCM. - sch = tvm.s_tir.Schedule(vrmpy(operations)) - sch = apply_vrmpy_parallelization(sch) - base_runtime, result = setup_and_run( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with basic memory loads to VTCM. - sch = tvm.s_tir.Schedule(vrmpy(operations)) - sch = apply_vtcm_cache_read_write(sch) - sch = apply_vrmpy_parallelization(sch) - basic_load_runtime, result = setup_and_run( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with vectorized memory loads to VTCM. - sch = tvm.s_tir.Schedule(vrmpy(operations)) - sch = apply_vtcm_cache_read_write(sch) - sch = apply_vrmpy_parallelization(sch) - sch = apply_unroll_vectorize( - sch, - [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], - unroll_split, - vector_split, - ) - sch = apply_unroll_vectorize( - sch, [sch.get_sblock("c_buffer_global.vtcm")], unroll_split, c_vector_split_unallocated - ) - vectorized_runtime, result = setup_and_run( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with vectorized and parallelized memory loads to VTCM. - sch = tvm.s_tir.Schedule(vrmpy(operations)) - sch = apply_vtcm_cache_read_write(sch) - sch = apply_vrmpy_parallelization(sch) - sch = apply_parallel_unroll_vectorize( - sch, - [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], - outer_split, - unroll_split, - vector_split, - ) - sch = apply_parallel_unroll_vectorize( - sch, - [sch.get_sblock("c_buffer_global.vtcm")], - outer_split, - unroll_split, - c_vector_split_unallocated, - ) - vectorized_parallelized_runtime, result = setup_and_run( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with preallocated and vectorized memory loads to VTCM. - sch = tvm.s_tir.Schedule(preallocated_vrmpy(operations)) - sch = apply_vrmpy_parallelization(sch) - sch = apply_unroll_vectorize( - sch, - [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], - unroll_split, - vector_split, - ) - sch = apply_unroll_vectorize( - sch, [sch.get_sblock("c_buffer_global.vtcm")], unroll_split, c_vector_split - ) - preallocated_vectorized_runtime, result = setup_and_run_preallocated( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - result = result.reshape((operations, 32)) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with preallocated, vectorized, and parallelized memory loads to VTCM. - sch = tvm.s_tir.Schedule(preallocated_vrmpy(operations)) - sch = apply_vrmpy_parallelization(sch) - sch = apply_parallel_unroll_vectorize( - sch, - [sch.get_sblock("a_buffer_global.vtcm"), sch.get_sblock("b_buffer_global.vtcm")], - outer_split, - unroll_split, - vector_split, - ) - sch = apply_parallel_unroll_vectorize( - sch, [sch.get_sblock("c_buffer_global.vtcm")], outer_split, unroll_split, c_vector_split - ) - prealloc_vector_parallelized, result = setup_and_run_preallocated( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - result = result.reshape((operations, 32)) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with preallocated single dma memory load to VTCM. - sch = tvm.s_tir.Schedule(preallocated_single_dma_vrmpy(operations)) - sch = apply_vrmpy_parallelization(sch) - single_dma_runtime, result = setup_and_run_preallocated( - hexagon_session, sch, input_a, input_b, input_c, operations - ) - result = result.reshape((operations, 32)) - tvm.testing.assert_allclose(result, expected_output) - - # Run parallel vrmpy with data preloaded in VTCM. - sch = tvm.s_tir.Schedule(preloaded_vrmpy(operations)) - sch = apply_vrmpy_parallelization(sch) - input_a = input_a.reshape(operations * 128) - input_b = input_b.reshape(operations * 128) - input_c = input_c.reshape(operations * 32) - preloaded_runtime, result = setup_and_run( - hexagon_session, sch, input_a, input_b, input_c, operations, "global.vtcm" - ) - result = result.reshape((operations, 32)) - tvm.testing.assert_allclose(result, expected_output) - - transfer_mb = round(3 * operations * 128 / 1e6, 2) - print( - TEST_OUTPUT_TEMPLATE.format( - transfer_mb, - base_runtime, - basic_load_runtime, - vectorized_runtime, - vectorized_parallelized_runtime, - preallocated_vectorized_runtime, - prealloc_vector_parallelized, - single_dma_runtime, - preloaded_runtime, - ) - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py deleted file mode 100644 index e14e6911f05d..000000000000 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ /dev/null @@ -1,177 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test parallelism for multiple different scalar workloads.""" - -import numpy as np -import pytest - -import tvm -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - -TEST_OUTPUT_TEMPLATE = ( - "Test {} with {} operations... \n" - " -Single Thread: {} ms \n" - " -Parallel: {} ms\n -Speedup: {}x\n" -) - - -def get_add_operator(operations): - """Generate add operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations], dtype="float64") - b_buffer = T.match_buffer(b, [operations], dtype="float64") - c_buffer = T.match_buffer(c, [operations], dtype="float64") - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind] = a_buffer[vn_ind] + b_buffer[vn_ind] - - return operator - - -def get_multiply_operator(operations): - """Generate multiply operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations], dtype="float64") - b_buffer = T.match_buffer(b, [operations], dtype="float64") - c_buffer = T.match_buffer(c, [operations], dtype="float64") - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind] = a_buffer[vn_ind] * b_buffer[vn_ind] - - return operator - - -def get_sub_operator(operations): - """Generate subtract operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tirx.noalias": True}) - a_buffer = T.match_buffer(a, [operations], dtype="float64") - b_buffer = T.match_buffer(b, [operations], dtype="float64") - c_buffer = T.match_buffer(c, [operations], dtype="float64") - for n in T.grid(operations): - with T.sblock("c_buffer"): - vn_ind = T.axis.remap("S", [n]) - c_buffer[vn_ind] = a_buffer[vn_ind] - b_buffer[vn_ind] - - return operator - - -def evaluate(hexagon_session, operations, expected, sch): - """Evalute schedule.""" - shape = operations - dtype = "float64" - - func_tir = tvm.compile(sch.mod["main"], target=get_hexagon_target("v68")) - module = hexagon_session.load_module(func_tir) - - # np.random.random returns float64 by default, but make the cast explicit - # to make it easier to switch when necessary. - a = np.random.random(shape).astype(dtype) - b = np.random.random(shape).astype(dtype) - c = np.zeros(shape, dtype=dtype) - - a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device) - b_hexagon = tvm.runtime.tensor(b, device=hexagon_session.device) - c_hexagon = tvm.runtime.tensor(c, device=hexagon_session.device) - - # These are reduced for CI but number=100 and repeat=10 does a good job of removing noise. - number = 1 - repeat = 1 - - timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) - runtime = timer(a_hexagon, b_hexagon, c_hexagon) - - tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b)) - - return round(runtime.mean * 1000, 6) - - -class TestMatMulVec: - """MatMul test class.""" - - ( - operation_name, - operator_producer, - expected_output_producer, - ) = tvm.testing.parameters( - ("add", get_add_operator, (lambda a, b: a + b)), - ("mul", get_multiply_operator, (lambda a, b: a * b)), - ("sub", get_sub_operator, (lambda a, b: a - b)), - ) - - # Removed most of these to speedup CI. - operations = tvm.testing.parameter( - 128, - # 256, - # 512, - # Single thread runs faster since L2 cache can handle the entire request quickly - # 1024, - # 2048, - # Significant performance degredation once the inputs and outputs cannot all fit in L2 - # 4096, - # 8192, - # 16384, - ) - - split_factor = tvm.testing.parameter(4) - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_add( - self, - hexagon_session, - operation_name, - operator_producer, - expected_output_producer, - operations, - split_factor, - ): - """Test Add operator.""" - - sch = tvm.s_tir.Schedule(operator_producer(operations)) - single_thread_runtime = evaluate(hexagon_session, operations, expected_output_producer, sch) - - sch = tvm.s_tir.Schedule(operator_producer(operations)) - block = sch.get_sblock("c_buffer") - b = sch.get_loops(block) - b_output, _ = sch.split(b[0], factors=[split_factor, None]) - sch.parallel(b_output) - parallel_runtime = evaluate(hexagon_session, operations, expected_output_producer, sch) - - speedup = round(single_thread_runtime / parallel_runtime, 2) - print( - TEST_OUTPUT_TEMPLATE.format( - operation_name, operations, single_thread_runtime, parallel_runtime, speedup - ) - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py deleted file mode 100644 index ab69c9fa0d97..000000000000 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ /dev/null @@ -1,93 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Relax hexagon 2d VTCM allocation test.""" - -import numpy as np -import pytest - -import tvm -import tvm.contrib.hexagon -import tvm.testing -from tvm import relax -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tirx as T - - -# pylint: disable=missing-docstring,no-self-argument,invalid-name -@I.ir_module(s_tir=True) -class Module: - @T.prim_func(s_tir=True) - def add( - arg0: T.Buffer((2, 2), "float32"), - arg1: T.Buffer((2, 2), "float32"), - output: T.Buffer((2, 2), "float32"), - ): - T.func_attr({"operator_name": "relax.add"}) - for ax0 in range(2): - for ax1 in range(2): - with T.sblock("T_add"): - v_ax0 = T.axis.spatial(2, ax0) - v_ax1 = T.axis.spatial(2, ax1) - T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1]) - T.writes(output[v_ax0, v_ax1]) - output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1] - - @R.function(pure=False) - def main(x: R.Tensor((2, 2), dtype="float32")): - cls = Module - # Try allocating 2d storage (2,2) in global.vtcm scope with nd allocator - storage = R.vm.alloc_storage( - R.shape([2, 2]), runtime_device_index=0, dtype="float32", storage_scope="global.vtcm" - ) - alloc = R.vm.alloc_tensor(storage, offset=0, shape=R.shape([2, 2]), dtype="float32") - _: R.Tuple = cls.add(x, x, alloc) - out: R.Tensor((2, 2), dtype="float32") = alloc - storage2 = R.vm.alloc_storage(R.shape([4 * 2 * 2]), runtime_device_index=0, dtype="uint8") - alloc2 = R.vm.alloc_tensor(storage2, offset=0, shape=R.shape([2, 2]), dtype="float32") - _1: R.Tuple = cls.add(out, x, alloc2) - out2: R.Tensor((2, 2), dtype="float32") = alloc2 - return out2 - - -# pylint: enable=missing-docstring,no-self-argument,invalid-name -@pytest.mark.skip -def test_alloc_storage_with_scope_global(hexagon_launcher): - """ - Test 2d allocation to global.vtcm memory scope in a Relax Function - """ - arg0 = np.random.uniform(size=(2, 2)).astype(np.float32) - - output_ref = arg0 + arg0 + arg0 - - mod = Module - - target_hexagon = tvm.target.Target({"tag": "qcom/hexagon-v69", "vtcm-capacity": 4 * 2**20}) - target = tvm.target.Target(target_hexagon, host=target_hexagon) - with tvm.transform.PassContext(opt_level=3): - lib = tvm.compile(mod, target, exec_mode="compiled") - - with hexagon_launcher.create_session() as session: - dev = session.device - vm_mod = session.get_executor_from_factory(lib) - # This is the important line which tests nd allocator - vm_rt = relax.VirtualMachine(vm_mod, dev, memory_cfg="naive") - x = tvm.runtime.tensor(arg0, dev) - vm_rt.set_input("main", x) - vm_rt.invoke_stateful("main") - hexagon_output = vm_rt.get_outputs("main").numpy() - tvm.testing.assert_allclose(output_ref, hexagon_output) diff --git a/tests/python/contrib/test_hexagon/test_relax_integration.py b/tests/python/contrib/test_hexagon/test_relax_integration.py deleted file mode 100644 index b8947b114c74..000000000000 --- a/tests/python/contrib/test_hexagon/test_relax_integration.py +++ /dev/null @@ -1,115 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F401, F821 -"""Relax hexagon test.""" - -import numpy as np -import pytest - -pytest.importorskip("onnx") # tvm.relax.frontend.onnx imports onnx - -import tvm.testing -from tvm import relax, runtime -from tvm.contrib.hexagon.session import Session -from tvm.relax.frontend import onnx -from tvm.relax.testing import relay_translator -from tvm.testing import env - - -def get_onnx_mobilenet(): - """Download and import mobilenet model with ONNX""" - import onnx # pylint: disable=import-outside-toplevel - - # pylint: disable=line-too-long - model_url = "https://github.com/onnx/models/raw/131c99da401c757207a40189385410e238ed0934/vision/classification/mobilenet/model/mobilenetv2-7.onnx" - model_path = tvm.contrib.download.download_testdata( - model_url, "mobilenetv2-7.onnx", module="onnx" - ) - return onnx.load(model_path) - - -@pytest.mark.skip("takes too long (~20min)") -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_mobilenet_onnx(hexagon_session: Session): - """Test MobileNetV2 ONNX model""" - onnx_model = get_onnx_mobilenet() - data_np = np.random.rand(1, 3, 224, 224).astype("float32") - shape_dict = {"input": data_np.shape} - relay_mod, _ = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True) - - target_hexagon = tvm.target.Target("qcom/hexagon-v68") - target = tvm.target.Target(target_hexagon, host=target_hexagon) - relax_mod = onnx.from_onnx(onnx_model, shape_dict, freeze_params=True) - relax_mod = relay_translator.from_relay(relay_mod["main"], target_hexagon) - - # Compile and run on Hexagon. - exe = tvm.compile(relax_mod, target) - dev = hexagon_session.device - - vm_mod = hexagon_session.get_executor_from_factory(exe) - vm_rt = relax.VirtualMachine(vm_mod, dev) - data = tvm.runtime.tensor(data_np, dev) - vm_rt.set_input("main", data) - vm_rt.invoke_stateful("main") - hexagon_res = vm_rt.get_outputs("main") - - # Compile and run on LLVM for comparison. - relax_mod = relay_translator.from_relay(relay_mod["main"], "llvm") - exe = tvm.compile(relax_mod, "llvm") - dev = tvm.cpu() - vm_rt = relax.VirtualMachine(exe, dev) - data = tvm.runtime.tensor(data_np, dev) - llvm_res = vm_rt["main"](data) - tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) - - -@pytest.mark.skip("takes too long (~20min)") -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_mobilenet(hexagon_session: Session): - """Test MobileNet workload""" - relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32") - data_np = np.random.rand(1, 3, 224, 224).astype("float32") - - target_hexagon = tvm.target.Target("qcom/hexagon-v68") - target = tvm.target.Target(target_hexagon, host=target_hexagon) - - # translate the relay mobilenet and bind params - relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) - - # Compile and run on Hexagon. - exe = tvm.compile(relax_mod, target) - dev = hexagon_session.device - - vm_mod = hexagon_session.get_executor_from_factory(exe) - vm_rt = relax.VirtualMachine(vm_mod, dev) - data = tvm.runtime.tensor(data_np, dev) - vm_rt.set_input("main", data) - vm_rt.invoke_stateful("main") - hexagon_res = vm_rt.get_outputs("main") - - # Compile and run on LLVM for comparison. - relax_mod = relay_translator.from_relay(relay_mod["main"], "llvm", params) - exe = tvm.compile(relax_mod, "llvm") - dev = tvm.cpu() - vm_rt = relax.VirtualMachine(exe, dev) - data = tvm.runtime.tensor(data_np, dev) - llvm_res = vm_rt["main"](data) - tvm.testing.assert_allclose(hexagon_res.numpy(), llvm_res.numpy(), rtol=1e-3) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_run_unit_tests.py b/tests/python/contrib/test_hexagon/test_run_unit_tests.py deleted file mode 100644 index f1cec118e4c3..000000000000 --- a/tests/python/contrib/test_hexagon/test_run_unit_tests.py +++ /dev/null @@ -1,183 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=redefined-outer-name - -"""capture gtest output and return over FFI""" - -import pytest - -import tvm -import tvm.testing -from tvm.contrib.hexagon.session import Session -from tvm.testing import env - -unit_test_name = tvm.testing.parameter( - "HexagonUserDMATest.wait", - "HexagonUserDMATest.poll", - "HexagonUserDMATest.bad_copy", - "HexagonUserDMATest.sync_dma", - "HexagonUserDMATest.async_dma_wait", - "HexagonUserDMATest.async_dma_poll", - "HexagonUserDMATest.pipeline", - "HexagonUserDMATest.pipeline_write_queue", - "HexagonUserDMATest.overflow_ring_buffer", - "HexagonUserDMATest.sync_dma_bypass", - "HexagonUserDMATest.sync_dma_bypass_vtcm_to_vtcm", - "HexagonUserDMATest.sync_dma_bypass_", - "HexagonBuffer.default_scope", - "HexagonBuffer.ddr_scope", - "HexagonBuffer.vtcm_scope", - "HexagonBuffer.invalid_scope", - "HexagonBuffer.micro_copies_corresponding_regions", - "HexagonBuffer.micro_copies_src_bigger", - "HexagonBuffer.micro_copies_dest_bigger", - "HexagonBuffer.micro_copies_src_overlaps_dest_region", - "HexagonBuffer.micro_copies_dest_overlaps_src_region", - "HexagonBuffer.micro_copies_discontiguous_regions", - "HexagonBuffer.micro_copies_invalid_size", - "HexagonBuffer.macro_copies_adjacent_corresponding_regions_merged", - "HexagonBuffer.macro_copies_discontiguous_regions_not_merged", - "HexagonBuffer.macro_copies_overlapping_regions_merged", - "HexagonBuffer.copy_from", - "HexagonBuffer.copy_from_invalid_size", - "HexagonBuffer.copy_from_smaller_size", - "HexagonBuffer.nd", - "HexagonBuffer.nd_copy_from", - "HexagonBuffer.1d_copy_from_1d", - "HexagonBuffer.2d_copy_from_1d", - "HexagonBuffer.1d_copy_from_2d", - "HexagonBuffer.nd_copy_from_nd_invalid_size", - "HexagonBuffer.nd_copy_from_nd_smaller_size", - "HexagonBuffer.md_copy_from_nd", - "HexagonBuffer.copy_to", - "HexagonBuffer.nd_copy_to", - "RingBufferTest.zero_size_ring_buffer", - "RingBufferTest.in_flight", - "RingBufferTest.next", - "RingBufferTest.full", - "RingBufferTest.wrap", - "RingBufferTest.wrap_corner", - "RingBufferTest.half_in_flight", - "RingBufferTest.half_in_flight_blocked", - "QueuedRingBufferTest.invalid_queue", - "QueuedRingBufferTest.two_queues", - "QueuedRingBufferTest.group_end_before_group_start", - "QueuedRingBufferTest.group_restart", - "QueuedRingBufferTest.zero_size_group", - "QueuedRingBufferTest.in_flight_before_group_end", - "QueuedRingBufferTest.group_of_one", - "QueuedRingBufferTest.group_of_two", - "QueuedRingBufferTest.group_of_three", - "QueuedRingBufferTest.two_groups_of_two", - "QueuedRingBufferTest.two_queues_two_groups_of_two", - "HexagonVtcmPoolTest.basic", - "HexagonVtcmPoolTest.small_allocations", - "HexagonVtcmPoolTest.no_free_vtcm", - "HexagonVtcmPoolTest.not_enough_free_vtcm", - "HexagonVtcmPoolTest.free_with_wrong_size", - "HexagonVtcmPoolTest.free_alloc_combinations", - "HexagonVtcmPoolTest.find_allocation", - "HexagonVtcmPoolTest.find_smallest_allocation_combinations", - "HexagonVtcmPoolTest.vtcm_alignment", - "HexagonThreadManagerTest.ctor_edge_cases", - "HexagonThreadManagerTest.init", - "HexagonThreadManagerTest.dispatch", - "HexagonThreadManagerTest.dispatch_wait", - "HexagonThreadManagerTest.wait_signal", - "HexagonThreadManagerTest.re_signal", - "HexagonThreadManagerTest.re_wait", - "HexagonThreadManagerTest.wait_signal_x2", - "HexagonThreadManagerTest.signal_wait", - "HexagonThreadManagerTest.sync_from_to", - "HexagonThreadManagerTest.sync_from_to_self", - "HexagonThreadManagerTest.sync_from_to_x2", - "HexagonThreadManagerTest.sync_from_to_all", - "HexagonThreadManagerTest.pipe_fill", - "HexagonThreadManagerTest.pipe_overflow", - "HexagonThreadManagerTest.producer_consumer", - "HexagonThreadManagerTest.producer_consumer_signal_wait", - "HexagonThreadManagerTest.thread_order", - "HexagonThreadManagerTest.thread_order_signal_wait", - "HexagonThreadManagerTest.dispatch_writes", - "HexagonThreadManagerTest.threads_for_resource_types", - "HexagonUtilsActivationsBlockizeTest.prepare_nhwc", - "HexagonUtilsActivationsBlockizeTest.blockize_hwc_16b", - "HexagonUtilsActivationsBlockizeTest.deblockize_hwc_16b", - "HexagonUtilsWeightsChunkifyTest.calculate_num_weight_chunks", - "HexagonUtilsWeightsChunkifyTest.prepare_hwio", - "HexagonUtilsWeightsChunkifyTest.chunkify_hwio_16b", - "HexagonUtilsQuantActivationsBlockizeTest.prepare_nhwc", - "HexagonUtilsQuantActivationsBlockizeTest.blockize_hwc_8b", - "HexagonUtilsQuantActivationsBlockizeTest.deblockize_hwc_8b", - "HexagonUtilsQuantWeightsChunkifyTest.calculate_num_weight_chunks", - "HexagonUtilsQuantWeightsChunkifyTest.prepare_hwio", - "HexagonUtilsQuantWeightsChunkifyTest.chunkify_hwio_8b", - "HexagonDeviceAPITest.global", - "HexagonDeviceAPITest.alloc_free_cpu", - "HexagonDeviceAPITest.alloc_free_hex", - "HexagonDeviceAPITest.alloc_errors", - "HexagonDeviceAPITest.free_errors", - "HexagonDeviceAPITest.allocnd_free_cpu", - "HexagonDeviceAPITest.allocnd_free_hex", - "HexagonDeviceAPITest.allocnd_free_hex_vtcm", - "HexagonDeviceAPITest.allocnd_erros", - "HexagonDeviceAPITest.alloc_scalar", - "HexagonDeviceAPITest.DISABLED_alloc_free_diff_dev", - "HexagonDeviceAPITest.runtime_buffer_manager", - "HexagonDeviceAPITest.thread_manager", - "HexagonDeviceAPITest.user_dma", - "HexagonDeviceAPITest.vtcm_pool", -) - - -# use pytest -sv to observe gtest output -# use --gtest_args to pass arguments to gtest -# for example to run all "foo" tests twice and observe gtest output run -# pytest -sv --gtests_args="--gtest_filter=*foo* --gtest_repeat=2" -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_run_unit_tests(hexagon_session: Session, gtest_args, unit_test_name): - """Try running gtest unit tests and capture output and error code""" - try: - func = hexagon_session._rpc.get_function("hexagon.run_unit_tests") - except Exception: - print( - "This test requires TVM Runtime to be built with a Hexagon gtest" - "version using Hexagon API cmake flag" - "-DUSE_HEXAGON_GTEST=/path/to/hexagon/sdk/utils/googletest/gtest" - ) - raise - - # Prepend the unit test name, so command-line arguments still take - # precedence, but CI runs each gtest as a separate pytest case. - if gtest_args: - gtest_args = f"--gtest_filter={unit_test_name} {gtest_args}" - else: - gtest_args = f"--gtest_filter={unit_test_name}" - - gtest_error_code_and_output = func(gtest_args) - gtest_error_code = int(gtest_error_code_and_output.splitlines()[0]) - gtest_output = gtest_error_code_and_output.split("\n", 1)[-1] - print(gtest_output) - if gtest_error_code != 0: - raise RuntimeError( - f"Hexagon gtest retruned non-zero error code = {gtest_error_code}:\n{gtest_output}" - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py deleted file mode 100644 index f9ffe5522097..000000000000 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ /dev/null @@ -1,120 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F401, F841 -"""Sigmoid operator tests.""" - -import numpy as np -import pytest - -import tvm -import tvm.testing -from tvm import te, tirx, topi -from tvm.contrib.hexagon import allocate_hexagon_array -from tvm.testing import env - -from .infrastructure import get_hexagon_target - - -def sigmoid_compute(sigmoid_input): - return topi.sigmoid(sigmoid_input) - - -def sigmoid_stir_schedule(sigmoid_input, sigmoid_output): - sigmoid_func = te.create_prim_func([sigmoid_input, sigmoid_output]) - sch = tvm.s_tir.Schedule(sigmoid_func, debug_mask="all") - block = sch.get_sblock("compute") - - (n,) = sch.get_loops(block) - sch.vectorize(n) - return sch - - -class BaseSigmoid: - ( - in_shape, - dtype, - min_val, - max_val, - ) = tvm.testing.parameters( - ((64,), "float16", -8.0, 8.0), - ((64,), "float16", -6.0, 7.0), - ((64,), "float16", -10.0, 15.0), - ((64,), "float16", -10.0, 0.0), - ((64,), "float16", 0.0, 10.0), - ) - - -class TestSigmoid(BaseSigmoid): - """Sigmoid test class.""" - - @tvm.testing.fixture - def input_np(self, in_shape, dtype, min_val, max_val): - return np.random.uniform(low=min_val, high=max_val, size=in_shape).astype(dtype) - - @tvm.testing.fixture - def ref_output_np(self, input_np): - output_np = 1 / (1 + np.exp(-input_np)) - return output_np - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_sigmoid( - self, - in_shape, - dtype, - input_np, - ref_output_np, - hexagon_session, - ): - """Sigmoid test.""" - input_tensor = te.placeholder(in_shape, name="input_tensor", dtype=dtype) - - output_tensor = sigmoid_compute(input_tensor) - - tir_s = sigmoid_stir_schedule(input_tensor, output_tensor) - - input_data = allocate_hexagon_array( - hexagon_session.device, - data=input_np, - ) - output_data = allocate_hexagon_array( - hexagon_session.device, - tensor_shape=ref_output_np.shape, - dtype=ref_output_np.dtype, - ) - - func_name = "sigmoid" - with tvm.transform.PassContext(opt_level=3): - runtime_module = tvm.compile(tir_s.mod, target=get_hexagon_target("v69")) - - assert "hvx_sigmoid" in runtime_module.inspect_source("asm") - assert "vmin" in runtime_module.inspect_source("asm") - assert "vmax" in runtime_module.inspect_source("asm") - mod = hexagon_session.load_module(runtime_module) - - mod(input_data, output_data) - output_np = output_data.numpy() - - tvm.testing.assert_allclose( - output_np, - ref_output_np, - 1e-3, - 1e-3, - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py deleted file mode 100644 index e06a07fef1ac..000000000000 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ /dev/null @@ -1,206 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F401 -"""Async software pipeline tests.""" - -import numpy as np -import pytest - -import tvm -from tvm import tirx -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - - -def compute(comp_type, outer, inner, dtype): - """Generate compute function.""" - if comp_type == "single_input": - - @T.prim_func(s_tir=True) - def a_plus_1_primfunc( - a_buffer: T.Buffer((outer, inner), dtype), out: T.Buffer((outer, inner), dtype) - ): - for i in T.serial(outer): - for j in T.serial(inner): - with T.sblock("compute"): - with T.sblock(): - out[i, j] = a_buffer[i, j] + T.cast(1, dtype) - - return a_plus_1_primfunc - else: - - @T.prim_func(s_tir=True) - def a_plus_b_plus_1_primfunc( - a_buffer: T.Buffer((outer, inner), dtype), - b_buffer: T.Buffer((outer, inner), dtype), - out: T.Buffer((outer, inner), dtype), - ): - for i in T.serial(outer): - for j in T.serial(inner): - with T.sblock("compute"): - with T.sblock(): - out[i, j] = a_buffer[i, j] + b_buffer[i, j] + T.cast(1, dtype) - - return a_plus_b_plus_1_primfunc - - -class TestAsyncSoftwarePipeline: - """Async software pipeline test class.""" - - outer = tvm.testing.parameter(8, 16) - inner = tvm.testing.parameter(64, 128) - dtype = tvm.testing.parameter("uint8", "float16") - scope = tvm.testing.parameter("global", "global.vtcm") - # TODO(Joseph) Turn on "multi_input_diffQ" compute type once we have upstreamed - # changes in the InjectSoftwarePipeline pass to alleviate this restriction: - # 'a_buffer dependency on multiple async stages is not supported' - comp_type = tvm.testing.parameter("single_input", "multi_input_sameQ") - # TODO(Straw) Add back "cache_write" schedule type once we have upstreamed - # buffer dependency analysis in InjectSoftwarePipeline pass - # to insert approprite TIR "wait" attributes for this schedule - sched_type = tvm.testing.parameter("cache_read", "cache_read_write") - - @tvm.testing.fixture - def data(self, comp_type, outer, inner, dtype): - out_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - if comp_type == "single_input": - return out_np, a_np - else: - b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - return out_np, a_np, b_np - - @tvm.testing.fixture - def verify(self, dtype): - def check(out, ref): - if "int" in dtype: - np.testing.assert_equal(out.numpy(), ref) - else: - tvm.testing.assert_allclose(out.numpy(), ref, rtol=1e-3, atol=1e-3) - - return check - - @tvm.testing.fixture - def reference(self, comp_type): - """Returns reference data.""" - if comp_type == "single_input": - - def a_plus_1_ref(a): - return a + 1 - - return a_plus_1_ref - else: - - def a_plus_b_plus_1_ref(a, b): - return a + b + 1 - - return a_plus_b_plus_1_ref - - @tvm.testing.fixture - def schedule(self, comp_type, sched_type, outer, inner, dtype, scope): - """Generate schedule.""" - sch = tvm.s_tir.Schedule(compute(comp_type, outer, inner, dtype)) - - compute_block = sch.get_sblock("compute") - i, _ = sch.get_loops(compute_block) - - if "read" in sched_type: - cache_read_a = sch.cache_read(compute_block, 0, scope) - sch.compute_at(cache_read_a, i) - - if "multi_input" in comp_type: - cache_read_b = sch.cache_read(compute_block, 1, scope) - sch.compute_at(cache_read_b, i) - - if "write" in sched_type: - cache_write_out = sch.cache_write(compute_block, 0, scope) - sch.reverse_compute_at(cache_write_out, i) - - if "read" in sched_type and "write" in sched_type: - if comp_type == "single_input": - sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2]) - sch.annotate(i, "software_pipeline_async_stages", [0, 2]) - elif comp_type == "multi_input_sameQ": - sch.annotate(i, "software_pipeline_stage", [0, 0, 1, 2]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3]) - sch.annotate(i, "software_pipeline_async_stages", [0, 2]) - elif comp_type == "multi_input_diffQ": - sch.annotate(i, "software_pipeline_stage", [0, 1, 2, 3]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3]) - sch.annotate(i, "software_pipeline_async_stages", [0, 1, 2]) - - elif "read" in sched_type: - if comp_type == "single_input": - sch.annotate(i, "software_pipeline_stage", [0, 1]) - sch.annotate(i, "software_pipeline_order", [0, 1]) - sch.annotate(i, "software_pipeline_async_stages", [0]) - elif comp_type == "multi_input_sameQ": - sch.annotate(i, "software_pipeline_stage", [0, 0, 1]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2]) - sch.annotate(i, "software_pipeline_async_stages", [0]) - elif comp_type == "multi_input_diffQ": - sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2]) - sch.annotate(i, "software_pipeline_async_stages", [0, 1]) - - elif "write" in sched_type: - sch.annotate(i, "software_pipeline_stage", [0, 1]) - sch.annotate(i, "software_pipeline_order", [0, 1]) - sch.annotate(i, "software_pipeline_async_stages", [1]) - - return sch - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_async_software_pipeline( - self, hexagon_launcher, comp_type, data, reference, schedule, verify - ): - """Async software pipeline test.""" - out_np = data[0] - a_np = data[1] - if comp_type == "single_input": - ref = reference(a_np) - else: - b_np = data[2] - ref = reference(a_np, b_np) - - with tvm.transform.PassContext( - config={ - "tirx.use_async_copy": 1, - "tirx.experimental_dma_bypass_cache": 1, - } - ): - func = tvm.compile(schedule.mod["main"], target=get_hexagon_target("v68")) - - with hexagon_launcher.create_session() as hexagon_session: - dev = hexagon_session.device - mod = hexagon_session.load_module(func) - out = tvm.runtime.tensor(out_np, device=dev) - a = tvm.runtime.tensor(a_np, device=dev) - if comp_type == "single_input": - mod(a, out) - else: - b = tvm.runtime.tensor(b_np, device=dev) - mod(a, b, out) - - verify(out, ref) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_take.py b/tests/python/contrib/test_hexagon/test_take.py deleted file mode 100644 index 9e46da1c1718..000000000000 --- a/tests/python/contrib/test_hexagon/test_take.py +++ /dev/null @@ -1,397 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring, invalid-name, unused-argument, not-callable -import numpy as np -import pytest -import tvm_ffi - -pytest.importorskip("scipy") - -from scipy import special - -import tvm -import tvm.testing -from tvm import relax -from tvm.contrib.hexagon import generate_take_op, hexagon_unary_ops -from tvm.script import relax as R -from tvm.script import tirx as T - -from .infrastructure import quantize_np - -# Testing the structural and value correctness on replacing unary op with take op. - - -@tvm.script.ir_module -class Module_tanh: - @R.function - def main( - input_tanh: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_tanh.tanh, - ( - input_tanh, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.002631544131858676, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def tanh( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.tanh"}}) - - -@tvm.script.ir_module -class Module_sqrt: - @R.function - def main( - input_sqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_sqrt.sqrt, - ( - input_sqrt, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.003535157327728918, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def sqrt( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.sqrt"}}) - - -@tvm.script.ir_module -class Module_rsqrt: - @R.function - def main( - input_rsqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_rsqrt.rsqrt, - ( - input_rsqrt, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.008154160766635542, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def rsqrt( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.rsqrt"}}) - - -@tvm.script.ir_module -class Module_exp: - @R.function - def main( - input_exp: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_exp.exp, - ( - input_exp, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.008838622987079832, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def exp( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.exp"}}) - - -@tvm.script.ir_module -class Module_erf: - @R.function - def main( - input_erf: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_erf.erf, - ( - input_erf, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.002939393251118067, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def erf( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.erf"}}) - - -@tvm.script.ir_module -class Module_sigmoid: - @R.function - def main( - input_sigmoid: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_sigmoid.sigmoid, - ( - input_sigmoid, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.002631544131858676, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def sigmoid( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.sigmoid"}}) - - -@tvm.script.ir_module -class Module_hardswish: - @R.function - def main( - input_hardswish: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_hardswish.hardswish, - ( - input_hardswish, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.0020250332087720325, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def hardswish( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.hardswish"}}) - - -@tvm.script.ir_module -class Module_log: - @R.function - def main( - input_log: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_log.log, - ( - input_log, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.0057414634248614226, "float32"), - R.const(255, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def log( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.log"}}) - - -@tvm.script.ir_module -class Module_abs: - @R.function - def main( - input_abs: R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"): - out = R.call_tir( - Module_abs.abs, - ( - input_abs, - R.const(0.003186821002586215, "float32"), - R.const(0, "int32"), - R.const(0.0031868210196078434, "float32"), - R.const(0, "int32"), - ), - out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"), - ) - return out - - @T.prim_func(s_tir=True) - def abs( - rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - rxplaceholder_1: T.Buffer((), "float32"), - rxplaceholder_2: T.Buffer((), "int32"), - rxplaceholder_3: T.Buffer((), "float32"), - rxplaceholder_4: T.Buffer((), "int32"), - compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), - ): - T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.abs"}}) - - -# data = np.random.random([1, 2, 2, 2]).astype("float32") : Need to hadcode the data -# so that we can get the quantization parameters and use them as input to the main func -data = [ - [ - [[0.3034368, 0.60848576], [0.29697746, 0.67340654]], - [[0.656068, 0.23129226], [0.42117321, 0.81263936]], - ] -] -dtype = "uint8" - -# Quantizing input : scale is returned as float64 and zp is returned as int32 -inp_quant, inp_scale, inp_zero_point = quantize_np(data, dtype) -inp_quant = tvm.runtime.tensor(inp_quant.astype(np.uint8)) - - -# Test the implementations value output with numpy data. First the IR is runn through pass -# to replace unary op with take op. Followed by value testing. -def test_value(): - ops = ["tanh", "sqrt", "rsqrt", "exp", "erf", "sigmoid", "hardswish", "log", "abs"] - - atol_val = 2 - for op_name in ops: - if op_name == "tanh": - op_val = np.tanh(data) - before = Module_tanh - elif op_name == "sqrt": - op_val = np.sqrt(data) - before = Module_sqrt - elif op_name == "rsqrt": - op_val = 1 / np.sqrt(data) - before = Module_rsqrt - elif op_name == "exp": - op_val = np.exp(data) - before = Module_exp - elif op_name == "erf": - op_val = special.erf(data) - before = Module_erf - elif op_name == "sigmoid": - op_val = 1 / (1 + np.exp(np.negative(data))) - atol_val = 15 - before = Module_sigmoid - elif op_name == "hardswish": - op_val = hexagon_unary_ops.hardswish_func(data) - before = Module_hardswish - elif op_name == "log": - op_val = np.log(data) - before = Module_log - elif op_name == "abs": - op_val = np.abs(data) - before = Module_abs - - # Quantizing output : scale is returned as float64 and zp is returned as int32 - out_quant, _, _ = quantize_np(op_val, dtype) - - after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(before) - target = tvm.target.Target("llvm", host="llvm") - ex = tvm.compile(after, target) - vm = relax.VirtualMachine(ex, tvm.cpu()) - res = vm["main"](inp_quant) - - tvm.testing.assert_allclose(res.numpy(), out_quant, atol=atol_val) - print("Passed Value : ", op_name) - - -# Testing the structural implementation, if the unary op is replaced with take op. -def test_structural(): - Modules = [ - Module_tanh, - Module_sqrt, - Module_rsqrt, - Module_exp, - Module_erf, - Module_sigmoid, - Module_hardswish, - Module_log, - Module_abs, - ] - for mod in Modules: - after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(mod) - assert not tvm_ffi.structural_equal(after["main"], mod["main"]) - print("Passed Structural") diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py deleted file mode 100644 index ae37427bc7e6..000000000000 --- a/tests/python/contrib/test_hexagon/test_thread_pool.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Add hexagon thread pool test""" - -import numpy as np -import pytest - -import tvm -import tvm.contrib.hexagon -import tvm.script -import tvm.testing -from tvm.contrib.hexagon.session import Session -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - - -@tvm.script.ir_module -class ElemwiseSumIRModule: - """IRModule definition for elementwise sum""" - - # pylint: disable=no-self-argument,invalid-name,missing-function-docstring - @T.prim_func(s_tir=True) - def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32): - T.func_attr({"global_symbol": "elemwise_sum_serial", "tirx.noalias": True}) - A = T.match_buffer(a, (n,), dtype="float32") - B = T.match_buffer(b, (n,), dtype="float32") - C = T.match_buffer(c, (n,), dtype="float32") - for i in T.serial(n): - with T.sblock("C"): - vi = T.axis.spatial(n, i) - C[vi] = A[vi] + B[vi] - - @T.prim_func(s_tir=True) - def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32): - T.func_attr({"global_symbol": "elemwise_sum_parallel", "tirx.noalias": True}) - A = T.match_buffer(a, (n,), dtype="float32") - B = T.match_buffer(b, (n,), dtype="float32") - C = T.match_buffer(c, (n,), dtype="float32") - for i in T.parallel(n): - with T.sblock("C"): - vi = T.axis.spatial(n, i) - C[vi] = A[vi] + B[vi] - - # pylint: enable=no-self-argument,invalid-name,missing-function-docstring - - -def generate_add_test_data(hexagon_session: Session, n=128 * 1024): - a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), hexagon_session.device) - b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), hexagon_session.device) - c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), hexagon_session.device) - return (a, b, c, n) - - -def benchmark_func(mod, name, args, hexagon_session): - (a, b, c, n) = args - evaluator = mod.time_evaluator(name, hexagon_session.device, number=100) - return evaluator(a, b, c, n).mean - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_speedup(hexagon_session: Session, capsys): - """Test speedup""" - func = tvm.compile( - ElemwiseSumIRModule, - target=get_hexagon_target("v68"), - ) - mod = hexagon_session.load_module(func) - args = generate_add_test_data(hexagon_session) - parallel_mean = benchmark_func(mod, "elemwise_sum_parallel", args, hexagon_session) - serial_mean = benchmark_func(mod, "elemwise_sum_serial", args, hexagon_session) - - with capsys.disabled(): - print(f"... speedup of {serial_mean / parallel_mean:.2f}", end=" ") - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_elemwise_sum_parallel(hexagon_session: Session): - """Test parallel elementwise sum""" - func = tvm.compile( - ElemwiseSumIRModule, - target=get_hexagon_target("v68"), - ) - mod = hexagon_session.load_module(func) - - (a, b, c, n) = generate_add_test_data(hexagon_session) - mod["elemwise_sum_parallel"](a, b, c, n) - tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy()) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py deleted file mode 100644 index f52fef0b9b73..000000000000 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ /dev/null @@ -1,94 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F401 -"""VTCM Tests""" - -import pytest - -import tvm.testing -from tvm import tirx -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - - -@T.prim_func(s_tir=True) -def scale_by_two(buffer_a: T.Buffer((8192,), "int8"), buffer_c: T.Buffer((8192,), "int8")): - for i in T.serial( - 0, - 8192, - ): - with T.sblock("C"): - buffer_c[i] = buffer_a[i] * T.int8(2) - - -def get_scale_by_two_schedule(): - mod = tvm.IRModule.from_expr(scale_by_two.with_attr("global_symbol", "main")) - sch = tvm.s_tir.Schedule(mod, debug_mask="all") - block_c = sch.get_sblock("C") - (flat,) = sch.get_loops(block_c) - outer, _, _, _ = sch.split(flat, factors=[8, 4, 2, 128]) - cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm") - sch.compute_at(cache_block, outer) - return sch - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -def test_vtcm_building(): - """Test building with vtcm mem scope""" - sch = get_scale_by_two_schedule() - target = get_hexagon_target("v68") - built = tvm.compile(sch.mod, target=target) - assert "global.vtcm" in built.inspect_source("asm") - - -@pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") -@pytest.mark.parametrize("vtcm_capacity,limited", [(8192, False), (1024, False), (128, True)]) -def test_vtcm_limit(vtcm_capacity, limited): - """Test building with vtcm mem scope limit""" - sch = get_scale_by_two_schedule() - - def _raises_exception(f): - try: - f() - except RuntimeError: - return True - return False - - target = get_hexagon_target("v68", vtcm_capacity=vtcm_capacity) - - assert _raises_exception(lambda: tvm.compile(sch.mod, target=target)) == limited, ( - "Case 1 - arg. VTCM memory allocation limiter does not work correctly " - ) - - with target: - assert _raises_exception(lambda: tvm.compile(sch.mod)) == limited, ( - "Case 2 - with.VTCM memory allocation limiter does not work correctly " - ) - - with tvm.transform.PassContext(config={"tirx.vtcm_capacity": vtcm_capacity}): - assert ( - _raises_exception( - lambda: tvm.compile(sch.mod, target=get_hexagon_target("v68", vtcm_capacity=0)) - ) - == limited - ), "Case 3 - context. VTCM memory allocation limiter does not work correctly " - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py deleted file mode 100644 index 9f42a9bbdb9f..000000000000 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ /dev/null @@ -1,194 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Test theoretical bandwith for data transfers to VTCM for different strategies.""" - -import numpy as np -import pytest - -import tvm -from tvm.s_tir.tensor_intrin.hexagon import DMA_READ_128_i8 -from tvm.script import tirx as T -from tvm.testing import env - -from .infrastructure import get_hexagon_target - -MB = 1024**2 -KB = 1024 -TEST_OUTPUT_TEMPLATE = ( - "Test bandwidth with buffer size {}MB... \n" - " -Base: {} GBps \n -Vectorized: {} GBps\n" - " -Vectorized and Parallelized: {} GBps\n" - " -Sync DMA: {} GBps\n" - " -Single DMA Copy: {} GBps\n" -) - - -def memcopy_operator(size): - """Generate memory copy operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, a_v: T.handle) -> None: - a_buffer = T.match_buffer(a, size, dtype="int8", align=128, scope="global") - a_global_vtcm = T.match_buffer(a_v, size, dtype="int8", align=128, scope="global.vtcm") - for ax0 in T.serial(size): - with T.sblock("A_global.vtcm"): - v0_ind = T.axis.spatial(size, ax0) - T.reads(a_buffer[v0_ind]) - T.writes(a_global_vtcm[v0_ind]) - a_global_vtcm[v0_ind] = a_buffer[v0_ind] - - return operator - - -def single_dma_operator(size): - """Generate single dma operator.""" - - @T.prim_func(s_tir=True) - def operator(a: T.handle, a_v: T.handle) -> None: - a_buffer = T.match_buffer(a, size, dtype="int8", align=128, scope="global") - a_global_vtcm = T.match_buffer(a_v, size, dtype="int8", align=128, scope="global.vtcm") - T.evaluate( - T.tvm_call_packed( - "device_api.hexagon.dma_copy_dltensor", - T.tvm_stack_make_array( - a_global_vtcm.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - a_global_vtcm.dtype, - 0, - dtype="handle", - ), - T.tvm_stack_make_array( - a_buffer.data, - T.tvm_stack_make_shape(size, dtype="handle"), - 0, - 1, - a_buffer.dtype, - 0, - dtype="handle", - ), - T.cast(size, dtype="int"), - True, # bypass cache - dtype="int32", - ) - ) - - return operator - - -def evaluate(hexagon_session, sch, size): - """Evaluate schedule.""" - a_shape = size - - func_tir = tvm.compile(sch.mod["main"], target=get_hexagon_target("v69")) - module = hexagon_session.load_module(func_tir) - - a = np.random.randint(-128, 127, a_shape, dtype="int8") - a_vtcm = np.zeros(a_shape, dtype="int8") - - a_hexagon = tvm.runtime.tensor(a, device=hexagon_session.device, mem_scope="global") - a_vtcm_hexagon = tvm.runtime.tensor( - a_vtcm, device=hexagon_session.device, mem_scope="global.vtcm" - ) - - if tvm.testing.utils.IS_IN_CI: - # Run with reduced number and repeat for CI - timer = module.time_evaluator("main", hexagon_session.device, number=1, repeat=1) - else: - timer = module.time_evaluator("main", hexagon_session.device, number=10, repeat=10) - - runtime = timer(a_hexagon, a_vtcm_hexagon) - - gbps = round((size / 2**30) / runtime.mean, 4) - tvm.testing.assert_allclose(a_vtcm_hexagon.numpy(), a) - - return gbps - - -class TestMatMulVec: - """MatMul test class.""" - - # Removed most of these to speedup CI. - size = tvm.testing.parameter( - 128, - KB, - 10 * KB, - 100 * KB, - MB, - ) - - outer_split = tvm.testing.parameter(4) - unroll_split = tvm.testing.parameter(2) - vector_split = tvm.testing.parameter(128) - - @pytest.mark.skipif(not env.has_hexagon(), reason="need hexagon") - def test_bandwidth(self, hexagon_session, size, outer_split, unroll_split, vector_split): - """Test bandwidth.""" - - if tvm.testing.utils.IS_IN_CI and (size > 128): - pytest.skip("Skipping test since it takes too long in CI.") - - # Run the base memcopy operator. - sch = tvm.s_tir.Schedule(memcopy_operator(size)) - base_gpbs = evaluate(hexagon_session, sch, size) - - # Run with some basic unroll and vectorize scheduling. - sch = tvm.s_tir.Schedule(memcopy_operator(size)) - vtcm_block_a = sch.get_sblock("A_global.vtcm") - v_block = sch.get_loops(vtcm_block_a) - _, vio_a, vii_a = sch.split(v_block[0], factors=[None, unroll_split, vector_split]) - sch.unroll(vio_a) - sch.vectorize(vii_a) - vectorize_gbps = evaluate(hexagon_session, sch, size) - - # Run with some basic unroll and vectorize scheduling and parallelization. - sch = tvm.s_tir.Schedule(memcopy_operator(size)) - vtcm_block_a = sch.get_sblock("A_global.vtcm") - v_block = sch.get_loops(vtcm_block_a) - vbo_a, _, vio_a, vii_a = sch.split( - v_block[0], factors=[outer_split, None, unroll_split, vector_split] - ) - sch.unroll(vio_a) - sch.vectorize(vii_a) - sch.parallel(vbo_a) - parallel_gbps = evaluate(hexagon_session, sch, size) - - # Run with some basic unroll and vectorize scheduling and parallelization. - sch = tvm.s_tir.Schedule(memcopy_operator(size)) - block = sch.get_sblock("A_global.vtcm") - loops = sch.get_loops(block) - _, inner = sch.split(loops[0], [None, 128]) - sch.tensorize(inner, DMA_READ_128_i8) - # print(sch.mod.script()) - sync_dma_gbps = evaluate(hexagon_session, sch, size) - - # Run using a single dma copy to transfer the data. - sch = tvm.s_tir.Schedule(single_dma_operator(size)) - single_dma_gbps = evaluate(hexagon_session, sch, size) - - mbs = round(size / MB, 2) - print( - TEST_OUTPUT_TEMPLATE.format( - mbs, base_gpbs, vectorize_gbps, parallel_gbps, sync_dma_gbps, single_dma_gbps - ) - ) - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/testing/test_env.py b/tests/python/testing/test_env.py index bba6d983c7a1..c6d381c9652e 100644 --- a/tests/python/testing/test_env.py +++ b/tests/python/testing/test_env.py @@ -175,7 +175,7 @@ def test_llvm_min_version_is_monotone(): assert env.has_llvm_min_version(1) -def test_hexagon_run_implies_toolchain(): +def test_runtime_hexagon_run_implies_toolchain(): """Full Hexagon support implies the compile-time toolchain is present.""" if env.has_hexagon(): assert env.has_hexagon_toolchain() From b6c73cdaebd8d0243876b9e1974bce67b3bfda2d Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Wed, 17 Jun 2026 00:29:10 +0800 Subject: [PATCH 17/23] [CI] Pin GitHub Actions to SHA for ASF INFRA compliance (#19793) ## Why ASF INFRA enforces that external GitHub Actions must be pinned to a commit SHA on the approved allowlist, failing the workflow with "not allowed in apache/tvm". See the [policy](https://infra.apache.org/github-actions-policy.html) and the [approved allowlist](https://github.com/apache/infrastructure-actions/blob/main/approved_patterns.yml). ## How - Pin `pre-commit/action` to `2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd` (v3.0.1) - Pin `pypa/cibuildwheel` to `294735312765b09d24a2fbec22660ce817587d55` (v4.1.0) - Pin `pypa/gh-action-pypi-publish` to `ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e` (v1.13.0) - Leave GitHub-owned `actions/*` and the allowlisted `conda-incubator/setup-miniconda@*` pattern untouched --------- Signed-off-by: Guan-Ming (Wesley) Chiu <105915352+guan404ming@users.noreply.github.com> --- .../build-wheel-for-publish/action.yml | 2 +- .github/workflows/lint.yml | 2 +- .github/workflows/publish_wheel.yml | 2 +- apps/cpp_rpc/rpc_env.cc | 14 +-- apps/cpp_rpc/rpc_server.cc | 5 +- python/tvm/backend/cuda/op.py | 1 - python/tvm/backend/cuda/script.py | 4 +- .../contrib/tensorrt/tensorrt_builder.cc | 5 +- .../extra/contrib/tensorrt/tensorrt_ops.cc | 6 +- .../task_scheduler/task_scheduler.cc | 8 +- src/target/llvm/codegen_llvm.cc | 7 +- src/target/llvm/codegen_params.cc | 4 +- src/tirx/transform/vectorize_loop.cc | 10 +- .../codegen/test_target_codegen_riscv.py | 1 + tests/python/relax/test_frontend_onnx.py | 2 +- .../cuda/elementwise/test_unary.py | 104 ++++++++++++++---- 16 files changed, 115 insertions(+), 62 deletions(-) diff --git a/.github/actions/build-wheel-for-publish/action.yml b/.github/actions/build-wheel-for-publish/action.yml index e71844237998..d4a3ca14c263 100644 --- a/.github/actions/build-wheel-for-publish/action.yml +++ b/.github/actions/build-wheel-for-publish/action.yml @@ -108,7 +108,7 @@ runs: # ---- Build and test wheels ---- - name: Build and test wheels - uses: pypa/cibuildwheel@v4.1.0 + uses: pypa/cibuildwheel@294735312765b09d24a2fbec22660ce817587d55 # v4.1.0 with: package-dir: . output-dir: wheelhouse diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 6c17e0f149a4..16fa502aaa65 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,4 +35,4 @@ jobs: with: fetch-depth: 0 fetch-tags: true - - uses: pre-commit/action@v3.0.1 + - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 diff --git a/.github/workflows/publish_wheel.yml b/.github/workflows/publish_wheel.yml index 63375e606325..a2fedda9e054 100644 --- a/.github/workflows/publish_wheel.yml +++ b/.github/workflows/publish_wheel.yml @@ -213,7 +213,7 @@ jobs: - name: Publish package distributions to PyPI if: ${{ inputs.publish_repository == 'pypi' }} - uses: pypa/gh-action-pypi-publish@v1.13.0 + uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0 with: attestations: true verbose: true diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index b0b1fe4064bc..4df5f87024b0 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -158,8 +158,7 @@ std::string RPCEnv::GetPath(const std::string& file_name) const { */ void RPCEnv::CleanUp() const { CleanDir(base_); - if (!CheckPath(base_)) - return; + if (!CheckPath(base_)) return; const int ret = rmdir(base_.c_str()); if (ret != 0) { LOG(WARNING) << "Remove directory " << base_ << " failed"; @@ -325,11 +324,11 @@ std::string BuildSharedLibrary(std::string file) { */ bool CheckPath(const std::string& pathname) { #if defined(_WIN32) - DWORD attribs = GetFileAttributesA(pathname.c_str()); - return (attribs != INVALID_FILE_ATTRIBUTES); + DWORD attribs = GetFileAttributesA(pathname.c_str()); + return (attribs != INVALID_FILE_ATTRIBUTES); #else - struct stat info; - return (stat(pathname.c_str(), &info) == 0); + struct stat info; + return (stat(pathname.c_str(), &info) == 0); #endif } @@ -338,8 +337,7 @@ bool CheckPath(const std::string& pathname) { * \param dirname The name of the directory */ void CleanDir(const std::string& dirname) { - if (!CheckPath(dirname)) - return; + if (!CheckPath(dirname)) return; auto files = ListDir(dirname); for (const auto& filename : files) { std::string file_path = dirname + "/"; diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index b601478f4374..88971cc34c6c 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -210,7 +210,7 @@ class RPCServer { << ", status = " << status_second; } else if (finished_first == worker_pid) { LOG(INFO) << "Child pid=" << worker_pid << " finished" - << ", status = "<< status_first; + << ", status = " << status_first; } } else { auto pid = fork(); @@ -334,8 +334,7 @@ class RPCServer { RPCServerLoop(int(sock.sockfd)); const auto e_time = std::chrono::high_resolution_clock::now(); std::chrono::duration elapsed = e_time - s_time; - LOG(INFO) << "Finished serving " << addr.AsString() - << " after " << elapsed.count() << " sec"; + LOG(INFO) << "Finished serving " << addr.AsString() << " after " << elapsed.count() << " sec"; env.CleanUp(); } diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py index 9570e266623c..bb3c59599e56 100644 --- a/python/tvm/backend/cuda/op.py +++ b/python/tvm/backend/cuda/op.py @@ -694,7 +694,6 @@ def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count): return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, count) - def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): """TVM intrinsic to call mbarrier.arrive_expect_tx.shared::cta.b64 diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py index a46aa7e7e472..76ba87344bc3 100644 --- a/python/tvm/backend/cuda/script.py +++ b/python/tvm/backend/cuda/script.py @@ -278,9 +278,7 @@ def __init__(self): self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init) self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait) self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once) - self.try_wait_acquire_cluster = _op_wrapper( - _cuda_op.ptx_mbarrier_try_wait_acquire_cluster - ) + self.try_wait_acquire_cluster = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_acquire_cluster) self.arrive = MbarrierArriveNamespace() diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc index f0c2a26b2e66..281d64cfbc33 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc @@ -201,9 +201,8 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() { delete runtime; TVM_FFI_THROW(InternalError) << "Failed to deserialize the TensorRT engine."; } - TVM_FFI_ICHECK_EQ( - engine->getNbIOTensors(), - static_cast(network_input_names_.size() + network_output_names_.size())); + TVM_FFI_ICHECK_EQ(engine->getNbIOTensors(), static_cast(network_input_names_.size() + + network_output_names_.size())); nvinfer1::IExecutionContext* context = engine->createExecutionContext(); CleanUp(); diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc index d3e68778fde9..00ca3cea967c 100644 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc @@ -422,9 +422,9 @@ class DenseOpConverter : public TensorRTOpConverter { ->addConstant(VectorToTrtDims(params->inputs.at(1).weight_shape), params->inputs.at(1).weight) ->getOutput(0); - auto* matmul_layer = params->network->addMatrixMultiply( - *input_tensor, nvinfer1::MatrixOperation::kNONE, *weight_tensor, - nvinfer1::MatrixOperation::kTRANSPOSE); + auto* matmul_layer = + params->network->addMatrixMultiply(*input_tensor, nvinfer1::MatrixOperation::kNONE, + *weight_tensor, nvinfer1::MatrixOperation::kTRANSPOSE); TVM_FFI_ICHECK(matmul_layer != nullptr); params->outputs.push_back(matmul_layer->getOutput(0)); } diff --git a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc index 76b407b5cf7c..3d7fadd40a0f 100644 --- a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc @@ -208,15 +208,13 @@ void TaskSchedulerNode::Tune(ffi::Array ctxs, ffi::Array int n_build_errs = 0; const ffi::Array& builder_results = task->builder_results.value(); for (int i = 0; i < num_candidates; i++) { - if (builder_results[i]->error_msg.has_value()) - ++n_build_errs; + if (builder_results[i]->error_msg.has_value()) ++n_build_errs; } if (n_build_errs > 0) { TVM_PY_LOG(INFO, this->logger) << "Build errors: " << n_build_errs << " sample(s)"; } - TVM_PY_LOG(INFO, this->logger) << "Sending " - << num_candidates - n_build_errs - << " valid sample(s) to runner"; + TVM_PY_LOG(INFO, this->logger) + << "Sending " << num_candidates - n_build_errs << " valid sample(s) to runner"; SendToRunner(task, runner); } else { TerminateTask(task_id); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f32dcdde11fd..912f8ec8c086 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -269,10 +269,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons user_symbol = user_symbol.substr(std::char_traits::length(kFFISymbolPrefix)); } TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" << user_symbol - << "' in LLVM codegen: IRModule keys '" << it->second - << "' and '" << gvar->name_hint - << "' both lower to the same exported symbol '" << symbol_name - << "'. " + << "' in LLVM codegen: IRModule keys '" << it->second << "' and '" + << gvar->name_hint << "' both lower to the same exported symbol '" + << symbol_name << "'. " << "Each exposed PrimFunc in one IRModule must have a unique " "global_symbol."; } diff --git a/src/target/llvm/codegen_params.cc b/src/target/llvm/codegen_params.cc index 6d8684a87eda..0633c4fcb3b6 100644 --- a/src/target/llvm/codegen_params.cc +++ b/src/target/llvm/codegen_params.cc @@ -61,8 +61,8 @@ struct LLVMConstantGetter::value>> static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantFP::get(ty, t); } }; -template ::value && std::is_trivial::value>> +template ::value && + std::is_trivial::value>> void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements, std::vector* elements) { elements->resize(num_elements, nullptr); diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc index fe6734863bb8..e746c6ac9507 100644 --- a/src/tirx/transform/vectorize_loop.cc +++ b/src/tirx/transform/vectorize_loop.cc @@ -71,9 +71,8 @@ bool TargetHasVLA(Target target) { } bool ContainsCallNode(const Stmt& stmt) { - return CheckContains::StmtContains(stmt, [](const PrimExpr& expr) { - return expr.as() != nullptr; - }); + return CheckContains::StmtContains( + stmt, [](const PrimExpr& expr) { return expr.as() != nullptr; }); } } // namespace @@ -1067,9 +1066,8 @@ class LoopVectorizer : public StmtMutator { PrimExpr index = outer * scalable_lanes_index + inner_index; Stmt body = Substitute(op->body, {{op->loop_var, index}}); Stmt guarded_body = IfThenElse(index < fixed_extent, body, std::nullopt, op->span); - Stmt vector_loop = - For(inner, make_const(lane_dtype, 0), scalable_lanes, ForKind::kVectorized, guarded_body, - std::nullopt, op->annotations, std::nullopt, op->span); + Stmt vector_loop = For(inner, make_const(lane_dtype, 0), scalable_lanes, ForKind::kVectorized, + guarded_body, std::nullopt, op->annotations, std::nullopt, op->span); Stmt loop = For(outer, zero, num_chunks, ForKind::kSerial, vector_loop, std::nullopt, {}, std::nullopt, op->span); diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index 3ac75dc33745..5447fddcdebe 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -17,6 +17,7 @@ # ruff: noqa: E501, F841 import re + import pytest import tvm diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index db8b977efcbb..414c3d5bbfcb 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -837,7 +837,7 @@ def test_reduce_min_max_nan_preserve(op_name, x): ref_out = (np.max if op_name == "ReduceMax" else np.min)(x) tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18) - out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else tvm_out).numpy() + out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else tvm_out).numpy() np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out)) if not np.isnan(ref_out): diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py index fb70b3754123..97a1be256e0a 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py @@ -1326,14 +1326,28 @@ def test_cast_wg_rejects_thread_local_view(): @T.prim_func def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: - A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) - B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + A = T.match_buffer( + A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) + B = T.match_buffer( + B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) T.device_entry() _bx = T.cta_id([1]) _wg = T.warpgroup_id([1]) tid = T.thread_id_in_wg([_SL_ROWS]) - src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) - dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src = T.alloc_buffer( + (_SL_ROWS, _SL_COLS), + "float32", + scope="local", + layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]), + ) + dst = T.alloc_buffer( + (_SL_ROWS, _SL_COLS), + "float16", + scope="local", + layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]), + ) src_row = src.local(_SL_COLS) for i in T.serial(_SL_COLS): src_row[i] = A[tid, i] @@ -1351,13 +1365,27 @@ def test_cast_cta_rejects_thread_local_view(): @T.prim_func def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: - A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) - B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + A = T.match_buffer( + A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) + B = T.match_buffer( + B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) T.device_entry() _bx = T.cta_id([1]) tx_var = T.thread_id([_SL_ROWS]) - src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)])) - dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)])) + src = T.alloc_buffer( + (_SL_ROWS, _SL_COLS), + "float32", + scope="local", + layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]), + ) + dst = T.alloc_buffer( + (_SL_ROWS, _SL_COLS), + "float16", + scope="local", + layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]), + ) src_row = src.local(_SL_COLS) for i in T.serial(_SL_COLS): src_row[i] = A[tx_var, i] @@ -1376,14 +1404,28 @@ def test_cast_wg_rejects_partial_thread_coverage(): @T.prim_func def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: - A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32", layout=TileLayout(S[(half, _SL_COLS)])) - B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16", layout=TileLayout(S[(half, _SL_COLS)])) + A = T.match_buffer( + A_ptr, (half, _SL_COLS), "float32", layout=TileLayout(S[(half, _SL_COLS)]) + ) + B = T.match_buffer( + B_ptr, (half, _SL_COLS), "float16", layout=TileLayout(S[(half, _SL_COLS)]) + ) T.device_entry() _bx = T.cta_id([1]) _wg = T.warpgroup_id([1]) tid = T.thread_id_in_wg([_SL_ROWS]) - src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)])) - dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src = T.alloc_buffer( + (half, _SL_COLS), + "float32", + scope="local", + layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]), + ) + dst = T.alloc_buffer( + (half, _SL_COLS), + "float16", + scope="local", + layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]), + ) src_row = src.local(_SL_COLS) for i in T.serial(_SL_COLS): src_row[i] = A[tid, i] @@ -1401,14 +1443,28 @@ def test_cast_wg_accepts_wg_level_layout(): @T.prim_func def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: - A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) - B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + A = T.match_buffer( + A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) + B = T.match_buffer( + B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) T.device_entry() _bx = T.cta_id([1]) _wg = T.warpgroup_id([1]) tid = T.thread_id_in_wg([_SL_ROWS]) - src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) - dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)])) + src = T.alloc_buffer( + (_SL_ROWS, _SL_COLS), + "float32", + scope="local", + layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]), + ) + dst = T.alloc_buffer( + (_SL_ROWS, _SL_COLS), + "float16", + scope="local", + layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]), + ) src_row = src.local(_SL_COLS) for i in T.serial(_SL_COLS): src_row[i] = A[tid, i] @@ -1425,13 +1481,21 @@ def test_cast_thread_accepts_local_view(): @T.prim_func def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: - A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) - B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])) + A = T.match_buffer( + A_ptr, (_SL_ROWS, _SL_COLS), "float32", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) + B = T.match_buffer( + B_ptr, (_SL_ROWS, _SL_COLS), "float16", layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]) + ) T.device_entry() _bx = T.cta_id([1]) tx_var = T.thread_id([_SL_ROWS]) - src = T.alloc_buffer((_SL_COLS,), "float32", scope="local", layout=TileLayout(S[(_SL_COLS,)])) - dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local", layout=TileLayout(S[(_SL_COLS,)])) + src = T.alloc_buffer( + (_SL_COLS,), "float32", scope="local", layout=TileLayout(S[(_SL_COLS,)]) + ) + dst = T.alloc_buffer( + (_SL_COLS,), "float16", scope="local", layout=TileLayout(S[(_SL_COLS,)]) + ) for i in T.serial(_SL_COLS): src[i] = A[tx_var, i] Tx.cast(dst, src) From a335a14e8fd2d424f8a655dfc56bee9fa9df585f Mon Sep 17 00:00:00 2001 From: Thomas Steiner Date: Tue, 16 Jun 2026 18:30:45 +0200 Subject: [PATCH 18/23] [Web] use singular requestFileHandle() instead of requestFileHandles() (#19780) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switches from the deprecated plural `requestFileHandles([hash])` to the new singular `requestFileHandle(hash)` API, which returns a handle directly rather than a single-element array. Also updates the interface definition and removes the now-redundant `handles[0]` indexing. The rename was adopted in the COS spec after a survey of all known real-world implementations found that every call site passed a single-element array and immediately indexed `[0]` — no implementation ever used the plural form as a batch. FYI guan404ming CharlieFRuan — as reviewers of the original COS PR ([#18893](https://github.com/apache/tvm/pull/18893)). See https://github.com/WICG/cross-origin-storage/issues/61 for details. --- web/src/artifact_cache.ts | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/web/src/artifact_cache.ts b/web/src/artifact_cache.ts index 2d66fd0c0a02..35cc918f0225 100644 --- a/web/src/artifact_cache.ts +++ b/web/src/artifact_cache.ts @@ -119,10 +119,10 @@ interface CrossOriginStorageWritable { } interface CrossOriginStorageAPI { - requestFileHandles( - descriptors: CrossOriginHashDescriptor[], + requestFileHandle( + descriptor: CrossOriginHashDescriptor, options?: CrossOriginStorageRequestFileHandleOptions, - ): Promise; + ): Promise; } declare global { @@ -169,8 +169,7 @@ class CrossOriginStorage { if (!api) { return undefined; } - const handles = await api.requestFileHandles([hash]); - const handle = handles[0]; + const handle = await api.requestFileHandle(hash); if (!handle) { return undefined; } @@ -189,10 +188,9 @@ class CrossOriginStorage { if (!api) { throw new Error("Cross-origin storage API unavailable."); } - const handles = await api.requestFileHandles([hash], { create: true }); - const handle = handles[0]; + const handle = await api.requestFileHandle(hash, { create: true }); if (!handle) { - throw new Error("Cross-origin storage API returned no handles."); + throw new Error("Cross-origin storage API returned no handle."); } const writableStream = await handle.createWritable(); await writableStream.write(blob); From b1b95b1e1ed8aa71129b02c996dcc1a7f05bbeb9 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 16 Jun 2026 13:03:34 -0400 Subject: [PATCH 19/23] [REFACTOR][IR] Simplify CallingConv attribute access (#19799) CallingConv already participates in TVM FFI integral enum conversion, so keeping call sites on manual integer casts adds noise without changing behavior. This was not possible before the TVM FFI Any support but now we natively support enum class int value conversion with Any, so we can simplify the codepath Main changes: - Read `tvm::attr::kCallingConv` as `CallingConv` directly - Compare optional/defaulted values against `CallingConv` enum values - Store CallingConv enum values directly where the cleanup touches attr writes --- src/backend/cuda/codegen/codegen_cuda.cc | 21 +++++++++----------- src/backend/metal/codegen/codegen_metal.cc | 10 ++++++---- src/backend/opencl/codegen/codegen_opencl.cc | 10 ++++++---- src/backend/vulkan/codegen/spirv_utils.cc | 10 ++++++---- src/backend/webgpu/codegen/codegen_webgpu.cc | 10 ++++++---- src/tirx/analysis/verify_memory.cc | 4 ++-- src/tirx/transform/make_packed_api.cc | 13 ++++++------ src/tirx/transform/split_host_device.cc | 8 ++++---- 8 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index aa2ef63b149a..e04541a73da4 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -160,16 +160,15 @@ void CodeGenCUDA::Init(bool output_ssa) { void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const PrimFunc& func, std::ostream& os) { - int64_t calling_conv = func->GetAttr(tvm::attr::kCallingConv, - static_cast(tvm::CallingConv::kDefault)) - .value(); - if (calling_conv == static_cast(CallingConv::kDeviceKernelLaunch)) { + CallingConv calling_conv = + func->GetAttr(tvm::attr::kCallingConv, CallingConv::kDefault).value(); + if (calling_conv == CallingConv::kDeviceKernelLaunch) { os << "extern \"C\" __global__ "; - } else if (calling_conv == static_cast(CallingConv::kDefault)) { + } else if (calling_conv == CallingConv::kDefault) { os << "extern \"C\" __device__ "; } else { TVM_FFI_THROW(InternalError) << "Unsupported calling convention for cuda codegen: " - << calling_conv; + << static_cast(calling_conv); } CodeGenC::PrintFunctionSignature(function_name, func, os); } @@ -2107,12 +2106,10 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { for (auto [gvar, base_func] : mod->functions) { TVM_FFI_ICHECK(base_func->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto prim_func = Downcast(base_func); - int64_t calling_conv = prim_func - ->GetAttr(tvm::attr::kCallingConv, - static_cast(tvm::CallingConv::kDefault)) - .value(); - TVM_FFI_ICHECK(calling_conv == static_cast(CallingConv::kDeviceKernelLaunch) || - calling_conv == static_cast(CallingConv::kDefault)) + CallingConv calling_conv = + prim_func->GetAttr(tvm::attr::kCallingConv, CallingConv::kDefault).value(); + TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch || + calling_conv == CallingConv::kDefault) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch or " "CallingConv::kDefault"; functions.Set(gvar, prim_func); diff --git a/src/backend/metal/codegen/codegen_metal.cc b/src/backend/metal/codegen/codegen_metal.cc index b68840f32752..17668a4867b3 100644 --- a/src/backend/metal/codegen/codegen_metal.cc +++ b/src/backend/metal/codegen/codegen_metal.cc @@ -474,10 +474,12 @@ ffi::Module BuildMetal(IRModule mod, Target target) { CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenMetal: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); cg.AddFunction(kv.first, f); diff --git a/src/backend/opencl/codegen/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc index 5bad02e55824..a5a94c41da89 100644 --- a/src/backend/opencl/codegen/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -689,10 +689,12 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { TVM_FFI_ICHECK(base_func->IsInstance()) << "CodeGenOpenCL: Can only take PrimFunc"; auto prim_func = Downcast(base_func); - auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = prim_func->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenOpenCL: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); functions.Set(gvar, prim_func); } diff --git a/src/backend/vulkan/codegen/spirv_utils.cc b/src/backend/vulkan/codegen/spirv_utils.cc index 11aecf1c43d3..6ee872a33afd 100644 --- a/src/backend/vulkan/codegen/spirv_utils.cc +++ b/src/backend/vulkan/codegen/spirv_utils.cc @@ -124,10 +124,12 @@ std::pair, std::string> Lo for (auto kv : mod->functions) { TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenSPIRV: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/backend/webgpu/codegen/codegen_webgpu.cc b/src/backend/webgpu/codegen/codegen_webgpu.cc index 08c75ed8404b..9e7d2f5e84b5 100644 --- a/src/backend/webgpu/codegen/codegen_webgpu.cc +++ b/src/backend/webgpu/codegen/codegen_webgpu.cc @@ -760,10 +760,12 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { TVM_FFI_ICHECK(kv.second->IsInstance()) << "CodeGenWebGPU: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - TVM_FFI_ICHECK(calling_conv.has_value() && - calling_conv.value() == static_cast(CallingConv::kDeviceKernelLaunch)) - << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + TVM_FFI_ICHECK(calling_conv.has_value()) + << "CodeGenWebGPU: expected kCallingConv attribute to be set."; + TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch) + << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch, but got " + << static_cast(calling_conv.value()); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index aa1a19cf0ec5..2c3396480cd1 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -177,8 +177,8 @@ std::vector VerifyMemory_(const PrimFunc& func) { << "' for primitive:" << std::endl << func; - if (func->GetAttr(tvm::attr::kCallingConv, static_cast(CallingConv::kDefault)) - .value() == static_cast(CallingConv::kDefault)) { + if (func->GetAttr(tvm::attr::kCallingConv, CallingConv::kDefault).value() == + CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType()); v.Run(); return v.Errors(); diff --git a/src/tirx/transform/make_packed_api.cc b/src/tirx/transform/make_packed_api.cc index 4f8229080f9c..2d4eb80f03e7 100644 --- a/src/tirx/transform/make_packed_api.cc +++ b/src/tirx/transform/make_packed_api.cc @@ -178,8 +178,8 @@ class SubroutineCallRewriter : public StmtExprMutator { ffi::Optional RequiresPackedAPI(const PrimFunc& func) { // A function with an explicit calling convention has already been // lowered, and should not be modified. - if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { - if (CallingConv(opt.value()) != CallingConv::kDefault) { + if (auto opt = func->GetAttr(tvm::attr::kCallingConv)) { + if (opt.value() != CallingConv::kDefault) { return std::nullopt; } } @@ -244,11 +244,10 @@ PrimFunc MakePackedAPI(PrimFunc func) { ffi::Array args{v_self_handle, v_packed_args, v_num_packed_args, v_result}; // reset global symbol to attach prefix - func = WithAttrs( - std::move(func), - {{tvm::attr::kCallingConv, static_cast(CallingConv::kCPackedFunc)}, - {tvm::attr::kTarget, target_host}, - {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, CallingConv::kCPackedFunc}, + {tvm::attr::kTarget, target_host}, + {tvm::attr::kGlobalSymbol, + ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}}); Stmt body = ReturnRewriter(v_result)(func_ptr->body); body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope, diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index acc5e473afb8..079309db3f95 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -494,10 +494,10 @@ class DeviceKernelMutator : public StmtExprMutator { write_ptr->body = ReturnRemover::Apply(write_ptr->body); } - func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, - static_cast(tvm::CallingConv::kDeviceKernelLaunch)}, - {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, - {tvm::attr::kGlobalSymbol, info.global_symbol}}); + func = WithAttrs(std::move(func), + {{tvm::attr::kCallingConv, tvm::CallingConv::kDeviceKernelLaunch}, + {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, + {tvm::attr::kGlobalSymbol, info.global_symbol}}); } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); From 9112a17d84deb389f0d0d593ebf0926620ea56a0 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 16 Jun 2026 15:33:06 -0400 Subject: [PATCH 20/23] [CI] Remove Jenkins PR linter step (#19798) The Jenkins PR title/body linter is comparatively heavy and can report false positives before the normal CI signal is available. This removes the check_pr step from the Jenkins prepare flow and drops the now-unused script-level test coverage. --- ci/jenkins/generated/arm_jenkinsfile.groovy | 21 +-- ci/jenkins/generated/cpu_jenkinsfile.groovy | 21 +-- .../generated/docker_jenkinsfile.groovy | 21 +-- ci/jenkins/generated/gpu_jenkinsfile.groovy | 21 +-- ci/jenkins/generated/wasm_jenkinsfile.groovy | 21 +-- ci/jenkins/templates/utils/Prepare.groovy.j2 | 19 --- ci/scripts/jenkins/check_pr.py | 143 ------------------ .../cuda/operator/intrinsics/header.py | 6 +- .../cuda/codegen/literal/cuda_half_t.h | 6 +- tests/python/ci/test_ci.py | 43 ------ 10 files changed, 11 insertions(+), 311 deletions(-) delete mode 100755 ci/scripts/jenkins/check_pr.py diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index 0cbed4cb2805..99a13e85fc97 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2026-06-09T19:52:01.246622 +// Generated at 2026-06-16T13:42:33.948016 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -339,31 +339,12 @@ def should_skip_ci(pr_number) { return git_skip_ci_code == 0 } -def check_pr(pr_number) { - if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) { - // never skip CI on build sourced from a branch - return false - } - withCredentials([string( - credentialsId: 'tvm-bot-jenkins-reader', - variable: 'GITHUB_TOKEN', - )]) { - sh ( - script: "python3 ${jenkins_scripts_root}/check_pr.py --pr ${pr_number}", - label: 'Check PR title and body', - ) - } - -} - def prepare(node_type) { stage('Prepare') { node(node_type) { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") { init_git() - check_pr(env.CHANGE_ID) - if (env.DETERMINE_DOCKER_IMAGES == 'yes') { sh( script: "./${jenkins_scripts_root}/determine_docker_images.py ci_arm ci_cpu ci_gpu ci_lint ci_wasm ", diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index 584bb8db92f3..a65edd3f7da8 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2026-06-09T19:52:01.232631 +// Generated at 2026-06-16T13:42:33.935741 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -339,31 +339,12 @@ def should_skip_ci(pr_number) { return git_skip_ci_code == 0 } -def check_pr(pr_number) { - if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) { - // never skip CI on build sourced from a branch - return false - } - withCredentials([string( - credentialsId: 'tvm-bot-jenkins-reader', - variable: 'GITHUB_TOKEN', - )]) { - sh ( - script: "python3 ${jenkins_scripts_root}/check_pr.py --pr ${pr_number}", - label: 'Check PR title and body', - ) - } - -} - def prepare(node_type) { stage('Prepare') { node(node_type) { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") { init_git() - check_pr(env.CHANGE_ID) - if (env.DETERMINE_DOCKER_IMAGES == 'yes') { sh( script: "./${jenkins_scripts_root}/determine_docker_images.py ci_arm ci_cpu ci_gpu ci_lint ci_wasm ", diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy index b64fd4bb018c..19b785fd2aa1 100644 --- a/ci/jenkins/generated/docker_jenkinsfile.groovy +++ b/ci/jenkins/generated/docker_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2026-06-09T19:52:01.257919 +// Generated at 2026-06-16T13:42:33.958277 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -339,31 +339,12 @@ def should_skip_ci(pr_number) { return git_skip_ci_code == 0 } -def check_pr(pr_number) { - if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) { - // never skip CI on build sourced from a branch - return false - } - withCredentials([string( - credentialsId: 'tvm-bot-jenkins-reader', - variable: 'GITHUB_TOKEN', - )]) { - sh ( - script: "python3 ${jenkins_scripts_root}/check_pr.py --pr ${pr_number}", - label: 'Check PR title and body', - ) - } - -} - def prepare(node_type) { stage('Prepare') { node(node_type) { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") { init_git() - check_pr(env.CHANGE_ID) - if (env.DETERMINE_DOCKER_IMAGES == 'yes') { sh( script: "./${jenkins_scripts_root}/determine_docker_images.py ci_arm ci_cpu ci_gpu ci_lint ci_wasm ", diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index 772639ee1ef1..8c04f811db36 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2026-06-09T19:52:01.271485 +// Generated at 2026-06-16T13:42:33.970141 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -339,31 +339,12 @@ def should_skip_ci(pr_number) { return git_skip_ci_code == 0 } -def check_pr(pr_number) { - if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) { - // never skip CI on build sourced from a branch - return false - } - withCredentials([string( - credentialsId: 'tvm-bot-jenkins-reader', - variable: 'GITHUB_TOKEN', - )]) { - sh ( - script: "python3 ${jenkins_scripts_root}/check_pr.py --pr ${pr_number}", - label: 'Check PR title and body', - ) - } - -} - def prepare(node_type) { stage('Prepare') { node(node_type) { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") { init_git() - check_pr(env.CHANGE_ID) - if (env.DETERMINE_DOCKER_IMAGES == 'yes') { sh( script: "./${jenkins_scripts_root}/determine_docker_images.py ci_arm ci_cpu ci_gpu ci_lint ci_wasm ", diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index 28e4462ae9cf..c7ab9a2cc974 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2026-06-09T19:52:01.285310 +// Generated at 2026-06-16T13:42:33.982453 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -339,31 +339,12 @@ def should_skip_ci(pr_number) { return git_skip_ci_code == 0 } -def check_pr(pr_number) { - if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) { - // never skip CI on build sourced from a branch - return false - } - withCredentials([string( - credentialsId: 'tvm-bot-jenkins-reader', - variable: 'GITHUB_TOKEN', - )]) { - sh ( - script: "python3 ${jenkins_scripts_root}/check_pr.py --pr ${pr_number}", - label: 'Check PR title and body', - ) - } - -} - def prepare(node_type) { stage('Prepare') { node(node_type) { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") { init_git() - check_pr(env.CHANGE_ID) - if (env.DETERMINE_DOCKER_IMAGES == 'yes') { sh( script: "./${jenkins_scripts_root}/determine_docker_images.py ci_arm ci_cpu ci_gpu ci_lint ci_wasm ", diff --git a/ci/jenkins/templates/utils/Prepare.groovy.j2 b/ci/jenkins/templates/utils/Prepare.groovy.j2 index 6770fab24850..b74e7f4ba514 100644 --- a/ci/jenkins/templates/utils/Prepare.groovy.j2 +++ b/ci/jenkins/templates/utils/Prepare.groovy.j2 @@ -221,31 +221,12 @@ def should_skip_ci(pr_number) { return git_skip_ci_code == 0 } -def check_pr(pr_number) { - if (env.BRANCH_NAME == null || !env.BRANCH_NAME.startsWith('PR-')) { - // never skip CI on build sourced from a branch - return false - } - withCredentials([string( - credentialsId: 'tvm-bot-jenkins-reader', - variable: 'GITHUB_TOKEN', - )]) { - sh ( - script: "python3 ${jenkins_scripts_root}/check_pr.py --pr ${pr_number}", - label: 'Check PR title and body', - ) - } - -} - def prepare(node_type) { stage('Prepare') { node(node_type) { ws("workspace/exec_${env.EXECUTOR_NUMBER}/tvm/prepare") { init_git() - check_pr(env.CHANGE_ID) - if (env.DETERMINE_DOCKER_IMAGES == 'yes') { sh( script: "./${jenkins_scripts_root}/determine_docker_images.py {% for image in images %}{{ image.name }} {% endfor %}", diff --git a/ci/scripts/jenkins/check_pr.py b/ci/scripts/jenkins/check_pr.py deleted file mode 100755 index 683c3fdd6ddb..000000000000 --- a/ci/scripts/jenkins/check_pr.py +++ /dev/null @@ -1,143 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E501 -import argparse -import json -import os -import re -import textwrap -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -from cmd_utils import init_log, tags_from_title -from git_utils import GitHubRepo, git, parse_remote - -GITHUB_USERNAME_REGEX = re.compile(r"(@[a-zA-Z0-9-]+)", flags=re.MULTILINE) -OK = object() -FAIL = object() - - -@dataclass -class Check: - # check to run, returning OK means it passed, anything else means it failed - check: Callable[[str], Any] - - # function to call to generate the error message - error_fn: Callable[[Any], str] - - -def non_empty(s: str): - if len(s) == 0: - return FAIL - return OK - - -def usernames(s: str): - m = GITHUB_USERNAME_REGEX.findall(s) - return m if m else OK - - -def tags(s: str): - items = tags_from_title(s) - if len(items) == 0: - return FAIL - return OK - - -def trailing_period(s: str): - if s.endswith("."): - return FAIL - return OK - - -title_checks = [ - Check(check=non_empty, error_fn=lambda d: "PR must have a title but title was empty"), - Check(check=trailing_period, error_fn=lambda d: "PR must not end in a tailing '.'"), - Check( - check=usernames, - error_fn=lambda d: f"PR title must not tag anyone but found these usernames: {d}", - ), -] -body_checks = [ - Check(check=non_empty, error_fn=lambda d: "PR must have a body but body was empty"), - Check( - check=usernames, - error_fn=lambda d: f"PR body must not tag anyone but found these usernames: {d}", - ), -] - - -def run_checks(checks: list[Check], s: str, name: str) -> bool: - print(f"Running checks for {name}") - print(textwrap.indent(s, prefix=" ")) - passed = True - print(" Checks:") - for i, check in enumerate(checks): - result = check.check(s) - if result == OK: - print(f" [{i + 1}] {check.check.__name__}: PASSED") - else: - passed = False - msg = check.error_fn(result) - print(f" [{i + 1}] {check.check.__name__}: FAILED: {msg}") - - return passed - - -if __name__ == "__main__": - init_log() - help = "Check a PR's title and body for conformance to guidelines" - parser = argparse.ArgumentParser(description=help) - parser.add_argument("--pr", required=True) - parser.add_argument("--remote", default="origin", help="ssh remote to parse") - parser.add_argument( - "--pr-data", help="(testing) PR data to use instead of fetching from GitHub" - ) - args = parser.parse_args() - - try: - pr = int(args.pr) - except ValueError: - print(f"PR was not a number: {args.pr}") - exit(0) - - if args.pr_data: - pr = json.loads(args.pr_data) - else: - remote = git(["config", "--get", f"remote.{args.remote}.url"]) - user, repo = parse_remote(remote) - - github = GitHubRepo(token=os.environ["GITHUB_TOKEN"], user=user, repo=repo) - pr = github.get(f"pulls/{args.pr}") - - body = "" if pr["body"] is None else pr["body"].strip() - title = "" if pr["title"] is None else pr["title"].strip() - - title_passed = run_checks(checks=title_checks, s=title, name="PR title") - print("") - body_passed = run_checks(checks=body_checks, s=body, name="PR body") - - if title_passed and body_passed: - print("All checks passed!") - exit(0) - else: - print( - "Some checks failed, please review the logs above and edit your PR on GitHub accordingly" - ) - exit(1) diff --git a/python/tvm/backend/cuda/operator/intrinsics/header.py b/python/tvm/backend/cuda/operator/intrinsics/header.py index 848c3bd0ecf5..2330f75a4f30 100644 --- a/python/tvm/backend/cuda/operator/intrinsics/header.py +++ b/python/tvm/backend/cuda/operator/intrinsics/header.py @@ -343,7 +343,7 @@ def header_generator(tags): TVec2 hi_half2 = *reinterpret_cast(&z); __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2); result.__x = - (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + (static_cast(lo_part.__x) | (static_cast(hi_part.__x) << 16)); return result; } __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) { @@ -363,7 +363,7 @@ def header_generator(tags): TVec2 hi_half2 = *reinterpret_cast(&z); __nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2); result.__x = - (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + (static_cast(lo_part.__x) | (static_cast(hi_part.__x) << 16)); return result; } __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e8m0& fp8x4) { @@ -383,7 +383,7 @@ def header_generator(tags): TVec2 hi_half2 = *reinterpret_cast(&z); __nv_fp8x2_e8m0 lo_part(lo_half2), hi_part(hi_half2); result.__x = - (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + (static_cast(lo_part.__x) | (static_cast(hi_part.__x) << 16)); return result; } """ diff --git a/src/backend/cuda/codegen/literal/cuda_half_t.h b/src/backend/cuda/codegen/literal/cuda_half_t.h index 78ee0298be24..d55453cba468 100644 --- a/src/backend/cuda/codegen/literal/cuda_half_t.h +++ b/src/backend/cuda/codegen/literal/cuda_half_t.h @@ -435,7 +435,7 @@ struct __align__(8) half4_bfloat164 { TVec2 hi_half2 = *reinterpret_cast(&z); __nv_fp8x2_e4m3 lo_part(lo_half2), hi_part(hi_half2); result.__x = - (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + (static_cast(lo_part.__x) | (static_cast(hi_part.__x) << 16)); return result; } __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e5m2& fp8x4) { @@ -455,7 +455,7 @@ struct __align__(8) half4_bfloat164 { TVec2 hi_half2 = *reinterpret_cast(&z); __nv_fp8x2_e5m2 lo_part(lo_half2), hi_part(hi_half2); result.__x = - (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + (static_cast(lo_part.__x) | (static_cast(hi_part.__x) << 16)); return result; } __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e8m0& fp8x4) { @@ -475,7 +475,7 @@ struct __align__(8) half4_bfloat164 { TVec2 hi_half2 = *reinterpret_cast(&z); __nv_fp8x2_e8m0 lo_part(lo_half2), hi_part(hi_half2); result.__x = - (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); + (static_cast(lo_part.__x) | (static_cast(hi_part.__x) << 16)); return result; } )"; diff --git a/tests/python/ci/test_ci.py b/tests/python/ci/test_ci.py index 251143c4306f..4ff8be49d2bd 100644 --- a/tests/python/ci/test_ci.py +++ b/tests/python/ci/test_ci.py @@ -1403,48 +1403,5 @@ def test_should_rebuild_docker(tmpdir_factory, changed_files, name, check, expec assert proc.returncode == expected_code -@parameterize_named( - passing=dict( - title="[something] a change", - body="something", - expected="All checks passed", - expected_code=0, - ), - period=dict( - title="[something] a change.", - body="something", - expected="trailing_period: FAILED", - expected_code=1, - ), - empty_body=dict( - title="[something] a change", - body=None, - expected="non_empty: FAILED", - expected_code=1, - ), -) -def test_pr_linter(title, body, expected, expected_code): - """ - Test the PR linter - """ - tag_script = JENKINS_SCRIPT_ROOT / "check_pr.py" - pr_data = { - "title": title, - "body": body, - } - proc = run_script( - [ - tag_script, - "--pr", - 1234, - "--pr-data", - json.dumps(pr_data), - ], - check=False, - ) - assert proc.returncode == expected_code - assert_in(expected, proc.stdout) - - if __name__ == "__main__": tvm.testing.main() From d4dcb7075571564e3181b64864937a7e7565f9ce Mon Sep 17 00:00:00 2001 From: Felix Hirwa Nshuti Date: Tue, 16 Jun 2026 22:21:10 +0200 Subject: [PATCH 21/23] [Relax][Frontend][TFLite] Add support for FFT/complex operators: REAL, IMAG, COMPLEX_ABS (#19763) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Part of https://github.com/apache/tvm/issues/19519 This PR adds support for the FFT and complex operator family in the Relax TFLite frontend. **Key implementations:** - Registered `REAL`, `IMAG`, `COMPLEX_ABS`to the TFLite op map. - Implemented `convert_real` and `convert_imag` which extract the real and imaginary parts of a complex tensor via `strided_slice` + `squeeze` along the last axis. - Implemented `convert_complex_abs` which computes `sqrt(re^2 + im^2)` using elementwise Relax ops. - All three ops adopt a unified representation convention: TFLite `complex64` tensors (which have no native Relax dtype equivalent) are represented as `float32[..., 2]`, where the last axis holds `(real, imaginary)` interleaved.. **Out of scope:** - `RFFT2D` is not registered in this PR. An O(N²) matmul decomposition is feasible using existing Relax ops and will be contributed separately with benchmarks showing the performance gap versus a native FFT op. A native `relax.op.signal.rfft2d` is tracked in https://github.com/apache/tvm/issues/19764 **Testing:** - Added structural equality tests for `REAL`, `IMAG`, and `COMPLEX_ABS` in `test_frontend_tflite.py` following the `verify(TestClass, Expected)` pattern. ```bash python3 -m pytest tests/python/relax/test_frontend_tflite.py -k "test_real or test_imag or test_complex_abs" ``` --- .../relax/frontend/tflite/tflite_frontend.py | 65 ++++++++++++- tests/python/relax/test_frontend_tflite.py | 94 +++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index d14643d75c60..3bd87d0af414 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -207,6 +207,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "BROADCAST_ARGS": self.convert_broadcast_args, "CALL": self.convert_call, "CALL_ONCE": self.convert_call_once, + "COMPLEX_ABS": self.convert_complex_abs, "CAST": self.convert_cast, "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil), "CONCATENATION": self.convert_concatenation, @@ -252,6 +253,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "HASHTABLE_LOOKUP": self.convert_hashtable_lookup, "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, + "IMAG": self.convert_imag, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"), "LEAKY_RELU": self.convert_leaky_relu, @@ -295,6 +297,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, "READ_VARIABLE": self.convert_read_variable, + "REAL": self.convert_real, "REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min), "REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max), "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max), @@ -1006,6 +1009,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper): TensorType.UINT32: np.uint32, TensorType.UINT64: np.uint64, TensorType.BOOL: np.bool_, + TensorType.COMPLEX64: np.complex64, }[tensor_wrapper.tensor.Type()] # pylint: disable=no-else-return @@ -1051,6 +1055,8 @@ def get_tensor_type_str(self, tensor_type): return "uint64" if tensor_type == TensorType.BOOL: return "bool" + if tensor_type == TensorType.COMPLEX64: + return "complex64" raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported") def _get_shape_expr_from_tensor(self, shape_tensor, prefix): @@ -7580,6 +7586,52 @@ def convert_fake_quant(self, op): rounded = relax.op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half)) return relax.op.add(_op.multiply(rounded, scale_expr), nudged_min_expr) + def convert_real(self, op): + """Convert TFLite REAL op. + + TFLite complex64 tensors are represented as float32[..., 2] in Relax, + where index 0 = real part, index 1 = imaginary part along the last axis + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + # slice last axis at index 0, and squeeze to remove the last axis + real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1]) + return _op.squeeze(real, axis=[-1]) + + def convert_imag(self, op): + """Convert TFLite IMAG op. + + See convert_real for representation of complex64 tensors in Relax. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + # slice last axis at index 1, and squeeze to remove the last axis + imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1]) + return _op.squeeze(imag, axis=[-1]) + + def convert_complex_abs(self, op): + """Convert TFLite COMPLEX_ABS op: sqrt(real^2 + imag^2) + + See convert_real for the float32[..., 2] complex representation convention. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + real = self.bb.emit( + _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1]) + ) + real = self.bb.emit(_op.squeeze(real, axis=[-1])) + imag = self.bb.emit( + _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1]) + ) + imag = self.bb.emit(_op.squeeze(imag, axis=[-1])) + real_sq = self.bb.emit(_op.multiply(real, real)) + imag_sq = self.bb.emit(_op.multiply(imag, imag)) + sum_expr = self.bb.emit(_op.add(real_sq, imag_sq)) + return _op.sqrt(sum_expr) + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -7609,6 +7661,12 @@ def get_tensor_expr(self, tensor, is_sparse=False): type_str = self.get_tensor_type_str(tensor.tensor.Type()) value = self.get_tensor_value_or_prefetched(tensor, is_sparse) + # complex64 constants have no native Relax dtype. Reinterpret the + # interleaved float32 storage as float32[..., 2] to match the + # convention used for input tensors. + if type_str == "complex64": + value = value.view(np.float32).reshape(value.shape + (2,)) + type_str = "float32" return self.exp_tab.new_const(value, dtype=type_str, source_name=tensor.tensor.Name()) def get_tensor_shape(self, tensor_wrapper): @@ -8044,8 +8102,9 @@ def _input_type(model): input_shape = tuple(tensor.ShapeAsNumpy()) tensor_type = tensor.Type() input_name = get_tensor_name(subgraph, input_) + input_dtype = _decode_type(tensor_type) shape_dict[input_name] = input_shape - dtype_dict[input_name] = _decode_type(tensor_type) + dtype_dict[input_name] = input_dtype return shape_dict, dtype_dict @@ -8183,6 +8242,10 @@ def func(self, data): dtype = ( _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32" ) + if dtype == "complex64": + dtype = "float32" + if shape is not None: + shape = tuple(shape) + (2,) input_var = relax.Var( name_hint=model_input_name, struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype), diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e4483b9d41cc..e9a842b8dfe8 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -13020,5 +13020,99 @@ def test_unidirectional_sequence_rnn_time_major(): assert tuple(int(d) for d in out_shape) == (batch, time, num_units) +def test_real(): + class Real(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.real(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, + (R.prim_value(-1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) + R.output(gv) + return gv + + verify(Real, Expected) + + +def test_imag(): + class Imag(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.imag(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, + (R.prim_value(-1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) + R.output(gv) + return gv + + verify(Imag, Expected) + + +def test_complex_abs(): + class ComplexAbs(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.abs(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, + (R.prim_value(-1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv1: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) + lv2: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, + (R.prim_value(-1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv3: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv2, axis=[-1]) + lv4: R.Tensor((2, 4), dtype="float32") = R.multiply(lv1, lv1) + lv5: R.Tensor((2, 4), dtype="float32") = R.multiply(lv3, lv3) + lv6: R.Tensor((2, 4), dtype="float32") = R.add(lv4, lv5) + gv: R.Tensor((2, 4), dtype="float32") = R.sqrt(lv6) + R.output(gv) + return gv + + verify(ComplexAbs, Expected) + + if __name__ == "__main__": pytest.main(["-s", __file__]) From 297f94476417cdbd6eae864f81149c0a7a1f12ba Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 16 Jun 2026 13:49:53 +0000 Subject: [PATCH 22/23] [REFACTOR][TIRX] Add IntImm common scalar ctor and streamline MakeConst (#19797) Common bool, int32, and int64 scalar constants show up throughout TIRX and related lowering code, and named constructors make these call sites easier to read than repeated DataType spelling. This PR establishes the scalar-constant construction policy and renames make_const to MakeConst to match the public helper naming style. - Prefer IntImm::Bool, IntImm::Int32, and IntImm::Int64 for common known scalar bool, int32, and int64 constants. - Prefer direct IntImm or FloatImm construction when dtype is known to be scalar integer or floating point. This makes the compiled code more compact and efficient. Keep MakeConst for generic overload cases where dtype can be integer, floating point, or vector-valued and the caller needs its scalar/vector dispatch. - Phase out make_zero in favor of explicit scalar constructors, or ConstHandle(0) for null handles. --- include/tvm/ir/expr.h | 29 ++++- include/tvm/ir/node_functor.h | 2 +- include/tvm/script/printer/doc.h | 4 +- include/tvm/tirx/buffer.h | 2 +- include/tvm/tirx/op.h | 108 ++++++++---------- include/tvm/topi/detail/broadcast.h | 2 +- include/tvm/topi/detail/extern.h | 13 +-- include/tvm/topi/detail/strided_slice.h | 4 +- include/tvm/topi/elemwise.h | 58 +++++----- include/tvm/topi/nn.h | 28 ++--- include/tvm/topi/nn/bnn.h | 2 +- include/tvm/topi/nn/dilate.h | 2 +- include/tvm/topi/nn/group_norm.h | 4 +- include/tvm/topi/nn/instance_norm.h | 4 +- include/tvm/topi/nn/layer_norm.h | 4 +- include/tvm/topi/nn/local_response_norm.h | 6 +- include/tvm/topi/nn/pooling.h | 27 +++-- include/tvm/topi/nn/rms_norm.h | 4 +- include/tvm/topi/nn/softmax.h | 2 +- include/tvm/topi/reduction.h | 12 +- include/tvm/topi/transform.h | 30 ++--- src/arith/analyzer.cc | 4 +- src/arith/canonical_simplify.cc | 62 +++++----- src/arith/conjunctive_normal_form.cc | 13 +-- src/arith/const_fold.h | 32 +++--- src/arith/detect_linear_equation.cc | 10 +- src/arith/int_constraints.cc | 2 +- src/arith/int_set.cc | 33 +++--- src/arith/ir_mutator_with_analyzer.cc | 10 +- src/arith/iter_affine_map.cc | 68 +++++------ src/arith/modular_set.cc | 2 +- src/arith/pattern_match.h | 2 +- src/arith/presburger_set.cc | 30 ++--- src/arith/product_normal_form.h | 4 +- src/arith/rewrite_simplify.cc | 45 ++++---- src/arith/solve_linear_equation.cc | 26 ++--- src/arith/solve_linear_inequality.cc | 24 ++-- src/backend/cuda/codegen/codegen_cuda.cc | 12 +- .../codegen/llvm/intrin_rule_hexagon.cc | 15 +-- src/backend/opencl/codegen/codegen_opencl.cc | 4 +- .../rocm/codegen/llvm/intrin_rule_rocm.cc | 6 +- .../trn/transform/lower_trainium_layout.cc | 2 +- src/backend/vulkan/codegen/codegen_spirv.cc | 2 +- src/ir/expr.cc | 2 +- src/relax/analysis/struct_info_analysis.cc | 66 +++++------ src/relax/analysis/tir_op_pattern_kind.cc | 6 +- src/relax/backend/contrib/clml/codegen.cc | 7 +- src/relax/backend/contrib/utils.cc | 2 +- src/relax/backend/vm/codegen_vm_tir.cc | 7 +- src/relax/ir/dataflow_matcher.cc | 8 +- src/relax/ir/dataflow_matcher.h | 2 +- src/relax/ir/emit_te.cc | 2 +- src/relax/ir/expr.cc | 4 +- src/relax/ir/expr_functor.cc | 2 +- src/relax/op/distributed/statistical.cc | 2 +- src/relax/op/image/resize.cc | 8 +- src/relax/op/memory/view.cc | 6 +- src/relax/op/nn/convolution.cc | 96 +++++++--------- src/relax/op/nn/nn.cc | 6 +- src/relax/op/nn/pooling.cc | 102 ++++++++--------- src/relax/op/tensor/binary.cc | 2 +- src/relax/op/tensor/index.cc | 4 +- src/relax/op/tensor/inspect.cc | 24 ++-- src/relax/op/tensor/manipulate.cc | 23 ++-- src/relax/op/tensor/set.cc | 14 +-- src/relax/op/tensor/statistical.cc | 14 +-- src/relax/op/vision/multibox_transform_loc.cc | 2 +- src/relax/op/vision/nms.cc | 6 +- src/relax/op/vision/roi_align.cc | 9 +- src/relax/op/vision/roi_pool.cc | 4 +- src/relax/transform/adjust_matmul_order.cc | 4 +- src/relax/transform/allocate_workspace.cc | 6 +- src/relax/transform/alter_op_impl.cc | 2 +- .../transform/combine_parallel_matmul.cc | 2 +- src/relax/transform/dataflow_inplace.cc | 4 +- src/relax/transform/fold_constant.cc | 2 +- src/relax/transform/fuse_tir.cc | 4 +- src/relax/transform/infer_amp_utils.cc | 4 +- src/relax/transform/lazy_transform_params.cc | 19 ++- src/relax/transform/lower_alloc_tensor.cc | 4 +- src/relax/transform/rewrite_cuda_graph.cc | 10 +- .../transform/split_layout_rewrite_preproc.cc | 4 +- .../transform/static_plan_block_memory.cc | 16 +-- src/s_tir/analysis/identify_memcpy.cc | 4 +- .../analysis/sblock_access_region_detector.cc | 2 +- .../backend/adreno/inject_texture_alloc.cc | 4 +- src/s_tir/backend/adreno/texture_flatten.cc | 2 +- .../meta_schedule/database/json_database.cc | 2 +- .../feature_extractor/per_store_feature.cc | 4 +- .../meta_schedule/mutator/mutate_parallel.cc | 4 +- .../postproc/rewrite_cooperative_fetch.cc | 25 ++-- .../rewrite_parallel_vectorize_unroll.cc | 7 +- .../postproc/rewrite_unbound_block.cc | 2 +- .../meta_schedule/postproc/verify_gpu_code.cc | 6 +- .../schedule/cuda/thread_bind.cc | 9 +- .../meta_schedule/schedule/cuda/winograd.cc | 59 +++++----- .../schedule_rule/add_rfactor.cc | 2 +- .../schedule_rule/multi_level_tiling.cc | 4 +- .../multi_level_tiling_tensor_core.cc | 17 ++- .../parallel_vectorize_unroll.cc | 4 +- src/s_tir/schedule/analysis/layout.cc | 12 +- src/s_tir/schedule/concrete_schedule.cc | 4 +- src/s_tir/schedule/concrete_schedule.h | 4 +- .../schedule/primitive/blockize_tensorize.cc | 10 +- src/s_tir/schedule/primitive/cache_index.cc | 10 +- .../schedule/primitive/cache_read_write.cc | 42 +++---- src/s_tir/schedule/primitive/compute_at.cc | 8 +- .../schedule/primitive/compute_inline.cc | 6 +- .../schedule/primitive/decompose_padding.cc | 4 +- .../primitive/layout_transformation.cc | 18 +-- .../schedule/primitive/loop_transformation.cc | 13 +-- src/s_tir/schedule/primitive/pad_einsum.cc | 10 +- src/s_tir/schedule/primitive/read_write_at.cc | 5 +- src/s_tir/schedule/primitive/reduction.cc | 30 ++--- src/s_tir/schedule/state.cc | 6 +- src/s_tir/schedule/trace.cc | 2 +- src/s_tir/schedule/traced_schedule.cc | 87 +++++++------- src/s_tir/schedule/transform.cc | 6 +- src/s_tir/support/nd_int_set.h | 2 +- src/s_tir/transform/bound_checker.cc | 4 +- src/s_tir/transform/canonicalize_loop.cc | 4 +- src/s_tir/transform/compact_buffer_region.cc | 24 ++-- src/s_tir/transform/decorate_device_scope.cc | 2 +- src/s_tir/transform/default_gpu_schedule.cc | 9 +- src/s_tir/transform/inject_double_buffer.cc | 14 +-- src/s_tir/transform/inject_ptx_ldg32.cc | 20 ++-- .../transform/inject_software_pipeline.cc | 23 ++-- src/s_tir/transform/inject_virtual_thread.cc | 6 +- src/s_tir/transform/loop_partition.cc | 16 +-- .../transform/lower_cross_thread_reduction.cc | 26 ++--- src/s_tir/transform/lower_match_buffer.cc | 4 +- src/s_tir/transform/lower_opaque_block.cc | 4 +- src/s_tir/transform/lower_thread_allreduce.cc | 18 +-- src/s_tir/transform/lower_vtcm_alloc.cc | 2 +- src/s_tir/transform/memhammer_coalesce.cc | 4 +- .../transform/memhammer_intermediate_stage.cc | 8 +- .../transform/memhammer_lower_auto_copy.cc | 7 +- .../transform/memhammer_tensorcore_rewrite.cc | 20 ++-- .../plan_update_buffer_allocation_location.cc | 2 +- src/s_tir/transform/thread_storage_sync.cc | 2 +- .../transform/transform_mma_buffer_layout.cc | 12 +- .../using_assume_to_reduce_branches.cc | 4 +- src/script/printer/ir/distributed.cc | 2 +- src/target/intrin_rule.cc | 30 ++--- src/target/llvm/codegen_cpu.cc | 4 +- src/target/llvm/codegen_llvm.cc | 8 +- src/target/llvm/codegen_x86_64.cc | 2 +- src/target/llvm/intrin_rule_llvm.cc | 11 +- src/te/operation/create_primfunc.cc | 12 +- src/te/tensor.cc | 2 +- src/tirx/analysis/exec_context.cc | 2 +- src/tirx/ir/buffer.cc | 18 +-- src/tirx/ir/exec_scope.cc | 9 +- src/tirx/ir/expr.cc | 8 +- src/tirx/ir/index_map.cc | 4 +- src/tirx/ir/layout/tile_core.cc | 2 +- src/tirx/ir/layout/tile_slice.cc | 4 +- src/tirx/ir/layout/utils.cc | 2 +- src/tirx/ir/script/script_complete.cc | 2 +- src/tirx/ir/stmt.cc | 6 +- src/tirx/op/op.cc | 88 +++++++------- src/tirx/script/builder/frame.cc | 9 +- src/tirx/script/builder/ir.cc | 10 +- src/tirx/transform/dtype_conversion.cc | 11 +- .../transform/force_narrow_index_to_i32.cc | 2 +- src/tirx/transform/ir_utils.h | 15 ++- src/tirx/transform/lower_intrin.cc | 18 +-- src/tirx/transform/lower_tirx_cleanup.cc | 2 +- src/tirx/transform/lower_tirx_opaque.cc | 2 +- src/tirx/transform/lower_tvm_builtin.cc | 49 ++++---- src/tirx/transform/lower_warp_memory.cc | 13 +-- src/tirx/transform/make_packed_api.cc | 31 +++-- src/tirx/transform/split_host_device.cc | 2 +- src/tirx/transform/storage_rewrite.cc | 22 ++-- src/tirx/transform/tile_primitive_dispatch.cc | 12 +- src/tirx/transform/tvm_ffi_binder.cc | 46 ++++---- src/tirx/transform/unroll_loop.cc | 2 +- src/tirx/transform/vectorize_loop.cc | 10 +- src/topi/einsum.cc | 6 +- tests/cpp/arith_simplify_test.cc | 14 +-- tests/cpp/ir_functor_test.cc | 14 +-- tests/cpp/nested_msg_test.cc | 34 +++--- tests/cpp/tir_scalable_datatype.cc | 11 +- 183 files changed, 1219 insertions(+), 1269 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index e614d7539487..eb924be63d94 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -515,6 +515,33 @@ class IntImm : public PrimExpr { */ TVM_DLL IntImm(DataType dtype, int64_t value, Span span = Span()); + /*! + * \brief Construct a scalar boolean constant. + * \param value The boolean value. + * \param span The location of this object in the source code. + */ + static IntImm Bool(bool value, Span span = Span()) { + return IntImm(DataType::Bool(), value, span); + } + + /*! + * \brief Construct a scalar int32 constant. + * \param value The integer value. + * \param span The location of this object in the source code. + */ + static IntImm Int32(int64_t value, Span span = Span()) { + return IntImm(DataType::Int(32), value, span); + } + + /*! + * \brief Construct a scalar int64 constant. + * \param value The integer value. + * \param span The location of this object in the source code. + */ + static IntImm Int64(int64_t value, Span span = Span()) { + return IntImm(DataType::Int(64), value, span); + } + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(IntImm, PrimExpr, IntImmNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(IntImmNode); }; @@ -636,7 +663,7 @@ struct TypeTraits : public ObjectRefWithFallbackTraitsBase::ConvertFallbackValue(StrictBool value) { - return IntImm(DataType::Bool(), value, Span()); + return IntImm::Bool(value); } TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(int64_t value) { diff --git a/include/tvm/ir/node_functor.h b/include/tvm/ir/node_functor.h index c7be2188d314..fc6c925da387 100644 --- a/include/tvm/ir/node_functor.h +++ b/include/tvm/ir/node_functor.h @@ -47,7 +47,7 @@ namespace tvm { * return prefix + "IntImm" * }); * - * Expr x = make_const(1); + * Expr x = MakeConst(1); * Expr y = x + x; * // dispatch to IntImm, outputs "MyIntImm" * LOG(INFO) << tostr(x, "My"); diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index d63942ac71df..2389c1b50d15 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -277,7 +277,7 @@ class LiteralDoc : public ExprDoc { * \param p The object path */ static LiteralDoc Int(int64_t v, const ffi::Optional& p) { - return LiteralDoc(IntImm(DataType::Int(64), v), p); + return LiteralDoc(IntImm::Int64(v), p); } /*! * \brief Create a LiteralDoc to represent boolean. @@ -285,7 +285,7 @@ class LiteralDoc : public ExprDoc { * \param p The object path */ static LiteralDoc Boolean(bool v, const ffi::Optional& p) { - return LiteralDoc(IntImm(DataType::Bool(), v), p); + return LiteralDoc(IntImm::Bool(v), p); } /*! * \brief Create a LiteralDoc to represent float. diff --git a/include/tvm/tirx/buffer.h b/include/tvm/tirx/buffer.h index a5146600f4fa..1456787d688b 100644 --- a/include/tvm/tirx/buffer.h +++ b/include/tvm/tirx/buffer.h @@ -206,7 +206,7 @@ class Buffer : public ffi::ObjectRef { * \param input_extent The extent of ptr. */ TVM_DLL PrimExpr access_ptr(int access_mask, DataType ptr_type = DataType::Handle(), - int content_lanes = 1, PrimExpr offset = IntImm(DataType::Int(32), 0), + int content_lanes = 1, PrimExpr offset = IntImm::Int32(0), ffi::Optional input_extent = std::nullopt) const; /*! * \brief Create an Expr that does a vector load at begin index. diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index a92ca7dc52bc..bae96f2d132a 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -52,7 +52,7 @@ namespace tvm { // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. // -// We put more developer oriented APIs -- make_const and is_const under tirx +// We put more developer oriented APIs -- MakeConst and is_const under tirx // as they are more specific to the tirx namespace. /*! @@ -816,7 +816,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) { /*! * \brief Make a const value with certain data type. - * \param t The target type. + * + * Prefer direct IntImm or FloatImm construction when dtype is known to be + * scalar integer or floating point. This makes the compiled code more compact + * and efficient. Keep MakeConst for generic overload cases where dtype can be + * integer, floating point, or vector-valued and the caller needs its + * scalar/vector dispatch. + * + * \param dtype The target type. * \param value The input value * \return the result expression. * \tparam ValueType The constant value type @@ -825,32 +832,14 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) { template ::value && std::is_trivial::value>::type> -inline PrimExpr make_const(DataType t, ValueType value, Span span = Span()); -/*! - * \brief Make a const zero expr. - * \param t The target type. - * \param span The location of this operation in the source. - * \return the result expression. - */ -inline PrimExpr make_zero(DataType t, Span span = Span()); -/*! - * \brief Make a constant true expression. - * \param lanes The number of lanes in the bool - * \param span The location of this operation in the source. - * \return The result expression. - */ -inline PrimExpr const_true(int lanes = 1, Span span = Span()) { - return make_const(DataType::Bool(lanes), 1); -} +inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span = Span()); /*! - * \brief Make a constant false expression. - * \param lanes The number of lanes in the bool + * \brief Make a constant handle value. + * \param value The integer payload to reinterpret as a handle. * \param span The location of this operation in the source. * \return The result expression. */ -inline PrimExpr const_false(int lanes = 1, Span span = Span()) { - return make_const(DataType::Bool(lanes), 0); -} +inline PrimExpr ConstHandle(int64_t value, Span span = Span()); /*! * \brief Get x as constant int expression. * \param x The expression @@ -981,53 +970,52 @@ inline bool is_no_op(const tirx::Stmt& stmt) { } template -inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) { - if (t.is_int() || t.is_bool()) return IntImm(t, static_cast(value), span); - if (t.is_uint()) { +inline PrimExpr MakeConstScalar(DataType dtype, ValueType value, Span span = Span()) { + if (dtype.is_int() || dtype.is_bool()) return IntImm(dtype, static_cast(value), span); + if (dtype.is_uint()) { // Use IntImm if it is a small integer uint64_t uval = static_cast(value); if (value < static_cast(0)) { TVM_FFI_THROW(InternalError) << "cannot make uint from negative value " << value; } else if (uval <= static_cast(std::numeric_limits::max())) { - return IntImm(t, static_cast(value), span); + return IntImm(dtype, static_cast(value), span); } else { uint64_t mask = (static_cast(1) << 32U) - 1U; uint64_t low = uval & mask; uint64_t high = uval >> 32U; - return LargeUIntImm(t, static_cast(low), static_cast(high), span); + return LargeUIntImm(dtype, static_cast(low), static_cast(high), span); } } - if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4()) - return FloatImm(t, static_cast(value), span); - TVM_FFI_THROW(InternalError) << "cannot make const for type " << t; + if (dtype.is_float() || dtype.is_bfloat16() || dtype.is_float8() || dtype.is_float6() || + dtype.is_float4()) { + return FloatImm(dtype, static_cast(value), span); + } + TVM_FFI_THROW(InternalError) << "cannot make const for type " << dtype; throw; } template <> -inline PrimExpr MakeConstScalar(DataType t, bool value, Span span) { - return MakeConstScalar(t, static_cast(value), span); +inline PrimExpr MakeConstScalar(DataType dtype, bool value, Span span) { + return MakeConstScalar(dtype, static_cast(value), span); } template -inline PrimExpr make_const(DataType t, ValueType value, Span span) { - if (t.is_scalar()) { - return MakeConstScalar(t, value, span); +inline PrimExpr MakeConst(DataType dtype, ValueType value, Span span) { + if (dtype.is_scalar()) { + return MakeConstScalar(dtype, value, span); } else { - if (t.is_fixed_length_vector()) { - return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + if (dtype.is_fixed_length_vector()) { + return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), dtype.lanes(), span); } else { - PrimExpr lanes = - tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), t.vscale_factor()); - return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span); + PrimExpr lanes = tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), + dtype.vscale_factor()); + return tirx::Broadcast(MakeConstScalar(dtype.element_of(), value, span), lanes, span); } } } -inline PrimExpr make_zero(DataType t, Span span) { - if (t.is_handle()) { - return reinterpret(t, make_const(DataType::UInt(64), 0, span)); - } - return make_const(t, 0, span); +inline PrimExpr ConstHandle(int64_t value, Span span) { + return reinterpret(DataType::Handle(), IntImm(DataType::UInt(64), value, span)); } } // namespace tirx @@ -1043,13 +1031,13 @@ inline PrimExpr make_zero(DataType t, Span span) { inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tirx::make_const(b.dtype(), a), b); \ + return Name(tirx::MakeConst(b.dtype(), a), b); \ } \ inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tirx::make_const(a.dtype(), b)); \ + return Name(a, tirx::MakeConst(a.dtype(), b)); \ } \ inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, tirx::make_const(DataType::Float(64), b)); \ + return Name(a, FloatImm(DataType::Float(64), b)); \ } #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \ @@ -1060,13 +1048,13 @@ inline PrimExpr make_zero(DataType t, Span span) { return Name(PrimExpr(a), b, span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tirx::make_const(b.dtype(), a), b, span); \ + return Name(tirx::MakeConst(b.dtype(), a), b, span); \ } \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tirx::make_const(a.dtype(), b), span); \ + return Name(a, tirx::MakeConst(a.dtype(), b), span); \ } \ inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \ - return Name(a, tirx::make_const(DataType::Float(64), b), span); \ + return Name(a, FloatImm(DataType::Float(64), b), span); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ @@ -1081,18 +1069,18 @@ inline PrimExpr make_zero(DataType t, Span span) { return Name(PrimExpr(a), b, span); \ } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tirx::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::make_const(b.dtype(), a), b); } +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tirx::MakeConst(a.dtype(), b)); \ + } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::MakeConst(b.dtype(), a), b); } #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tirx::make_const(a.dtype(), b), span); \ + return Name(a, tirx::MakeConst(a.dtype(), b), span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tirx::make_const(b.dtype(), a), b, span); \ + return Name(tirx::MakeConst(b.dtype(), a), b, span); \ } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index 9a4c7f9339ab..c9dce9eb7489 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -130,7 +130,7 @@ inline tvm::ffi::Array InputIndexFromBroadcast( // Only inject 0 here if we have not yet reached the dimension of I // (i.e. this must be a 1) if (!found && (ovars.size() - i) <= expected_dims) { - ivars.push_back(tvm::tirx::make_zero(ovars[i].dtype())); + ivars.push_back(tvm::IntImm(ovars[i].dtype(), 0)); } } TVM_FFI_ICHECK(expected_dims == ivars.size()); diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index 14eb54c3ed65..161d5291c38e 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -108,13 +108,12 @@ inline PrimExpr pack_buffer(Buffer buf) { } else { strides = 0; } - ffi::Array pack_args{ - buf->data, - shape, - strides, - make_const(DataType::Int(32), static_cast(buf->shape.size())), - make_const(buf->dtype, 0), - buf->elem_offset}; + ffi::Array pack_args{buf->data, + shape, + strides, + IntImm::Int32(static_cast(buf->shape.size())), + MakeConst(buf->dtype, 0), + buf->elem_offset}; return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); } diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index 2e5df30808be..19ee79a2086f 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -99,10 +99,10 @@ inline ffi::Array StridedSliceCanonicalizeBegin(const ffi::ArrayIsInstance()) { int64_t dim_i = GetConstInt(ishape[ax]); int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]); - begin_expr.push_back(make_const(dtype, begin_i)); + begin_expr.push_back(MakeConst(dtype, begin_i)); } else { auto idim = ishape[ax]; - auto b_expr = make_const(dtype, begin[i]); + auto b_expr = MakeConst(dtype, begin[i]); PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr; auto s = strides[i]; if (s < 0) { diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 940b1f149ead..57225af9b493 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -82,22 +82,22 @@ TOPI_DECLARE_UNARY_OP(isinf); inline Tensor fast_tanh_float(const Tensor& in, std::string name, std::string tag) { // Clamp the inputs to the range [-9, 9] since anything outside // this range is +/-1.0f in single-precision. - auto x = maximum(make_const(in->dtype, -9.0), minimum(make_const(in->dtype, 9.0), in)); + auto x = maximum(MakeConst(in->dtype, -9.0), minimum(MakeConst(in->dtype, 9.0), in)); // The monomial coefficients of the numerator polynomial (odd). - auto alpha_1 = make_const(in->dtype, 4.89352455891786e-03); - auto alpha_3 = make_const(in->dtype, 6.37261928875436e-04); - auto alpha_5 = make_const(in->dtype, 1.48572235717979e-05); - auto alpha_7 = make_const(in->dtype, 5.12229709037114e-08); - auto alpha_9 = make_const(in->dtype, -8.60467152213735e-11); - auto alpha_11 = make_const(in->dtype, 2.00018790482477e-13); - auto alpha_13 = make_const(in->dtype, -2.76076847742355e-16); + auto alpha_1 = MakeConst(in->dtype, 4.89352455891786e-03); + auto alpha_3 = MakeConst(in->dtype, 6.37261928875436e-04); + auto alpha_5 = MakeConst(in->dtype, 1.48572235717979e-05); + auto alpha_7 = MakeConst(in->dtype, 5.12229709037114e-08); + auto alpha_9 = MakeConst(in->dtype, -8.60467152213735e-11); + auto alpha_11 = MakeConst(in->dtype, 2.00018790482477e-13); + auto alpha_13 = MakeConst(in->dtype, -2.76076847742355e-16); // The monomial coefficients of the denominator polynomial (even). - auto beta_0 = make_const(in->dtype, 4.89352518554385e-03); - auto beta_2 = make_const(in->dtype, 2.26843463243900e-03); - auto beta_4 = make_const(in->dtype, 1.18534705686654e-04); - auto beta_6 = make_const(in->dtype, 1.19825839466702e-06); + auto beta_0 = MakeConst(in->dtype, 4.89352518554385e-03); + auto beta_2 = MakeConst(in->dtype, 2.26843463243900e-03); + auto beta_4 = MakeConst(in->dtype, 1.18534705686654e-04); + auto beta_6 = MakeConst(in->dtype, 1.19825839466702e-06); return compute( x->shape, @@ -209,9 +209,9 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag return compute( x->shape, [&](const ffi::Array& i) { - PrimExpr zero = make_zero(x->dtype); - PrimExpr one = make_const(x->dtype, 1); - PrimExpr minus_one = make_const(x->dtype, -1); + PrimExpr zero = MakeConst(x->dtype, 0); + PrimExpr one = MakeConst(x->dtype, 1); + PrimExpr minus_one = MakeConst(x->dtype, -1); auto s1 = tvm::tirx::Select((x(i) < zero), minus_one, zero); auto s2 = tvm::tirx::Select((x(i) > zero), one, s1); return s2; @@ -232,7 +232,7 @@ inline Tensor rsqrt(const Tensor& x, std::string name = "tensor", std::string ta return compute( x->shape, [&](const ffi::Array& i) { - PrimExpr one = make_const(x->dtype, 1); + PrimExpr one = MakeConst(x->dtype, 1); return one / tvm::sqrt(x(i)); }, name, tag); @@ -392,19 +392,19 @@ inline Tensor full_like(const Tensor& x, const PrimExpr fill_value, * y = exp(f) = 1 + 2 * P(x**2)/(Q(x**2) - P(x**2)) */ inline Tensor fast_exp_float32(const Tensor& _x, std::string name, std::string tag) { - auto x_hi = make_const(DataType::Float(32), 88.3762626647950f); - auto x_lo = make_const(DataType::Float(32), -88.3762626647949f); - auto log2e = make_const(DataType::Float(32), 1.44269504088896341f); - auto ln2 = make_const(DataType::Float(32), 0.6931471805599453f); - PrimExpr p[6] = {make_const(DataType::Float(32), 1.9875691500E-4f), - make_const(DataType::Float(32), 1.3981999507E-3f), - make_const(DataType::Float(32), 8.3334519073E-3f), - make_const(DataType::Float(32), 4.1665795894E-2f), - make_const(DataType::Float(32), 1.6666665459E-1f), - make_const(DataType::Float(32), 5.0000001201E-1f)}; - auto one = make_const(DataType::Float(32), 1.0f); - auto one_half = make_const(DataType::Float(32), 0.5f); - auto b = make_const(DataType::Float(32), 127.0f); + auto x_hi = FloatImm(DataType::Float(32), 88.3762626647950f); + auto x_lo = FloatImm(DataType::Float(32), -88.3762626647949f); + auto log2e = FloatImm(DataType::Float(32), 1.44269504088896341f); + auto ln2 = FloatImm(DataType::Float(32), 0.6931471805599453f); + PrimExpr p[6] = {FloatImm(DataType::Float(32), 1.9875691500E-4f), + FloatImm(DataType::Float(32), 1.3981999507E-3f), + FloatImm(DataType::Float(32), 8.3334519073E-3f), + FloatImm(DataType::Float(32), 4.1665795894E-2f), + FloatImm(DataType::Float(32), 1.6666665459E-1f), + FloatImm(DataType::Float(32), 5.0000001201E-1f)}; + auto one = FloatImm(DataType::Float(32), 1.0f); + auto one_half = FloatImm(DataType::Float(32), 0.5f); + auto b = FloatImm(DataType::Float(32), 127.0f); return compute( _x->shape, diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 7df01fe8c1b4..0a448620dae3 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -57,7 +57,7 @@ inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast< return tvm::te::compute( t->shape, [&](const tvm::ffi::Array& i) { - auto threshold_const = tvm::tirx::make_const(t->dtype, threshold); + auto threshold_const = tvm::tirx::MakeConst(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, name, tag); @@ -80,7 +80,7 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, t->shape, [&](const tvm::ffi::Array& i) { auto value = t(i); - auto calpha = tvm::tirx::make_const(value.dtype(), alpha); + auto calpha = tvm::tirx::MakeConst(value.dtype(), alpha); return tvm::tirx::Select(value > 0, value, value * calpha); }, name, tag); @@ -194,7 +194,7 @@ inline tvm::te::Tensor pad( } if (!pad_value.defined()) { - pad_value = tvm::tirx::make_const(t->dtype, 0); + pad_value = tvm::tirx::MakeConst(t->dtype, 0); } auto l = [&](tvm::ffi::Array ovars) { @@ -232,12 +232,12 @@ inline tvm::te::Tensor pad( if (pad_mode == "constant") { return tvm::if_then_else( foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, - const_true(1), sel), + IntImm::Bool(true), sel), t(indices), pad_value); } else if (pad_mode == "edge" || pad_mode == "reflect") { return tvm::if_then_else( foldl([](PrimExpr a, PrimExpr b, Span span) { return tvm::logical_and(a, b, span); }, - const_true(1), sel), + IntImm::Bool(true), sel), t(indices), t(pad_idx)); } } @@ -507,7 +507,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, // pad the input with paddings provided if (!pad_value.defined()) { - pad_value = tvm::tirx::make_const(data->dtype, 0); + pad_value = tvm::tirx::MakeConst(data->dtype, 0); } padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value); @@ -534,7 +534,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, << padded_input << ")" << " must be divisible by its block size (" << block_size << ")"; - PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]); + PrimExpr bs = IntImm::Int64(block_shape[i - 1]); r_shape.push_back(div(padded_shape[i], bs)); r_shape.push_back(bs); block_shape_prod *= bs; @@ -549,7 +549,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, } o_shape.push_back(tvm::PrimExpr(batch) * block_shape_prod); for (size_t i = 1; i <= num_block_dims; i++) { - PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]); + PrimExpr bs = IntImm::Int64(block_shape[i - 1]); o_shape.push_back(div(padded_shape[i], bs)); } // append remaining shape @@ -595,7 +595,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, int batch = static_cast(GetConstInt(in_shape[0])); for (size_t i = 0; i < num_block_dims; i++) { - PrimExpr bs = IntImm(DataType::Int(64), block_shape[i]); + PrimExpr bs = IntImm::Int64(block_shape[i]); r_shape.push_back(bs); block_shape_prod *= bs; } @@ -614,7 +614,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, ffi::Array r_p_shape; r_p_shape.push_back(batch / block_shape_prod); for (size_t i = 1; i <= num_block_dims; i++) { - PrimExpr bs = IntImm(DataType::Int(64), block_shape[i - 1]); + PrimExpr bs = IntImm::Int64(block_shape[i - 1]); r_p_shape.push_back(in_shape[i] * bs); } for (size_t i = num_block_dims + 1; i < num_input_dims; i++) { @@ -677,7 +677,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tirx::Select(c != ignore_index, -predictions(c) * weights(c), - tvm::tirx::make_const(predictions->dtype, 0)); + tvm::tirx::MakeConst(predictions->dtype, 0)); }, name, tag); if (reduction == "mean") { @@ -686,7 +686,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(); return tvm::tirx::Select(c != ignore_index, weights(c), - tvm::tirx::make_const(predictions->dtype, 0)); + tvm::tirx::MakeConst(predictions->dtype, 0)); }, name, tag); return topi::divide(T, W); @@ -705,7 +705,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T pred_indices.push_back(target_indices[i]); // indices for multidimensional loss } return tvm::tirx::Select(c != ignore_index, -predictions(pred_indices) * weights(c), - tvm::tirx::make_const(predictions->dtype, 0)); + tvm::tirx::MakeConst(predictions->dtype, 0)); }, name, tag); TVM_FFI_ICHECK(T->shape.size() != 0); @@ -715,7 +715,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); return tvm::tirx::Select(c != ignore_index, weights(c), - tvm::tirx::make_const(predictions->dtype, 0)); + tvm::tirx::MakeConst(predictions->dtype, 0)); }, name, tag); return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), diff --git a/include/tvm/topi/nn/bnn.h b/include/tvm/topi/nn/bnn.h index 5a3ba871d56b..5faed879c005 100644 --- a/include/tvm/topi/nn/bnn.h +++ b/include/tvm/topi/nn/bnn.h @@ -71,7 +71,7 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, int axis, start_idx.push_back(i == static_cast(axis) ? indices[i] * 32 : static_cast(indices[i])); } - auto packed = make_const(DataType::UInt(32), 0); + PrimExpr packed = IntImm(DataType::UInt(32), 0); for (size_t j = 0; j < 32; ++j) { ffi::Array idx; for (size_t i = 0; i < n; ++i) { diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h index e6f280c4bcba..0c8ea395c701 100644 --- a/include/tvm/topi/nn/dilate.h +++ b/include/tvm/topi/nn/dilate.h @@ -95,7 +95,7 @@ inline Tensor dilate(const Tensor& x, ffi::Array strides, double dilat if (not_zero.size() > 0) { auto all_not_zero = all(not_zero); return tvm::if_then_else(all_not_zero, x(index_tuple), - make_const(x->dtype, dilation_value)); + MakeConst(x->dtype, dilation_value)); } return x(index_tuple); }, diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h index 1f1ac91867af..4962587a9396 100644 --- a/include/tvm/topi/nn/group_norm.h +++ b/include/tvm/topi/nn/group_norm.h @@ -126,7 +126,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - auto reduce_extent = make_const(DataType::Float(32), 1); + PrimExpr reduce_extent = FloatImm(DataType::Float(32), 1); for (auto axis : new_axes) { reduce_extent *= data_reshaped->shape[axis]; } @@ -143,7 +143,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; PrimExpr group_norm = - (data_reshaped(indices) - mean) * tvm::rsqrt(var + make_const(data->dtype, epsilon)); + (data_reshaped(indices) - mean) * tvm::rsqrt(var + MakeConst(data->dtype, epsilon)); if (is_float16) { group_norm = Cast(DataType::Float(16), group_norm); } diff --git a/include/tvm/topi/nn/instance_norm.h b/include/tvm/topi/nn/instance_norm.h index 48fcf23904d5..60361e8bc681 100644 --- a/include/tvm/topi/nn/instance_norm.h +++ b/include/tvm/topi/nn/instance_norm.h @@ -106,7 +106,7 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso auto temp_x = temp_x_x2[0]; auto temp_x2 = temp_x_x2[1]; - auto reduce_extent = make_const(data->dtype, 1); + auto reduce_extent = MakeConst(data->dtype, 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; } @@ -124,7 +124,7 @@ inline Tensor instance_norm(const Tensor& data, const Tensor& gamma, const Tenso channel = indices[channel_axis]; auto mean = temp_x(non_reduce_indices) / reduce_extent; auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean; - auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + make_const(var->dtype, epsilon)); + auto instance_norm = (data(indices) - mean) * tvm::rsqrt(var + MakeConst(var->dtype, epsilon)); if (is_float16) { instance_norm = Cast(DataType::Float(16), instance_norm); } diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h index d74bbce23f65..fb8155ef654a 100644 --- a/include/tvm/topi/nn/layer_norm.h +++ b/include/tvm/topi/nn/layer_norm.h @@ -102,7 +102,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& data->op->name + "_sum", kCommReduce); DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype; - PrimExpr reduce_extent = make_const(reduce_dtype, 1); + PrimExpr reduce_extent = MakeConst(reduce_dtype, 1); for (int i : real_axis) { reduce_extent *= data->shape[i]; } @@ -138,7 +138,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor& } auto mean = temp_mean(non_reduce_indices); auto var = temp_var_sum(non_reduce_indices) / reduce_extent; - auto layer_norm = (data(indices) - mean) * rsqrt(var + make_const(var->dtype, epsilon)); + auto layer_norm = (data(indices) - mean) * rsqrt(var + MakeConst(var->dtype, epsilon)); if (is_float16) { layer_norm = Cast(DataType::Float(16), layer_norm); } diff --git a/include/tvm/topi/nn/local_response_norm.h b/include/tvm/topi/nn/local_response_norm.h index 0c045a1631bc..7407448f88c5 100644 --- a/include/tvm/topi/nn/local_response_norm.h +++ b/include/tvm/topi/nn/local_response_norm.h @@ -79,9 +79,9 @@ inline Tensor lrn(const Tensor& data, int size, int axis = 1, float alpha = 0.00 }, "tensor", "sqr_sum"); } - PrimExpr alpha_imm = tvm::te::make_const(data->dtype, alpha); - PrimExpr beta_imm = tvm::te::make_const(data->dtype, beta); - PrimExpr bias_imm = tvm::te::make_const(data->dtype, bias); + PrimExpr alpha_imm = tvm::te::MakeConst(data->dtype, alpha); + PrimExpr beta_imm = tvm::te::MakeConst(data->dtype, beta); + PrimExpr bias_imm = tvm::te::MakeConst(data->dtype, bias); auto sqrt_sum_up = tvm::te::compute( input_shape, [&](Var i, Var j, Var k, Var l) { diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 3cdb5b03c58a..e8410d8add22 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -145,17 +145,17 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); PrimExpr out_idx_lower_h = tirx::Select( - pad_inds[height_axis] < kernel_height, make_const(pad_inds[height_axis].dtype(), 0), + pad_inds[height_axis] < kernel_height, IntImm(pad_inds[height_axis].dtype(), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = tirx::Select( - pad_inds[width_axis] < kernel_width, make_const(pad_inds[width_axis].dtype(), 0), + pad_inds[width_axis] < kernel_width, IntImm(pad_inds[width_axis].dtype(), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, out_idx[width_axis] >= out_idx_lower_w), mp_inds(out_idx) == idx), - out_grad(out_idx), make_const(x->dtype, 0)), + out_grad(out_idx), MakeConst(x->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_max"); @@ -176,10 +176,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); PrimExpr out_idx_lower_h = - tirx::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0), + tirx::Select(pad_h_idx < kernel_height, IntImm(pad_h_idx.dtype(), 0), (pad_h_idx - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = - tirx::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0), + tirx::Select(pad_w_idx < kernel_width, IntImm(pad_w_idx.dtype(), 0), (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements @@ -191,17 +191,16 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, PrimExpr h_end = min(h_start + kernel_height, height); PrimExpr w_end = min(w_start + kernel_width, width); - h_start = max(h_start, make_const(h_start.dtype(), 0)); - w_start = max(w_start, make_const(w_start.dtype(), 0)); - divide_factor = - max((h_end - h_start) * (w_end - w_start), make_const(h_end.dtype(), 1)); + h_start = max(h_start, IntImm(h_start.dtype(), 0)); + w_start = max(w_start, IntImm(w_start.dtype(), 0)); + divide_factor = max((h_end - h_start) * (w_end - w_start), MakeConst(h_end.dtype(), 1)); } return tvm::sum( tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, out_idx[height_axis] < out_height), tirx::And(out_idx[width_axis] >= out_idx_lower_w, out_idx[width_axis] < out_width)), - out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), + out_grad(out_idx) / divide_factor, MakeConst(out_grad->dtype, 0)), {windowh, windoww}); }, "T_pool_grad", "pool_grad_avg"); @@ -627,7 +626,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s if (count_include_pad) { std::vector start(k_size); std::vector end(k_size); - auto num_el = make_const(DataType::Int(32), 1); + auto num_el = IntImm::Int32(1); for (int i = 0; i < k_size; i++) { int ii = axis[i]; start[i] = output[ii] * stride[i] - pad_head[i]; @@ -643,7 +642,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s } else { std::vector start(k_size); std::vector end(k_size); - auto num_el = make_const(DataType::Int(32), 1); + auto num_el = IntImm::Int32(1); for (int i = 0; i < k_size; i++) { int ii = axis[i]; @@ -658,13 +657,13 @@ inline Tensor pool_impl_nd(const Tensor& x, const ffi::Array& kernel_s // number that represents the number of steps along the dilated kernel to reach a // non-padded value. Otherwise this should be 0. PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i]; - jumps_to_non_pad = max(jumps_to_non_pad, make_const(jumps_to_non_pad.dtype(), 0)); + jumps_to_non_pad = max(jumps_to_non_pad, IntImm(jumps_to_non_pad.dtype(), 0)); end[i] = min(end[i], data_shape[ii] - 1); num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1; } - PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1)); + PrimExpr divide_factor = max(num_el, IntImm::Int32(1)); return div(pool_sum(indices), divide_factor); } }, diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h index ac36e5badd41..294d82054e3e 100644 --- a/include/tvm/topi/nn/rms_norm.h +++ b/include/tvm/topi/nn/rms_norm.h @@ -63,7 +63,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra auto ndim = data_fp32->shape.size(); TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor"; auto real_axis = GetRealAxis(static_cast(ndim), axis); - auto reduce_extent = make_const(data_fp32->dtype, 1); + auto reduce_extent = MakeConst(data_fp32->dtype, 1); for (int i : real_axis) { reduce_extent *= data_fp32->shape[i]; } @@ -75,7 +75,7 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const ffi::Arra } } auto output = - tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon)); + tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + MakeConst(data_type, epsilon)); return output; }; auto rsqrt_shape = ffi::Array(); diff --git a/include/tvm/topi/nn/softmax.h b/include/tvm/topi/nn/softmax.h index 9786099f9edb..0479e3431908 100644 --- a/include/tvm/topi/nn/softmax.h +++ b/include/tvm/topi/nn/softmax.h @@ -61,7 +61,7 @@ inline Tensor softmax(const Tensor& x, int axis = -1, std::string name = "tensor auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false); tvm::ffi::Map attrs; - attrs.Set("axis", IntImm(DataType::Int(32), axis)); + attrs.Set("axis", IntImm::Int32(axis)); auto insert_reduce_index = [axis, ndim](const ffi::Array& indices, const IterVar& reduce_index) { diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index d8889000f2dd..e6b4c5af1dea 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -286,7 +286,7 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, auto result = fcombine(lhs, rhs); auto id_elem = fidentity(dtypes); - auto cond = condition != nullptr ? *condition : tirx::const_true(); + auto cond = condition != nullptr ? *condition : IntImm::Bool(true); auto combiner = tvm::tirx::CommReducer(lhs, rhs, result, id_elem); ffi::Array outputs; @@ -479,8 +479,8 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { }; auto fidentity = [&](std::vector types) { ffi::Array result; - result.push_back(tvm::tirx::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val + result.push_back(tvm::tirx::MakeConst(types[0], -1)); // idx + result.push_back(tvm::max_value(types[1])); // val return result; }; return MakeCommReducer(fcombine, fidentity, "argmin"); @@ -541,8 +541,8 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { }; auto fidentity = [&](std::vector types) { ffi::Array result; - result.push_back(tvm::tirx::make_const(types[0], -1)); // idx - result.push_back(tvm::min_value(types[1])); // val + result.push_back(tvm::tirx::MakeConst(types[0], -1)); // idx + result.push_back(tvm::min_value(types[1])); // val return result; }; return MakeCommReducer(fcombine, fidentity, "argmax"); @@ -604,7 +604,7 @@ inline FCommReduce MakeTupleSumReducer() { auto fidentity = [](std::vector types) { ffi::Array result; for (size_t i = 0; i < types.size(); ++i) { - result.push_back(tvm::tirx::make_const(types[i], 0)); + result.push_back(tvm::tirx::MakeConst(types[i], 0)); } return result; }; diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index a46c1c05b344..1b26d03cc183 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -98,16 +98,16 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind // Length of the shape along this dimension. auto dim_len = x->shape[_axis + i]; // Length of the window along this dimension. - PrimExpr window_len = IntImm(DataType::Int(64), window_shape[i]); + PrimExpr window_len = IntImm::Int64(window_shape[i]); // Strides along this dimension. - PrimExpr stride = IntImm(DataType::Int(64), strides[i]); + PrimExpr stride = IntImm::Int64(strides[i]); new_shape.push_back(floordiv(dim_len - (window_len - 1) + stride - 1, stride)); } // Dimensions comprising the window. for (size_t i = 0; i < window_shape.size(); ++i) { - new_shape.push_back(IntImm(DataType::Int(64), window_shape[i])); + new_shape.push_back(IntImm::Int64(window_shape[i])); } TVM_FFI_ICHECK(new_shape.size() == _axis + 2 * window_shape.size()); @@ -129,7 +129,7 @@ inline Tensor sliding_window(const Tensor& x, int axis, ffi::Array wind // Which index within the window we are indexing. auto idx_within_window = indices[_axis + window_shape.size() + i]; // Stride value for this dimension. - PrimExpr stride = IntImm(DataType::Int(64), strides[i]); + PrimExpr stride = IntImm::Int64(strides[i]); idx.push_back(window_idx * stride + idx_within_window); } @@ -842,7 +842,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b ffi::Array begin_expr, end_expr, strides_expr; for (int64_t i = 0; i < num_dynamic_axes; ++i) { - auto ind = make_const(index_dtype, i); + auto ind = MakeConst(index_dtype, i); begin_expr.push_back(begin(ind)); end_expr.push_back(end(ind)); strides_expr.push_back(strides(ind)); @@ -938,7 +938,7 @@ inline Tensor strided_slice_with_axes( for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < normalized_axes.size(); ++i) { int64_t ax = normalized_axes[i]; - auto stride = make_const(strides[i]->dtype, strides_vec[i]); + auto stride = MakeConst(strides[i]->dtype, strides_vec[i]); PrimExpr ind = indices[ax] * stride + begin_expr[i]; real_indices.Set(ax, ind); } @@ -1118,7 +1118,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub len_index.push_back(bid); PrimExpr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tirx::make_const(data->dtype, mask_value), data(out_index)); + tvm::tirx::MakeConst(data->dtype, mask_value), data(out_index)); return ret; }, name, tag); @@ -1293,7 +1293,7 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int PrimExpr in_bounds = idx >= 0 && idx < axis_dim; return tvm::if_then_else( in_bounds, a(real_indices), - tvm::tirx::make_const(a->dtype, std::numeric_limits::quiet_NaN())); + tvm::tirx::MakeConst(a->dtype, std::numeric_limits::quiet_NaN())); }, name, tag); } else { // mode == "wrap" @@ -1428,16 +1428,16 @@ inline Tensor tile(const Tensor& x, ffi::Array reps, std::string name = if (ndim == rdim) { for (size_t i = 0; i < ndim; ++i) { data_shape.push_back(x->shape[i]); - reps_shape.push_back(IntImm(DataType::Int(64), reps[i])); + reps_shape.push_back(IntImm::Int64(reps[i])); } } else if (ndim > rdim) { for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); for (size_t i = 0; i < (ndim - rdim); ++i) reps_shape.push_back(1); - for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(IntImm(DataType::Int(64), reps[i])); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(IntImm::Int64(reps[i])); } else { for (size_t i = 0; i < (rdim - ndim); ++i) data_shape.push_back(1); for (size_t i = 0; i < ndim; ++i) data_shape.push_back(x->shape[i]); - for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(IntImm(DataType::Int(64), reps[i])); + for (size_t i = 0; i < rdim; ++i) reps_shape.push_back(IntImm::Int64(reps[i])); } for (size_t i = 0; i < tdim; ++i) new_shape.push_back(data_shape[i] * reps_shape[i]); @@ -1592,7 +1592,7 @@ inline Tensor gather_nd(const Tensor& data, const Tensor& indices, int batch_dim real_indices.push_back(out_index[i]); } for (size_t i = 0; i < indices_dim0; ++i) { - indices_position.Set(0, make_const(DataType::Int(32), i)); + indices_position.Set(0, IntImm::Int32(i)); if (indices->dtype.is_int() || indices->dtype.is_uint()) { real_indices.push_back(indices(indices_position)); } else { @@ -1960,7 +1960,7 @@ inline Tensor meta_schedule_layout_transform( ffi::Array iter_domain; iter_domain.reserve(src->shape.size()); for (const PrimExpr& e : src->shape) { - iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e)); + iter_domain.push_back(Range::FromMinExtent(IntImm(e->dtype, 0), e)); } ffi::Array post_transform_shape = index_map->MapShape(src->shape, analyzer); return compute( @@ -2046,7 +2046,7 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim int indices_index = 0; for (int i = 0; i < ndim; i++) { if (i == true_axis) { - oshape.push_back(IntImm(DataType::Int(32), depth)); + oshape.push_back(IntImm::Int32(depth)); } else { oshape.push_back(indices->shape[indices_index++]); } @@ -2268,7 +2268,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b [&](const ffi::Array& indices) { ffi::Array real_indices; for (size_t i = 0; i < num_dynamic_axes; ++i) { - auto ind = make_const(DataType::Int(64), i); + auto ind = IntImm::Int64(i); real_indices.push_back(indices[i] * strides(ind) + tvm::min(begin(ind), x->shape[i] - 1)); } return x(real_indices); diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index cc3c73bb6207..69dbe97f5e8a 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -70,7 +70,7 @@ void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // decompose value as symbol * scale + offset int64_t offset = 0; - PrimExpr symbol_scale = tirx::make_const(value.dtype(), 0); + PrimExpr symbol_scale = tirx::MakeConst(value.dtype(), 0); auto fcollect_sum = [&](PrimExpr val, int sign) { if (const auto* intimm = val.as()) { @@ -87,7 +87,7 @@ void AnalyzerObj::MarkGlobalNonNegValue(const PrimExpr& value) { // split out the symbol and non-symbolic part int64_t cscale = 1; - PrimExpr symbol = tirx::make_const(value.dtype(), 1); + PrimExpr symbol = tirx::MakeConst(value.dtype(), 1); auto fcollect_prod = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index a6093a25ba6a..f1dd1a63c5e7 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -130,18 +130,18 @@ class SplitExprNode : public CanonicalExprNode { PrimExpr res = this->index; DataType dtype = this->dtype; if (this->scale == 0) { - return make_const(dtype, 0); + return IntImm(dtype, 0); } if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, make_const(dtype, this->upper_factor), div_mode); + res = ModImpl(res, MakeConst(dtype, this->upper_factor), div_mode); } if (this->lower_factor != 1) { - res = DivImpl(res, make_const(dtype, this->lower_factor), div_mode); + res = DivImpl(res, MakeConst(dtype, this->lower_factor), div_mode); } sscale *= this->scale; if (sscale != 1) { TVM_FFI_ICHECK(!dtype.is_uint() || sscale > 0); - res = res * make_const(dtype, sscale); + res = res * MakeConst(dtype, sscale); } return res; } @@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode { return false; } if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, make_const(this->dtype, this->upper_factor), div_mode); + res = ModImpl(res, MakeConst(this->dtype, this->upper_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->lower_factor != 1) { - res = DivImpl(res, make_const(this->dtype, this->lower_factor), div_mode); + res = DivImpl(res, MakeConst(this->dtype, this->lower_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->scale != 1) { TVM_FFI_ICHECK(!this->dtype.is_uint() || this->scale > 0); - res = res * make_const(this->dtype, this->scale); + res = res * MakeConst(this->dtype, this->scale); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -252,7 +252,7 @@ class SumExprNode : public CanonicalExprNode { PrimExpr Normalize() const final { // quick path 1. if (this->args.size() == 0) { - return make_const(this->dtype, this->base); + return MakeConst(this->dtype, this->base); } return Normalize_(this->dtype, SimplifySplitExprs(args), base); } @@ -344,7 +344,7 @@ class SumExprNode : public CanonicalExprNode { if (dtype.bits() >= this->dtype.bits()) { return true; // upcast is safe } - PrimExpr res = make_const(dtype, 0); + PrimExpr res = IntImm(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale > 0) { res = res + args[i]->Normalize(); @@ -354,7 +354,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base > 0 || is_min_value) { - res = res + make_const(dtype, base); + res = res + MakeConst(dtype, base); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -369,7 +369,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base < 0 && !is_min_value) { - res = res - make_const(dtype, -base); + res = res - MakeConst(dtype, -base); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -500,14 +500,14 @@ class SumExprNode : public CanonicalExprNode { bool is_min_value = dtype.bits() == 64 ? base == std::numeric_limits::lowest() : base == -(1LL << (dtype.bits() - 1)); // Positive scales first - PrimExpr res = make_const(dtype, 0); + PrimExpr res = IntImm(dtype, 0); for (size_t i = 0; i < args.size(); ++i) { if (args[i]->scale > 0) { res = res + args[i]->Normalize(); } } if (base > 0 || is_min_value) { - res = res + make_const(dtype, base); + res = res + MakeConst(dtype, base); } // negative scales follows using sub. for (size_t i = 0; i < args.size(); ++i) { @@ -516,7 +516,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base < 0 && !is_min_value) { - res = res - make_const(dtype, -base); + res = res - MakeConst(dtype, -base); } return res; } @@ -834,11 +834,11 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, return lhs; } else if (lhs->upper_factor <= (lhs->lower_factor * scaled_cval)) { // (x % c1) / c2 => 0 when c2 >= c1 - return ToSplitExpr(make_zero(lhs.dtype())); + return ToSplitExpr(IntImm(lhs.dtype(), 0)); } else { // move the upper_factor modular into index. lhs.CopyOnWrite()->index = - ModImpl(lhs->index, make_const(lhs.dtype(), lhs->upper_factor), div_mode); + ModImpl(lhs->index, MakeConst(lhs.dtype(), lhs->upper_factor), div_mode); lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; lhs.CopyOnWrite()->scale = 1; lhs.CopyOnWrite()->lower_factor *= scaled_cval; @@ -862,8 +862,8 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, if (prhs->as()) return false; // collect lhs products and try to eliminate by matching them to prod in rhs ffi::Array> lhs_prods; - PrimExpr new_rhs = make_const(prhs->dtype(), 1); - PrimExpr new_common_scale = make_const(prhs->dtype(), 1); + PrimExpr new_rhs = MakeConst(prhs->dtype(), 1); + PrimExpr new_common_scale = MakeConst(prhs->dtype(), 1); int64_t lhs_cscale = 1, rhs_cscale = 1; int num_elimination = 0; @@ -905,13 +905,13 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, if (num_elimination == 0 && cscale_gcd == 1) return false; // construct prod via canonical form - PrimExpr new_lhs = make_const(plhs->dtype(), 1); + PrimExpr new_lhs = MakeConst(plhs->dtype(), 1); for (ffi::Optional val : lhs_prods) { if (val.defined()) new_lhs = new_lhs * val.value(); } - *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale); - *prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale); - *common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd); + *plhs = new_lhs * MakeConst(plhs->dtype(), lhs_cscale); + *prhs = new_rhs * MakeConst(prhs->dtype(), rhs_cscale); + *common_scale = new_common_scale * MakeConst(prhs->dtype(), cscale_gcd); return true; } @@ -958,7 +958,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return make_zero(a.dtype()); + return IntImm(a.dtype(), 0); } } return SplitDivConst(ToSplitExpr(std::move(a)), cval, kTruncDiv); @@ -1019,7 +1019,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // if a >= 0 && a < cval, then result == 0 auto cbound = analyzer_->const_int_bound(Normalize(a)); if (cbound->min_value >= 0 && cbound->max_value < cval) { - return make_zero(a.dtype()); + return IntImm(a.dtype(), 0); } } // Identity: floordiv(floormod(index, m*n), n) = floormod(floordiv(index, n), m) @@ -1049,7 +1049,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } // Apply floormod(floordiv_result, m) to complete the identity PrimExpr div_result = Normalize(lhs); - return this->VisitExpr(floormod(div_result, make_const(a.dtype(), new_mod))); + return this->VisitExpr(floormod(div_result, MakeConst(a.dtype(), new_mod))); } } } @@ -1096,7 +1096,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, // Do a recursive call to simplify the mod with the new factor. if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { auto updated = ToSplitExpr(this->VisitExpr( - ModImpl(lhs->index, make_const(lhs.dtype(), new_upper_factor), div_mode))); + ModImpl(lhs->index, MakeConst(lhs.dtype(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -1144,7 +1144,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { SumExpr lhs, extra; SeparateDivisibleParts(psum, cval, &lhs, &extra); if (extra->IsZero()) { - return make_zero(a.dtype()); + return IntImm(a.dtype(), 0); } // both lhs and extra are non-negative if (analyzer_->CanProveGreaterEqual(lhs->Normalize(), 0) && @@ -1414,11 +1414,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { DataType dtype = divisible->dtype; TVM_FFI_ICHECK(extra->dtype == dtype); PrimExpr normal_extra = extra->Normalize(); - if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) && - this->analyzer_->CanProve(normal_extra >= make_const(dtype, 0))) { + if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) && + this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) { // Case 1. 0 <= xn < d divisible.CopyOnWrite()->DivideBy(gcd); - return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype)); + return Rewriter::VisitExpr(divisible->Normalize() < IntImm(dtype, 0)); } else if (extra->args.size() == 1 && extra->args[0]->scale == 1 && extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf && extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) { @@ -1435,7 +1435,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { int64_t lower_factor = gcd * extra->args[0]->lower_factor; PrimExpr extra_expr = floormod(floordiv(split_expr->index, lower_factor), floordiv(split_expr->upper_factor, lower_factor)); - return Rewriter::VisitExpr(divisible->Normalize() + extra_expr < make_zero(dtype)); + return Rewriter::VisitExpr(divisible->Normalize() + extra_expr < IntImm(dtype, 0)); } } diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index d88d9fd34df4..a3bb95347e9e 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -139,16 +139,15 @@ class AndOfOrs { /*! \brief Mapping from PrimExpr to internal Key */ std::unordered_map expr_to_key_; - /*! \brief Cached key representing tirx::IntImm(DataType::Bool(), 1) */ + /*! \brief Cached key representing IntImm::Bool(true) */ Key key_true_; - /*! \brief Cached key representing tirx::IntImm(DataType::Bool(), 0) */ + /*! \brief Cached key representing IntImm::Bool(false) */ Key key_false_; }; AndOfOrs::AndOfOrs(const PrimExpr& expr) - : key_true_(GetKey(IntImm(DataType::Bool(), 1))), - key_false_(GetKey(IntImm(DataType::Bool(), 0))) { + : key_true_(GetKey(IntImm::Bool(true))), key_false_(GetKey(IntImm::Bool(false))) { VisitAndExpressions(expr, [&](const PrimExpr& outer_expr) { std::vector or_components; VisitOrExpressions(outer_expr, [&](const PrimExpr& inner_expr) { @@ -235,9 +234,9 @@ PrimExpr AndOfOrs::GetExpr(AndOfOrs::Key key) const { } PrimExpr AndOfOrs::AsPrimExpr() const { - PrimExpr expr = IntImm(DataType::Bool(), 1); + PrimExpr expr = IntImm::Bool(true); for (const auto& chunk : chunks_) { - PrimExpr chunk_expr = IntImm(DataType::Bool(), 0); + PrimExpr chunk_expr = IntImm::Bool(false); for (Key j : chunk) { chunk_expr = chunk_expr || GetExpr(j); } @@ -368,7 +367,7 @@ void AndOfOrs::SimplifyAcrossChunks(AnalyzerObj* analyzer) { // When attempting to simplify (B and C), the analyzer may // assume that A is false. PrimExpr known = [&]() { - PrimExpr known = IntImm(DataType::Bool(), 1); + PrimExpr known = IntImm::Bool(true); for (const auto& key : i_chunk) { if (&key != &key_i) { known = known && analyzer->Simplify(!GetExpr(key)); diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 91db540f2e82..fb1055660e3b 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -263,7 +263,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa->value == 0) return a; } if (pb) { - if (pb->value == 1) return tirx::make_zero(rtype); + // MakeConst can handle both vector and scalar types. + if (pb->value == 1) return tirx::MakeConst(rtype, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -318,7 +319,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr if (pa->value == 0) return a; } if (pb) { - if (pb->value == 1) return tirx::make_zero(rtype); + // MakeConst can handle both vector and scalar types. + if (pb->value == 1) return tirx::MakeConst(rtype, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -350,8 +352,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value); - if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value); + if (pa && pb) return IntImm::Bool(pa->value > pb->value); + if (fa && fb) return IntImm::Bool(fa->value > fb->value); }); return std::nullopt; } @@ -359,8 +361,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value); - if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value); + if (pa && pb) return IntImm::Bool(pa->value >= pb->value); + if (fa && fb) return IntImm::Bool(fa->value >= fb->value); }); return std::nullopt; } @@ -368,8 +370,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value); - if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value); + if (pa && pb) return IntImm::Bool(pa->value < pb->value); + if (fa && fb) return IntImm::Bool(fa->value < fb->value); }); return std::nullopt; } @@ -377,8 +379,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value); - if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value); + if (pa && pb) return IntImm::Bool(pa->value <= pb->value); + if (fa && fb) return IntImm::Bool(fa->value <= fb->value); }); return std::nullopt; } @@ -386,8 +388,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value); - if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value); + if (pa && pb) return IntImm::Bool(pa->value == pb->value); + if (fa && fb) return IntImm::Bool(fa->value == fb->value); }); return std::nullopt; } @@ -395,8 +397,8 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { template <> inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ - if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value); - if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value); + if (pa && pb) return IntImm::Bool(pa->value != pb->value); + if (fa && fb) return IntImm::Bool(fa->value != fb->value); }); return std::nullopt; } @@ -427,7 +429,7 @@ template <> inline ffi::Optional TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { - return IntImm(DataType::Bool(), !(pa->value)); + return IntImm::Bool(!(pa->value)); } return std::nullopt; } diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index d7a4874de0b3..b629b005f759 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -54,10 +54,10 @@ class LinearEqDetector : public ExprFunctorbase.defined()) { - ret->base = make_zero(var_.dtype()); + ret->base = IntImm(var_.dtype(), 0); } if (!ret->coeff.defined()) { - ret->coeff = make_zero(var_.dtype()); + ret->coeff = IntImm(var_.dtype(), 0); } return true; } @@ -102,7 +102,7 @@ class LinearEqDetector : public ExprFunctordtype; - ret.coeff = make_const(DataType::Int(dtype.bits(), dtype.lanes()), 1); + ret.coeff = MakeConst(DataType::Int(dtype.bits(), dtype.lanes()), 1); } else { ret.base = e; } @@ -195,13 +195,13 @@ bool DetectClipBound(const PrimExpr& cond, PrimExpr canonical; if (const LTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; - canonical = op->b - op->a - make_const(op->a.dtype(), 1); + canonical = op->b - op->a - MakeConst(op->a.dtype(), 1); } else if (const LENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->b - op->a; } else if (const GTNode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; - canonical = op->a - op->b - make_const(op->a.dtype(), 1); + canonical = op->a - op->b - MakeConst(op->a.dtype(), 1); } else if (const GENode* op = cond.as()) { if (!op->a.dtype().is_int()) return false; canonical = op->a - op->b; diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 8a24d262e4fc..55db4fc774b6 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -86,7 +86,7 @@ IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, IntGroupBounds IntGroupBounds::FromRange(const Range& r) { Analyzer analyzer; - PrimExpr coef = tirx::make_const(r->min.dtype(), 1); + PrimExpr coef = tirx::MakeConst(r->min.dtype(), 1); ffi::Array equal; ffi::Array lower; ffi::Array upper; diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8659807cc7ea..a1e01d3e86a0 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -46,8 +46,7 @@ namespace arith { using tirx::is_one; using tirx::is_zero; -using tirx::make_const; -using tirx::make_zero; +using tirx::MakeConst; TVM_FFI_STATIC_INIT_BLOCK() { IntervalSetNode::RegisterReflection(); } @@ -133,7 +132,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, IntervalSet b, return IntervalSet::SinglePoint(expr); } if (is_logical_op::value) { - return IntervalSet(make_const(dtype, 0), make_const(dtype, 1)); + return IntervalSet(IntImm(dtype, 0), IntImm(dtype, 1)); } if (a->IsEmpty()) return a; if (b->IsEmpty()) return b; @@ -196,7 +195,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); + PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -230,7 +229,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); + PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -259,7 +258,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, Inte // is the case of our application. // TODO(tqchen): add bound constraints for a. if (analyzer->CanProveGreaterEqual(divisor, 0)) { - return IntervalSet(make_zero(divisor.dtype()), divisor - 1); + return IntervalSet(IntImm(divisor.dtype(), 0), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -293,7 +292,7 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { using tirx::Select; - PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); + PrimExpr sign = b->min_value >= IntImm(b->min_value.dtype().element_of(), 0); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); return IntervalSet(Select(sign, e1, e2), Select(sign, e2, e1)); @@ -349,12 +348,12 @@ inline IntervalSet Combine(AnalyzerObj* analyzer, IntervalSet a, int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); if (max_mod_result >= 0 && max_mod_result < div_val) { - return IntervalSet(make_zero(op->dtype), make_const(op->dtype, max_mod_result)); + return IntervalSet(IntImm(op->dtype, 0), IntImm(op->dtype, max_mod_result)); } } } } - return IntervalSet(make_zero(divisor.dtype()), divisor - 1); + return IntervalSet(IntImm(divisor.dtype(), 0), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; return IntervalSet(-bound, bound); @@ -528,25 +527,25 @@ class IntervalSetEvaluator : public ExprFunctor { if (op->lanes->IsInstance()) { int lanes = static_cast(Downcast(op->lanes)->value); if (vstride > 0) { - PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + PrimExpr stride_expr = MakeConst(t, vstride * (lanes - 1)); auto add_op = tirx::Add(op->base, stride_expr); auto add_node = add_op.as(); - return Combine(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node); + return Combine(analyzer_, base, IntervalSet(IntImm(t, 0), stride_expr), add_node); } else { - PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + PrimExpr stride_expr = MakeConst(t, vstride * (lanes - 1)); auto add_op = tirx::Add(op->base, stride_expr); auto add_node = add_op.as(); - return Combine(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node); + return Combine(analyzer_, base, IntervalSet(stride_expr, IntImm(t, 0)), add_node); } } else { /* Scalable vector */ if (vstride > 0) { - auto add_op = tirx::Add(op->base, make_zero(t)); + auto add_op = tirx::Add(op->base, IntImm(t, 0)); auto add_node = add_op.as(); - return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node); + return Combine(analyzer_, base, IntervalSet(IntImm(t, 0), pos_inf()), add_node); } else { - auto add_op = tirx::Add(op->base, make_zero(t)); + auto add_op = tirx::Add(op->base, IntImm(t, 0)); auto add_node = add_op.as(); - return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node); + return Combine(analyzer_, base, IntervalSet(neg_inf(), IntImm(t, 0)), add_node); } } } diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 8aa821fa453a..1e78ab2ff218 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -55,10 +55,10 @@ void AppendFloorDivConstraints(const FloorDivNode* div, int64_t value, CompareKi if (!TryGetIntImm(div->b, &divisor_value) || divisor_value <= 0) return; DataType dtype = div->a.dtype(); - PrimExpr divisor = make_const(dtype, divisor_value); - PrimExpr k = make_const(dtype, value); + PrimExpr divisor = MakeConst(dtype, divisor_value); + PrimExpr k = MakeConst(dtype, value); PrimExpr lo = k * divisor; - PrimExpr hi = (k + make_const(dtype, 1)) * divisor; + PrimExpr hi = (k + MakeConst(dtype, 1)) * divisor; switch (kind) { case CompareKind::kEQ: @@ -160,7 +160,7 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tirx::PrimFunc& func) { ffi::Array IRMutatorWithAnalyzer::IterMapSimplifyWithContext( const ffi::Array& indices, bool non_trivial_only) { - PrimExpr pred = const_true(); + PrimExpr pred = IntImm::Bool(true); for (PrimExpr val : iter_predicates_) { pred = pred && val; } @@ -260,7 +260,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value); + Range dom = Range::FromMinExtent(IntImm(op->value.dtype(), 0), op->value); analyzer_->Bind(iv->var, dom); iter_vars_.Set(iv->var, dom); } diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 1930feb42877..36923b9dfe2a 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -66,7 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { IterSplitExpr::IterSplitExpr(IterMark source) { auto n = ffi::make_object(); - auto one = make_const(source->source->dtype, 1); + auto one = MakeConst(source->source->dtype, 1); n->dtype = source->source->dtype; n->source = std::move(source); n->extent = n->source->extent; @@ -77,7 +77,7 @@ IterSplitExpr::IterSplitExpr(IterMark source) { IterSplitExpr::IterSplitExpr(IterMark source, PrimExpr scale) { auto n = ffi::make_object(); - auto one = make_const(source->source->dtype, 1); + auto one = MakeConst(source->source->dtype, 1); n->dtype = source->source->dtype; n->source = std::move(source); n->extent = n->source->extent; @@ -180,7 +180,7 @@ class IterMapRewriter : public ExprMutator { : analyzer_(analyzer), check_level_(check_level), errors_(*errors), - padding_predicate_(const_false()) { + padding_predicate_(IntImm::Bool(false)) { for (auto kv : input_iters) { const Var& var = kv.first; const Range& vrng = kv.second; @@ -563,7 +563,7 @@ class IterMapRewriter : public ExprMutator { IterMapLevel check_level) { std::vector used(splits.size(), false); std::vector iters; - PrimExpr expected_lower_factor = make_const(mark->source->dtype, 1); + PrimExpr expected_lower_factor = MakeConst(mark->source->dtype, 1); for (size_t i = 0; i < splits.size(); ++i) { size_t j = 0; @@ -694,7 +694,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr iter_min = mark_offset; PrimExpr iter_max = iter_min + mark->extent; // the delta of iter_min when it is updated when the lower bound predicate is present - PrimExpr iter_min_delta = make_const(iter_min.dtype(), 0); + PrimExpr iter_min_delta = IntImm(iter_min.dtype(), 0); if (predicate_induced_min.defined()) { iter_min_delta = max(predicate_induced_min.value(), iter_min) - iter_min; iter_min = max(predicate_induced_min.value(), iter_min); @@ -788,7 +788,7 @@ class IterMapRewriter : public ExprMutator { for (IterSplitExpr split : expr->args) { int64_t symbol_prod_count = 0; int64_t cscale = 1; - PrimExpr res = tirx::make_const(split.dtype(), 1); + PrimExpr res = tirx::MakeConst(split.dtype(), 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -799,7 +799,7 @@ class IterMapRewriter : public ExprMutator { }; UnpackReduction(split->scale, fcollect); if (cscale != 1) { - res = res * tirx::make_const(res.dtype(), cscale); + res = res * tirx::MakeConst(res.dtype(), cscale); } split.CopyOnWrite()->scale = res; items.emplace_back(Item{cscale, symbol_prod_count, split}); @@ -830,7 +830,7 @@ class IterMapRewriter : public ExprMutator { if (auto op = expr.as()) { return op.value(); } else if (auto op = expr.as()) { - return IterSumExpr({op.value()}, make_zero(expr->dtype)); + return IterSumExpr({op.value()}, IntImm(expr->dtype, 0)); } else { TVM_FFI_ICHECK(!expr->IsInstance()); return IterSumExpr({}, expr); @@ -1103,8 +1103,8 @@ class IterMapRewriter : public ExprMutator { std::vector flattened_iters, grouped_iters; // check if it can be remapped into a fused pattern. - PrimExpr expected_extra_base = make_const(expr.dtype(), 0); - PrimExpr tail_extent = make_const(expr.dtype(), 0); + PrimExpr expected_extra_base = IntImm(expr.dtype(), 0); + PrimExpr tail_extent = IntImm(expr.dtype(), 0); PrimExpr expected_scale = base_scale; int first_possible_unit_extent_pos = FindFirstPossibleUnitExtentIndex(expr); @@ -1200,10 +1200,10 @@ class IterMapRewriter : public ExprMutator { IterSumExpr structured_form = expr, flattened_form = expr; flattened_form.CopyOnWrite()->args = ffi::Array(flattened_iters.rbegin(), flattened_iters.rend()); - flattened_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); + flattened_form.CopyOnWrite()->base = IntImm(expr.dtype(), 0); structured_form.CopyOnWrite()->args = ffi::Array(grouped_iters.rbegin(), grouped_iters.rend()); - structured_form.CopyOnWrite()->base = make_const(expr.dtype(), 0); + structured_form.CopyOnWrite()->base = IntImm(expr.dtype(), 0); auto it = sum_fuse_map_.find(flattened_form); if (it != sum_fuse_map_.end()) { // old iter @@ -1245,7 +1245,7 @@ class IterMapRewriter : public ExprMutator { if (sign > 0) { lhs->args.push_back(rhs); } else { - rhs.CopyOnWrite()->scale = make_zero(rhs->scale.dtype()) - rhs->scale; + rhs.CopyOnWrite()->scale = IntImm(rhs->scale.dtype(), 0) - rhs->scale; lhs->args.push_back(rhs); } } @@ -1677,7 +1677,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) { if (dividend->IsInstance()) { auto split = Downcast(dividend); - return IterSumExpr({split}, make_zero(split.dtype())); + return IterSumExpr({split}, IntImm(split.dtype(), 0)); } else if (dividend->IsInstance()) { auto sum = Downcast(dividend); if (sum->args.empty()) { @@ -1715,7 +1715,7 @@ PrimExpr ApproxLeastCommonMultiple(const PrimExpr& a, const PrimExpr& b, Analyze }; auto p1 = fsplit(a); auto p2 = fsplit(b); - auto const_lcm = IntImm(DataType::Int(32), LeastCommonMultiple(p1.second, p2.second)); + auto const_lcm = IntImm::Int32(LeastCommonMultiple(p1.second, p2.second)); if (analyzer->CanProveEqual(p1.first, p2.first)) { return p1.first * const_lcm; } else if (analyzer->CanProveEqual(floormod(p1.first, p2.first), 0)) { @@ -1880,12 +1880,12 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floordiv(x*c1, c1*c2) = floordiv(x, c2), c2=rhs/scale rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); + lhs.CopyOnWrite()->scale = MakeConst(rhs->dtype, 1); } else if (CanProveDivisible(rhs, lhs->scale) && CanProveDivisible(base, lhs->scale)) { // floordiv(x*c1 + y*c1, c1*c2) = floordiv(x+y, c2), c2=rhs/scale base = floordiv(base, lhs->scale); rhs = floordiv(rhs, lhs->scale); - lhs.CopyOnWrite()->scale = make_const(rhs->dtype, 1); + lhs.CopyOnWrite()->scale = MakeConst(rhs->dtype, 1); } else { // mark as unresolved. ErrorLogger(this) << "Cannot represent as IterMap: the numerator's scaling factor, " @@ -1931,7 +1931,7 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P new_split = IterSplitExpr(IterMark(padded, padded->extent), /* lower_factor = */ rhs, /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)), - /* scale = */ make_const(rhs->dtype, 1)); + /* scale = */ MakeConst(rhs->dtype, 1)); } auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); @@ -1987,13 +1987,13 @@ PrimExpr IterMapRewriter::SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, P if (is_one(rhs)) { // floormod(x, 1) = 0 - return make_zero(lhs->dtype); + return IntImm(lhs->dtype, 0); } if (!is_one(lhs->scale)) { if (CanProveDivisible(lhs->scale, rhs) && CanProveDivisible(base, rhs)) { // floormod(x*c1*c2, c1) = 0 - return make_zero(lhs->dtype); + return IntImm(lhs->dtype, 0); } else if (CanProveDivisible(rhs, lhs->scale) && is_zero(base)) { // floormod(x*c1, c1*c2) = (floormod(x, c2)) * c1, where c2 = rhs/scale rhs = floordiv(rhs, lhs->scale); @@ -2113,7 +2113,7 @@ class IterMapToExprNormalizer : public ExprMutator { // simplify trivial iters like `vi \in [0, 1)`, which can be useful for subsequent analysis // like tensorization. if (is_one(expr->extent) && !is_one(expr->source->extent)) { - return make_const(expr->extent->dtype, 0); + return IntImm(expr->extent->dtype, 0); } return floordiv(source, expr->lower_factor) * expr->scale; } else { @@ -2168,7 +2168,7 @@ ffi::Array IterMapSimplify(const ffi::Array& indices, // The input predicate may cause detect iter map to fail // but we can still detect the iter map without the input predicate // in which case the resulting iter map is valid and can be used for simplification. - rewrite = DetectIterMap(indices, input_iters, const_true(), check_level, ana, + rewrite = DetectIterMap(indices, input_iters, IntImm::Bool(true), check_level, ana, /*simplify_trivial_iterators=*/simplify_trivial_iterators) ->indices; } @@ -2256,14 +2256,14 @@ class SubspaceDivider { static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) { auto dtype = iter.dtype(); - return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), iter, - extent, Kind::kInner); + return DivisionResult(IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), iter, extent, + Kind::kInner); } static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) { auto dtype = iter.dtype(); - return DivisionResult(iter, extent, IterSumExpr({}, make_const(dtype, 0)), - make_const(dtype, 1), Kind::kOuter); + return DivisionResult(iter, extent, IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), + Kind::kOuter); } // Special value to indicate the division is not possible @@ -2288,8 +2288,8 @@ class SubspaceDivider { auto dtype = expr.dtype(); if (expr->args.empty()) { // base - return DivisionResult(IterSumExpr({}, make_const(dtype, 0)), make_const(dtype, 1), - IterSumExpr({}, expr->base), make_const(dtype, 1)); + return DivisionResult(IterSumExpr({}, IntImm(dtype, 0)), IntImm(dtype, 1), + IterSumExpr({}, expr->base), IntImm(dtype, 1)); } else if (expr->args.size() == 1) { // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base) if (!is_one(expr->args[0]->scale)) { @@ -2303,7 +2303,7 @@ class SubspaceDivider { // arg1 + arg2 + ... + argn + base // then we can write it as Y*E(X)+X // if it starts with contiguous outer splits, followed by contiguous inner splits - PrimExpr extent = make_const(dtype, 1); + PrimExpr extent = IntImm(dtype, 1); std::vector outer_args, inner_args; bool inner = true, scale_is_one = false; // we check in inverse order so we can visit from inner to outer @@ -2335,7 +2335,7 @@ class SubspaceDivider { return DivisionResult::Failure(); } bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent); - const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, make_const(dtype, 0)); + const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, IntImm(dtype, 0)); const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base); IterSumExpr outer_source = Downcast(outer_mark->source); IterSumExpr inner_source = Downcast(inner_mark->source); @@ -2377,7 +2377,7 @@ class SubspaceDivider { // args are sorted from inner to outer static IterMark MarkFromArgsAndBase(const std::vector& args, PrimExpr base) { std::vector res; - PrimExpr extent = make_const(base.dtype(), 1); + PrimExpr extent = MakeConst(base.dtype(), 1); for (const IterSplitExpr& it : args) { IterSplitExpr arg = it; arg.CopyOnWrite()->scale = extent; @@ -2429,7 +2429,7 @@ class SubspaceDivider { bool encountered_boundary = mark_division.IsOuter(); std::vector used(splits.size(), false); std::vector inner_iters, outer_iters; - PrimExpr expected_lower_factor = make_const(expr->source->source->dtype, 1); + PrimExpr expected_lower_factor = MakeConst(expr->source->source->dtype, 1); // find the boundary of outer and inner, like case 1 above for (size_t i = 0; i < splits.size(); ++i) { size_t j = 0; @@ -2485,7 +2485,7 @@ class SubspaceDivider { std::unordered_map split_map_; // predicate of outer space and inner space; - PrimExpr outer_preds_{const_true()}, inner_preds_{const_true()}; + PrimExpr outer_preds_{IntImm::Bool(true)}, inner_preds_{IntImm::Bool(true)}; }; ffi::Array> SubspaceDivide(const ffi::Array& bindings, @@ -2547,7 +2547,7 @@ class InverseAffineIterMapTransformer { // initialize back propagation accumulator for (const IterMapExprNode* node : post_dfs_order) { - backprop_.Set(ffi::GetRef(node), IntImm(DataType::Int(32), 0)); + backprop_.Set(ffi::GetRef(node), IntImm::Int32(0)); } for (size_t i = 0; i < iter_map.size(); i++) { backprop_.Set(iter_map[i], outputs[i]); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index d10a7bad2932..5f66356e1ae9 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -303,7 +303,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctorargs[1]); if (b.is_const()) { int shift; - if (is_const_power_of_two_integer(IntImm(DataType::Int(32), b.base + 1), &shift)) { + if (is_const_power_of_two_integer(IntImm::Int32(b.base + 1), &shift)) { return ModByConst(op->args[0], static_cast(1) << shift, true); } } diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 795bebd0ae0e..bb1ebd54cca7 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -377,7 +377,7 @@ class PConstWithTypeLike : public Pattern> { } } - PrimExpr Eval() const { return tirx::make_const(ref_.Eval().dtype(), value_); } + PrimExpr Eval() const { return tirx::MakeConst(ref_.Eval().dtype(), value_); } private: typename TA::Nested ref_; diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index bbe330147cf9..ba6b5564d967 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -126,11 +126,11 @@ void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const ffi:: } PrimExpr PresburgerSetNode::GenerateConstraint() const { - PrimExpr constraint = const_false(); + PrimExpr constraint = IntImm::Bool(false); for (const IntegerRelation& disjunct : disjuncts) { - PrimExpr union_entry = const_true(); + PrimExpr union_entry = IntImm::Bool(true); for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) { - PrimExpr linear_eq = IntImm(DataType::Int(64), 0); + PrimExpr linear_eq = IntImm::Int64(0); if (disjunct.getNumCols() > 1) { for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) { #if TVM_MLIR_VERSION >= 160 @@ -139,9 +139,9 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { auto coeff = disjunct.atEq(i, j); #endif if (coeff >= 0 || is_zero(linear_eq)) { - linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j]; + linear_eq = linear_eq + IntImm::Int64(coeff) * vars[j]; } else { - linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j]; + linear_eq = linear_eq - IntImm::Int64(-coeff) * vars[j]; } } } @@ -150,11 +150,11 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { #else auto c0 = disjunct.atEq(i, disjunct.getNumCols() - 1); #endif - linear_eq = linear_eq + IntImm(DataType::Int(64), c0); + linear_eq = linear_eq + IntImm::Int64(c0); union_entry = (union_entry && (linear_eq == 0)); } for (unsigned i = 0, e = disjunct.getNumInequalities(); i < e; ++i) { - PrimExpr linear_eq = IntImm(DataType::Int(64), 0); + PrimExpr linear_eq = IntImm::Int64(0); if (disjunct.getNumCols() > 1) { for (unsigned j = 0, f = disjunct.getNumCols() - 1; j < f; ++j) { #if TVM_MLIR_VERSION >= 160 @@ -163,9 +163,9 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { auto coeff = disjunct.atIneq(i, j); #endif if (coeff >= 0 || is_zero(linear_eq)) { - linear_eq = linear_eq + IntImm(DataType::Int(64), coeff) * vars[j]; + linear_eq = linear_eq + IntImm::Int64(coeff) * vars[j]; } else { - linear_eq = linear_eq - IntImm(DataType::Int(64), -coeff) * vars[j]; + linear_eq = linear_eq - IntImm::Int64(-coeff) * vars[j]; } } } @@ -175,9 +175,9 @@ PrimExpr PresburgerSetNode::GenerateConstraint() const { auto c0 = disjunct.atIneq(i, disjunct.getNumCols() - 1); #endif if (c0 >= 0) { - linear_eq = linear_eq + IntImm(DataType::Int(64), c0); + linear_eq = linear_eq + IntImm::Int64(c0); } else { - linear_eq = linear_eq - IntImm(DataType::Int(64), -c0); + linear_eq = linear_eq - IntImm::Int64(-c0); } union_entry = (union_entry && (linear_eq >= 0)); } @@ -245,15 +245,15 @@ IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) { auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs)); auto opt = range.first.getOptimumIfBounded(); #if TVM_MLIR_VERSION >= 160 - auto min = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : neg_inf(); + auto min = opt.has_value() ? IntImm::Int64(int64_t(opt.value())) : neg_inf(); #else - auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf(); + auto min = opt.hasValue() ? IntImm::Int64(opt.getValue()) : neg_inf(); #endif opt = range.second.getOptimumIfBounded(); #if TVM_MLIR_VERSION >= 160 - auto max = opt.has_value() ? IntImm(DataType::Int(64), int64_t(opt.value())) : pos_inf(); + auto max = opt.has_value() ? IntImm::Int64(int64_t(opt.value())) : pos_inf(); #else - auto max = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : pos_inf(); + auto max = opt.hasValue() ? IntImm::Int64(opt.getValue()) : pos_inf(); #endif auto interval = IntervalSet(min, max); result = Union({result, interval}); diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h index d3308c07bb2a..40d02c1952b7 100644 --- a/src/arith/product_normal_form.h +++ b/src/arith/product_normal_form.h @@ -79,7 +79,7 @@ inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) { */ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { int64_t cscale = 1; - PrimExpr res = tirx::make_const(lhs.dtype(), 1); + PrimExpr res = tirx::MakeConst(lhs.dtype(), 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -90,7 +90,7 @@ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { UnpackReduction(lhs, fcollect); UnpackReduction(rhs, fcollect); if (cscale != 1) { - res = res * tirx::make_const(res.dtype(), cscale); + res = res * tirx::MakeConst(res.dtype(), cscale); } return res; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 5a86cdd15abb..b5b0cc604e22 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -237,7 +237,7 @@ CompareResult RewriteSimplifier::Impl::TryComparisonOfProductAndSum(const PrimEx (B * A) + (A + B) * C, } .Match(diff)) { - return std::tuple{A.Eval(), B.Eval(), C.Eval(), IntImm(DataType::Int(32), -1)}; + return std::tuple{A.Eval(), B.Eval(), C.Eval(), IntImm::Int32(-1)}; } else { return std::nullopt; } @@ -543,7 +543,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c // applied. negation = NormalizeBooleanOperators(Not(subconstraint)); } else { - negation = subconstraint == make_zero(subconstraint.dtype()); + negation = subconstraint == IntImm(subconstraint.dtype(), 0); } literal_constraints_.push_back(Not(negation)); } @@ -839,7 +839,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode* op) { if (truncdiv(c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - return make_const(op->dtype, truncdiv(c1val, c2val)); + return MakeConst(op->dtype, truncdiv(c1val, c2val)); } // while it is always true for trunc div @@ -1019,7 +1019,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const ModNode* op) { // canonicalization: x % c == x % (-c) for truncated division // NOTE: trunc div required TVM_TRY_RECURSIVE_REWRITE_IF( - truncmod(x, c1), truncmod(x, PConst(make_const(op->dtype, -c1.Eval()->value))), + truncmod(x, c1), truncmod(x, PConst(MakeConst(op->dtype, -c1.Eval()->value))), c1.Eval()->value < 0); // try modular analysis @@ -1089,7 +1089,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { floordiv(y + x * c1, c2).Match(ret)) { int64_t c1val = c1.Eval()->value; int64_t c2val = c2.Eval()->value; - PrimExpr yval = y.EvalOr(IntImm(DataType::Int(32), 0)); + PrimExpr yval = y.EvalOr(IntImm::Int32(0)); if (c2val == 0) return ret; // try eliminate residue part @@ -1098,8 +1098,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val); auto bound = analyzer_->const_int_bound(residue); if (bound.defined() && bound->max_value == bound->min_value) { - return x.Eval() * floordiv(c1val, c2.Eval()) + - (y_div + IntImm(DataType::Int(32), bound->max_value)); + return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + IntImm::Int32(bound->max_value)); } // try simplify divisor @@ -1687,10 +1686,10 @@ ffi::Optional RewriteSimplifier::Impl::TryMatchLiteralConstraint( ExprDeepEqual expr_equal; for (const auto& constraint : literal_constraints_) { if (expr_equal(constraint, expr)) { - return make_const(expr->dtype, true); + return MakeConst(expr->dtype, true); } if (expr_equal(constraint, negation)) { - return make_const(expr->dtype, false); + return MakeConst(expr->dtype, false); } } return std::nullopt; @@ -1716,7 +1715,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var match IntImm PVar c1, c2; PVar lanes; - PConst ctrue(make_const(ret->dtype, true)); + PConst ctrue(MakeConst(ret->dtype, true)); // vector rule if (ret->dtype.is_scalable_or_fixed_length_vector()) { @@ -1726,10 +1725,10 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { if (IsIndexType(ret->a.dtype())) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kEQ) { - return make_const(ret->dtype, true); + return MakeConst(ret->dtype, true); } else if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return make_const(ret->dtype, false); + return MakeConst(ret->dtype, false); } TVM_TRY_REWRITE(c1 == x, x == c1); @@ -1763,9 +1762,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const NENode* op) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kNE || result == CompareResult::kGT || result == CompareResult::kLT) { - return make_const(op->dtype, true); + return MakeConst(op->dtype, true); } else if (result == CompareResult::kEQ) { - return make_const(op->dtype, false); + return MakeConst(op->dtype, false); } else if (result == CompareResult::kGE) { // Known: a >= b // @@ -1807,9 +1806,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LENode* op) { CompareResult result = TryCompare(op->a, op->b); if (result == CompareResult::kLE || result == CompareResult::kLT || result == CompareResult::kEQ) { - return make_const(op->dtype, true); + return MakeConst(op->dtype, true); } else if (result == CompareResult::kGT) { - return make_const(op->dtype, false); + return MakeConst(op->dtype, false); } else if (result == CompareResult::kNE) { // Known: a != b // @@ -1866,11 +1865,11 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { if (IsIndexType(ret->a.dtype())) { CompareResult result = TryCompare(ret->a, ret->b); if (result == CompareResult::kLT) { - return make_const(ret->dtype, true); + return MakeConst(ret->dtype, true); } if (result == CompareResult::kEQ || result == CompareResult::kGT || result == CompareResult::kGE) { - return make_const(ret->dtype, false); + return MakeConst(ret->dtype, false); } // clang-format off @@ -1988,9 +1987,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) { } else if (diff == 1) { return lhs <= rhs; } else if (diff < 0 && rhs_offset != 0) { - return lhs + make_const(lhs.dtype(), -diff) < rhs; + return lhs + MakeConst(lhs.dtype(), -diff) < rhs; } else if (diff > 0 && lhs_offset != 0) { - return lhs < rhs + make_const(rhs.dtype(), diff); + return lhs < rhs + MakeConst(rhs.dtype(), diff); } return std::nullopt; @@ -2105,7 +2104,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AndNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x && y, lanes)); } - auto cfalse = PConst(make_const(op->dtype, false)); + auto cfalse = PConst(MakeConst(op->dtype, false)); TVM_TRY_REWRITE(x == y && x != y, cfalse); TVM_TRY_REWRITE(x != y && x == y, cfalse); TVM_TRY_REWRITE(x && !x, cfalse); @@ -2253,7 +2252,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const OrNode* op) { TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x || y, lanes)); } - auto ctrue = PConst(make_const(op->dtype, true)); + auto ctrue = PConst(MakeConst(op->dtype, true)); TVM_TRY_REWRITE(x == y || x != y, ctrue); TVM_TRY_REWRITE(x != y || x == y, ctrue); @@ -2342,7 +2341,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } else if (op->op.same_as(clz_op)) { if (const auto* arg_int = op->args[0].as()) { int bits = arg_int->dtype.bits(); - if (arg_int->value == 0) return make_const(op->dtype, bits); + if (arg_int->value == 0) return MakeConst(op->dtype, bits); for (int i = bits - 1; i >= 0; --i) { if ((int64_t(1) << i) & arg_int->value) { return IntImm(op->dtype, bits - i - 1); diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 623a906ee75c..27144c674b9f 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -133,10 +133,10 @@ void SmithNormalFormDiag(std::vector>* S, std::vector>* S, std::vector= 0) { - PrimExpr a = tirx::make_const(Uy[j].dtype(), S[j][j]); + PrimExpr a = tirx::MakeConst(Uy[j].dtype(), S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers - PrimExpr a = tirx::make_const(Uy[j].dtype(), -S[j][j]); + PrimExpr a = tirx::MakeConst(Uy[j].dtype(), -S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem->Simplify(floordiv(-Uy[j], a))); } } @@ -416,9 +416,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // V V^{-1} x = x for (size_t i = 0; i < num_vars; ++i) { - PrimExpr e = tirx::make_zero(system_to_solve->variables[i].dtype()); + PrimExpr e = IntImm(system_to_solve->variables[i].dtype(), 0); for (size_t j = 0; j < num_vars; ++j) { - e = e + tirx::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; + e = e + tirx::MakeConst(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem->Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index aa66dcf5a655..80d064f71157 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -92,9 +92,9 @@ class NormalizeComparisons : public ExprMutator { PrimExpr Make(const PrimExpr& a, const PrimExpr& b) { // rewrite LT to LE for ints if (std::is_same::value && (a.dtype().is_int() || a.dtype().is_uint())) { - return LE(analyzer_->Simplify(a - b + 1), make_zero(a.dtype())); + return LE(analyzer_->Simplify(a - b + 1), IntImm(a.dtype(), 0)); } - return T(analyzer_->Simplify(a - b), make_zero(a.dtype())); + return T(analyzer_->Simplify(a - b), IntImm(a.dtype(), 0)); } arith::Analyzer analyzer_; }; @@ -248,11 +248,11 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t for (const auto& pos : coef_pos) { for (const auto& neg : coef_neg) { auto first_gcd = ExtendedEuclidean(pos.first, -neg.first, &gcd_x, &gcd_y); - PrimExpr c_pos = make_const(v.dtype(), neg.first / first_gcd); - PrimExpr c_neg = make_const(v.dtype(), pos.first / first_gcd); + PrimExpr c_pos = MakeConst(v.dtype(), neg.first / first_gcd); + PrimExpr c_neg = MakeConst(v.dtype(), pos.first / first_gcd); // eliminate the current variable PrimExpr new_lhs = c_neg * neg.second - c_pos * pos.second; - PrimExpr new_ineq = LE(new_lhs, make_zero(pos.second.dtype())); + PrimExpr new_ineq = LE(new_lhs, IntImm(pos.second.dtype(), 0)); // we need rewrite_simplify -> canonical_simplify -> rewrite_simplify // to help simplify things like (((y + 10) - (-1*(y - 20))) <= 0) => y - 5 <= 0 // with steps = 2 it's (y*2) - 10 <= 0 @@ -281,7 +281,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t lower_bounds.reserve(coef_neg.size()); for (const auto& pos : coef_pos) { - PrimExpr bound = make_const(v.dtype(), -coef_lcm / pos.first) * pos.second; + PrimExpr bound = MakeConst(v.dtype(), -coef_lcm / pos.first) * pos.second; bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(upper_bounds.begin(), upper_bounds.end(), @@ -302,7 +302,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t upper_bounds.push_back(bound); } for (const auto& neg : coef_neg) { - PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; + PrimExpr bound = MakeConst(v.dtype(), -coef_lcm / neg.first) * neg.second; bound = analyzer->Simplify(bound, kSimplifyRewriteCanonicalRewrite); // Don't add if any of the existing bounds is better if (std::any_of(lower_bounds.begin(), lower_bounds.end(), @@ -330,7 +330,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t std::sort(equal_list.begin(), equal_list.end(), ExprLess()); // Write it to the result. - IntGroupBounds bnds(make_const(v.dtype(), coef_lcm), + IntGroupBounds bnds(MakeConst(v.dtype(), coef_lcm), ffi::Array(lower_bounds.begin(), lower_bounds.end()), ffi::Array(equal_list.begin(), equal_list.end()), ffi::Array(upper_bounds.begin(), upper_bounds.end())); @@ -345,7 +345,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t PrimExpr e_simp = analyzer->Simplify(e, kSimplifyRewriteCanonicalRewrite); if (is_const_int(e_simp, 0)) { // contradiction detected - other_conditions = {const_false()}; + other_conditions = {IntImm::Bool(false)}; break; } else if (is_const_int(e_simp, 1)) { continue; @@ -413,7 +413,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { if (analyzer->CanProveGreaterEqual(-best_range->extent, 0)) { // range.extent <= 0 implies the input inequality system is unsolvable return IntConstraints(/*variables=*/{}, /*ranges=*/{}, - /*relations=*/{tirx::make_zero(DataType::Bool())}); + /*relations=*/{IntImm::Bool(false)}); } res_ranges.Set(var, best_range); vranges.Set(var, best_range); @@ -498,7 +498,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ IntConstraints( /*variables=*/{}, /*ranges=*/{}, - /*relations=*/{tirx::make_zero(DataType::Bool())}), + /*relations=*/{IntImm::Bool(false)}), {}, {}); } else { // created new_var starts from 0 @@ -509,7 +509,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ analyzer->Simplify(var - Substitute(best_range->min, res_dst_to_src))); // Add the new var to the resulting axis - auto range = Range(make_zero(new_var.dtype()), best_range->extent); + auto range = Range(IntImm(new_var.dtype(), 0), best_range->extent); res_variables.push_back(new_var); res_ranges.Set(new_var, range); diff --git a/src/backend/cuda/codegen/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc index e04541a73da4..357f2c95857c 100644 --- a/src/backend/cuda/codegen/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -201,12 +201,12 @@ class ThreadIdxExtractor : public tirx::StmtVisitor { } public: - PrimExpr threadIdx_x_ext = IntImm(DataType::Int(32), 1); - PrimExpr threadIdx_y_ext = IntImm(DataType::Int(32), 1); - PrimExpr threadIdx_z_ext = IntImm(DataType::Int(32), 1); - PrimExpr clusterCtaIdx_x_ext = IntImm(DataType::Int(32), 1); - PrimExpr clusterCtaIdx_y_ext = IntImm(DataType::Int(32), 1); - PrimExpr clusterCtaIdx_z_ext = IntImm(DataType::Int(32), 1); + PrimExpr threadIdx_x_ext = IntImm::Int32(1); + PrimExpr threadIdx_y_ext = IntImm::Int32(1); + PrimExpr threadIdx_z_ext = IntImm::Int32(1); + PrimExpr clusterCtaIdx_x_ext = IntImm::Int32(1); + PrimExpr clusterCtaIdx_y_ext = IntImm::Int32(1); + PrimExpr clusterCtaIdx_z_ext = IntImm::Int32(1); }; void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { diff --git a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc index 0a4ca893b631..3e46e322a881 100644 --- a/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc @@ -135,16 +135,17 @@ TVM_REGISTER_OP("tirx.tanh") return TVMExternCall(call, tvm_wrapper); } #endif - PrimExpr one = tirx::make_const(x.dtype(), 1); - PrimExpr two = tirx::make_const(x.dtype(), 2); - PrimExpr neg_two = tirx::make_const(x.dtype(), -2); + PrimExpr one = tirx::MakeConst(x.dtype(), 1); + PrimExpr two = tirx::MakeConst(x.dtype(), 2); + PrimExpr neg_two = tirx::MakeConst(x.dtype(), -2); PrimExpr exp_neg2x = exp(neg_two * x); PrimExpr exp_pos2x = exp(two * x); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - PrimExpr tanh_x = tirx::Select(x >= tirx::make_zero(x.dtype()), tanh_pos, tanh_neg); + // MakeConst can handle both vector and scalar types. + PrimExpr tanh_x = tirx::Select(x >= tirx::MakeConst(x.dtype(), 0), tanh_pos, tanh_neg); return tanh_x; }); @@ -194,8 +195,8 @@ TVM_REGISTER_OP("tirx.sigmoid") useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } - PrimExpr MinBound = tirx::make_const(x.dtype(), -8); - PrimExpr MaxBound = tirx::make_const(x.dtype(), 8); + PrimExpr MinBound = tirx::MakeConst(x.dtype(), -8); + PrimExpr MaxBound = tirx::MakeConst(x.dtype(), 8); const PrimExpr v1 = tirx::Max(x, MinBound); const PrimExpr v2 = tirx::Min(v1, MaxBound); @@ -208,7 +209,7 @@ TVM_REGISTER_OP("tirx.sigmoid") return TVMExternCall(new_call.get(), tvm_wrapper); } #endif - PrimExpr one = tirx::make_const(x.dtype(), 1); + PrimExpr one = tirx::MakeConst(x.dtype(), 1); return one / (one + exp(-x)); }); diff --git a/src/backend/opencl/codegen/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc index a5a94c41da89..9265fcc55547 100644 --- a/src/backend/opencl/codegen/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -457,7 +457,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { os << ", "; this->PrintExpr(op->args[3], os); os << ", "; - this->PrintExpr(make_const(DataType::Int(32), 0), os); + this->PrintExpr(IntImm::Int32(0), os); os << "), "; os << "as_"; this->PrintType(channel_type, os); @@ -490,7 +490,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { ss << ", "; this->PrintExpr(op->args[3], ss); ss << ", "; - this->PrintExpr(make_const(DataType::Int(32), 0), ss); + this->PrintExpr(IntImm::Int32(0), ss); ss << "))))"; std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(data_lanes)); diff --git a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc index 8bd0497a0d59..4859fd5f4a24 100644 --- a/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc +++ b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc @@ -69,8 +69,8 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { TVM_FFI_ICHECK_EQ(var.dtype().bits(), 32); // get own lane in self (__lane_id) - PrimExpr minus_one = tirx::make_const(DataType::Int(32), -1); - PrimExpr zero = tirx::make_zero(DataType::Int(32)); + PrimExpr minus_one = IntImm::Int32(-1); + PrimExpr zero = IntImm::Int32(0); PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), @@ -111,7 +111,7 @@ void RegisterROCMIntrinRules() { // dummy because we don't have the activemask TVM_REGISTER_OP("tirx.tvm_warp_activemask") .set_attr("rocm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - PrimExpr zero = tirx::make_zero(DataType::Int(32)); + PrimExpr zero = IntImm::Int32(0); return zero; }); diff --git a/src/backend/trn/transform/lower_trainium_layout.cc b/src/backend/trn/transform/lower_trainium_layout.cc index b6b2cdcb3209..b0fba77ebab4 100644 --- a/src/backend/trn/transform/lower_trainium_layout.cc +++ b/src/backend/trn/transform/lower_trainium_layout.cc @@ -147,7 +147,7 @@ class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { if (auto tile_layout = buf->layout.as(); tile_layout && tile_layout->HasThreadAxis()) { arith::Analyzer ana; - PrimExpr mem_span = make_const(DataType::Int(32), 1); + PrimExpr mem_span = IntImm::Int32(1); for (const auto& iter : tile_layout->shard) { if (iter->axis->IsMemoryAxis()) { mem_span = mem_span + (iter->extent - 1) * iter->stride; diff --git a/src/backend/vulkan/codegen/codegen_spirv.cc b/src/backend/vulkan/codegen/codegen_spirv.cc index 4828dd2d5eb3..90b251cc8d6b 100644 --- a/src/backend/vulkan/codegen/codegen_spirv.cc +++ b/src/backend/vulkan/codegen/codegen_spirv.cc @@ -536,7 +536,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) { for (int i = 0; i < lanes; ++i) { spirv::Value v = base; if (i != 0) { - spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride); + spirv::Value offset = MakeValue(MakeConst(op->stride.dtype(), i) * op->stride); v = builder_->Add(v, offset); } values.push_back(v); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index f3f55878f849..ef6ea0ed6dca 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -46,7 +46,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { RangeNode::RegisterReflection(); } -PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} +PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm::Int32(value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 00fc9ceb4ef3..932e7efeedfa 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -648,97 +648,97 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo(const StructInfo& lhs, const StructInfo& other) override { if (lhs.same_as(other)) { // Early bail-out if the StructInfo has reference equality. - return tirx::const_true(); + return IntImm::Bool(true); } else { return StructInfoFunctor::VisitStructInfo(lhs, other); } } PrimExpr VisitStructInfo_(const ObjectStructInfoNode* lhs, const StructInfo& other) final { - return IntImm(DataType::Bool(), 1); + return IntImm::Bool(true); } PrimExpr VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } if (lhs->dtype != rhs->dtype) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } if (lhs->value.defined() && rhs->value.defined()) { return lhs->value.value() == rhs->value.value(); } else if (lhs->value.defined() && !rhs->value.defined()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } else { - return IntImm(DataType::Bool(), 1); + return IntImm::Bool(true); } } PrimExpr VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } // lhs have unknown ndim if (lhs->IsUnknownNdim()) { - return IntImm(DataType::Bool(), 1); + return IntImm::Bool(true); } // ndim must match if (lhs->ndim != rhs->ndim) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } if (lhs->values.defined() && rhs->values.defined()) { return ArrayCheck(lhs->values.value(), rhs->values.value()); } else if (lhs->values.defined() && !rhs->values.defined()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } else { - return IntImm(DataType::Bool(), 1); + return IntImm::Bool(true); } } PrimExpr VisitStructInfo_(const TensorStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } // dtype mismatch if (!lhs->IsUnknownDtype() && lhs->dtype != rhs->dtype) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } // ndim mismatch if (!lhs->IsUnknownNdim() && lhs->ndim != rhs->ndim) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } // vdevice mismatch if (lhs->vdevice.defined() && !rhs->vdevice.defined()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } if (lhs->vdevice.defined() && rhs->vdevice.defined()) { VDevice lhs_vdevice = lhs->vdevice.value(); VDevice rhs_vdevice = rhs->vdevice.value(); if (lhs_vdevice->target.defined() && !rhs_vdevice->target.defined()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } // mismatch in either the target, vdevice_id, or memory_scope if ((lhs_vdevice->target.defined() && rhs_vdevice->target.defined()) && (lhs_vdevice->target != rhs_vdevice->target || lhs_vdevice->vdevice_id != rhs_vdevice->vdevice_id || lhs_vdevice->memory_scope != rhs_vdevice->memory_scope)) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } } if (lhs->shape.same_as(rhs->shape)) { - return IntImm(DataType::Bool(), 1); + return IntImm::Bool(true); } else if (lhs->shape.defined() && !rhs->shape.defined()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } auto* lhs_shape = lhs->shape.as(); @@ -746,23 +746,23 @@ class StructInfoBasePreconditionCollector if (lhs_shape && rhs_shape) { return ArrayCheck(lhs_shape->values, rhs_shape->values); } else if (lhs_shape && !rhs_shape) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } - return IntImm(DataType::Bool(), 1); + return IntImm::Bool(true); } PrimExpr VisitStructInfo_(const distributed::DTensorStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } ffi::StructuralEqual struct_equal; if (!struct_equal(lhs->device_mesh, rhs->device_mesh) || !struct_equal(lhs->placement, rhs->placement)) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } return this->VisitStructInfo(lhs->tensor_sinfo, rhs->tensor_sinfo); @@ -771,7 +771,7 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo_(const TupleStructInfoNode* lhs, const StructInfo& other) final { auto* rhs = other.as(); if (rhs == nullptr) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } return ArrayCheck(lhs->fields, rhs->fields); } @@ -779,19 +779,19 @@ class StructInfoBasePreconditionCollector PrimExpr VisitStructInfo_(const FuncStructInfoNode* lhs, const StructInfo& other) override { auto* rhs = other.as(); if (rhs == nullptr) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } // Check purity: Pure functions are a subtype of impure functions if (lhs->purity && !rhs->purity) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } if (lhs->derive_func.defined() && !lhs->derive_func.same_as(rhs->derive_func)) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } if (lhs->params.defined() && !rhs->params.defined()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } PrimExpr all_match = VisitStructInfo(lhs->ret, rhs->ret); @@ -800,7 +800,7 @@ class StructInfoBasePreconditionCollector if (lhs->params.defined()) { param_check = ArrayCheck(lhs->params.value(), rhs->params.value()); } else { - param_check = IntImm(DataType::Bool(), 1); + param_check = IntImm::Bool(true); } PrimExpr ret_check = VisitStructInfo(lhs->ret, rhs->ret); @@ -811,10 +811,10 @@ class StructInfoBasePreconditionCollector private: PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } - PrimExpr all_equal = IntImm(DataType::Bool(), 1); + PrimExpr all_equal = IntImm::Bool(true); for (size_t i = 0; i < lhs.size(); i++) { all_equal = all_equal && (lhs[i] == rhs[i]); } @@ -823,10 +823,10 @@ class StructInfoBasePreconditionCollector PrimExpr ArrayCheck(const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) { - return IntImm(DataType::Bool(), 0); + return IntImm::Bool(false); } - PrimExpr all_pass = IntImm(DataType::Bool(), 1); + PrimExpr all_pass = IntImm::Bool(true); for (size_t i = 0; i < lhs.size(); ++i) { all_pass = all_pass && VisitStructInfo(lhs[i], rhs[i]); diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 6fb6e8549bbb..369f5793d9b5 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -445,7 +445,7 @@ bool HasReshapePattern(const PrimFunc& func) { return arith::IterMapSimplify( /*indices=*/{idx}, /*input_iters=*/var_range, - /*input_pred=*/const_true(), + /*input_pred=*/IntImm::Bool(true), /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/ana_, /*simplify_trivial_iterators=*/true)[0]; @@ -459,7 +459,7 @@ bool HasReshapePattern(const PrimFunc& func) { for (int i = 0; i < static_cast(block->iter_vars.size()); ++i) { if (!(indices[i].same_as(block->iter_vars[i]->var) && this->ana_->CanProveEqual(block->iter_vars[i]->dom->min, - IntImm(DataType::Int(64), /*value=*/0)) && + IntImm::Int64(/*value=*/0)) && this->ana_->CanProveEqual(buffer->shape[i], block->iter_vars[i]->dom->extent))) { return false; } @@ -495,7 +495,7 @@ bool HasReshapePattern(const PrimFunc& func) { ffi::Array simplify_res = arith::IterMapSimplify( /*indices=*/{flattened_idx}, /*input_iters=*/{{fused_var, Range(IntImm(dtype, /*value=*/0), stride)}}, - /*input_pred=*/const_true(), + /*input_pred=*/IntImm::Bool(true), /*check_level=*/arith::IterMapLevel::Surjective, /*analyzer=*/this->ana_, /*simplify_trivial_iterators=*/true); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 75073de17da4..fb2cebb4e099 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -48,8 +48,7 @@ struct OpenCLMLCompilerConfigNode : public ffi::Object { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro( "clml_version", &OpenCLMLCompilerConfigNode::clml_version, - "OpenCLML version as (major, minor, patch).", - refl::DefaultValue(IntImm(DataType::Int(32), 3))); + "OpenCLML version as (major, minor, patch).", refl::DefaultValue(IntImm::Int32(3))); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.ext.attrs.OpenCLMLCompilerConfig", OpenCLMLCompilerConfigNode, ffi::Object); @@ -335,9 +334,9 @@ inline constexpr bool IsOpenCLMLRuntimeEnabled() { */ IntImm GetOpenCLMLVersion() { #if TVM_GRAPH_EXECUTOR_CLML - return IntImm(DataType::Int(32), TVM_CLML_VERSION); + return IntImm::Int32(TVM_CLML_VERSION); #else - return IntImm(DataType::Int(32), 3); + return IntImm::Int32(3); #endif // TVM_GRAPH_EXECUTOR_CLML } diff --git a/src/relax/backend/contrib/utils.cc b/src/relax/backend/contrib/utils.cc index a66c070b3030..a8987a4092db 100644 --- a/src/relax/backend/contrib/utils.cc +++ b/src/relax/backend/contrib/utils.cc @@ -57,7 +57,7 @@ ffi::Map ExtractArgIdx(ffi::String pattern_name, Function f auto exp = matched_expr.value()[pat]; if (auto arg_var = exp.as()) { if (auto idx = find_index(f->params, ffi::GetRef(arg_var))) { - arg_idx.Set(name, IntImm(DataType::Int(64), *idx)); + arg_idx.Set(name, IntImm::Int64(*idx)); } } } diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index a1089eafb3dd..e93c2ee199db 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -82,9 +82,9 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { private: int64_t NewRegister() { return registers_num_++; } - static IntImm ConstInt64(int64_t value) { return IntImm(DataType::Int(64), value); } + static IntImm ConstInt64(int64_t value) { return IntImm::Int64(value); } - static IntImm ConstInt32(int64_t value) { return IntImm(DataType::Int(32), value); } + static IntImm ConstInt32(int64_t value) { return IntImm::Int32(value); } PrimExpr RegListGet(int64_t slot) const { // use 128 bits to represent any @@ -231,8 +231,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { - return tirx::Call(DataType::Handle(), tirx::builtin::reinterpret(), - {IntImm(DataType::Int(64), 0)}); + return tirx::Call(DataType::Handle(), tirx::builtin::reinterpret(), {IntImm::Int64(0)}); } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index ad653087a088..d44e6ae42ce1 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -472,7 +472,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { constraints.begin(), constraints.end(), [&sort_key](const PrimExpr& a, const PrimExpr& b) { return sort_key(a) < sort_key(b); }); - PrimExpr sorted_condition = tirx::const_true(); + PrimExpr sorted_condition = IntImm::Bool(true); for (const PrimExpr& constraint : constraints) { sorted_condition = sorted_condition && constraint; } @@ -505,7 +505,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( bool all_shapes_defined = true; // The expression that must be true in order - PrimExpr all_dimensions_equal = IntImm(DataType::Bool(), 1); + PrimExpr all_dimensions_equal = IntImm::Bool(true); for (const auto& arg : args) { if (auto opt_var = match_state(arg.get())) { @@ -524,7 +524,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( if (!opt_var_shape.defined()) { // The pattern has matched to something without a shape. // Therefore, it cannot have the same shape as something else. - return {PrimExpr(IntImm(DataType::Bool(), 0)), true}; + return {PrimExpr(IntImm::Bool(false)), true}; } auto var_shape = opt_var_shape.value(); @@ -541,7 +541,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( // The shapes have different dimensionality. No need to // perform potentially-expensive simplifications, because // the dimensions do not match. - return {PrimExpr(IntImm(DataType::Bool(), 0)), true}; + return {PrimExpr(IntImm::Bool(false)), true}; } } else { diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index e4006e2bc4bb..e9833d9b297b 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -94,7 +94,7 @@ class DFPatternMatcher : public DFPatternFunctor memo_; var2val_t var2val_; std::vector matched_nodes_; - PrimExpr symbolic_expr_condition_{IntImm(DataType::Bool(), 1)}; + PrimExpr symbolic_expr_condition_{IntImm::Bool(true)}; arith::Analyzer analyzer_; bool memoize_ = true; }; diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index f5b0c4474d33..e8b99a21ddcd 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -49,7 +49,7 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std:: ffi::Array shape; shape.reserve(ndim); for (int i = 0; i < ndim; ++i) { - shape.push_back(IntImm(DataType::Int(64), shape_tuple[i])); + shape.push_back(IntImm::Int64(shape_tuple[i])); } n->shape = std::move(shape); return te::PlaceholderOp(n).output(0); diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index cec5ae65fbc2..a6fd7636f15f 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -342,7 +342,7 @@ Constant::Constant(runtime::Tensor data, ffi::Optional struct_info_a ffi::Array values; auto shape_tuple = n->data.Shape(); for (size_t dim = 0; dim < shape_tuple.size(); ++dim) { - values.push_back(IntImm(DataType::Int(64), shape_tuple[dim])); + values.push_back(IntImm::Int64(shape_tuple[dim])); } if (struct_info_annotation.defined()) { n->struct_info_ = struct_info_annotation.value(); @@ -371,7 +371,7 @@ PrimValue::PrimValue(PrimExpr value, Span span) { } PrimValue PrimValue::Int64(int64_t value, Span span) { - return PrimValue(IntImm(DataType::Int(64), value), span); + return PrimValue(IntImm::Int64(value), span); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index b69f58ebb7af..e9995fa31d08 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -799,7 +799,7 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, ffi::OptionalIsInstance()) << "Normal form requires all new scope is stored as SeqExpr"; - PrimExpr constraint = IntImm(DataType::Bool(), 1); + PrimExpr constraint = IntImm::Bool(true); if (params.defined()) { auto non_negative_expressions = CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo))); diff --git a/src/relax/op/distributed/statistical.cc b/src/relax/op/distributed/statistical.cc index fe6439188b92..5384219f884b 100644 --- a/src/relax/op/distributed/statistical.cc +++ b/src/relax/op/distributed/statistical.cc @@ -68,7 +68,7 @@ StructInfo InferDistStructInfoStatistical(const Call& call, const BlockBuilder& if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { - out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); + out_shape.push_back(IntImm::Int64(/*value=*/1)); } } TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc index 91ae6bf5961c..1b84f3dfc8e3 100644 --- a/src/relax/op/image/resize.cc +++ b/src/relax/op/image/resize.cc @@ -414,10 +414,10 @@ StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& ctx) // Output shape: [batch, 2, target_height, target_width] ffi::Array out_shape; - out_shape.push_back(data_shape->values[0]); // batch - out_shape.push_back(IntImm(DataType::Int(64), 2)); // 2 (spatial dimensions) - out_shape.push_back(size_value->values[0]); // target_height - out_shape.push_back(size_value->values[1]); // target_width + out_shape.push_back(data_shape->values[0]); // batch + out_shape.push_back(IntImm::Int64(2)); // 2 (spatial dimensions) + out_shape.push_back(size_value->values[0]); // target_height + out_shape.push_back(size_value->values[1]); // target_width return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice); } diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc index 04eba0dbe6fc..10b42f8002c2 100644 --- a/src/relax/op/memory/view.cc +++ b/src/relax/op/memory/view.cc @@ -129,7 +129,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { if (HasVoidStructInfo(arg_relative_byte_offset)) { // No byte offset is specified, so no change is applied. - return IntImm(DataType::Int(64), 0); + return IntImm::Int64(0); } else if (auto prim_sinfo = sinfo.as()) { TVM_FFI_CHECK_EQ(prim_sinfo->dtype, DataType::Int(64), TypeError) << "Operator " << call->op @@ -177,7 +177,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return std::nullopt; } else { auto size_bits = dtype.bits() * dtype.lanes(); - return IntImm(DataType::Int(64), (size_bits + 7) / 8); + return IntImm::Int64((size_bits + 7) / 8); } }; @@ -189,7 +189,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) { return std::nullopt; } - PrimExpr num_elements = IntImm(DataType::Int(32), 1); + PrimExpr num_elements = IntImm::Int32(1); for (const auto& dim : shape.value()) { num_elements *= dim; } diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index a1239f2c681a..9fe4a8e84da1 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -128,8 +128,7 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_OIW_shape[2]; - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); @@ -137,9 +136,9 @@ StructInfo InferStructInfoConv1d(const Call& call, const BlockBuilder& ctx) { out_NCW_shape[1] = weight_OIW_shape[0]; PrimExpr numerator_w = - input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; + input_w + padding_w - IntImm::Int32(attrs->dilation[0]) * (kernel_w - 1) - 1; out_NCW_shape[2] = - analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + analyzer->Simplify(floordiv(numerator_w, IntImm::Int32(attrs->strides[0])) + 1); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -301,10 +300,8 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_OIHW_shape[2]; PrimExpr kernel_w = weight_OIHW_shape[3]; - PrimExpr padding_h = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[2]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[1]) + IntImm::Int32(attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); @@ -312,13 +309,13 @@ StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[1] = weight_OIHW_shape[0]; PrimExpr numerator_h = - input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; + input_h + padding_h - IntImm::Int32(attrs->dilation[0]) * (kernel_h - 1) - 1; PrimExpr numerator_w = - input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; + input_w + padding_w - IntImm::Int32(attrs->dilation[1]) * (kernel_w - 1) - 1; out_NCHW_shape[2] = - analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + analyzer->Simplify(floordiv(numerator_h, IntImm::Int32(attrs->strides[0])) + 1); out_NCHW_shape[3] = - analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1); + analyzer->Simplify(floordiv(numerator_w, IntImm::Int32(attrs->strides[1])) + 1); ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -518,12 +515,9 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { PrimExpr kernel_d = weight_OIDHW_shape[2]; PrimExpr kernel_h = weight_OIDHW_shape[3]; PrimExpr kernel_w = weight_OIDHW_shape[4]; - PrimExpr padding_d = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); - PrimExpr padding_h = - IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); + PrimExpr padding_d = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[3]); + PrimExpr padding_h = IntImm::Int32(attrs->padding[1]) + IntImm::Int32(attrs->padding[4]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[2]) + IntImm::Int32(attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); @@ -531,17 +525,17 @@ StructInfo InferStructInfoConv3d(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[1] = weight_OIDHW_shape[0]; PrimExpr numerator_d = - input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; + input_d + padding_d - IntImm::Int32(attrs->dilation[0]) * (kernel_d - 1) - 1; PrimExpr numerator_h = - input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; + input_h + padding_h - IntImm::Int32(attrs->dilation[1]) * (kernel_h - 1) - 1; PrimExpr numerator_w = - input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; + input_w + padding_w - IntImm::Int32(attrs->dilation[2]) * (kernel_w - 1) - 1; out_NCDHW_shape[2] = - analyzer->Simplify(floordiv(numerator_d, IntImm(DataType::Int(32), attrs->strides[0])) + 1); + analyzer->Simplify(floordiv(numerator_d, IntImm::Int32(attrs->strides[0])) + 1); out_NCDHW_shape[3] = - analyzer->Simplify(floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[1])) + 1); + analyzer->Simplify(floordiv(numerator_h, IntImm::Int32(attrs->strides[1])) + 1); out_NCDHW_shape[4] = - analyzer->Simplify(floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1); + analyzer->Simplify(floordiv(numerator_w, IntImm::Int32(attrs->strides[2])) + 1); ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice); @@ -714,17 +708,16 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCW_shape[2]; PrimExpr kernel_w = weight_IOW_shape[2]; - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[1]); std::vector out_NCW_shape; out_NCW_shape.resize(3); out_NCW_shape[0] = data_NCW_shape[0]; out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups; - PrimExpr out_w = (input_w - 1) * IntImm(DataType::Int(32), attrs->strides[0]) - padding_w + - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) + - IntImm(DataType::Int(32), attrs->output_padding[0]) + 1; + PrimExpr out_w = (input_w - 1) * IntImm::Int32(attrs->strides[0]) - padding_w + + IntImm::Int32(attrs->dilation[0]) * (kernel_w - 1) + + IntImm::Int32(attrs->output_padding[0]) + 1; out_NCW_shape[2] = analyzer->Simplify(out_w); ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); @@ -907,22 +900,20 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder& PrimExpr input_w = data_NCHW_shape[3]; PrimExpr kernel_h = weight_IOHW_shape[2]; PrimExpr kernel_w = weight_IOHW_shape[3]; - PrimExpr padding_h = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr padding_h = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[2]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[1]) + IntImm::Int32(attrs->padding[3]); std::vector out_NCHW_shape; out_NCHW_shape.resize(4); out_NCHW_shape[0] = data_NCHW_shape[0]; out_NCHW_shape[1] = weight_IOHW_shape[1] * attrs->groups; - PrimExpr out_h = (input_h - 1) * IntImm(DataType::Int(32), attrs->strides[0]) - padding_h + - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) + - IntImm(DataType::Int(32), attrs->output_padding[0]) + 1; - PrimExpr out_w = (input_w - 1) * IntImm(DataType::Int(32), attrs->strides[1]) - padding_w + - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) + - IntImm(DataType::Int(32), attrs->output_padding[1]) + 1; + PrimExpr out_h = (input_h - 1) * IntImm::Int32(attrs->strides[0]) - padding_h + + IntImm::Int32(attrs->dilation[0]) * (kernel_h - 1) + + IntImm::Int32(attrs->output_padding[0]) + 1; + PrimExpr out_w = (input_w - 1) * IntImm::Int32(attrs->strides[1]) - padding_w + + IntImm::Int32(attrs->dilation[1]) * (kernel_w - 1) + + IntImm::Int32(attrs->output_padding[1]) + 1; out_NCHW_shape[2] = analyzer->Simplify(out_h); out_NCHW_shape[3] = analyzer->Simplify(out_w); @@ -1144,27 +1135,24 @@ StructInfo InferStructInfoConv3dTranspose(const Call& call, const BlockBuilder& PrimExpr kernel_d = weight_IODHW_shape[2]; PrimExpr kernel_h = weight_IODHW_shape[3]; PrimExpr kernel_w = weight_IODHW_shape[4]; - PrimExpr padding_d = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); - PrimExpr padding_h = - IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); + PrimExpr padding_d = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[3]); + PrimExpr padding_h = IntImm::Int32(attrs->padding[1]) + IntImm::Int32(attrs->padding[4]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[2]) + IntImm::Int32(attrs->padding[5]); std::vector out_NCDHW_shape; out_NCDHW_shape.resize(5); out_NCDHW_shape[0] = data_NCDHW_shape[0]; out_NCDHW_shape[1] = weight_IODHW_shape[1] * attrs->groups; - PrimExpr out_d = (input_d - 1) * IntImm(DataType::Int(32), attrs->strides[0]) - padding_d + - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) + - IntImm(DataType::Int(32), attrs->output_padding[0]) + 1; - PrimExpr out_h = (input_h - 1) * IntImm(DataType::Int(32), attrs->strides[1]) - padding_h + - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) + - IntImm(DataType::Int(32), attrs->output_padding[1]) + 1; - PrimExpr out_w = (input_w - 1) * IntImm(DataType::Int(32), attrs->strides[2]) - padding_w + - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) + - IntImm(DataType::Int(32), attrs->output_padding[2]) + 1; + PrimExpr out_d = (input_d - 1) * IntImm::Int32(attrs->strides[0]) - padding_d + + IntImm::Int32(attrs->dilation[0]) * (kernel_d - 1) + + IntImm::Int32(attrs->output_padding[0]) + 1; + PrimExpr out_h = (input_h - 1) * IntImm::Int32(attrs->strides[1]) - padding_h + + IntImm::Int32(attrs->dilation[1]) * (kernel_h - 1) + + IntImm::Int32(attrs->output_padding[1]) + 1; + PrimExpr out_w = (input_w - 1) * IntImm::Int32(attrs->strides[2]) - padding_w + + IntImm::Int32(attrs->dilation[2]) * (kernel_w - 1) + + IntImm::Int32(attrs->output_padding[2]) + 1; out_NCDHW_shape[2] = analyzer->Simplify(out_d); out_NCDHW_shape[3] = analyzer->Simplify(out_h); out_NCDHW_shape[4] = analyzer->Simplify(out_w); diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index d036bff48ec5..0363da159b03 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -280,7 +280,7 @@ StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) { const auto* data_shape = input_sinfo[0]->shape.as(); for (int i = 0; i < ndim; i++) { // Sum pad width for this axis. - PrimExpr added_width = IntImm(DataType::Int(64), pad_width[2 * i] + pad_width[(2 * i) + 1]); + PrimExpr added_width = IntImm::Int64(pad_width[2 * i] + pad_width[(2 * i) + 1]); const PrimExpr current_width = data_shape->values[i]; out_shape.push_back(current_width + added_width); } @@ -337,7 +337,7 @@ StructInfo InferStructInfoPixelShuffle(const Call& call, const BlockBuilder& ctx PrimExpr h_in = in_shape[h_idx]; PrimExpr w_in = in_shape[w_idx]; - PrimExpr r_expr = IntImm(DataType::Int(32), r); + PrimExpr r_expr = IntImm::Int32(r); PrimExpr r_squared = r_expr * r_expr; const auto* c_in_imm = c_in.as(); @@ -1214,7 +1214,7 @@ StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx } PrimExpr batch_dim = data_shape->values[0]; - PrimExpr flat_dim = IntImm(DataType::Int(64), 1); + PrimExpr flat_dim = IntImm::Int64(1); for (size_t i = 1; i < data_shape->values.size(); ++i) { flat_dim = flat_dim * data_shape->values[i]; } diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index dcf44eebba80..ca963010c3b6 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -99,9 +99,8 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); PrimExpr input_w = data_NCW_shape[2]; - PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[0]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr kernel_w = IntImm::Int32(attrs->pool_size[0]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[1]); arith::Analyzer analyzer = ctx->GetAnalyzer(); std::vector out_NCW_shape; @@ -110,14 +109,14 @@ StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) { out_NCW_shape[1] = data_NCW_shape[1]; PrimExpr numerator_w = - input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_w - 1) - 1; + input_w + padding_w - IntImm::Int32(attrs->dilation[0]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_w += IntImm(DataType::Int(32), attrs->strides[0]) - 1; + numerator_w += IntImm::Int32(attrs->strides[0]) - 1; } - PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[0])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, IntImm::Int32(attrs->strides[0])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= - input_w + IntImm(DataType::Int(32), attrs->padding[0]); + PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm::Int32(attrs->strides[0]) >= + input_w + IntImm::Int32(attrs->padding[0]); out_NCW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { out_NCW_shape[2] = analyzer->Simplify(raw_out_w); @@ -225,12 +224,10 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_h = data_NCHW_shape[2]; PrimExpr input_w = data_NCHW_shape[3]; - PrimExpr kernel_h = IntImm(DataType::Int(32), attrs->pool_size[0]); - PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[1]); - PrimExpr padding_h = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[2]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[3]); + PrimExpr kernel_h = IntImm::Int32(attrs->pool_size[0]); + PrimExpr kernel_w = IntImm::Int32(attrs->pool_size[1]); + PrimExpr padding_h = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[2]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[1]) + IntImm::Int32(attrs->padding[3]); arith::Analyzer analyzer = ctx->GetAnalyzer(); std::vector out_NCHW_shape; @@ -239,20 +236,20 @@ StructInfo InferStructInfoPool2D(const Call& call, const BlockBuilder& ctx) { out_NCHW_shape[1] = data_NCHW_shape[1]; PrimExpr numerator_h = - input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_h - 1) - 1; + input_h + padding_h - IntImm::Int32(attrs->dilation[0]) * (kernel_h - 1) - 1; PrimExpr numerator_w = - input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_w - 1) - 1; + input_w + padding_w - IntImm::Int32(attrs->dilation[1]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_h += IntImm(DataType::Int(32), attrs->strides[0]) - 1; - numerator_w += IntImm(DataType::Int(32), attrs->strides[1]) - 1; + numerator_h += IntImm::Int32(attrs->strides[0]) - 1; + numerator_w += IntImm::Int32(attrs->strides[1]) - 1; } - PrimExpr raw_out_h = floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[0])) + 1; - PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[1])) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, IntImm::Int32(attrs->strides[0])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, IntImm::Int32(attrs->strides[1])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_h = (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= - input_h + IntImm(DataType::Int(32), attrs->padding[0]); - PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= - input_w + IntImm(DataType::Int(32), attrs->padding[1]); + PrimExpr invalid_last_h = (raw_out_h - 1) * IntImm::Int32(attrs->strides[0]) >= + input_h + IntImm::Int32(attrs->padding[0]); + PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm::Int32(attrs->strides[1]) >= + input_w + IntImm::Int32(attrs->padding[1]); out_NCHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); } else { @@ -384,15 +381,12 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { PrimExpr input_d = data_NCDHW_shape[2]; PrimExpr input_h = data_NCDHW_shape[3]; PrimExpr input_w = data_NCDHW_shape[4]; - PrimExpr kernel_d = IntImm(DataType::Int(32), attrs->pool_size[0]); - PrimExpr kernel_h = IntImm(DataType::Int(32), attrs->pool_size[1]); - PrimExpr kernel_w = IntImm(DataType::Int(32), attrs->pool_size[2]); - PrimExpr padding_d = - IntImm(DataType::Int(32), attrs->padding[0]) + IntImm(DataType::Int(32), attrs->padding[3]); - PrimExpr padding_h = - IntImm(DataType::Int(32), attrs->padding[1]) + IntImm(DataType::Int(32), attrs->padding[4]); - PrimExpr padding_w = - IntImm(DataType::Int(32), attrs->padding[2]) + IntImm(DataType::Int(32), attrs->padding[5]); + PrimExpr kernel_d = IntImm::Int32(attrs->pool_size[0]); + PrimExpr kernel_h = IntImm::Int32(attrs->pool_size[1]); + PrimExpr kernel_w = IntImm::Int32(attrs->pool_size[2]); + PrimExpr padding_d = IntImm::Int32(attrs->padding[0]) + IntImm::Int32(attrs->padding[3]); + PrimExpr padding_h = IntImm::Int32(attrs->padding[1]) + IntImm::Int32(attrs->padding[4]); + PrimExpr padding_w = IntImm::Int32(attrs->padding[2]) + IntImm::Int32(attrs->padding[5]); arith::Analyzer analyzer = ctx->GetAnalyzer(); std::vector out_NCDHW_shape; @@ -401,26 +395,26 @@ StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) { out_NCDHW_shape[1] = data_NCDHW_shape[1]; PrimExpr numerator_d = - input_d + padding_d - IntImm(DataType::Int(32), attrs->dilation[0]) * (kernel_d - 1) - 1; + input_d + padding_d - IntImm::Int32(attrs->dilation[0]) * (kernel_d - 1) - 1; PrimExpr numerator_h = - input_h + padding_h - IntImm(DataType::Int(32), attrs->dilation[1]) * (kernel_h - 1) - 1; + input_h + padding_h - IntImm::Int32(attrs->dilation[1]) * (kernel_h - 1) - 1; PrimExpr numerator_w = - input_w + padding_w - IntImm(DataType::Int(32), attrs->dilation[2]) * (kernel_w - 1) - 1; + input_w + padding_w - IntImm::Int32(attrs->dilation[2]) * (kernel_w - 1) - 1; if (attrs->ceil_mode) { - numerator_d += IntImm(DataType::Int(32), attrs->strides[0]) - 1; - numerator_h += IntImm(DataType::Int(32), attrs->strides[1]) - 1; - numerator_w += IntImm(DataType::Int(32), attrs->strides[2]) - 1; + numerator_d += IntImm::Int32(attrs->strides[0]) - 1; + numerator_h += IntImm::Int32(attrs->strides[1]) - 1; + numerator_w += IntImm::Int32(attrs->strides[2]) - 1; } - PrimExpr raw_out_d = floordiv(numerator_d, IntImm(DataType::Int(32), attrs->strides[0])) + 1; - PrimExpr raw_out_h = floordiv(numerator_h, IntImm(DataType::Int(32), attrs->strides[1])) + 1; - PrimExpr raw_out_w = floordiv(numerator_w, IntImm(DataType::Int(32), attrs->strides[2])) + 1; + PrimExpr raw_out_d = floordiv(numerator_d, IntImm::Int32(attrs->strides[0])) + 1; + PrimExpr raw_out_h = floordiv(numerator_h, IntImm::Int32(attrs->strides[1])) + 1; + PrimExpr raw_out_w = floordiv(numerator_w, IntImm::Int32(attrs->strides[2])) + 1; if (attrs->ceil_mode) { - PrimExpr invalid_last_d = (raw_out_d - 1) * IntImm(DataType::Int(32), attrs->strides[0]) >= - input_d + IntImm(DataType::Int(32), attrs->padding[0]); - PrimExpr invalid_last_h = (raw_out_h - 1) * IntImm(DataType::Int(32), attrs->strides[1]) >= - input_h + IntImm(DataType::Int(32), attrs->padding[1]); - PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm(DataType::Int(32), attrs->strides[2]) >= - input_w + IntImm(DataType::Int(32), attrs->padding[2]); + PrimExpr invalid_last_d = (raw_out_d - 1) * IntImm::Int32(attrs->strides[0]) >= + input_d + IntImm::Int32(attrs->padding[0]); + PrimExpr invalid_last_h = (raw_out_h - 1) * IntImm::Int32(attrs->strides[1]) >= + input_h + IntImm::Int32(attrs->padding[1]); + PrimExpr invalid_last_w = (raw_out_w - 1) * IntImm::Int32(attrs->strides[2]) >= + input_w + IntImm::Int32(attrs->padding[2]); out_NCDHW_shape[2] = analyzer->Simplify(if_then_else(invalid_last_d, raw_out_d - 1, raw_out_d)); out_NCDHW_shape[3] = analyzer->Simplify(if_then_else(invalid_last_h, raw_out_h - 1, raw_out_h)); out_NCDHW_shape[4] = analyzer->Simplify(if_then_else(invalid_last_w, raw_out_w - 1, raw_out_w)); @@ -575,7 +569,7 @@ StructInfo InferStructInfoAdaptiveAvgPool1D(const Call& call, const BlockBuilder ffi::Array data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values); ffi::Array out_NCW_shape(data_NCW_shape); if (attrs->output_size.defined()) { - out_NCW_shape.Set(2, IntImm(DataType::Int(32), attrs->output_size.value()[0])); + out_NCW_shape.Set(2, IntImm::Int32(attrs->output_size.value()[0])); } ffi::Array out_shape = out2NCW.BackwardShape(out_NCW_shape); @@ -660,8 +654,8 @@ StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder ffi::Array data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values); ffi::Array out_NCHW_shape(data_NCHW_shape); if (attrs->output_size.defined()) { - out_NCHW_shape.Set(2, IntImm(DataType::Int(32), attrs->output_size.value()[0])); - out_NCHW_shape.Set(3, IntImm(DataType::Int(32), attrs->output_size.value()[1])); + out_NCHW_shape.Set(2, IntImm::Int32(attrs->output_size.value()[0])); + out_NCHW_shape.Set(3, IntImm::Int32(attrs->output_size.value()[1])); } ffi::Array out_shape = out2NCHW.BackwardShape(out_NCHW_shape); @@ -762,9 +756,9 @@ StructInfo InferStructInfoAdaptiveAvgPool3D(const Call& call, const BlockBuilder ffi::Array data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values); ffi::Array out_NCDHW_shape(data_NCDHW_shape); if (attrs->output_size.defined()) { - out_NCDHW_shape.Set(2, IntImm(DataType::Int(32), attrs->output_size.value()[0])); - out_NCDHW_shape.Set(3, IntImm(DataType::Int(32), attrs->output_size.value()[1])); - out_NCDHW_shape.Set(4, IntImm(DataType::Int(32), attrs->output_size.value()[2])); + out_NCDHW_shape.Set(2, IntImm::Int32(attrs->output_size.value()[0])); + out_NCDHW_shape.Set(3, IntImm::Int32(attrs->output_size.value()[1])); + out_NCDHW_shape.Set(4, IntImm::Int32(attrs->output_size.value()[2])); } ffi::Array out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape); diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index 6470183baf4f..85c71641f4f1 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -91,7 +91,7 @@ StructInfo InferStructInfoBroadcast(const Call& call, const BlockBuilder& ctx, auto get_shape = [](const StructInfo& sinfo) -> ffi::Optional> { if (sinfo.as()) { - return ffi::Array{IntImm(DataType::Int(64), 1)}; + return ffi::Array{IntImm::Int64(1)}; } else if (const auto* tensor = sinfo.as()) { return tensor->GetShape(); } else { diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 3088df64e158..c09dc1050107 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -387,7 +387,7 @@ StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx strides_tuple = opt_strides_tuple.value(); } else { - strides_tuple = ffi::Array(axes_tuple.size(), IntImm(DataType::Int(64), 1)); + strides_tuple = ffi::Array(axes_tuple.size(), IntImm::Int64(1)); } TVM_FFI_ICHECK_EQ(axes_tuple.size(), strides_tuple.size()) @@ -475,7 +475,7 @@ InferLayoutOutput InferLayoutStridedSlice( } return InferLayoutOutput({existing_layout}, {existing_layout}, call->attrs, - {{IntImm(DataType::Int(32), 1), relax::Tuple(new_axes)}}); + {{IntImm::Int32(1), relax::Tuple(new_axes)}}); } TVM_REGISTER_OP("relax.strided_slice") diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 3988e0ba2359..53ee3b18eafa 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -93,11 +93,10 @@ tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataTyp tirx::Var value("value", field_dtype); - tirx::Stmt body = - tirx::SeqStmt({tirx::Bind(value, tirx::Call(field_dtype, tirx::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), field)})), - tirx::Evaluate(tvm::ret(value))}); + tirx::Stmt body = tirx::SeqStmt( + {tirx::Bind(value, tirx::Call(field_dtype, tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm::Int32(0), IntImm::Int32(field)})), + tirx::Evaluate(tvm::ret(value))}); DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host_func", true}}); @@ -309,19 +308,18 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { tirx::Stmt body = tirx::SeqStmt( {tirx::AssertStmt(0 <= axis, tirx::StringImm("RuntimeError"), {tirx::StringImm("Specified axis may not be negative")}), - tirx::Bind(ndim, tirx::Call(ndim->dtype, tirx::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), - tirx::builtin::TVMStructFieldKind::kDLTensorNDim)})), + tirx::Bind(ndim, + tirx::Call(ndim->dtype, tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm::Int32(0), + IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorNDim)})), tirx::AssertStmt( axis < tvm::cast(axis->dtype, ndim), tirx::StringImm("RuntimeError"), {tirx::StringImm( "Specified axis may not be larger than the tensor's dimensionality")}), tirx::Bind(shape_buffer->data, tirx::Call(DataType::Handle(), tirx::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), - tirx::builtin::TVMStructFieldKind::kDLTensorShape)})), + {dlpack_handle, IntImm::Int32(0), + IntImm::Int32(tirx::builtin::TVMStructFieldKind::kDLTensorShape)})), tirx::DeclBuffer(shape_buffer), tirx::Bind(extent, tirx::BufferLoad(shape_buffer, {axis})), tirx::Evaluate(tvm::ret(extent))}); @@ -379,7 +377,7 @@ StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { // striding of a tensor, it implicitly requires compact striding // for any legalizable Tensor. auto tensor_shape = opt_tensor_shape.value(); - PrimExpr stride = IntImm(DataType::Int(64), 1); + PrimExpr stride = IntImm::Int64(1); for (size_t axis = int_imm_axis->value + 1; axis < tensor_shape.size(); axis++) { stride = stride * tensor_shape[axis]; } diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 0d1adc939e4d..85936fae3fb2 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -168,11 +168,11 @@ ffi::Optional> CheckConcatOutputShape( return structural_equal(a[axis], first_concat_dim); }); if (all_same) { - return first_concat_dim * IntImm(DataType::Int(64), shape_values.size()); + return first_concat_dim * IntImm::Int64(shape_values.size()); } // General case, add up the dimensions along the specified axis. - PrimExpr concat_sum = IntImm(DataType::Int(64), 0); + PrimExpr concat_sum = IntImm::Int64(0); for (ffi::Array shape_value : shape_values) { concat_sum += shape_value[axis]; } @@ -439,7 +439,7 @@ StructInfo InferStructInfoExpandDims(const Call& call, const BlockBuilder& ctx) std::vector output_shape; output_shape.resize(output_ndim, PrimExpr()); for (int i = 0; i < n_new_dim; ++i) { - output_shape[axes[i]] = IntImm(DataType::Int(64), 1); + output_shape[axes[i]] = IntImm::Int64(1); } int i_data_shape = 0; @@ -507,7 +507,7 @@ TVM_REGISTER_OP("relax.expand_dims") // Helper function for flatten and reshape. PrimExpr ComputeShapeProduct(const ffi::Array& shape_values) { - PrimExpr shape_prod = IntImm(DataType::Int(64), 1); + PrimExpr shape_prod = IntImm::Int64(1); for (PrimExpr value : shape_values) { shape_prod *= value; } @@ -623,7 +623,7 @@ StructInfo InferStructInfoIndexTensor(const Call& call, const BlockBuilder& ctx) // initialise broadcast result with 1's ffi::Array out_shape; for (int i = 0; i < max_index_ndim; ++i) { - out_shape.push_back(IntImm(DataType::Int(64), 1)); + out_shape.push_back(IntImm::Int64(1)); } for (const auto& ishape : index_shapes) { @@ -973,7 +973,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, // Set any -1 dimensions to complete the number of appropriate elements. // Start by computing the shape product of all positive indices. - PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1); + PrimExpr new_shape_prod = IntImm::Int64(1); for (int i = 0; i < static_cast(array_ref.size()); ++i) { PrimExpr new_dim = array_ref[i]; const auto* int_dim = new_dim.as(); @@ -1076,7 +1076,7 @@ Expr split(Expr x, ffi::Variant> indices_or_sections, << "Split op expects the input number of sections to be a " "positive integer. However, the given number of sections is " << n_section->value; - indices_or_sections_obj = IntImm(DataType::Int(64), n_section->value); + indices_or_sections_obj = IntImm::Int64(n_section->value); } else { TVM_FFI_THROW(InternalError) << "Split op expects the input indices_or_sections to be either an Array of " @@ -1478,7 +1478,7 @@ ffi::Optional> CheckStackOutputShape( for (int i = 0; i < axis; ++i) { output_shape.push_back(shape_values[0][i]); } - output_shape.push_back(IntImm(DataType::Int(64), shape_values.size())); // Stack dimension + output_shape.push_back(IntImm::Int64(shape_values.size())); // Stack dimension for (int i = axis; i < static_cast(shape_values[0].size()); ++i) { output_shape.push_back(shape_values[0][i]); } @@ -1920,11 +1920,10 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) { if (i < l_delta) { out_shape.push_back(data_shape->values[i - ndim_delta]); } else if (i < ndim_delta) { - out_shape.push_back(IntImm(DataType::Int(64), attrs->repeats[i - l_delta])); + out_shape.push_back(IntImm::Int64(attrs->repeats[i - l_delta])); } else { - out_shape.push_back( - analyzer->Simplify(data_shape->values[i - ndim_delta] * - IntImm(DataType::Int(64), attrs->repeats[i - l_delta]))); + out_shape.push_back(analyzer->Simplify(data_shape->values[i - ndim_delta] * + IntImm::Int64(attrs->repeats[i - l_delta]))); } } diff --git a/src/relax/op/tensor/set.cc b/src/relax/op/tensor/set.cc index a2743ab574c6..edf2a385b429 100644 --- a/src/relax/op/tensor/set.cc +++ b/src/relax/op/tensor/set.cc @@ -93,7 +93,7 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { // unique values if (data_sinfo->ndim == 0) { - output_sinfo.push_back(TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), + output_sinfo.push_back(TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), data_sinfo->dtype, data_sinfo->vdevice)); } else if (axis.defined()) { output_sinfo.push_back( @@ -107,8 +107,8 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_index->value)) { TensorStructInfo index_sinfo{nullptr}; if (data_sinfo->ndim == 0) { - index_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), - DataType::Int(64), data_sinfo->vdevice); + index_sinfo = TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), + data_sinfo->vdevice); } else { index_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice); } @@ -119,8 +119,8 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_inverse->value)) { TensorStructInfo inverse_sinfo{nullptr}; if (data_sinfo->ndim == 0) { - inverse_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), - DataType::Int(64), data_sinfo->vdevice); + inverse_sinfo = TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), + data_sinfo->vdevice); } else { inverse_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice); } @@ -131,8 +131,8 @@ StructInfo InferStructInfoUnique(const Call& call, const BlockBuilder& ctx) { if (f_convert_to_int64(return_counts->value)) { TensorStructInfo counts_sinfo{nullptr}; if (data_sinfo->ndim == 0) { - counts_sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(64), /*value=*/1)}), - DataType::Int(64), data_sinfo->vdevice); + counts_sinfo = TensorStructInfo(ShapeExpr({IntImm::Int64(/*value=*/1)}), DataType::Int(64), + data_sinfo->vdevice); } else { counts_sinfo = TensorStructInfo(DataType::Int(64), /*ndim=*/1, data_sinfo->vdevice); } diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc index d6f3a15005f3..1da75b71309b 100644 --- a/src/relax/op/tensor/statistical.cc +++ b/src/relax/op/tensor/statistical.cc @@ -68,9 +68,8 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) const auto* data_shape = data_sinfo->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo( - ShapeExpr(ffi::Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), - data_sinfo->dtype, data_sinfo->vdevice); + return TensorStructInfo(ShapeExpr(ffi::Array(out_ndim, IntImm::Int64(/*value=*/1))), + data_sinfo->dtype, data_sinfo->vdevice); } else { return out_ndim == 0 ? TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype, data_sinfo->vdevice) @@ -84,7 +83,7 @@ StructInfo InferStructInfoStatistical(const Call& call, const BlockBuilder& ctx) if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { - out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); + out_shape.push_back(IntImm::Int64(/*value=*/1)); } } TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); @@ -211,9 +210,8 @@ StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuil const auto* data_shape = data_sinfo->shape.as(); if (data_shape == nullptr) { if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) { - return TensorStructInfo( - ShapeExpr(ffi::Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))), - data_sinfo->dtype, data_sinfo->vdevice); + return TensorStructInfo(ShapeExpr(ffi::Array(out_ndim, IntImm::Int64(/*value=*/1))), + data_sinfo->dtype, data_sinfo->vdevice); } if (out_ndim == 0) { return TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype, @@ -229,7 +227,7 @@ StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuil if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) { out_shape.push_back(data_shape->values[i]); } else if (attrs->keepdims) { - out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1)); + out_shape.push_back(IntImm::Int64(/*value=*/1)); } } TVM_FFI_ICHECK_EQ(static_cast(out_shape.size()), out_ndim); diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index fd2f467671a2..e87c2d439caf 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -177,7 +177,7 @@ StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil } } - ffi::Array boxes_shape = {batch, num_anchors, IntImm(DataType::Int(32), 4)}; + ffi::Array boxes_shape = {batch, num_anchors, IntImm::Int32(4)}; ffi::Array scores_shape = {batch, num_classes, num_anchors}; ffi::Array fields = { TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev), diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index a88e08c99258..88139c62fc5b 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -331,8 +331,7 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { tvm::ffi::Array fields = { TensorStructInfo(ffi::GetRef(data_shape), data_sinfo->dtype, vdev), TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}), DataType::Int(32), - vdev)}; + TensorStructInfo(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; return TupleStructInfo(fields); } @@ -346,8 +345,7 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { auto num_anchors = data_shape->values[1]; tvm::ffi::Array fields = { TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), - TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}), DataType::Int(32), - vdev)}; + TensorStructInfo(ShapeExpr({batch, IntImm::Int64(1)}), DataType::Int(32), vdev)}; return TupleStructInfo(fields); } diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc index aca4c17daeac..131c634bc46a 100644 --- a/src/relax/op/vision/roi_align.cc +++ b/src/relax/op/vision/roi_align.cc @@ -119,12 +119,11 @@ StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) { ffi::Array data_shape = data_sinfo->shape.as()->values; ffi::Array out_shape; if (attrs->layout == "NCHW") { - out_shape = {rois_shape->values[0], data_shape[1], - IntImm(DataType::Int(32), attrs->pooled_size[0]), - IntImm(DataType::Int(32), attrs->pooled_size[1])}; + out_shape = {rois_shape->values[0], data_shape[1], IntImm::Int32(attrs->pooled_size[0]), + IntImm::Int32(attrs->pooled_size[1])}; } else { - out_shape = {rois_shape->values[0], IntImm(DataType::Int(32), attrs->pooled_size[0]), - IntImm(DataType::Int(32), attrs->pooled_size[1]), data_shape[3]}; + out_shape = {rois_shape->values[0], IntImm::Int32(attrs->pooled_size[0]), + IntImm::Int32(attrs->pooled_size[1]), data_shape[3]}; } return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } diff --git a/src/relax/op/vision/roi_pool.cc b/src/relax/op/vision/roi_pool.cc index f0315289322a..ae8fc5d57bbb 100644 --- a/src/relax/op/vision/roi_pool.cc +++ b/src/relax/op/vision/roi_pool.cc @@ -111,8 +111,8 @@ StructInfo InferStructInfoROIPool(const Call& call, const BlockBuilder& ctx) { ffi::Array data_shape = data_sinfo->shape.as()->values; ffi::Array out_shape = {rois_shape->values[0], data_shape[1], - IntImm(DataType::Int(32), attrs->pooled_size[0]), - IntImm(DataType::Int(32), attrs->pooled_size[1])}; + IntImm::Int32(attrs->pooled_size[0]), + IntImm::Int32(attrs->pooled_size[1])}; return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice); } diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index e97e423e9b78..5d13dc6773d4 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -49,7 +49,7 @@ ffi::Array GetBatchPrefix(const ffi::Array& shape) { } PrimExpr ProductDims(const ffi::Array& dims) { - PrimExpr product = IntImm(DataType::Int(64), 1); + PrimExpr product = IntImm::Int64(1); for (const auto& dim : dims) product = product * dim; return product; } @@ -95,7 +95,7 @@ std::tuple)>> auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | pat_permuted_matmul_on_lhs | pat_permuted_matmul_on_rhs; - PrimExpr symbolic_var_constraints = tirx::const_true(); + PrimExpr symbolic_var_constraints = IntImm::Bool(true); auto upper_bounds = func->GetAttr>("tir_var_upper_bound"); auto lower_bounds = func->GetAttr>("tir_var_lower_bound"); diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc index 6bbe86d148f9..13dd506a3fb5 100644 --- a/src/relax/transform/allocate_workspace.cc +++ b/src/relax/transform/allocate_workspace.cc @@ -61,8 +61,8 @@ class ExternFunctionRewriter : ExprMutator { // Append the workspace parameter to this function. ffi::Array new_params = func_node->params; - auto sinfo = TensorStructInfo(ShapeExpr({IntImm(DataType::Int(32), max_workspace_size_)}), - DataType::UInt(8)); + auto sinfo = + TensorStructInfo(ShapeExpr({IntImm::Int32(max_workspace_size_)}), DataType::UInt(8)); Var workspace_param(name_sup_->FreshName("workspace"), sinfo); if (func_node->GetAttr(attr::kCodegen)) { @@ -149,7 +149,7 @@ class WorkspaceProvider : ExprMutator { BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { builder_->BeginDataflowBlock(); if (!workspace_var_main_.defined()) { - auto shape = ShapeExpr({IntImm(DataType::Int(32), max_workspace_size_)}); + auto shape = ShapeExpr({IntImm::Int32(max_workspace_size_)}); auto ty = DataTypeImm(DataType::UInt(8)); auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0)); workspace_var_main_ = builder_->Emit(workspace, "workspace_main"); diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index 16e492a80d0a..fa9db81f3aca 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -45,7 +45,7 @@ static constexpr const char* kOperatorName = "operator_name"; /*! \brief Construct ranges from shape dimensions */ static ffi::Array ConstructRangeFromShape(const ffi::Array& shape) { - return shape.Map([](const PrimExpr& dim) { return Range(tirx::make_zero(dim.dtype()), dim); }); + return shape.Map([](const PrimExpr& dim) { return Range(IntImm(dim.dtype(), 0), dim); }); } static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index d55dacc0ff26..8e2591c0dea6 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -234,7 +234,7 @@ ffi::TypedFunction(ffi::Map, ffi::Mapvalue; - sections.push_back(IntImm(DataType::Int(64), split_index)); + sections.push_back(IntImm::Int64(split_index)); } int lhs_dim = GetTensorSInfo(lhs)->ndim; diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index c3ed7ef0b609..271eda0d499c 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -368,7 +368,7 @@ class AliasAnalyzer { // given a shape, return the number of elements corresponding to it (product of elements) PrimExpr NumElements(const ShapeExpr& shape) { - PrimExpr ret = IntImm(DataType::Int(64), 1); + PrimExpr ret = IntImm::Int64(1); for (auto dim : shape->values) { ret *= dim; } @@ -1063,7 +1063,7 @@ ffi::Array DataflowAliasAnalysis(const DataflowBlock& block, } elem_aliases.push_back(dim_aliases); } - new_tuple_map.Set(IntImm(DataType::Int(32), kv.first), elem_aliases); + new_tuple_map.Set(IntImm::Int32(kv.first), elem_aliases); } return {new_alias_sets, new_tuple_map}; } diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index ea2342589941..75b1b09bd48b 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -376,7 +376,7 @@ class ConstantFolder : public ExprMutator { int64_t num_elems = ndarray->shape[0]; ffi::Array shape_values; for (int64_t i = 0; i < num_elems; i++) { - shape_values.push_back(IntImm(DataType::Int(64), data[i])); + shape_values.push_back(IntImm::Int64(data[i])); } return ShapeExpr(shape_values); } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 9e4c11ee707a..2b9320dcfd29 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -155,7 +155,7 @@ class SymbolicMatcher : ExprFunctor* var_remap_; - PrimExpr must_prove_ = const_true(); + PrimExpr must_prove_ = IntImm::Bool(true); }; /*! @@ -1021,7 +1021,7 @@ class FusedTIRConstructor : public ExprVisitor { body = subst.Substitute(body); body = tirx::SBlock({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); - body = tirx::SBlockRealize({}, IntImm(DataType::Bool(), 1), Downcast(body)); + body = tirx::SBlockRealize({}, IntImm::Bool(true), Downcast(body)); tirx::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, DictAttrs(attr_map)); // Renew function defs to prevent using the same symbolic vars in different functions diff --git a/src/relax/transform/infer_amp_utils.cc b/src/relax/transform/infer_amp_utils.cc index 94fe226146fc..01bd47d96073 100644 --- a/src/relax/transform/infer_amp_utils.cc +++ b/src/relax/transform/infer_amp_utils.cc @@ -54,11 +54,11 @@ NType NTypeMerge(const NType& a, const NType& b) { } ffi::Array InferMixedPrecisionFollow(const Call& call, const DataType& out_dtype) { - return {IntImm(DataType::Int(32), MixedPrecisionPolicyKind::kFollow), call}; + return {IntImm::Int32(MixedPrecisionPolicyKind::kFollow), call}; } ffi::Array InferMixedPrecisionNever(const Call& call, const DataType& out_dtype) { - return {IntImm(DataType::Int(32), MixedPrecisionPolicyKind::kNever), call}; + return {IntImm::Int32(MixedPrecisionPolicyKind::kNever), call}; } } // namespace relax diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index c2939f5a2a46..fb3b014b03df 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -100,13 +100,12 @@ class LazyInputMutator : public ExprMutator { if (plan_) { Var var = ffi::GetRef(op); if (auto it = plan_->param_lookup.find(var); it != plan_->param_lookup.end()) { - auto untyped = - builder_->Emit(relax::Call(plan_->fget_param, - { - PrimValue(IntImm(DataType::Int(64), it->second)), - StringImm(var->name_hint()), - }), - var->name_hint() + "_untyped"); + auto untyped = builder_->Emit(relax::Call(plan_->fget_param, + { + PrimValue(IntImm::Int64(it->second)), + StringImm(var->name_hint()), + }), + var->name_hint() + "_untyped"); return builder_->EmitMatchCast(untyped, GetStructInfo(var), var->name_hint()); } } @@ -173,8 +172,7 @@ class LazyOutputMutator : public ExprMutator { BindingBlock end_of_func = [&]() { ffi::Array propagated_params; for (const auto& [output_index, expr] : inline_outputs) { - Call fset_output_call(fset_output, - {PrimValue(IntImm(DataType::Int(64), output_index)), expr}); + Call fset_output_call(fset_output, {PrimValue(IntImm::Int64(output_index)), expr}); Var void_output("_void", TupleStructInfo(ffi::Array{})); propagated_params.push_back(VarBinding(void_output, fset_output_call)); } @@ -215,8 +213,7 @@ class LazyOutputMutator : public ExprMutator { if (plan_.has_value()) { if (auto it = plan_->output_lookup.find(var); it != plan_->output_lookup.end()) { for (auto output_index : it->second) { - callback( - Call(plan_->fset_output, {PrimValue(IntImm(DataType::Int(64), output_index)), var})); + callback(Call(plan_->fset_output, {PrimValue(IntImm::Int64(output_index)), var})); } } } diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index db5c0b24870d..793dbd3f3f43 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -72,7 +72,7 @@ class Mutator : public ExprMutator { }(); PrimExpr nbytes = [&]() -> PrimExpr { - PrimExpr nbytes = tirx::make_const(DataType::Int(64), dtype->value.bytes()); + PrimExpr nbytes = IntImm::Int64(dtype->value.bytes()); for (const auto& dim : shape) { nbytes *= dim; } @@ -89,7 +89,7 @@ class Mutator : public ExprMutator { if (vdevice.defined()) { std::string dev_kind = vdevice.value()->target->kind->name; - PrimExpr dev_size = tirx::make_const(DataType::Int(64), 1); + PrimExpr dev_size = IntImm::Int64(1); if (vdevice.value()->memory_scope != "global") { auto device_size_handler = tvm::ffi::Function::GetGlobal(std::string("DeviceGetMemSize.") + dev_kind); diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 9cfb41d13e93..a7ec2a587615 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -785,10 +785,10 @@ class CUDAGraphRewriter : public ExprMutator { TVM_FFI_ICHECK(plan->inputs.empty()); auto gv_alloc = gv_global_alloc_.value(); auto ret_struct_info = Downcast(gv_alloc->struct_info_.value())->ret; - launch_subgraph = Call( - call_builtin_with_ctx_op, - {builtin_get_cached_alloc, Tuple({gv_alloc, PrimValue(IntImm(DataType::Int(64), 0))})}, - Attrs(), {ret_struct_info}); + launch_subgraph = + Call(call_builtin_with_ctx_op, + {builtin_get_cached_alloc, Tuple({gv_alloc, PrimValue(IntImm::Int64(0))})}, Attrs(), + {ret_struct_info}); } else { auto gv_func = builder_->AddFunction( plan->func, current_func_.value()->name_hint + "_cuda_graph_capture"); @@ -816,7 +816,7 @@ class CUDAGraphRewriter : public ExprMutator { } // Arguments of builtin_run_or_capture ffi::Array tuple_arg_fields{gv_func, Tuple(args), - PrimValue(IntImm(DataType::Int(64), index_capture_++))}; + PrimValue(IntImm::Int64(index_capture_++))}; if (plan->propogated_tir_vars.defined()) { // The shape expr is explicitly passed twice, one as the last argument of the lifted // function, one as the last argument of builtin_run_or_capture as the cache key. Explicitly diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index da25f5f10d64..c4103429c30b 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -80,7 +80,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { : SeqStmt(layout_rewrite_preproc_stmts_); body = SBlockRealize( /*iter_values=*/ffi::Array(), - /*predicate=*/const_true(), + /*predicate=*/IntImm::Bool(true), /*block=*/ SBlock(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body)); @@ -124,7 +124,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { body = SBlockRealize( /*iter_values=*/ffi::Array(), - /*predicate=*/const_true(), + /*predicate=*/IntImm::Bool(true), /*block=*/ SBlock(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"root", body, diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index ccce50274311..b69dce0155f0 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -139,7 +139,7 @@ class StorageToken : public ffi::ObjectRef { ffi::Optional vdevice = std::nullopt) { // Compute the tensor size from the shape. int64_t const_coeff = dtype.bytes() * dtype.lanes(); - PrimExpr size = tirx::make_const(DataType::Int(64), 1); + PrimExpr size = IntImm::Int64(1); bool size_computed = false; if (vdevice.defined()) { @@ -173,7 +173,7 @@ class StorageToken : public ffi::ObjectRef { } } - size = tirx::make_const(DataType::Int(64), const_coeff) * size; + size = IntImm::Int64(const_coeff) * size; ffi::ObjectPtr n = ffi::make_object(); n->bytes = size; @@ -259,7 +259,7 @@ class TokenAllocatorMixed { TVM_FFI_ICHECK_GE(available_size, 0); TVM_FFI_ICHECK_GE(size, available_size); // Enlarge the token size. - available_token->bytes = tirx::make_const(DataType::Int(64), size); + available_token->bytes = IntImm::Int64(size); available_token->ref_counter = prototype->ref_counter; pool.erase(mid); return available_token; @@ -447,8 +447,8 @@ void SetTIRVarRangeConstraints(Function func, arith::AnalyzerObj* ana, if (it_upper != var_upper_bound_attr.end()) { int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0; int64_t upper = it_upper->second->value; - tvm::Range range = tvm::Range::FromMinExtent( - tvm::IntImm(DataType::Int(64), lower), tvm::IntImm(DataType::Int(64), upper - lower + 1)); + tvm::Range range = tvm::Range::FromMinExtent(tvm::IntImm::Int64(lower), + tvm::IntImm::Int64(upper - lower + 1)); ana->Bind(tir_var, range); dom_map->Set(tir_var, arith::IntSet::FromRange(range)); } else if (it_lower != var_lower_bound_attr.end() && it_lower->second->value >= 0) { @@ -483,7 +483,7 @@ ffi::Array GetUpperBoundShape(ffi::Array shape, arith::Analy upper_bounded_shape.push_back(dim_len); } } else { - upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound)); + upper_bounded_shape.push_back(tvm::IntImm::Int64(max_bound)); } } return upper_bounded_shape; @@ -900,7 +900,7 @@ class StorageAllocationRewriter : public ExprMutator { } constexpr static const char* plan_dyn_attr_ = "relax.memory_plan_dynamic_func_output"; plan_dynamic_output_ = static_cast( - func_->GetAttr(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value); + func_->GetAttr(plan_dyn_attr_).value_or(IntImm::Int32(0))->value); if (plan_dynamic_output_) { SetTIRVarRangeConstraints(ffi::GetRef(func_), ana_.get(), &dom_map_); } @@ -1058,7 +1058,7 @@ PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType d size_t size = runtime::GetTextureMemorySize(shape, dtype.bytes() * 8, dtype.lanes(), vdevice->memory_scope, image_row_align); - return tirx::make_const(DataType::Int(64), size); + return IntImm::Int64(size); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/s_tir/analysis/identify_memcpy.cc b/src/s_tir/analysis/identify_memcpy.cc index e008f7e7ebc3..8a77c5b3eb89 100644 --- a/src/s_tir/analysis/identify_memcpy.cc +++ b/src/s_tir/analysis/identify_memcpy.cc @@ -107,7 +107,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, // B[i] = A[T.abs(i-8)] arith::Analyzer analyzer_ref = ffi::GetRef(analyzer); - auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, const_true(), + auto src_iter_map = arith::DetectIterMap({src_index}, loop_ranges, IntImm::Bool(true), arith::IterMapLevel::Bijective, analyzer_ref); if (src_iter_map->errors.size()) { return static_cast(std::stringstream() @@ -117,7 +117,7 @@ std::variant IdentifyMemCpyImpl(const For& loop, << " for src_index = " << src_index) .str(); } - auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, const_true(), + auto dst_iter_map = arith::DetectIterMap({dst_index}, loop_ranges, IntImm::Bool(true), arith::IterMapLevel::Bijective, analyzer_ref); if (dst_iter_map->errors.size()) { return static_cast(std::stringstream() diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index 0eddf22d8506..18eef8e2fe01 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -348,7 +348,7 @@ ffi::Array BlockReadWriteDetector::CollectRegions( const tvm::arith::IntSet& range = regions[i][j]; if (range.CanProveSinglePoint(ana_)) { PrimExpr min = range.min(); - region.push_back(Range::FromMinExtent(min, make_const(min.dtype(), 1))); + region.push_back(Range::FromMinExtent(min, MakeConst(min.dtype(), 1))); } else { region.push_back(range.CoverRange(Range::FromMinExtent(0, buffers[i]->shape[j]))); } diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index 9b2b627dd49a..709e7c3336c9 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -78,10 +78,10 @@ class TextureAllocInjector : public arith::IRMutatorWithAnalyzer { auto texture = ApplyTexture2DFlattening(extents, extents.size(), axis); ffi::Array args; args.push_back(StringImm(storage_scope)); - args.push_back(IntImm(DataType::Int(64), 3)); + args.push_back(IntImm::Int64(3)); args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), {texture.width, texture.height, texture.depth})); - args.push_back(IntImm(DataType::Int(64), channel_size)); + args.push_back(IntImm::Int64(channel_size)); stmt = Bind(op->buffer->data, Call(op->buffer->data.dtype(), builtin::nd_mem_alloc_with_scope(), args)); } diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index 91cdc0b6e4bf..0dd939ad817a 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -58,7 +58,7 @@ class TextureLoweringBase : public StmtExprMutator { inline PrimExpr SimplifyOffset(const ffi::Array& shape, const ffi::Array& index) const { - PrimExpr base = make_const(DataType::Int(32), 0); + PrimExpr base = IntImm::Int32(0); TVM_FFI_ICHECK_EQ(shape.size(), index.size()); if (index.size() > 0) { PrimExpr offset = index[0]; diff --git a/src/s_tir/meta_schedule/database/json_database.cc b/src/s_tir/meta_schedule/database/json_database.cc index cc6ee009b471..94f02867da92 100644 --- a/src/s_tir/meta_schedule/database/json_database.cc +++ b/src/s_tir/meta_schedule/database/json_database.cc @@ -118,7 +118,7 @@ class JSONDatabaseNode : public DatabaseNode { JSONFileAppendLine( this->path_tuning_record, JSONDumps(ffi::Array{ - /*workload_index=*/IntImm(DataType::Int(32), this->workloads2idx_.at(record->workload)), + /*workload_index=*/IntImm::Int32(this->workloads2idx_.at(record->workload)), /*tuning_record=*/record->AsJSON() // })); } diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index b567ffa4eb1f..27a31def93b3 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -273,12 +273,12 @@ Pass SimplifyForFeatureExtraction() { HasBufferLoad(node->condition)) { return ffi::GetRef