diff --git a/CMakeLists.txt b/CMakeLists.txt index edb74b9e24f2..0950db7b0ba4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -271,7 +271,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/ir/*.cc src/arith/*.cc src/te/*.cc - src/tir/*.cc + src/tirx/*.cc src/s_tir/*.cc src/topi/*.cc src/driver/*.cc diff --git a/docs/README.md b/docs/README.md index f708afd76861..7c59ed6bda69 100644 --- a/docs/README.md +++ b/docs/README.md @@ -15,6 +15,22 @@ + + + + + + + + + + + + + + + + # TVM Documentation This folder contains the source of TVM's documentation, hosted at https://tvm.apache.org/docs diff --git a/docs/arch/index.rst b/docs/arch/index.rst index 9677f994ce97..c3b296266f9b 100644 --- a/docs/arch/index.rst +++ b/docs/arch/index.rst @@ -64,10 +64,10 @@ contains a collection of functions. Currently, we support two primary variants o - **relax::Function** is a high-level functional program representation. A relax.Function represents high-level graph structure, usually corresponds to an end-to-end model or a sub-graph of the overall model. You can view a relax.Function as a computational graph with additional support for control-flow, and complex data structures. -- **tir::PrimFunc** is a low-level program representation that contains elements including loop-nest choices, multi-dimensional load/store, +- **tirx::PrimFunc** is a low-level program representation that contains elements including loop-nest choices, multi-dimensional load/store, threading, and vector/tensor instructions. It is usually used to represent an operator program that executes a (possibly-fused) layer in a model. -During the compilation and transformation, all relax operators are lowered to ``tir::PrimFunc`` or ``TVM PackedFunc``, which can be executed directly +During the compilation and transformation, all relax operators are lowered to ``tirx::PrimFunc`` or ``TVM PackedFunc``, which can be executed directly on the target device, while the calls to relax operators are lowered to calls to low-level functions (e.g. ``R.call_tir`` or ``R.call_dps``). Transformations @@ -83,15 +83,14 @@ relax transformations relax transformations contain a collection of passes that apply to relax functions. The optimizations include common graph-level optimizations such as constant folding and dead-code elimination for operators, and backend-specific optimizations such as library dispatch. -tir transformations -^^^^^^^^^^^^^^^^^^^ -tir transformations contain a collection of passes that apply to tir functions. There are two major types of transformations: +tirx transformations +^^^^^^^^^^^^^^^^^^^^ - **TensorIR schedule**: TensorIR schedules are designed to optimize the TensorIR functions for a specific target, with user-guided instructions and control how the target code is generated. - For CPU targets, TIR PrimFunc can generate valid code and execute on the target device without schedule but with very-low performance. However, for GPU targets, the schedule is essential - for generating valid code with thread bindings. For more details, please refer to the :ref:`TensorIR Transformation ` section. Additionally, we provides ``MetaSchedule`` to + For CPU targets, tirx PrimFunc can generate valid code and execute on the target device without schedule but with very-low performance. However, for GPU targets, the schedule is essential + for generating valid code with thread bindings. For more details, please refer to the :ref:`TensorIR Transformation ` section. Additionally, we provides ``MetaSchedule`` to automate the search of TensorIR schedule. -- **Lowering Passes**: These passes usually perform after the schedule is applied, transforming a TIR PrimFunc into another functionally equivalent PrimFunc, but closer to the +- **Lowering Passes**: These passes usually perform after the schedule is applied, transforming a tirx PrimFunc into another functionally equivalent PrimFunc, but closer to the target-specific representation. For example, there are passes to flatten multi-dimensional access to one-dimensional pointer access, to expand the intrinsics into target-specific ones, and to decorate the function entry to meet the runtime calling convention. @@ -102,12 +101,12 @@ focus on optimizations that are not covered by them. cross-level transformations ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Apache TVM enables cross-level optimization of end-to-end models. As the IRModule includes both relax and tir functions, the cross-level transformations are designed to mutate +Apache TVM enables cross-level optimization of end-to-end models. As the IRModule includes both relax and tirx functions, the cross-level transformations are designed to mutate the IRModule by applying different transformations to these two types of functions. -For example, ``relax.LegalizeOps`` pass mutates the IRModule by lowering relax operators, adding corresponding TIR PrimFunc into the IRModule, and replacing the relax operators -with calls to the lowered TIR PrimFunc. Another example is operator fusion pipeline in relax (including ``relax.FuseOps`` and ``relax.FuseTIR``), which fuses multiple consecutive tensor operations -into one. Different from the previous implementations, relax fusion pipeline analyzes the pattern of TIR functions and detects the best fusion rules automatically rather +For example, ``relax.LegalizeOps`` pass mutates the IRModule by lowering relax operators, adding corresponding tirx PrimFunc into the IRModule, and replacing the relax operators +with calls to the lowered tirx PrimFunc. Another example is operator fusion pipeline in relax (including ``relax.FuseOps`` and ``relax.FuseTIR``), which fuses multiple consecutive tensor operations +into one. Different from the previous implementations, relax fusion pipeline analyzes the pattern of tirx functions and detects the best fusion rules automatically rather than human-defined operator fusion patterns. Target Translation @@ -172,12 +171,12 @@ Summary and Discussions In summary, the key data structures in the compilation flows are: -- IRModule: contains relax.Function and tir.PrimFunc +- IRModule: contains relax.Function and tirx.PrimFunc - runtime.Module: contains runtime.PackedFunc Most parts of the compilation are transformations among the key data structures. -- relax/transform and tir/transform are deterministic rule-based transformations +- relax/transform and tirx/transform are deterministic rule-based transformations - meta-schedule contains the search-based transformations Finally, the compilation flow example is only a typical use-case of the TVM stack. @@ -237,9 +236,9 @@ Thanks to the node module, we can directly access any field of the TVM's IRNode .. code-block:: python - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Add(x, x) - # a and b are fields of a tir.Add node + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Add(x, x) + # a and b are fields of a tirx.Add node # we can directly use the field name to access the IR structures assert y.a == x @@ -249,14 +248,14 @@ The ability to save/store, and inspect an IR node provides a foundation for maki tvm/ir ------ The `tvm/ir` folder contains the unified data structure and interfaces across all IR function variants. -The components in `tvm/ir` are shared by `tvm/relax` and `tvm/tir`, notable ones include +The components in `tvm/ir` are shared by `tvm/relax` and `tvm/tirx`, notable ones include - IRModule - Type - PassContext and Pass - Op -Different variants of functions(e.g. relax.Function and tir.PrimFunc) can co-exist in an IRModule. +Different variants of functions(e.g. relax.Function and tirx.PrimFunc) can co-exist in an IRModule. While these variants may not have the same content representation, they use the same data structure to represent types. As a consequence, we use the same data structure to represent function (type) signatures of these variants. The unified type system allows one function variant to call another function @@ -267,8 +266,8 @@ The following code snippet gives an example of PassContext configuration. .. code-block:: python - # configure the behavior of the tir.UnrollLoop pass - with tvm.transform.PassContext(config={"tir.UnrollLoop": { "auto_max_step": 10 }}): + # configure the behavior of the tirx.UnrollLoop pass + with tvm.transform.PassContext(config={"tirx.UnrollLoop": { "auto_max_step": 10 }}): # code affected by the pass context @@ -304,34 +303,34 @@ Relax is the high-level IR used to represent the computational graph of a model. Note that Relax usually works closely with the TensorIR IRModule, most of the transformations are applied on both Relax and TensorIR functions in the IRModule. Please refer to the :ref:`Relax Deep Dive ` for more details. -tvm/tir -------- +tvm/tirx +-------- -TIR contains the definition of the low-level program representations. We use `tir::PrimFunc` to represent functions that can be transformed by TIR passes. -Besides the IR data structures, the tir module also includes: +tirx contains the definition of the low-level program representations. We use `tirx::PrimFunc` to represent functions that can be transformed by tirx passes. +Besides the IR data structures, the tirx module also includes: -- A set of schedule primitives to control the generated code in ``tir/schedule``. -- A set of builtin intrinsics in ``tir/tensor_intrin``. -- A set of analysis passes to analyze the TIR functions in ``tir/analysis``. -- A set of transformation passes to lower or optimize the TIR functions in ``tir/transform``. +- A set of schedule primitives to control the generated code in ``tirx/schedule``. +- A set of builtin intrinsics in ``tirx/tensor_intrin``. +- A set of analysis passes to analyze the tirx functions in ``tirx/analysis``. +- A set of transformation passes to lower or optimize the tirx functions in ``tirx/transform``. Please refer to the :ref:`TensorIR Deep Dive ` for more details. tvm/arith --------- -This module is closely tied to the TIR. One of the key problems in the low-level code generation is the analysis of the indices' +This module is closely tied to tirx. One of the key problems in the low-level code generation is the analysis of the indices' arithmetic properties — the positiveness, variable bound, and the integer set that describes the iterator space. arith module provides -a collection of tools that do (primarily integer) analysis. A TIR pass can use these analyses to simplify and optimize the code. +a collection of tools that do (primarily integer) analysis. A tirx pass can use these analyses to simplify and optimize the code. tvm/te and tvm/topi ------------------- TE stands for Tensor Expression. TE is a domain-specific language (DSL) for describing tensor computations. Importantly, a tensor expression -itself is not a self-contained function that can be stored into IRModule. We can use ``te.create_prim_func`` to convert a tensor expression to a ``tir::PrimFunc`` +itself is not a self-contained function that can be stored into IRModule. We can use ``te.create_prim_func`` to convert a tensor expression to a ``tirx::PrimFunc`` and then integrate it into the IRModule. -While possible to construct operators directly via TIR or tensor expressions (TE) for each use case, it is tedious to do so. +While possible to construct operators directly via tirx or tensor expressions (TE) for each use case, it is tedious to do so. `topi` (Tensor operator inventory) provides a set of pre-defined operators defined by numpy and found in common deep learning workloads. tvm/s_tir/meta_schedule @@ -343,7 +342,7 @@ and can be used to optimize TensorIR schedules. Note that MetaSchedule only work tvm/dlight ---------- -DLight is a set of pre-defined, easy-to-use, and performant TIR schedules. DLight aims: +DLight is a set of pre-defined, easy-to-use, and performant tirx schedules. DLight aims: - Fully support **dynamic shape workloads**. - **Light weight**. DLight schedules provides tuning-free schedule with reasonable performance. diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index 0d2043a66cd3..047a0f48b396 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -31,7 +31,7 @@ transformation using the analysis result collected during and/or before traversa However, as TVM evolves quickly, the need for a more systematic and efficient way to manage these passes is becoming apparent. In addition, a generic framework that manages the passes across different layers of the TVM stack (e.g. -Relax and tir) paves the way for developers to quickly prototype and plug the +Relax and tirx) paves the way for developers to quickly prototype and plug the implemented passes into the system. This doc describes the design of such an infra that takes the advantage of the @@ -166,7 +166,7 @@ Pass Constructs ^^^^^^^^^^^^^^^ The pass infra is designed in a hierarchical manner, and it could work at -different granularities of Relax/tir programs. A pure virtual class ``PassNode`` is +different granularities of Relax/tirx programs. A pure virtual class ``PassNode`` is introduced to serve as the base of the different optimization passes. This class contains several virtual methods that must be implemented by the subclasses at the level of modules, functions, or sequences of passes. @@ -222,13 +222,13 @@ Function-Level Passes ^^^^^^^^^^^^^^^^^^^^^ Function-level passes are used to implement various intra-function level -optimizations for a given Relax/tir module. It fetches one function at a time from +optimizations for a given Relax/tirx module. It fetches one function at a time from the function list of a module for optimization and yields a rewritten Relax -``Function`` or tir ``PrimFunc``. Most of passes can be classified into this category, such as +``Function`` or tirx ``PrimFunc``. Most of passes can be classified into this category, such as common subexpression elimination and inference simplification in Relax as well as vectorization -and flattening storage in tir, etc. +and flattening storage in tirx, etc. -Note that the scope of passes at this level is either a Relax function or a tir primitive function. +Note that the scope of passes at this level is either a Relax function or a tirx primitive function. Therefore, we cannot add or delete a function through these passes as they are not aware of the global information. @@ -571,9 +571,9 @@ loop unrolling pass .. code:: c++ - TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); + TVM_REGISTER_PASS_CONFIG_OPTION("tirx.UnrollLoop", UnrollLoopConfig); -Please refer to `src/tir/transform/unroll_loop.cc`_ for more details. +Please refer to `src/tirx/transform/unroll_loop.cc`_ for more details. .. _pass_instrument_py_frontend: @@ -656,7 +656,7 @@ new ``PassInstrument`` are called. .. _python/tvm/ir/instrument.py: https://github.com/apache/tvm/blob/main/python/tvm/ir/instrument.py -.. _src/tir/transform/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tir/transform/unroll_loop.cc +.. _src/tirx/transform/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tirx/transform/unroll_loop.cc .. _use pass infra: https://github.com/apache/tvm/blob/main/docs/how_to/tutorials/customize_opt.py diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index 25199b57ebc8..9839e559cf3f 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -241,7 +241,7 @@ For example, we can access the value field of the IntImmNode. import tvm - x = tvm.tir.IntImm("int32", 1) + x = tvm.tirx.IntImm("int32", 1) # access the value field of IntImmNode print(x.value) diff --git a/docs/arch/runtimes/vulkan.rst b/docs/arch/runtimes/vulkan.rst index 36b1508f5448..e60cc4092487 100644 --- a/docs/arch/runtimes/vulkan.rst +++ b/docs/arch/runtimes/vulkan.rst @@ -254,6 +254,6 @@ string are all false boolean flags. validated with `spvValidate`_. * ``TVM_VULKAN_DEBUG_SHADER_SAVEPATH`` - A path to a directory. If - set to a non-empty string, the Vulkan codegen will save tir, binary + set to a non-empty string, the Vulkan codegen will save tirx, binary SPIR-V, and disassembled SPIR-V shaders to this directory, to be used for debugging purposes. diff --git a/docs/conf.py b/docs/conf.py index 80b82f986f01..902ff6a657e5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -479,8 +479,8 @@ def force_gc(gallery_conf, fname): # Maps the original namespace to list of potential modules # that we can import alias from. tvm_alias_check_map = { - "tvm.te": ["tvm.tir"], - "tvm.tir": ["tvm.ir", "tvm.runtime"], + "tvm.te": ["tvm.tirx"], + "tvm.tirx": ["tvm.ir", "tvm.runtime"], } ## Setup header and other configs @@ -606,7 +606,7 @@ def update_alias_docstring(name, obj, lines): tvm_class_name_rewrite_map = { - "tvm.tir": ["Var", "Call"], + "tvm.tirx": ["Var", "Call"], "tvm.relax": ["Var", "Call"], "tvm.relax.frontend.nn": ["Module"], } @@ -616,7 +616,7 @@ def distinguish_class_name(name: str, lines: list[str]): """Distinguish the docstring of type annotations. In the whole TVM, there are many classes with the same name but in different modules, - e.g. ``tir.Var``, ``relax.Var``. This function is used to distinguish them in the docstring, + e.g. ``tirx.Var``, ``relax.Var``. This function is used to distinguish them in the docstring, by adding the module name as prefix. To be specific, this function will check the current object name, and if it in the specific diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst index f8128ce297a4..7ae6706fd44e 100644 --- a/docs/contribute/pull_request.rst +++ b/docs/contribute/pull_request.rst @@ -207,7 +207,7 @@ each time (e.g. you can test a change in CPU and GPU while retaining incremental python tests/scripts/ci.py cpu --unittest # quickly iterate by running a specific test and skipping the rebuild each time - python tests/scripts/ci.py cpu --skip-build --tests tests/python/tir-transform/test_tir_transform_inject_rolling_buffer.py::test_upscale + python tests/scripts/ci.py cpu --skip-build --tests tests/python/tirx-transform/test_tir_transform_inject_rolling_buffer.py::test_upscale # run the CPU build and drop into a shell in the container python tests/scripts/ci.py cpu --interactive @@ -261,7 +261,7 @@ If you want to run a single test: export PYTHONPATH=python rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc - python -m pytest -v tests/python/tir-transform/test_tir_transform_storage_rewrite.py + python -m pytest -v tests/python/tirx-transform/test_tir_transform_storage_rewrite.py # Additionally if you want to run a single test, for example test_all_elemwise inside a file. python -m pytest -v -k "test_all_elemwise" tests/python/frontend/tflite/test_forward.py diff --git a/docs/deep_dive/relax/tutorials/relax_creation.py b/docs/deep_dive/relax/tutorials/relax_creation.py index a20ad7350442..d178279d4302 100644 --- a/docs/deep_dive/relax/tutorials/relax_creation.py +++ b/docs/deep_dive/relax/tutorials/relax_creation.py @@ -40,7 +40,7 @@ from tvm import relax, topi from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module diff --git a/docs/deep_dive/tensor_ir/abstraction.rst b/docs/deep_dive/tensor_ir/abstraction.rst index 86536b1dea6f..a36e15677eee 100644 --- a/docs/deep_dive/tensor_ir/abstraction.rst +++ b/docs/deep_dive/tensor_ir/abstraction.rst @@ -29,7 +29,7 @@ the compute statements themselves. .. code:: python - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main( diff --git a/docs/deep_dive/tensor_ir/learning.rst b/docs/deep_dive/tensor_ir/learning.rst index 8a0227443167..229d6d9d69ca 100644 --- a/docs/deep_dive/tensor_ir/learning.rst +++ b/docs/deep_dive/tensor_ir/learning.rst @@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License. -.. _tir-learning: +.. _tirx-learning: Understand TensorIR Abstraction =============================== diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py index 48b6dcc627b5..973eac4c6d34 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_creation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_creation.py @@ -23,7 +23,7 @@ ----------------- In this section, we will introduce the methods to write a TensorIR function in Apache TVM. This tutorial presumes familiarity with the fundamental concepts of TensorIR. -If not already acquainted, please refer to :ref:`tir-learning` initially. +If not already acquainted, please refer to :ref:`tirx-learning` initially. .. note:: @@ -49,14 +49,14 @@ # # Standard Format # *************** -# Let's take an example of ``mm_relu`` from :ref:`tir-learning`. Here is the complete +# Let's take an example of ``mm_relu`` from :ref:`tirx-learning`. Here is the complete # format of the ir_module and in TVMScript: import numpy as np import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module diff --git a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py index 41fc3c43a0f4..4e59c6c1a7f6 100644 --- a/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py +++ b/docs/deep_dive/tensor_ir/tutorials/tir_transformation.py @@ -17,7 +17,7 @@ # ruff: noqa: E402 """ -.. _tir-transform: +.. _tirx-transform: Transformation -------------- @@ -26,7 +26,7 @@ """ ###################################################################### -# In the :ref:`previous section `, we have given an example of how to write +# In the :ref:`previous section `, we have given an example of how to write # ``mm_relu`` using TensorIR. In practice, there can be multiple ways to implement # the same functionality, and each implementation can result in different performance. # @@ -38,7 +38,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module @@ -49,7 +49,7 @@ def main( B: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -82,7 +82,7 @@ def main( def evaluate(mod: tvm.IRModule): - lib = tvm.tir.build(mod, target="llvm") + lib = tvm.tirx.build(mod, target="llvm") # check correctness lib(a_nd, b_nd, c_nd) np.testing.assert_allclose(c_nd.numpy(), c_np, rtol=1e-5) diff --git a/docs/get_started/overview.rst b/docs/get_started/overview.rst index 6d775b5de10b..3b7e7767a083 100644 --- a/docs/get_started/overview.rst +++ b/docs/get_started/overview.rst @@ -1,4 +1,20 @@ -.. Licensed to the Apache Software Foundation (ASF) under one +.. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file diff --git a/docs/how_to/tutorials/optimize_llm.py b/docs/how_to/tutorials/optimize_llm.py index f698a37f3c09..c5341266be35 100644 --- a/docs/how_to/tutorials/optimize_llm.py +++ b/docs/how_to/tutorials/optimize_llm.py @@ -62,7 +62,7 @@ from pprint import pprint import tvm -from tvm import relax, te, tir +from tvm import relax, te, tirx from tvm.relax import register_pipeline from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op @@ -281,10 +281,10 @@ def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache): def create_tir_paged_kv_cache( self, - max_batch_size: tir.Var, - max_total_seq_len: tir.Var, - prefill_chunk_size: tir.Var, - page_size: tir.Var, + max_batch_size: tirx.Var, + max_total_seq_len: tirx.Var, + prefill_chunk_size: tirx.Var, + page_size: tirx.Var, ) -> PagedKVCache: return TIRPagedKVCache( attn_kind="mha", diff --git a/docs/reference/api/python/index.rst b/docs/reference/api/python/index.rst index a89938367e7d..dec8167a1ab3 100644 --- a/docs/reference/api/python/index.rst +++ b/docs/reference/api/python/index.rst @@ -51,12 +51,12 @@ Python API .. toctree:: :maxdepth: 1 - :caption: tvm.tir + :caption: tvm.tirx - tir/tir - tir/analysis - tir/stmt_functor - tir/transform + tirx/tirx + tirx/analysis + tirx/stmt_functor + tirx/transform .. toctree:: :maxdepth: 1 diff --git a/docs/reference/api/python/te.rst b/docs/reference/api/python/te.rst index 363dae675d84..6ea96680a13b 100644 --- a/docs/reference/api/python/te.rst +++ b/docs/reference/api/python/te.rst @@ -17,7 +17,7 @@ tvm.te ------ -.. Exclude the ops imported from tir. +.. Exclude the ops imported from tirx. .. automodule:: tvm.te :members: diff --git a/docs/reference/api/python/tir/analysis.rst b/docs/reference/api/python/tirx/analysis.rst similarity index 91% rename from docs/reference/api/python/tir/analysis.rst rename to docs/reference/api/python/tirx/analysis.rst index aa777358bcf2..67ac1e919190 100644 --- a/docs/reference/api/python/tir/analysis.rst +++ b/docs/reference/api/python/tirx/analysis.rst @@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License. -tvm.tir.analysis ----------------- -.. automodule:: tvm.tir.analysis.analysis +tvm.tirx.analysis +================= +.. automodule:: tvm.tirx.analysis.analysis :members: diff --git a/docs/reference/api/python/tir/stmt_functor.rst b/docs/reference/api/python/tirx/stmt_functor.rst similarity index 90% rename from docs/reference/api/python/tir/stmt_functor.rst rename to docs/reference/api/python/tirx/stmt_functor.rst index 3b6c9bb64a89..f0f6a45df038 100644 --- a/docs/reference/api/python/tir/stmt_functor.rst +++ b/docs/reference/api/python/tirx/stmt_functor.rst @@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License. -tvm.tir.stmt_functor --------------------- -.. automodule:: tvm.tir.stmt_functor +tvm.tirx.stmt_functor +--------------------- +.. automodule:: tvm.tirx.stmt_functor :members: diff --git a/docs/reference/api/python/tir/tir.rst b/docs/reference/api/python/tirx/tirx.rst similarity index 95% rename from docs/reference/api/python/tir/tir.rst rename to docs/reference/api/python/tirx/tirx.rst index 14a64d5592d2..cf55ad5588b1 100644 --- a/docs/reference/api/python/tir/tir.rst +++ b/docs/reference/api/python/tirx/tirx.rst @@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License. -tvm.tir -------- -.. automodule:: tvm.tir +tvm.tirx +-------- +.. automodule:: tvm.tirx :members: :imported-members: :exclude-members: PrimExpr, const, StmtSRef, SBlockScope, ScheduleState, Schedule, ScheduleError diff --git a/docs/reference/api/python/tir/transform.rst b/docs/reference/api/python/tirx/transform.rst similarity index 92% rename from docs/reference/api/python/tir/transform.rst rename to docs/reference/api/python/tirx/transform.rst index 29f1bcbbf036..6f476c1b3000 100644 --- a/docs/reference/api/python/tir/transform.rst +++ b/docs/reference/api/python/tirx/transform.rst @@ -16,9 +16,9 @@ under the License. -tvm.tir.transform ------------------ -.. automodule:: tvm.tir.transform +tvm.tirx.transform +------------------ +.. automodule:: tvm.tirx.transform :members: :exclude-members: Attrs :imported-members: diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index b77f2ee5dbac..76881a2cd087 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -51,7 +51,7 @@ namespace arith { // Forward declare Analyzer class Analyzer; -using tir::Var; +using tirx::Var; enum DivMode { /*! \brief Truncated division. */ diff --git a/include/tvm/arith/bound.h b/include/tvm/arith/bound.h index 6cde90b0b8e5..7ae36fb289f0 100644 --- a/include/tvm/arith/bound.h +++ b/include/tvm/arith/bound.h @@ -25,18 +25,18 @@ #include #include -#include -#include +#include +#include #include namespace tvm { namespace arith { -using tir::Region; -using tir::Stmt; -using tir::Var; -using tir::VarNode; +using tirx::Region; +using tirx::Stmt; +using tirx::Var; +using tirx::VarNode; /*! * \brief Deduce the bound of the target variable in a expression, @@ -77,7 +77,7 @@ IntSet DeduceBound(PrimExpr v, PrimExpr cond, * \param consider_stores If stores are considered. * \return The domain that covers all the calls or provides within the given statement. */ -Region DomainTouched(const Stmt& body, const tir::Buffer& buffer, bool consider_loads, +Region DomainTouched(const Stmt& body, const tirx::Buffer& buffer, bool consider_loads, bool consider_stores); } // namespace arith diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h index d1e8f9475750..9b8afefb7d53 100644 --- a/include/tvm/arith/int_set.h +++ b/include/tvm/arith/int_set.h @@ -25,16 +25,16 @@ #define TVM_ARITH_INT_SET_H_ #include -#include +#include #include namespace tvm { namespace arith { -using tir::IterVar; -using tir::Var; -using tir::VarNode; +using tirx::IterVar; +using tirx::Var; +using tirx::VarNode; class Analyzer; @@ -199,7 +199,7 @@ IntSet EvalSet(PrimExpr e, const ffi::Map& dom_map); * \param dom_map The domain of each variable. * \return An integer set that can cover all the possible values of e. */ -IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); +IntSet EvalSet(PrimExpr e, const std::unordered_map& dom_map); /*! * \brief Find an symbolic integer set that contains is union over * all the possible conditional values in dom_map. diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index b8f0ac6d4327..a3be4c422bc3 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -25,8 +25,8 @@ #define TVM_ARITH_INT_SOLVER_H_ #include -#include -#include +#include +#include #include #include @@ -37,9 +37,9 @@ namespace tvm { namespace arith { -using tir::IterVar; -using tir::Var; -using tir::VarNode; +using tirx::IterVar; +using tirx::Var; +using tirx::VarNode; // According to experiments two best simplifications orders were can->rw and rw->can->rw, // but rw->can->rw is better for a couple of cases. diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 223fb3509571..e91c4d0b37fb 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -52,7 +52,7 @@ #include #include #include -#include +#include namespace tvm { namespace arith { diff --git a/include/tvm/arith/pattern.h b/include/tvm/arith/pattern.h index 254c1d0933ec..c476e2f01ab0 100644 --- a/include/tvm/arith/pattern.h +++ b/include/tvm/arith/pattern.h @@ -25,7 +25,7 @@ #define TVM_ARITH_PATTERN_H_ #include -#include +#include namespace tvm { namespace arith { @@ -37,7 +37,7 @@ namespace arith { * \param vars List of variables to be used in detection. * \return [coeff[i]] if it is possible, empty array if it is not. */ -ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array& vars); +ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array& vars); /*! * \brief Detect if expression corresponds to clip bound of the vars @@ -47,7 +47,7 @@ ffi::Array DetectLinearEquation(const PrimExpr& e, const ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars); +ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars); } // namespace arith } // namespace tvm diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index c440e6fc9e17..8778ace5cebc 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -83,7 +83,7 @@ enum class LinkageType : int { /*! * \brief Generic attribute names that can be attached to any function. * - * \sa tvm::tir::attr, tvm::relax::attr + * \sa tvm::tirx::attr, tvm::relax::attr */ namespace attr { /*! diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index 092a7a53f103..518109e6805c 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -30,8 +30,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -118,7 +118,7 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca * * @R.function * def f(x: R.Tensor[(n, m)]): - * k = tir.Var("k", "int64") + * k = tirx.Var("k", "int64") * v0 = opaque_fn(x) * v1 = match_cast(v0, R.Tensor[(n, k)]) * v2 : R.Tensor[(n + 1, k + 2)] = pad(v1) @@ -158,7 +158,7 @@ TVM_DLL StructInfo DeriveCallRetStructInfo(const FuncStructInfo& finfo, const Ca */ TVM_DLL StructInfo EraseToWellDefined( const StructInfo& info, - std::function(const tir::Var& var)> f_shape_var_map = nullptr, + std::function(const tirx::Var& var)> f_shape_var_map = nullptr, std::function(const Var& var)> f_var_map = nullptr, arith::Analyzer* ana = nullptr); @@ -176,7 +176,7 @@ TVM_DLL StructInfo EraseToWellDefined( * \return the corresponding erased struct info. */ TVM_DLL StructInfo EraseToWellDefined(const StructInfo& info, - ffi::Map shape_var_map, + ffi::Map shape_var_map, ffi::Map var_map, arith::Analyzer* ana = nullptr); /*! @@ -291,7 +291,7 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, * \param sinfo The struct info object to be analyzed. * \return The list of TIR variables that appear in the input struct info. */ -TVM_DLL ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo); /*! * \brief Get the TIR variables that appear in the input struct info. @@ -305,7 +305,7 @@ TVM_DLL ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo); * deduplicated, each TIR variable will appear at most once, and in * order of occurrence. */ -TVM_DLL ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); +TVM_DLL ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo); /*! \brief Collect expressions whose usage requires them to be non-negative * @@ -326,7 +326,7 @@ TVM_DLL ffi::Array CollectNonNegativeExpressions(const StructInfo& sin * \param expr The relax expression (e.g. a Function) to be analyzed. * \return The list of TIR variables that are defined in the input function. */ -TVM_DLL ffi::Array DefinedSymbolicVars(const Expr& expr); +TVM_DLL ffi::Array DefinedSymbolicVars(const Expr& expr); /*! * \brief Get the TIR variables that are used but not defined in the input function. @@ -334,7 +334,7 @@ TVM_DLL ffi::Array DefinedSymbolicVars(const Expr& expr); * \param expr The relax expression (e.g. a Function) to be analyzed. * \return The list of TIR variables that are used but not defined in the input function. */ -TVM_DLL ffi::Array FreeSymbolicVars(const Expr& expr); +TVM_DLL ffi::Array FreeSymbolicVars(const Expr& expr); //----------------------------------- // General IR analysis //----------------------------------- @@ -525,7 +525,7 @@ TVM_DLL Expr RemoveAllUnused(Expr expr); * \note This analysis applies on TIR function but is primarily used by relax passes. * As a result we place it under the relax namespace. */ -TVM_DLL OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); +TVM_DLL OpPatternKind AnalyzeOpPatternKind(const tirx::PrimFunc& func); /*! * \brief Check if the given PrimFunc is essentially doing a reshape operation. @@ -540,7 +540,7 @@ TVM_DLL OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func); * cannot be false-positive, since whenever we cannot prove the equality, we return false. This * property guarantees the safety of this function. */ -TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); +TVM_DLL bool HasReshapePattern(const tirx::PrimFunc& func); /*! * \brief Check if the given expression (likely a function body) contains any impure calls. @@ -594,8 +594,8 @@ TVM_DLL bool WellFormed(ffi::Variant obj, bool check_struct_ * from the object (block or buffer) to it's index map transformation. */ -TVM_DLL ffi::Map> SuggestLayoutTransforms( - const Function& fn, ffi::Array write_buffer_transformations); +TVM_DLL ffi::Map> SuggestLayoutTransforms( + const Function& fn, ffi::Array write_buffer_transformations); /* \brief Collect variables whose value can be computed at compile-time * diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h index 21184848e3c7..8a91b3348723 100644 --- a/include/tvm/relax/attrs/manipulate.h +++ b/include/tvm/relax/attrs/manipulate.h @@ -25,7 +25,7 @@ #define TVM_RELAX_ATTRS_MANIPULATE_H_ #include -#include +#include namespace tvm { namespace relax { @@ -60,7 +60,7 @@ struct ExpandDimsAttrs : public AttrsNodeReflAdapter { /*! \brief Attributes used in layout_transform operator */ struct LayoutTransformAttrs : public AttrsNodeReflAdapter { - tir::IndexMap index_map; + tirx::IndexMap index_map; // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future. ffi::Optional pad_value; diff --git a/include/tvm/relax/attrs/sorting.h b/include/tvm/relax/attrs/sorting.h index 0731c6cf4f6d..354b77047272 100644 --- a/include/tvm/relax/attrs/sorting.h +++ b/include/tvm/relax/attrs/sorting.h @@ -25,7 +25,7 @@ #define TVM_RELAX_ATTRS_SORTING_H_ #include -#include +#include namespace tvm { namespace relax { diff --git a/include/tvm/relax/distributed/axis_group_graph.h b/include/tvm/relax/distributed/axis_group_graph.h index 6dc2022d5f19..68a466bfadbc 100644 --- a/include/tvm/relax/distributed/axis_group_graph.h +++ b/include/tvm/relax/distributed/axis_group_graph.h @@ -23,8 +23,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -36,7 +36,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { // (var, axis) using TIRVarAxis = std::pair; // (buffer, axis) @@ -198,7 +198,7 @@ class BufferAxisGraphExtractor : public StmtExprVisitor { ffi::Map iter_var_range_; std::string func_name; }; -} // namespace tir +} // namespace tirx } // namespace tvm namespace tvm { @@ -470,7 +470,7 @@ void BuildAxisGraphPermuteDims(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph); void BuildAxisGraphReshape(const Var& output_var, const Call& call, distributed::AxisGroupGraph* axis_group_graph); -void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tir::PrimFunc& func, +void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tirx::PrimFunc& func, distributed::AxisGroupGraph* axis_group_graph); } // namespace distributed diff --git a/include/tvm/relax/distributed/transform.h b/include/tvm/relax/distributed/transform.h index 99b23331f70a..12a87614cffd 100644 --- a/include/tvm/relax/distributed/transform.h +++ b/include/tvm/relax/distributed/transform.h @@ -28,8 +28,8 @@ #include #include #include -#include -#include +#include +#include namespace tvm { namespace relax { diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index f8cebafa551c..0407e3c604c5 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -28,8 +28,8 @@ #include #include #include -#include -#include +#include +#include #include diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 4bb0c52eb0a7..e633603ff364 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h index c51a6db5a2a0..0d9658d8cffc 100644 --- a/include/tvm/relax/struct_info.h +++ b/include/tvm/relax/struct_info.h @@ -319,7 +319,7 @@ class FuncStructInfo : public StructInfo { * \param purity The purity of the function (true by default). * \param span The span of the AST. * - * \note If the ret contains variables(tir::Var and relax::Var), they must be deducible from + * \note If the ret contains variables(tirx::Var and relax::Var), they must be deducible from * params. If you are unsure, you can always erase ret to static. */ TVM_DLL FuncStructInfo(ffi::Array params, StructInfo ret, bool purity = true, diff --git a/include/tvm/relax/tir_pattern.h b/include/tvm/relax/tir_pattern.h index 6bd36560a6ac..d962c64ace82 100644 --- a/include/tvm/relax/tir_pattern.h +++ b/include/tvm/relax/tir_pattern.h @@ -26,24 +26,24 @@ #define TVM_RELAX_TIR_PATTERN_H_ #include -#include +#include namespace tvm { namespace relax { -using TIRPattern = tir::PrimFunc; +using TIRPattern = tirx::PrimFunc; /* * \brief The match result of a TIR pattern. */ class MatchResultNode : public Object { public: - /*! The matched tir pattern*/ + /*! The matched tirx pattern*/ TIRPattern pattern; /*! \brief The evaluated values of symbolic vars. */ ffi::Array symbol_values; /*! \brief The matched buffers of input and output. */ - ffi::Array matched_buffers; + ffi::Array matched_buffers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -62,12 +62,12 @@ class MatchResult : public ObjectRef { public: /*! * \brief Constructor - * \param pattern The matched tir pattern. + * \param pattern The matched tirx pattern. * \param symbol_values The evaluated values of symbolic vars. * \param matched_buffers The matched buffers of input and output. */ TVM_DLL explicit MatchResult(TIRPattern pattern, ffi::Array symbol_values, - ffi::Array matched_buffers); + ffi::Array matched_buffers); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(MatchResult, ObjectRef, MatchResultNode); }; diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 9ffeb05f8f0b..04b6d80be236 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -28,8 +28,8 @@ #include #include #include -#include -#include +#include +#include namespace tvm { namespace relax { @@ -205,7 +205,7 @@ TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map params); * * \param binding_map The dictionary of symbolic variables and their * constant shape values. Dictionary keys may be either a - * `tir.Var` or a string name of the variable. If the variables + * `tirx.Var` or a string name of the variable. If the variables * are referred to by name, the name must uniquely identify a * symbolic variable in each function where it is used. * @@ -215,7 +215,7 @@ TVM_DLL Pass BindParams(ffi::String func_name, ffi::Map params); * * \return The Pass. */ -TVM_DLL Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, +TVM_DLL Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, ffi::Optional func_name = std::nullopt); /*! @@ -262,9 +262,9 @@ TVM_DLL Pass LegalizeOps(ffi::Optional> cma TVM_DLL Pass RealizeVDevice(); /*! - * \brief Attach layout free buffers to the tir::PrimFunc. + * \brief Attach layout free buffers to the tirx::PrimFunc. * - * This pass is used to attach layout free buffers to the tir::PrimFunc according to + * This pass is used to attach layout free buffers to the tirx::PrimFunc according to * the function usage in the relax function. Currently, the layout free buffers are the model * weights and relax constants. * @@ -274,7 +274,7 @@ TVM_DLL Pass RealizeVDevice(); TVM_DLL Pass AttachAttrLayoutFreeBuffers(); /*! - * \brief Split the layout rewrite preproc block to a separate tir::PrimFunc. + * \brief Split the layout rewrite preproc block to a separate tirx::PrimFunc. * * This pass is used in the prepack weight after meta_schedule tuning. * @@ -598,8 +598,8 @@ TVM_DLL Pass DecomposeOpsForTraining(ffi::Optional func_name); * \return The Pass. */ TVM_DLL Pass AlterOpImpl( - const ffi::Map& op_impl_map, - const ffi::Map>& op_buffer_transforms, + const ffi::Map& op_impl_map, + const ffi::Map>& op_buffer_transforms, const ffi::Map>>>& axis_separators, const ffi::Map>>>& input_axis_separators); diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h index 32ec0f0f8f05..b70a2756b71f 100644 --- a/include/tvm/relax/type.h +++ b/include/tvm/relax/type.h @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h index 70ecbe4855ac..792f7dd11f90 100644 --- a/include/tvm/relax/utils.h +++ b/include/tvm/relax/utils.h @@ -48,14 +48,14 @@ namespace relax { * \return The updated expression. */ TVM_DLL Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, - const tvm::ffi::Map& symbolic_var_map = {}); + const tvm::ffi::Map& symbolic_var_map = {}); /*! * \brief Bind the symbolic variables to a StructInfo. This is a helper function usually called by * other pass functions to help optimizations. */ TVM_DLL StructInfo Bind(const StructInfo& sinfo, - const tvm::ffi::Map& symbolic_var_map); + const tvm::ffi::Map& symbolic_var_map); /*! * \brief Infer a binding map for symbolic variables @@ -74,7 +74,7 @@ TVM_DLL StructInfo Bind(const StructInfo& sinfo, * * \return A map of TIR variables to TIR expressions */ -TVM_DLL tvm::ffi::Map InferSymbolicVarMap( +TVM_DLL tvm::ffi::Map InferSymbolicVarMap( const tvm::ffi::Map& binds, arith::Analyzer* analyzer); /*! diff --git a/include/tvm/s_tir/analysis.h b/include/tvm/s_tir/analysis.h index 844bf20a3592..c5f4bd90f465 100644 --- a/include/tvm/s_tir/analysis.h +++ b/include/tvm/s_tir/analysis.h @@ -27,13 +27,13 @@ #include #include #include -#include -#include +#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Auto detect the block access region according to its body stmt @@ -85,9 +85,9 @@ TVM_DLL ffi::Map> DetectBufferAccessLCA(const PrimFu * \param mod The input TIR module. * \return The anchor block if found, nullptr otherwise. */ -const tir::SBlockNode* FindAnchorBlock(const IRModule& mod); +const tirx::SBlockNode* FindAnchorBlock(const IRModule& mod); -} // namespace tir +} // namespace tirx namespace arith { class Analyzer; @@ -95,7 +95,7 @@ class Analyzer; namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Estimate the FLOPs of a TIR fragment. diff --git a/include/tvm/s_tir/backend/adreno/transform.h b/include/tvm/s_tir/backend/adreno/transform.h index 5d373b20add4..0db71c56ff7f 100644 --- a/include/tvm/s_tir/backend/adreno/transform.h +++ b/include/tvm/s_tir/backend/adreno/transform.h @@ -27,8 +27,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -39,7 +39,7 @@ namespace backend { namespace adreno { namespace transform { -using tir::transform::CreatePrimFuncPass; +using tirx::transform::CreatePrimFuncPass; using tvm::transform::Pass; using tvm::transform::PassContext; diff --git a/include/tvm/s_tir/data_layout.h b/include/tvm/s_tir/data_layout.h index e57ed9bc965e..7a0ac5d6fba5 100644 --- a/include/tvm/s_tir/data_layout.h +++ b/include/tvm/s_tir/data_layout.h @@ -26,8 +26,8 @@ #define TVM_S_TIR_DATA_LAYOUT_H_ #include -#include -#include +#include +#include #include #include @@ -35,10 +35,10 @@ #include #include -#include "tvm/tir/var.h" +#include "tvm/tirx/var.h" namespace tvm { -namespace tir { +namespace tirx { class Layout; @@ -47,7 +47,7 @@ class LayoutAxis { static const LayoutAxis& Get(const char name); // Get the singleton LayoutAxis using itvar->var->name_hint - static const LayoutAxis& Get(const tir::IterVar& itvar); + static const LayoutAxis& Get(const tirx::IterVar& itvar); // Get the singleton LayoutAxis using name[0] (size of name must be 1). static const LayoutAxis& Get(const std::string& name); @@ -108,7 +108,7 @@ class LayoutNode : public Object { * it is a variable for a primal axis, but a constant for a subordinate axis. * Empty for scalar's layout. */ - ffi::Array axes; + ffi::Array axes; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -125,7 +125,7 @@ class LayoutNode : public Object { */ class Layout : public ObjectRef { public: - explicit Layout(const ffi::Array& axes); + explicit Layout(const ffi::Array& axes); /*! \brief construct from a string */ Layout(const tvm::ffi::String& name) : Layout(name.operator std::string()) {} // NOLINT(*) @@ -273,7 +273,7 @@ class Layout : public ObjectRef { * \param iter the input iter var. * \return the index or -1 if not found. */ - inline int32_t IndexOf(const tir::IterVar& iter) const { return IndexOf(iter->var->name_hint); } + inline int32_t IndexOf(const tirx::IterVar& iter) const { return IndexOf(iter->var->name_hint); } /*! * \brief Get the factor size of the subordinate axis. @@ -291,7 +291,7 @@ class Layout : public ObjectRef { */ bool Contains(const LayoutAxis& axis) const { if (!defined()) return false; - for (const tir::IterVar packed_var : operator->()->axes) { + for (const tirx::IterVar packed_var : operator->()->axes) { auto iter_vars = UnpackIterVar(packed_var); for (auto var : iter_vars) { if (var->var->name_hint == axis.name()) { @@ -306,7 +306,7 @@ class Layout : public ObjectRef { TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined layout."; int32_t index = i < 0 ? static_cast(ndim() + i) : i; TVM_FFI_ICHECK(index >= 0 && static_cast(index) < ndim()) << "Invalid index " << i; - const tir::IterVar axis = operator->()->axes[index]; + const tirx::IterVar axis = operator->()->axes[index]; return LayoutAxis::Get(axis); } @@ -314,7 +314,7 @@ class Layout : public ObjectRef { TVM_FFI_ICHECK(defined()) << "Try to access axis from an undefined layout."; int32_t index = i < 0 ? static_cast(ndim() + i) : i; TVM_FFI_ICHECK(index >= 0 && static_cast(index) < ndim()) << "Invalid index " << i; - const tir::IterVar axis = operator->()->axes[index]; + const tirx::IterVar axis = operator->()->axes[index]; return axis; } @@ -404,7 +404,7 @@ class BijectiveLayout : public ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(BijectiveLayout, ObjectRef, BijectiveLayoutNode); }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_S_TIR_DATA_LAYOUT_H_ diff --git a/include/tvm/s_tir/meta_schedule/arg_info.h b/include/tvm/s_tir/meta_schedule/arg_info.h index ae2c3c9057df..b5477882405c 100644 --- a/include/tvm/s_tir/meta_schedule/arg_info.h +++ b/include/tvm/s_tir/meta_schedule/arg_info.h @@ -24,7 +24,7 @@ #include #include #include -#include +#include namespace tvm { namespace s_tir { @@ -59,7 +59,7 @@ class ArgInfo : public runtime::ObjectRef { * \param func The PrimFunc to get argument information from. * \return An array of the argument information derived. */ - TVM_DLL static ffi::Array FromPrimFunc(const tir::PrimFunc& func); + TVM_DLL static ffi::Array FromPrimFunc(const tirx::PrimFunc& func); /*! * \brief Extract a list of the argument information from the entry func of an IRModule * \param mod The IRModule to extract argument information from. diff --git a/include/tvm/s_tir/meta_schedule/database.h b/include/tvm/s_tir/meta_schedule/database.h index c6947e573473..b23d991bbb69 100644 --- a/include/tvm/s_tir/meta_schedule/database.h +++ b/include/tvm/s_tir/meta_schedule/database.h @@ -196,7 +196,7 @@ class DatabaseNode : public runtime::Object { * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a * given module. The "ignore-tensor" varint is used for the extracted blocks * or in case no anchor block is found. - * For the definition of the anchor block, see tvm/tir/analysis.h. + * For the definition of the anchor block, see tvm/tirx/analysis.h. */ explicit DatabaseNode(ffi::String mod_eq_name = "structural"); @@ -295,7 +295,7 @@ class PyDatabaseNode : public DatabaseNode { * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a * given module. The "ignore-tensor" varint is used for the extracted blocks * or in case no anchor block is found. - * For the definition of the anchor block, see tvm/tir/analysis.h. + * For the definition of the anchor block, see tvm/tirx/analysis.h. */ explicit PyDatabaseNode(ffi::String mod_eq_name = "structural"); diff --git a/include/tvm/s_tir/meta_schedule/extracted_task.h b/include/tvm/s_tir/meta_schedule/extracted_task.h index be55a9db94e2..6ec20ba81f0f 100644 --- a/include/tvm/s_tir/meta_schedule/extracted_task.h +++ b/include/tvm/s_tir/meta_schedule/extracted_task.h @@ -27,9 +27,9 @@ #include namespace tvm { -namespace tir { +namespace tirx { class PrimFunc; -} // namespace tir +} // namespace tirx namespace te { class Tensor; } // namespace te diff --git a/include/tvm/s_tir/sblock_dependence_info.h b/include/tvm/s_tir/sblock_dependence_info.h index 408afb639bb0..56d2c88469b9 100644 --- a/include/tvm/s_tir/sblock_dependence_info.h +++ b/include/tvm/s_tir/sblock_dependence_info.h @@ -37,7 +37,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /** * @brief An object that helps build and query block level dependences using the 2 core objects @@ -99,6 +99,6 @@ class SBlockDependenceInfo : public ObjectRef { SBlockDependenceInfoNode); }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_S_TIR_SBLOCK_DEPENDENCE_INFO_H_ diff --git a/include/tvm/s_tir/sblock_scope.h b/include/tvm/s_tir/sblock_scope.h index a302cab26019..f923b2ab907a 100644 --- a/include/tvm/s_tir/sblock_scope.h +++ b/include/tvm/s_tir/sblock_scope.h @@ -26,16 +26,16 @@ #define TVM_S_TIR_SBLOCK_SCOPE_H_ #include -#include -#include -#include +#include +#include +#include #include #include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief An object that refers to schedulable elements (block/for-loop) in TensorIR, aka "sref". @@ -86,7 +86,7 @@ class StmtSRefNode : public Object { * \brief Get the referenced statement with proper type checking. * It serves the same purpose as `ObjectRef::as`, but does not acquire strong reference to `stmt` * \tparam StmtType The type that `this->stmt` to be downcasted to. Presumably - * tvm::tir::SBlockNode or tvm::tir::ForNode + * tvm::tirx::SBlockNode or tvm::tirx::ForNode * \return nullptr if type check fails, otherwise the casted result for `this->stmt` */ template @@ -311,7 +311,7 @@ class SBlockScope : public ObjectRef { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockScope, ObjectRef, SBlockScopeNode); }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_S_TIR_SBLOCK_SCOPE_H_ diff --git a/include/tvm/s_tir/schedule/instruction.h b/include/tvm/s_tir/schedule/instruction.h index b422fe953567..cfd719632985 100644 --- a/include/tvm/s_tir/schedule/instruction.h +++ b/include/tvm/s_tir/schedule/instruction.h @@ -30,7 +30,7 @@ template class AttrRegistry; namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; // Forward declaration class Schedule; diff --git a/include/tvm/s_tir/schedule/schedule.h b/include/tvm/s_tir/schedule/schedule.h index 4a1bbe207d4e..be903e10cdc6 100644 --- a/include/tvm/s_tir/schedule/schedule.h +++ b/include/tvm/s_tir/schedule/schedule.h @@ -22,11 +22,11 @@ #include #include #include -#include +#include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief The level of detailed error message rendering */ enum class ScheduleErrorRenderLevel : int32_t { diff --git a/include/tvm/s_tir/schedule/state.h b/include/tvm/s_tir/schedule/state.h index 03d7ddbdb36b..d0b855d14bc0 100644 --- a/include/tvm/s_tir/schedule/state.h +++ b/include/tvm/s_tir/schedule/state.h @@ -26,14 +26,14 @@ #include #include #include -#include +#include #include #include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief The information about a TensorIR block, it contains two categories of information @@ -147,7 +147,7 @@ class ScheduleStateNode : public Object { * that points to the old block will point to the new one * \note The reuse of loop srefs are detected automatically according to the reuse of loop vars. */ - TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt, + TVM_DLL void Replace(const tirx::StmtSRef& src_sref, const Stmt& tgt_stmt, const ffi::Map& block_sref_reuse); /*! * \brief Trigger the verification according to the `debug_mask` bitmask. diff --git a/include/tvm/s_tir/schedule/trace.h b/include/tvm/s_tir/schedule/trace.h index 5640c4b7f50e..0100f00c891c 100644 --- a/include/tvm/s_tir/schedule/trace.h +++ b/include/tvm/s_tir/schedule/trace.h @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; // Forward declaration class Trace; diff --git a/include/tvm/s_tir/stmt.h b/include/tvm/s_tir/stmt.h index 86435f94b698..c0dcf942cf28 100644 --- a/include/tvm/s_tir/stmt.h +++ b/include/tvm/s_tir/stmt.h @@ -153,7 +153,7 @@ constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule"; * if (mask & 1) the read region should be detected, * if (mask & 2) the write region should be detected. */ -constexpr const char* script_parsing_detect_access = "tir.script_parsing_detect_access"; +constexpr const char* script_parsing_detect_access = "tirx.script_parsing_detect_access"; /*! * \brief Mark that the block need to add predicate for block var bounds during lowering @@ -176,7 +176,8 @@ constexpr const char* software_pipeline_async_stages = "software_pipeline_async_ constexpr const char* layout_free_buffers = "layout_free_buffers"; /*! \brief Mark the local stage for the shared memory access should be added. */ -constexpr const char* manifest_shared_memory_local_stage = "tir.manifest_shared_memory_local_stage"; +constexpr const char* manifest_shared_memory_local_stage = + "tirx.manifest_shared_memory_local_stage"; /*! * \brief Mark alignment of buffer dimension diff --git a/include/tvm/s_tir/transform.h b/include/tvm/s_tir/transform.h index 564e879c2e69..36424ab6cc19 100644 --- a/include/tvm/s_tir/transform.h +++ b/include/tvm/s_tir/transform.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -41,11 +41,11 @@ namespace s_tir { * \param func The input PrimFunc. * \return The renewed func. */ -TVM_DLL tir::PrimFunc RenewDefs(const tir::PrimFunc& func); +TVM_DLL tirx::PrimFunc RenewDefs(const tirx::PrimFunc& func); namespace transform { -using tir::transform::CreatePrimFuncPass; +using tirx::transform::CreatePrimFuncPass; using tvm::transform::Pass; using tvm::transform::PassContext; @@ -350,7 +350,7 @@ TVM_DLL Pass DefaultGPUSchedule(); TVM_DLL Pass RemoveWeightLayoutRewriteBlock(bool skip_tensor_rewrite = false); /*! - * \brief Remove stores of tir::builtin::undef. + * \brief Remove stores of tirx::builtin::undef. * \return The pass. */ TVM_DLL Pass RemoveStoreUndef(); diff --git a/include/tvm/s_tir/utils.h b/include/tvm/s_tir/utils.h index 621efba737b8..066c9fc38d62 100644 --- a/include/tvm/s_tir/utils.h +++ b/include/tvm/s_tir/utils.h @@ -20,12 +20,12 @@ #define TVM_S_TIR_UTILS_H_ #include -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief A helper macro to convert an sref to the statement it points to, @@ -48,7 +48,7 @@ namespace tir { */ #define TVM_SREF_TO_SBLOCK(SRef) \ [&]() { \ - auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::SBlockNode) \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tirx::SBlockNode) \ << "Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \ << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ return result; \ @@ -64,7 +64,7 @@ namespace tir { */ #define TVM_SREF_TO_FOR(SRef) \ [&]() { \ - auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \ + auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tirx::ForNode) \ << "Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \ << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \ return result; \ @@ -135,7 +135,7 @@ inline void SetSeqIndexInChildren( } } -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_S_TIR_UTILS_H_ diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index e5679c6064ac..86888dea1c8b 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -43,7 +43,7 @@ namespace ir_builder { * * \code {.cpp} * - * using T = tvm::script::ir_builder::tir; + * using T = tvm::script::ir_builder::tirx; * With _(...); * Buffer buffer = T::MatchBuffer(...); * @@ -53,7 +53,7 @@ namespace ir_builder { * * \code {.cpp} * - * using T = tvm::script::ir_builder::tir; + * using T = tvm::script::ir_builder::tirx; * With _(...); * { * With _2(...); @@ -141,7 +141,7 @@ class IRBuilderFrame : public runtime::ObjectRef { * * PrimFunc ConstructPrimFunc() { * using tvm::script::ir_builder::IRBuilder; - * using T = tvm::script::ir_builder::tir; + * using T = tvm::script::ir_builder::tirx; * IRBuilder builder; * // Step 1. Place IRBuilder inside the with-scope. * { diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tirx/frame.h similarity index 87% rename from include/tvm/script/ir_builder/tir/frame.h rename to include/tvm/script/ir_builder/tirx/frame.h index b608338d9ea2..3b4f9cf2edab 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tirx/frame.h @@ -21,14 +21,14 @@ #include #include -#include +#include #include namespace tvm { namespace script { namespace ir_builder { -namespace tir { +namespace tirx { /*! * \brief A base frame that represents the TIR fame with body of statements. @@ -38,13 +38,13 @@ namespace tir { class TIRFrameNode : public IRBuilderFrameNode { public: /*! \brief The Stmt within in this frame. */ - ffi::Array stmts; + ffi::Array stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("stmts", &TIRFrameNode::stmts); } - TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.tir.TIRFrame", TIRFrameNode, IRBuilderFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO("script.ir_builder.tirx.TIRFrame", TIRFrameNode, IRBuilderFrameNode); }; /*! @@ -71,19 +71,19 @@ class PrimFuncFrameNode : public TIRFrameNode { /*! \brief The name of the block. */ ffi::Optional name; /*! \brief Function parameters. */ - ffi::Array args; + ffi::Array args; /*! \brief Whether the PrimFunc is annotated as private. */ bool is_private; /*! \brief The return type of the function. */ ffi::Optional ret_type; /*! \brief Maps some parameters to specific Buffer data structures. */ - ffi::Map buffer_map; + ffi::Map buffer_map; /*! \brief Additional attributes storing the meta-data */ ffi::Map attrs; /*! \brief The variable map bound to thread env. */ - ffi::Map env_threads; + ffi::Map env_threads; /*! \brief The buffer allocated in root block. */ - ffi::Array root_alloc_buffers; + ffi::Array root_alloc_buffers; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -97,7 +97,7 @@ class PrimFuncFrameNode : public TIRFrameNode { .def_ro("env_threads", &PrimFuncFrameNode::env_threads) .def_ro("root_alloc_buffers", &PrimFuncFrameNode::root_alloc_buffers); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.PrimFuncFrame", PrimFuncFrameNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.PrimFuncFrame", PrimFuncFrameNode, TIRFrameNode); public: @@ -132,17 +132,17 @@ class SBlockFrameNode : public TIRFrameNode { /*! \brief The name of the block. */ ffi::String name; /*! \brief The variables of the block. */ - ffi::Array iter_vars; + ffi::Array iter_vars; /*! \brief The read buffer regions of the block. */ - ffi::Optional> reads; + ffi::Optional> reads; /*! \brief The write buffer regions of the block. */ - ffi::Optional> writes; + ffi::Optional> writes; /*! \brief The init statement of the bolck. */ - ffi::Optional init; + ffi::Optional init; /*! \brief The buffer allocated in the block. */ - ffi::Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The match buffer regions. */ - ffi::Array match_buffers; + ffi::Array match_buffers; /*! \brief The annotation of the block. */ ffi::Optional> annotations; /*! \brief The corresponding values of the iter vars. */ @@ -170,7 +170,7 @@ class SBlockFrameNode : public TIRFrameNode { .def_ro("predicate", &SBlockFrameNode::predicate) .def_ro("no_realize", &SBlockFrameNode::no_realize); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.SSBlockFrame", SBlockFrameNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.SSBlockFrame", SBlockFrameNode, TIRFrameNode); public: @@ -207,7 +207,7 @@ class BlockInitFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.SBlockInitFrame", BlockInitFrameNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.SBlockInitFrame", BlockInitFrameNode, TIRFrameNode); public: @@ -251,11 +251,11 @@ class ForFrameNode : public TIRFrameNode { * \param loop_body The loop body * \return A stmt, the loop nest */ - using FMakeForLoop = ffi::TypedFunction loop_vars, ffi::Array loop_extents, - ffi::Array> loop_steps, tvm::tir::Stmt loop_body)>; + using FMakeForLoop = ffi::TypedFunction loop_vars, ffi::Array loop_extents, + ffi::Array> loop_steps, tvm::tirx::Stmt loop_body)>; /*! \brief The loop variable. */ - ffi::Array vars; + ffi::Array vars; /*! \brief The domains of iteration. */ ffi::Array doms; /*! \brief The optional steps of iteration. */ @@ -270,7 +270,7 @@ class ForFrameNode : public TIRFrameNode { .def_ro("doms", &ForFrameNode::doms); // `f_make_for_loop` is not registered as it's not visited. } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.ForFrame", ForFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.ForFrame", ForFrameNode, TIRFrameNode); public: /*! @@ -305,9 +305,9 @@ class AssertFrameNode : public TIRFrameNode { /*! \brief The PrimExpr to test. */ PrimExpr condition; /*! \brief The error kind, e.g. "RuntimeError", "TypeError", "ValueError". */ - tvm::tir::StringImm error_kind; + tvm::tirx::StringImm error_kind; /*! \brief Error message fragments, concatenated at runtime when assertion fails. */ - ffi::Array message_parts; + ffi::Array message_parts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -316,7 +316,7 @@ class AssertFrameNode : public TIRFrameNode { .def_ro("error_kind", &AssertFrameNode::error_kind) .def_ro("message_parts", &AssertFrameNode::message_parts); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AssertFrame", AssertFrameNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.AssertFrame", AssertFrameNode, TIRFrameNode); public: @@ -352,7 +352,7 @@ class LaunchThreadFrameNode : public TIRFrameNode { /*! \brief The attribute key, could be either virtual_thread or thread_extent. */ ffi::String attr_key; /*! \brief The iteration variable. */ - tvm::tir::IterVar iter_var; + tvm::tirx::IterVar iter_var; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -361,7 +361,7 @@ class LaunchThreadFrameNode : public TIRFrameNode { .def_ro("attr_key", &LaunchThreadFrameNode::attr_key) .def_ro("iter_var", &LaunchThreadFrameNode::iter_var); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.LaunchThreadFrame", + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.LaunchThreadFrame", LaunchThreadFrameNode, TIRFrameNode); public: @@ -407,7 +407,8 @@ class AttrFrameNode : public TIRFrameNode { .def_ro("attr_key", &AttrFrameNode::attr_key) .def_ro("value", &AttrFrameNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.AttrFrame", AttrFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.AttrFrame", AttrFrameNode, + TIRFrameNode); public: /*! @@ -445,7 +446,7 @@ class WhileFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("condition", &WhileFrameNode::condition); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.WhileFrame", WhileFrameNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.WhileFrame", WhileFrameNode, TIRFrameNode); public: @@ -480,9 +481,9 @@ class IfFrameNode : public TIRFrameNode { /*! \brief The condition of the if statement. */ PrimExpr condition; /*! \brief The statements in the true branch. */ - ffi::Optional> then_stmts; + ffi::Optional> then_stmts; /*! \brief The stetements in the false branch. */ - ffi::Optional> else_stmts; + ffi::Optional> else_stmts; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -491,7 +492,7 @@ class IfFrameNode : public TIRFrameNode { .def_ro("then_stmts", &IfFrameNode::then_stmts) .def_ro("else_stmts", &IfFrameNode::else_stmts); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.IfFrame", IfFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.IfFrame", IfFrameNode, TIRFrameNode); public: /*! @@ -525,7 +526,8 @@ class ThenFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.ThenFrame", ThenFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.ThenFrame", ThenFrameNode, + TIRFrameNode); public: /*! @@ -564,7 +566,8 @@ class ElseFrameNode : public TIRFrameNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tir.ElseFrame", ElseFrameNode, TIRFrameNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.tirx.ElseFrame", ElseFrameNode, + TIRFrameNode); public: /*! @@ -593,7 +596,7 @@ class ElseFrame : public TIRFrame { TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ElseFrame, TIRFrame, ElseFrameNode); }; -} // namespace tir +} // namespace tirx } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tirx/ir.h similarity index 96% rename from include/tvm/script/ir_builder/tir/ir.h rename to include/tvm/script/ir_builder/tirx/ir.h index 959bf330ca63..1bf6ce8ffe2d 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tirx/ir.h @@ -20,16 +20,16 @@ #define TVM_SCRIPT_IR_BUILDER_TIR_IR_H_ #include -#include -#include +#include +#include namespace tvm { namespace script { namespace ir_builder { -namespace tir { +namespace tirx { -using tvm::tir::Buffer; -using tvm::tir::Var; +using tvm::tirx::Buffer; +using tvm::tirx::Var; /*! * \brief The buffer declaration function. @@ -443,18 +443,19 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), } else { type_annotation = PointerType(PrimType(dtype), storage_scope); } - return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); + return is_size_var ? tvm::tirx::SizeVar("", type_annotation) + : tvm::tirx::Var("", type_annotation); } -inline Var TensormapHandle() { return tvm::tir::Var("", PointerType(TensorMapType())); } +inline Var TensormapHandle() { return tvm::tirx::Var("", PointerType(TensorMapType())); } -#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ - inline PrimExpr FuncName(ffi::Optional expr = std::nullopt, \ - bool is_size_var = false) { \ - DataType dtype = DType; \ - return expr.defined() \ - ? tvm::cast(dtype, expr.value()) \ - : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(ffi::Optional expr = std::nullopt, \ + bool is_size_var = false) { \ + DataType dtype = DType; \ + return expr.defined() \ + ? tvm::cast(dtype, expr.value()) \ + : (is_size_var ? tvm::tirx::SizeVar("", dtype) : tvm::tirx::Var("", dtype)); \ } #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ @@ -513,7 +514,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); #undef TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST -} // namespace tir +} // namespace tirx } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index d92ef674f12e..59db02f8b8e9 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 17de92c8be36..8dc1e01f705e 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -27,9 +27,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include diff --git a/include/tvm/te/tensor.h b/include/tvm/te/tensor.h index 501b5b062b52..c93ac2b3799d 100644 --- a/include/tvm/te/tensor.h +++ b/include/tvm/te/tensor.h @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -38,7 +38,7 @@ namespace tvm { namespace te { using arith::IntSet; -using namespace tvm::tir; +using namespace tvm::tirx; // internal node container for Operation class OperationNode; diff --git a/include/tvm/tir/analysis.h b/include/tvm/tirx/analysis.h similarity index 95% rename from include/tvm/tir/analysis.h rename to include/tvm/tirx/analysis.h index 5a543f852b7f..83e235ea1684 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tirx/analysis.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/analysis.h + * \file tvm/tirx/analysis.h * \brief Analysis utilities and passes for TIR. */ #ifndef TVM_TIR_ANALYSIS_H_ @@ -28,10 +28,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -42,7 +42,7 @@ namespace arith { class Analyzer; } -namespace tir { +namespace tirx { /*! * \brief Compare two expressions recursively and check if they are equal @@ -181,7 +181,7 @@ TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, * * - Each variable has a single point of definition. * - * - Expressions within a tir::SBlock may not reference variables + * - Expressions within a tirx::SBlock may not reference variables * defined outside the block. For example, for a block with iter * vars `vi, vj = T.axis.remap('SS', [i,j])`, the statement * `B[i,j] = A[i,j]` would be ill-formed, because it uses the loop @@ -210,7 +210,7 @@ TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool assert_mode = true); /*! * \brief Find the entry function of the given IRModule, i.e, functions marked by - * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * `tirx::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. * \param mod The IRModule to find the entry function. * \param result_g_var The result GlobalVar of the entry function. * \return The entry function. @@ -228,7 +228,7 @@ using tvm::transform::PassContext; * \brief Pass variant of VerifySSA. * * \returns The pass. - * \sa tvm::tir::VerifySSA + * \sa tvm::tirx::VerifySSA */ TVM_DLL Pass VerifySSA(); @@ -236,11 +236,11 @@ TVM_DLL Pass VerifySSA(); * \brief Pass variant of VerifyMemory. * * \returns The pass. - * \sa tvm::tir::VerifyMemory + * \sa tvm::tirx::VerifyMemory */ TVM_DLL Pass VerifyMemory(); } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_ANALYSIS_H_ diff --git a/include/tvm/tir/buffer.h b/include/tvm/tirx/buffer.h similarity index 94% rename from include/tvm/tir/buffer.h rename to include/tvm/tirx/buffer.h index 1075693bb541..8f5c916a5c11 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tirx/buffer.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/buffer.h + * \file tvm/tirx/buffer.h * \brief Symbolic n-dimensional array, to represent a memory buffer. */ #ifndef TVM_TIR_BUFFER_H_ @@ -29,12 +29,12 @@ #include #include #include -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { #ifndef TVM_INDEX_DEFAULT_I64 #define TVM_INDEX_DEFAULT_I64 1 @@ -131,7 +131,7 @@ class BufferNode : public Object { /*! \return preferred index type for this buffer node */ DataType DefaultIndexType() const { - return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType(); + return shape.size() != 0 ? shape[0].dtype() : tvm::tirx::DefaultIndexType(); } /*! \brief Determine the offset in the buffer of the given index. @@ -144,7 +144,7 @@ class BufferNode : public Object { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Buffer", BufferNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Buffer", BufferNode, Object); TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); }; @@ -276,7 +276,7 @@ class DataProducerNode : public PrimExprConvertibleNode { * \return The data type. */ virtual ffi::String GetNameHint() const = 0; - TVM_FFI_DECLARE_OBJECT_INFO("tir.DataProducer", DataProducerNode, PrimExprConvertibleNode); + TVM_FFI_DECLARE_OBJECT_INFO("tirx.DataProducer", DataProducerNode, PrimExprConvertibleNode); }; /*! @@ -301,10 +301,10 @@ class DataProducer : public PrimExprConvertible { * \param compact If the statement has already bound to a compact buffer. * \param memory_scope memory scope of the buffer */ -TVM_DLL tir::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, - std::string name, int data_alignment, - int offset_factor, bool compact, - std::string memory_scope = ""); -} // namespace tir +TVM_DLL tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, + std::string name, int data_alignment, + int offset_factor, bool compact, + std::string memory_scope = ""); +} // namespace tirx } // namespace tvm #endif // TVM_TIR_BUFFER_H_ diff --git a/include/tvm/tir/builtin.h b/include/tvm/tirx/builtin.h similarity index 99% rename from include/tvm/tir/builtin.h rename to include/tvm/tirx/builtin.h index 4650edc05483..d0d5b3d57e27 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tirx/builtin.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/builtin.h + * \file tvm/tirx/builtin.h * \brief TIR builtin intrinsics. * * TIR builtin intrinsics are stored as tvm:Op. @@ -34,10 +34,10 @@ #define TVM_TIR_BUILTIN_H_ #include -#include +#include namespace tvm { -namespace tir { +namespace tirx { /*! \brief Collection of builtin intrinsics as ops */ namespace builtin { @@ -986,6 +986,6 @@ enum TVMStructFieldKind : int { kInt64ArrayElem, }; } // namespace builtin -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_BUILTIN_H_ diff --git a/include/tvm/tir/expr.h b/include/tvm/tirx/expr.h similarity index 90% rename from include/tvm/tir/expr.h rename to include/tvm/tirx/expr.h index 34c11bdd3ed8..ebd318d82288 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tirx/expr.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/expr.h + * \file tvm/tirx/expr.h * \brief TIR expressions. */ // Acknowledgement: Many low-level IR nodes originate from Halide. @@ -32,8 +32,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -43,7 +43,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { using IntImmNode = tvm::IntImmNode; using FloatImmNode = tvm::FloatImmNode; @@ -58,7 +58,7 @@ class StringImmNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &StringImmNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StringImm", StringImmNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.StringImm", StringImmNode, PrimExprNode); }; /*! @@ -85,7 +85,7 @@ class CastNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &CastNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Cast", CastNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Cast", CastNode, PrimExprNode); }; /*! @@ -124,7 +124,7 @@ class BinaryOpNode : public PrimExprNode { /*! \brief a + b */ class AddNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Add"; + static constexpr const char* _type_key = "tirx.Add"; }; /*! @@ -141,7 +141,7 @@ class Add : public PrimExpr { /*! \brief a - b */ class SubNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Sub"; + static constexpr const char* _type_key = "tirx.Sub"; }; /*! @@ -159,7 +159,7 @@ class Sub : public PrimExpr { /*! \brief a * b */ class MulNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Mul"; + static constexpr const char* _type_key = "tirx.Mul"; }; /*! @@ -179,7 +179,7 @@ class Mul : public PrimExpr { */ class DivNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Div"; + static constexpr const char* _type_key = "tirx.Div"; }; /*! @@ -199,7 +199,7 @@ class Div : public PrimExpr { */ class ModNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Mod"; + static constexpr const char* _type_key = "tirx.Mod"; }; /*! @@ -216,7 +216,7 @@ class Mod : public PrimExpr { /*! \brief Floor division, floor(a/b) */ class FloorDivNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.FloorDiv"; + static constexpr const char* _type_key = "tirx.FloorDiv"; }; /*! @@ -233,7 +233,7 @@ class FloorDiv : public PrimExpr { /*! \brief The remainder of the floordiv */ class FloorModNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.FloorMod"; + static constexpr const char* _type_key = "tirx.FloorMod"; }; /*! @@ -250,7 +250,7 @@ class FloorMod : public PrimExpr { /*! \brief min(a, b) */ class MinNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Min"; + static constexpr const char* _type_key = "tirx.Min"; }; /*! @@ -267,7 +267,7 @@ class Min : public PrimExpr { /*! \brief max(a, b) */ class MaxNode : public BinaryOpNode { public: - static constexpr const char* _type_key = "tir.Max"; + static constexpr const char* _type_key = "tirx.Max"; }; /*! @@ -306,7 +306,7 @@ class CmpOpNode : public PrimExprNode { /*! \brief a == b */ class EQNode : public CmpOpNode { public: - static constexpr const char* _type_key = "tir.EQ"; + static constexpr const char* _type_key = "tirx.EQ"; }; /*! @@ -323,7 +323,7 @@ class EQ : public PrimExpr { /*! \brief a != b */ class NENode : public CmpOpNode { public: - static constexpr const char* _type_key = "tir.NE"; + static constexpr const char* _type_key = "tirx.NE"; }; /*! @@ -340,7 +340,7 @@ class NE : public PrimExpr { /*! \brief a < b */ class LTNode : public CmpOpNode { public: - static constexpr const char* _type_key = "tir.LT"; + static constexpr const char* _type_key = "tirx.LT"; }; /*! @@ -357,7 +357,7 @@ class LT : public PrimExpr { /*! \brief a <= b */ struct LENode : public CmpOpNode { public: - static constexpr const char* _type_key = "tir.LE"; + static constexpr const char* _type_key = "tirx.LE"; }; /*! @@ -374,7 +374,7 @@ class LE : public PrimExpr { /*! \brief a > b */ class GTNode : public CmpOpNode { public: - static constexpr const char* _type_key = "tir.GT"; + static constexpr const char* _type_key = "tirx.GT"; }; /*! @@ -391,7 +391,7 @@ class GT : public PrimExpr { /*! \brief a >= b */ class GENode : public CmpOpNode { public: - static constexpr const char* _type_key = "tir.GE"; + static constexpr const char* _type_key = "tirx.GE"; }; /*! @@ -417,7 +417,7 @@ class AndNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.And", AndNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.And", AndNode, PrimExprNode); }; /*! @@ -443,7 +443,7 @@ class OrNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Or", OrNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Or", OrNode, PrimExprNode); }; /*! @@ -467,7 +467,7 @@ class NotNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("a", &NotNode::a); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Not", NotNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Not", NotNode, PrimExprNode); }; /*! @@ -504,7 +504,7 @@ class SelectNode : public PrimExprNode { .def_ro("true_value", &SelectNode::true_value) .def_ro("false_value", &SelectNode::false_value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Select", SelectNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Select", SelectNode, PrimExprNode); }; /*! @@ -545,7 +545,7 @@ class BufferLoadNode : public PrimExprNode { .def_ro("indices", &BufferLoadNode::indices) .def_ro("predicate", &BufferLoadNode::predicate); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferLoad", BufferLoadNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferLoad", BufferLoadNode, PrimExprNode); private: /*! \brief Set the dtype based on the buffer/indices @@ -598,7 +598,7 @@ class ProducerLoadNode : public PrimExprNode { .def_ro("producer", &ProducerLoadNode::producer) .def_ro("indices", &ProducerLoadNode::indices); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.ProducerLoad", ProducerLoadNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.ProducerLoad", ProducerLoadNode, PrimExprNode); }; /*! @@ -639,7 +639,7 @@ class RampNode : public PrimExprNode { .def_ro("stride", &RampNode::stride) .def_ro("lanes", &RampNode::lanes); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Ramp", RampNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Ramp", RampNode, PrimExprNode); }; /*! @@ -667,7 +667,7 @@ class BroadcastNode : public PrimExprNode { .def_ro("value", &BroadcastNode::value) .def_ro("lanes", &BroadcastNode::lanes); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Broadcast", BroadcastNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Broadcast", BroadcastNode, PrimExprNode); }; /*! @@ -700,7 +700,7 @@ class LetNode : public PrimExprNode { .def_ro("value", &LetNode::value) .def_ro("body", &LetNode::body); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Let", LetNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Let", LetNode, PrimExprNode); }; /*! @@ -734,7 +734,7 @@ class CallNode : public PrimExprNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Call", CallNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Call", CallNode, PrimExprNode); }; /*! @@ -766,7 +766,7 @@ class ShuffleNode : public PrimExprNode { .def_ro("vectors", &ShuffleNode::vectors) .def_ro("indices", &ShuffleNode::indices); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Shuffle", ShuffleNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Shuffle", ShuffleNode, PrimExprNode); }; /*! @@ -821,7 +821,7 @@ class CommReducerNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.CommReducer", CommReducerNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.CommReducer", CommReducerNode, Object); }; /*! @@ -865,7 +865,7 @@ class ReduceNode : public PrimExprNode { .def_ro("condition", &ReduceNode::condition) .def_ro("value_index", &ReduceNode::value_index); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Reduce", ReduceNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Reduce", ReduceNode, PrimExprNode); }; /*! @@ -898,18 +898,18 @@ inline std::unordered_map as_unordered_map(const ffi::Map& dmap) { } return ret; } -} // namespace tir +} // namespace tirx namespace ffi { template <> -inline constexpr bool use_default_type_traits_v = false; +inline constexpr bool use_default_type_traits_v = false; template <> -struct TypeTraits - : public ObjectRefWithFallbackTraitsBase { - TVM_FFI_INLINE static tvm::tir::StringImm ConvertFallbackValue(ffi::String value) { - return tvm::tir::StringImm(value); +struct TypeTraits + : public ObjectRefWithFallbackTraitsBase { + TVM_FFI_INLINE static tvm::tirx::StringImm ConvertFallbackValue(ffi::String value) { + return tvm::tirx::StringImm(value); } }; } // namespace ffi @@ -917,6 +917,6 @@ struct TypeTraits namespace std { template <> -struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; +struct hash<::tvm::tirx::IterVar> : public ::tvm::ObjectPtrHash {}; } // namespace std #endif // TVM_TIR_EXPR_H_ diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tirx/expr_functor.h similarity index 98% rename from include/tvm/tir/expr_functor.h rename to include/tvm/tirx/expr_functor.h index 02cb74a5d87b..78e80769323f 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tirx/expr_functor.h @@ -18,20 +18,20 @@ */ /*! - * \file tvm/tir/expr_functor.h + * \file tvm/tirx/expr_functor.h * - * \brief Functors for tir expressions. + * \brief Functors for tirx expressions. */ #ifndef TVM_TIR_EXPR_FUNCTOR_H_ #define TVM_TIR_EXPR_FUNCTOR_H_ #include -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief A dynamical functor that dispatches on in the first Expr argument. @@ -44,7 +44,7 @@ namespace tir { * \code * // A functor that set variable to b. and calculate results. * class MyExprFunctor - * : public tir::ExprFunctor { + * : public tirx::ExprFunctor { * public: * int VisitExpr_(const Variable* op, int b) final { * return b; @@ -292,6 +292,6 @@ class TVM_DLL ExprMutator : protected ExprFunctor { PrimExpr VisitExpr_(const StringImmNode* op) override; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_EXPR_FUNCTOR_H_ diff --git a/include/tvm/tir/function.h b/include/tvm/tirx/function.h similarity index 90% rename from include/tvm/tir/function.h rename to include/tvm/tirx/function.h index 31c6e3bc5ccd..0c98deb8b309 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tirx/function.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/function.h + * \file tvm/tirx/function.h * \brief TIR Function. */ #ifndef TVM_TIR_FUNCTION_H_ @@ -29,14 +29,14 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Primitive functions that contains TIR statements. @@ -49,7 +49,7 @@ namespace tir { class PrimFuncNode : public BaseFuncNode { public: /*! \brief Function parameters */ - ffi::Array params; + ffi::Array params; /*! \brief The return type of the function. */ Type ret_type; /*! @@ -97,9 +97,9 @@ class PrimFuncNode : public BaseFuncNode { * all usage in the body of the function is done through a * flattened alias of the buffer. */ - ffi::Map buffer_map; + ffi::Map buffer_map; /*! \brief The body of the function */ - tir::Stmt body; + tirx::Stmt body; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -120,7 +120,7 @@ class PrimFuncNode : public BaseFuncNode { TVM_DLL FuncType func_type_annotation() const; TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.PrimFunc", PrimFuncNode, BaseFuncNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.PrimFunc", PrimFuncNode, BaseFuncNode); }; /*! @@ -147,8 +147,8 @@ class PrimFunc : public BaseFunc { * * \param span The location of this object in the source code. */ - TVM_DLL PrimFunc(ffi::Array params, Stmt body, Type ret_type = VoidType(), - ffi::Map buffer_map = ffi::Map(), + TVM_DLL PrimFunc(ffi::Array params, Stmt body, Type ret_type = VoidType(), + ffi::Map buffer_map = ffi::Map(), DictAttrs attrs = DictAttrs(), Span span = Span()); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrimFunc, BaseFunc, PrimFuncNode); @@ -171,7 +171,7 @@ class TensorIntrinNode : public Object { .def_ro("desc", &TensorIntrinNode::desc) .def_ro("impl", &TensorIntrinNode::impl); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.TensorIntrin", TensorIntrinNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.TensorIntrin", TensorIntrinNode, Object); }; /*! @@ -233,7 +233,7 @@ class TensorIntrin : public ObjectRef { * * \code{.py} * a, _, m, n = mem_copy.params - * func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + * func = mem_copy.specialize({a: tirx.decl_buffer((16, 16))}) * # or * func = mem_copy.specialize({n: 16, m: 16}) * \endcode @@ -300,18 +300,18 @@ namespace attr { * CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES attribute in * CUDA. * - * Defined as "tir.use_dyn_shared_memory". + * Defined as "tirx.use_dyn_shared_memory". * * \sa tvm::CallingConv::kDeviceKernelLaunch */ -constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params"; +constexpr const char* kKernelLaunchParams = "tirx.kernel_launch_params"; /*! * \brief Whether to set noalias rule on the function arguments. * * Type: Integer */ -constexpr const char* kNoAlias = "tir.noalias"; +constexpr const char* kNoAlias = "tirx.noalias"; /*! * \brief Mark the function as the entry function of @@ -321,30 +321,30 @@ constexpr const char* kNoAlias = "tir.noalias"; * * \note There can only be one entry function per module. */ -constexpr const char* kIsEntryFunc = "tir.is_entry_func"; +constexpr const char* kIsEntryFunc = "tirx.is_entry_func"; /*! * \brief Mark the function as the global function called from the host. * * Type: Integer */ -constexpr const char* kIsGlobalFunc = "tir.is_global_func"; +constexpr const char* kIsGlobalFunc = "tirx.is_global_func"; /*! * \brief Mark the function as run on the host, mutually exclusive with kTarget. * * Type: Integer */ -constexpr const char* kIsHostFunc = "tir.is_host_func"; +constexpr const char* kIsHostFunc = "tirx.is_host_func"; /*! * \brief Mark the function as scheduled, so the default schedule will pass will skip it. * * Type: Integer */ -constexpr const char* kIsScheduled = "tir.is_scheduled"; +constexpr const char* kIsScheduled = "tirx.is_scheduled"; } // namespace attr -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_FUNCTION_H_ diff --git a/include/tvm/tir/index_map.h b/include/tvm/tirx/index_map.h similarity index 97% rename from include/tvm/tir/index_map.h rename to include/tvm/tirx/index_map.h index c4e716cbe166..61134ec53a3f 100644 --- a/include/tvm/tir/index_map.h +++ b/include/tvm/tirx/index_map.h @@ -18,10 +18,10 @@ */ /*! - * \file tvm/tir/index_map.h + * \file tvm/tirx/index_map.h * \brief Defines a remapping of buffer indices * - * For use with tvm::tir::Buffer. + * For use with tvm::tirx::Buffer. */ #ifndef TVM_TIR_INDEX_MAP_H_ #define TVM_TIR_INDEX_MAP_H_ @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include @@ -41,7 +41,7 @@ class Analyzer; } // namespace tvm namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Defines a mapping between two representations of indices @@ -164,7 +164,7 @@ class IndexMapNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IndexMap", IndexMapNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.IndexMap", IndexMapNode, Object); }; class IndexMap : public ObjectRef { @@ -232,7 +232,7 @@ class IndexMap : public ObjectRef { IndexMap Substitute(const IndexMap& index_map, std::function(const Var& var)> f_subst); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_INDEX_MAP_H_ diff --git a/include/tvm/tir/op.h b/include/tvm/tirx/op.h similarity index 91% rename from include/tvm/tir/op.h rename to include/tvm/tirx/op.h index 59f04e76a3de..66d9d932b3fa 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tirx/op.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/op.h + * \file tvm/tirx/op.h * \brief Common operators defined for Expr. * * \note Most of the operator defined here perform simple constant folding @@ -31,9 +31,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -42,13 +42,13 @@ namespace tvm { #define TVM_TIR_REGISTER_OP(OpName) \ - TVM_REGISTER_OP("tir." OpName).set_attr("TScriptPrinterName", OpName) + TVM_REGISTER_OP("tirx." OpName).set_attr("TScriptPrinterName", OpName) // Most common operators can be overloaded by argument type(PrimExpr). // So we put them under the root namespace. // -// We put more developer oriented APIs -- make_const and is_const under tir -// as they are more specific to the tir namespace. +// We put more developer oriented APIs -- make_const and is_const under tirx +// as they are more specific to the tirx namespace. /*! * \brief Get the type of the expression under the unified type system. @@ -580,8 +580,8 @@ TVM_DLL PrimExpr isinf(PrimExpr x, Span span = Span()); * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr sum(PrimExpr source, ffi::Array axis, ffi::Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr sum(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief logical And of source expression over axis @@ -590,8 +590,8 @@ TVM_DLL PrimExpr sum(PrimExpr source, ffi::Array axis, ffi::Array< * \param init The value with which to initialize the output. * \param span The location of this operation in the source. */ -TVM_DLL PrimExpr all(PrimExpr source, ffi::Array axis, ffi::Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr all(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief logical Or of source expression over axis @@ -601,8 +601,8 @@ TVM_DLL PrimExpr all(PrimExpr source, ffi::Array axis, ffi::Array< * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr any(PrimExpr source, ffi::Array axis, ffi::Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr any(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief max of source expression over axis @@ -612,8 +612,8 @@ TVM_DLL PrimExpr any(PrimExpr source, ffi::Array axis, ffi::Array< * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr max(PrimExpr source, ffi::Array axis, ffi::Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr max(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief max of source expression over axis @@ -623,8 +623,8 @@ TVM_DLL PrimExpr max(PrimExpr source, ffi::Array axis, ffi::Array< * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr min(PrimExpr source, ffi::Array axis, ffi::Array init = {}, - Span span = Span()); +TVM_DLL PrimExpr min(PrimExpr source, ffi::Array axis, + ffi::Array init = {}, Span span = Span()); /*! * \brief product of source expression over axis @@ -634,7 +634,7 @@ TVM_DLL PrimExpr min(PrimExpr source, ffi::Array axis, ffi::Array< * \param span The location of this operation in the source. * \return The result. */ -TVM_DLL PrimExpr prod(PrimExpr source, ffi::Array axis, +TVM_DLL PrimExpr prod(PrimExpr source, ffi::Array axis, ffi::Array init = {}, Span span = Span()); /*! @@ -722,23 +722,23 @@ TVM_DLL PrimExpr fast_erf_float_expr(PrimExpr arg, int bits); inline void CheckMathUnaryOpInputDType(const char* op_name, DataType dtype) { TVM_FFI_CHECK(dtype.is_float() || dtype.is_bfloat16(), TypeError) - << "tir." << op_name << " only supports floating-point inputs, but got " << dtype; + << "tirx." << op_name << " only supports floating-point inputs, but got " << dtype; } // Intrinsic operators -#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ - inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - CheckInputDType(#OpName, x.dtype()); \ - if (x.dtype().is_bfloat16()) { \ - DataType bf16_dtype = x.dtype(); \ - DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ - PrimExpr x_fp32 = tir::Cast(fp32_dtype, {x}, span); \ - PrimExpr result_fp32 = tir::Call(fp32_dtype, op, {x_fp32}, span); \ - return tir::Cast(bf16_dtype, {result_fp32}, span); \ - } else { \ - return tir::Call(x.dtype(), op, {x}, span); \ - } \ +#define TVM_DECLARE_INTRIN_UNARY_WITH_CHECK(OpName, CheckInputDType) \ + inline PrimExpr OpName(PrimExpr x, Span span = Span()) { \ + static const Op& op = Op::Get("tirx." #OpName); \ + CheckInputDType(#OpName, x.dtype()); \ + if (x.dtype().is_bfloat16()) { \ + DataType bf16_dtype = x.dtype(); \ + DataType fp32_dtype(kDLFloat, 32, bf16_dtype.lanes()); \ + PrimExpr x_fp32 = tirx::Cast(fp32_dtype, {x}, span); \ + PrimExpr result_fp32 = tirx::Call(fp32_dtype, op, {x_fp32}, span); \ + return tirx::Cast(bf16_dtype, {result_fp32}, span); \ + } else { \ + return tirx::Call(x.dtype(), op, {x}, span); \ + } \ } #define TVM_DECLARE_INTRIN_UNARY(OpName) \ @@ -775,8 +775,8 @@ TVM_DECLARE_INTRIN_UNARY(clz); #define TVM_DECLARE_INTRIN_BINARY(OpName) \ inline PrimExpr OpName(PrimExpr x, PrimExpr y, Span span = Span()) { \ - static const Op& op = Op::Get("tir." #OpName); \ - return tir::Call(x.dtype(), op, {x, y}, span); \ + static const Op& op = Op::Get("tirx." #OpName); \ + return tirx::Call(x.dtype(), op, {x, y}, span); \ } TVM_DECLARE_INTRIN_BINARY(atan2); @@ -785,7 +785,7 @@ TVM_DECLARE_INTRIN_BINARY(copysign); TVM_DECLARE_INTRIN_BINARY(hypot); TVM_DECLARE_INTRIN_BINARY(ldexp); -namespace tir { +namespace tirx { /*! * \brief Check if type is a pointer to a runtime element type. @@ -847,7 +847,7 @@ inline PrimExpr const_false(int lanes = 1, Span span = Span()) { */ inline const int64_t* as_const_int(const PrimExpr& x) { if (!x.defined()) return nullptr; - if (const tir::IntImmNode* op = x.as()) { + if (const tirx::IntImmNode* op = x.as()) { return &(op->value); } @@ -867,7 +867,7 @@ inline bool is_const_int(const PrimExpr& x, int64_t value); * \param stmt The input statement * \return whether stmt is nop */ -inline bool is_no_op(const tir::Stmt& stmt); +inline bool is_no_op(const tirx::Stmt& stmt); /*! * \brief Check whether x is a constant integer 1 @@ -931,12 +931,13 @@ TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift); inline bool is_const_int(const PrimExpr& x) { return as_const_int(x); } inline bool is_const_number(const PrimExpr& x) { - if (x.as()) { + if (x.as()) { return true; - } else if (x.as()) { + } else if (x.as()) { return true; - } else if (const auto* op = x.as()) { - return (op->value->IsInstance() || op->value->IsInstance()); + } else if (const auto* op = x.as()) { + return (op->value->IsInstance() || + op->value->IsInstance()); } return false; } @@ -956,12 +957,12 @@ inline bool is_const_int(const PrimExpr& x, int64_t value) { return as_int && (*as_int == value); } -inline bool is_no_op(const tir::Stmt& stmt) { +inline bool is_no_op(const tirx::Stmt& stmt) { if (!stmt.defined()) return true; - if (const auto* op = stmt.as()) { + if (const auto* op = stmt.as()) { return is_const_int(op->value); } - if (const auto* op = stmt.as()) { + if (const auto* op = stmt.as()) { return op->seq.size() == 0; } return false; @@ -1008,11 +1009,11 @@ inline PrimExpr make_const(DataType t, ValueType value, Span span) { return MakeConstScalar(t, value, span); } else { if (t.is_fixed_length_vector()) { - return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); + return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), t.lanes(), span); } else { PrimExpr lanes = - tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), t.vscale_factor()); - return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span); + tirx::Mul(tirx::Call(DataType::Int(32), tirx::builtin::vscale(), {}), t.vscale_factor()); + return tirx::Broadcast(MakeConstScalar(t.element_of(), value, span), lanes, span); } } } @@ -1024,7 +1025,7 @@ inline PrimExpr make_zero(DataType t, Span span) { return make_const(t, 0, span); } -} // namespace tir +} // namespace tirx // additional const expression overloading #define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ @@ -1037,13 +1038,13 @@ inline PrimExpr make_zero(DataType t, Span span) { inline PrimExpr Name(const PrimExpr& a, float b) { return Name(a, PrimExpr(b)); } \ inline PrimExpr Name(float a, const PrimExpr& b) { return Name(PrimExpr(a), b); } \ inline PrimExpr Name(int a, const PrimExpr& b) { \ - return Name(tir::make_const(b.dtype(), a), b); \ + return Name(tirx::make_const(b.dtype(), a), b); \ } \ inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ + return Name(a, tirx::make_const(a.dtype(), b)); \ } \ inline PrimExpr Name(const PrimExpr& a, double b) { \ - return Name(a, tir::make_const(DataType::Float(64), b)); \ + return Name(a, tirx::make_const(DataType::Float(64), b)); \ } #define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD_SPANNED(Name) \ @@ -1054,13 +1055,13 @@ inline PrimExpr make_zero(DataType t, Span span) { return Name(PrimExpr(a), b, span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tir::make_const(b.dtype(), a), b, span); \ + return Name(tirx::make_const(b.dtype(), a), b, span); \ } \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tir::make_const(a.dtype(), b), span); \ + return Name(a, tirx::make_const(a.dtype(), b), span); \ } \ inline PrimExpr Name(const PrimExpr& a, double b, Span span = Span()) { \ - return Name(a, tir::make_const(DataType::Float(64), b), span); \ + return Name(a, tirx::make_const(DataType::Float(64), b), span); \ } #define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ @@ -1075,18 +1076,18 @@ inline PrimExpr make_zero(DataType t, Span span) { return Name(PrimExpr(a), b, span); \ } -#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ - inline PrimExpr Name(const PrimExpr& a, int b) { \ - return Name(a, tir::make_const(a.dtype(), b)); \ - } \ - inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tir::make_const(b.dtype(), a), b); } +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline PrimExpr Name(const PrimExpr& a, int b) { \ + return Name(a, tirx::make_const(a.dtype(), b)); \ + } \ + inline PrimExpr Name(int a, const PrimExpr& b) { return Name(tirx::make_const(b.dtype(), a), b); } #define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD_SPANNED(Name) \ inline PrimExpr Name(const PrimExpr& a, int b, Span span = Span()) { \ - return Name(a, tir::make_const(a.dtype(), b), span); \ + return Name(a, tirx::make_const(a.dtype(), b), span); \ } \ inline PrimExpr Name(int a, const PrimExpr& b, Span span = Span()) { \ - return Name(tir::make_const(b.dtype(), a), b, span); \ + return Name(tirx::make_const(b.dtype(), a), b, span); \ } TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); @@ -1146,7 +1147,7 @@ inline void DivAmbiguityError(const TA& a) { "please call div, indexdiv/indexmod, " "floordiv/floormod or truncdiv/truncmod directly " "to avoid ambiguity in the code. " - "Checkout these functions in tir/op.h."); + "Checkout these functions in tirx/op.h."); } // The following code are not intended to be used in the codebase. diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tirx/op_attr_types.h similarity index 96% rename from include/tvm/tir/op_attr_types.h rename to include/tvm/tirx/op_attr_types.h index e9727f7ab3d8..9d0173bfd49f 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tirx/op_attr_types.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/op_attr_types.h + * \file tvm/tirx/op_attr_types.h * \brief Attribute types in the Op registry for TIR ops. * * These attributes can be set via OpRegEntry::set_attr @@ -35,7 +35,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Global symbol of the op after lowering. */ @@ -52,7 +52,7 @@ using TVectorizable = bool; using FLowerIntrinsic = ffi::TypedFunction; /*! - * \brief The legalization function for given tir op. + * \brief The legalization function for given tirx op. */ using FLegalize = ffi::TypedFunction; @@ -151,6 +151,6 @@ inline std::ostream& operator<<(std::ostream& os, CallEffectKind side_effect) { /*! \brief Use integer to record the kind. */ using TCallEffectKind = Integer; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_OP_ATTR_TYPES_H_ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tirx/stmt.h similarity index 95% rename from include/tvm/tir/stmt.h rename to include/tvm/tirx/stmt.h index 2888a26da3f9..c191c4e6bf1f 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tirx/stmt.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tvm/tir/stmt.h + * \file tvm/tirx/stmt.h * \brief TIR statements. */ // Acknowledgement: Many low-level stmts originate from Halide. @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -34,7 +34,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! \brief Base node of all statements. */ class StmtNode : public Object { @@ -58,7 +58,7 @@ class StmtNode : public Object { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; static constexpr const uint32_t _type_child_slots = 15; - TVM_FFI_DECLARE_OBJECT_INFO("tir.Stmt", StmtNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tirx.Stmt", StmtNode, Object); }; /*! \brief Container of all statements */ @@ -87,7 +87,7 @@ class BindNode : public StmtNode { .def_ro("var", &BindNode::var, refl::AttachFieldFlag::SEqHashDef()) .def_ro("value", &BindNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Bind", BindNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Bind", BindNode, StmtNode); }; /*! @@ -131,7 +131,7 @@ class AttrStmtNode : public StmtNode { .def_ro("value", &AttrStmtNode::value) .def_ro("body", &AttrStmtNode::body); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AttrStmt", AttrStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AttrStmt", AttrStmtNode, StmtNode); }; /*! @@ -172,7 +172,7 @@ class AssertStmtNode : public StmtNode { .def_ro("error_kind", &AssertStmtNode::error_kind) .def_ro("message_parts", &AssertStmtNode::message_parts); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AssertStmt", AssertStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AssertStmt", AssertStmtNode, StmtNode); }; /*! @@ -217,7 +217,7 @@ class BufferStoreNode : public StmtNode { .def_ro("indices", &BufferStoreNode::indices) .def_ro("predicate", &BufferStoreNode::predicate); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferStore", BufferStoreNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferStore", BufferStoreNode, StmtNode); }; /*! @@ -244,7 +244,7 @@ class DeclBufferNode : public StmtNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("buffer", &DeclBufferNode::buffer); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.DeclBuffer", DeclBufferNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.DeclBuffer", DeclBufferNode, StmtNode); }; /*! \brief Managed reference to DeclBufferNode */ @@ -274,7 +274,7 @@ class AllocBufferNode : public StmtNode { .def_ro("buffer", &AllocBufferNode::buffer, refl::AttachFieldFlag::SEqHashDef()) .def_ro("annotations", &AllocBufferNode::annotations); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.AllocBuffer", AllocBufferNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.AllocBuffer", AllocBufferNode, StmtNode); }; /*! \brief Managed reference to AllocBufferNode */ @@ -324,7 +324,7 @@ class SeqStmtNode : public StmtNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("seq", &SeqStmtNode::seq); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SeqStmt", SeqStmtNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SeqStmt", SeqStmtNode, StmtNode); }; /*! @@ -342,7 +342,7 @@ class EvaluateNode : public StmtNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("value", &EvaluateNode::value); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Evaluate", EvaluateNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.Evaluate", EvaluateNode, StmtNode); }; /*! @@ -529,7 +529,7 @@ class IfThenElseNode : public StmtNode { .def_ro("then_case", &IfThenElseNode::then_case) .def_ro("else_case", &IfThenElseNode::else_case); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IfThenElse", IfThenElseNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.IfThenElse", IfThenElseNode, StmtNode); }; /*! @@ -630,7 +630,7 @@ class ForNode : public StmtNode { /*! \brief Check it is a loop without nontrivial loop step. */ bool HasTrivialStep() const; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.For", ForNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.For", ForNode, StmtNode); }; /*! @@ -671,7 +671,7 @@ class WhileNode : public StmtNode { .def_ro("condition", &WhileNode::condition) .def_ro("body", &WhileNode::body); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.While", WhileNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.While", WhileNode, StmtNode); }; /*! @@ -706,7 +706,7 @@ class BufferRegionNode : public PrimExprConvertibleNode { TVM_DLL PrimExpr ToPrimExpr() const final; static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BufferRegion", BufferRegionNode, PrimExprConvertibleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.BufferRegion", BufferRegionNode, PrimExprConvertibleNode); }; /*! @@ -760,7 +760,7 @@ class MatchBufferRegionNode : public Object { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.MatchBufferRegion", MatchBufferRegionNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.MatchBufferRegion", MatchBufferRegionNode, Object); }; /*! @@ -836,7 +836,7 @@ class SBlockNode : public StmtNode { .def_ro("init", &SBlockNode::init) .def_ro("body", &SBlockNode::body); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlock", SBlockNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlock", SBlockNode, StmtNode); }; /*! @@ -880,7 +880,7 @@ class SBlockRealizeNode : public StmtNode { .def_ro("predicate", &SBlockRealizeNode::predicate) .def_ro("block", &SBlockRealizeNode::block); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockRealize", SBlockRealizeNode, StmtNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SBlockRealize", SBlockRealizeNode, StmtNode); }; /*! @@ -934,7 +934,7 @@ constexpr const char* storage_alignment = "storage_alignment"; /*! \brief Mark launching extent of thread, used by device API. */ constexpr const char* thread_extent = "thread_extent"; /*! \brief Annotation key on AllocBuffer marking the allocation as volatile. */ -constexpr const char* kVolatile = "tir.volatile"; +constexpr const char* kVolatile = "tirx.volatile"; /*! * \brief Check if attr_key is a pragma key extension @@ -975,6 +975,6 @@ inline const char* ForKind2String(ForKind t) { TVM_FFI_UNREACHABLE(); } -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_STMT_H_ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tirx/stmt_functor.h similarity index 98% rename from include/tvm/tir/stmt_functor.h rename to include/tvm/tirx/stmt_functor.h index d99cdfb84e59..d883c3f6db4c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tirx/stmt_functor.h @@ -18,25 +18,25 @@ */ /*! - * \file tvm/tir/stmt_functor.h + * \file tvm/tirx/stmt_functor.h * - * \brief Functors for tir stmts + * \brief Functors for tirx stmts * utility functions to call common functors. */ #ifndef TVM_TIR_STMT_FUNCTOR_H_ #define TVM_TIR_STMT_FUNCTOR_H_ #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Same as ExprFunctor except it is applied on statements * \tparam FType The function signature. @@ -582,7 +582,7 @@ bool ContainsNode(const Stmt& stmt) { return visitor.contains_node; } -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_STMT_FUNCTOR_H_ diff --git a/include/tvm/tir/transform.h b/include/tvm/tirx/transform.h similarity index 97% rename from include/tvm/tir/transform.h rename to include/tvm/tirx/transform.h index 83a103ac8570..4d1267e97bb9 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tirx/transform.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/transform.h + * \file tvm/tirx/transform.h * \brief TIR specific transformation passes. */ #ifndef TVM_TIR_TRANSFORM_H_ @@ -26,14 +26,14 @@ #include #include -#include -#include +#include +#include #include #include namespace tvm { -namespace tir { +namespace tirx { namespace transform { using tvm::transform::CreateModulePass; @@ -103,10 +103,10 @@ TVM_DLL Pass Simplify(); /*! * \brief Convert an IRModule to be SSA form. * - * This pass handles cases where the same tir::Var appears in + * This pass handles cases where the same tirx::Var appears in * multiple functions within the same module. For example, after * extracting a fragment from one function into another, where the - * same `tir::Var` may be defined both as within the body of the + * same `tirx::Var` may be defined both as within the body of the * original function, and as a parameter within the hoisted function. * * \return The pass. @@ -355,7 +355,7 @@ TVM_DLL Pass Filter(ffi::TypedFunction fcond); * \return The pass. */ } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_H_ diff --git a/include/tvm/tir/var.h b/include/tvm/tirx/var.h similarity index 94% rename from include/tvm/tir/var.h rename to include/tvm/tirx/var.h index b4106f2d2e9f..7da5cb31c152 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tirx/var.h @@ -18,7 +18,7 @@ */ /*! - * \file tvm/tir/var.h + * \file tvm/tirx/var.h * \brief Variables in the TIR. */ #ifndef TVM_TIR_VAR_H_ @@ -31,7 +31,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief A variable node in the IR. @@ -69,7 +69,7 @@ class VarNode : public PrimExprNode { static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindFreeVar; static constexpr const uint32_t _type_child_slots = 1; - TVM_FFI_DECLARE_OBJECT_INFO("tir.Var", VarNode, PrimExprNode); + TVM_FFI_DECLARE_OBJECT_INFO("tirx.Var", VarNode, PrimExprNode); }; /*! \brief a named variable in TIR */ @@ -135,7 +135,7 @@ class SizeVarNode : public VarNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef(); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SizeVar", SizeVarNode, VarNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.SizeVar", SizeVarNode, VarNode); }; /*! \brief a named variable represents a tensor index size */ @@ -284,7 +284,7 @@ class IterVarNode : public PrimExprConvertibleNode { } static constexpr TVMFFISEqHashKind _type_s_eq_hash_kind = kTVMFFISEqHashKindTreeNode; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.IterVar", IterVarNode, PrimExprConvertibleNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.IterVar", IterVarNode, PrimExprConvertibleNode); }; /*! @@ -332,10 +332,10 @@ inline const char* IterVarType2String(IterVarType t) { } return "Unknown"; } -} // namespace tir +} // namespace tirx } // namespace tvm -/* \brief Allow tir.Var as key in STL tables +/* \brief Allow tirx.Var as key in STL tables * * For most TIR expressions, it would be ambiguous whether the * expression should follow reference equality or structural equality. @@ -344,21 +344,21 @@ inline const char* IterVarType2String(IterVarType t) { * tables. * * Providing `std::hash` and `std::equal_to` specializations for - * `tir::Var` allows it to be used as a key in STL tables. For + * `tirx::Var` allows it to be used as a key in STL tables. For * `PrimExpr`, the user must specify the type of equality used * (e.g. `std::unordered_set` or * `std::unordered_set`). */ template <> -struct std::hash { - std::size_t operator()(const tvm::tir::Var& var) const { +struct std::hash { + std::size_t operator()(const tvm::tirx::Var& var) const { return tvm::runtime::ObjectPtrHash()(var); } }; template <> -struct std::equal_to { - bool operator()(const tvm::tir::Var& var_a, const tvm::tir::Var& var_b) const { +struct std::equal_to { + bool operator()(const tvm::tirx::Var& var_a, const tvm::tirx::Var& var_b) const { return tvm::runtime::ObjectPtrEqual()(var_a, var_b); } }; diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 41a6ed6ca5c8..f8ef2edc39d1 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -56,39 +56,39 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, TVM_FFI_ICHECK_EQ(output_shape.size(), bh.common_shape.size()); ffi::Array oshape; for (size_t i = 0; i < output_shape.size(); ++i) { - if (output_shape[i].as() == nullptr) { + if (output_shape[i].as() == nullptr) { oshape.push_back(output_shape[i]); } else { TVM_FFI_ICHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i])); oshape.push_back(bh.common_shape[i]); } } - auto l = [&](tvm::ffi::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars)); }; return tvm::te::compute(oshape, l, name, tag); } -#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ - inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, std::string tag = kBroadcast) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return detail::WithBroadcast(l, A, B, name, tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ - std::string name = "T_" #Name, std::string tag = kElementWise) { \ - auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute( \ - A->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A(i), B); }, name, \ - tag); \ - } \ - inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ - std::string name = "T_" #Name, std::string tag = kElementWise) { \ - auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ - return tvm::te::compute( \ - B->shape, [&](const ::tvm::ffi::Array<::tvm::tir::Var>& i) { return l(A, B(i)); }, name, \ - tag); \ +#define TOPI_DEFINE_BCAST_OP(Name, ComputeRule) \ + inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kBroadcast) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return detail::WithBroadcast(l, A, B, name, tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::PrimExpr& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + A->shape, [&](const ::tvm::ffi::Array<::tvm::tirx::Var>& i) { return l(A(i), B); }, name, \ + tag); \ + } \ + inline tvm::te::Tensor Name(const tvm::PrimExpr& A, const tvm::te::Tensor& B, \ + std::string name = "T_" #Name, std::string tag = kElementWise) { \ + auto l = [&](tvm::PrimExpr a, tvm::PrimExpr b) { ComputeRule; }; \ + return tvm::te::compute( \ + B->shape, [&](const ::tvm::ffi::Array<::tvm::tirx::Var>& i) { return l(A, B(i)); }, name, \ + tag); \ } #define TOPI_DEFINE_OP_OVERLOAD(Name, OpName) \ diff --git a/include/tvm/topi/detail/broadcast.h b/include/tvm/topi/detail/broadcast.h index 9d9f78ba022e..9a4c7f9339ab 100644 --- a/include/tvm/topi/detail/broadcast.h +++ b/include/tvm/topi/detail/broadcast.h @@ -37,9 +37,9 @@ namespace detail { struct BroadcastHelper { std::deque common_shape; - std::deque all_vars; - std::deque vars1; - std::deque vars2; + std::deque all_vars; + std::deque vars1; + std::deque vars2; }; static inline DataType CommonType(DataType type1, DataType type2) { @@ -66,7 +66,7 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap const IntImmNode* static_size2 = shape2[s2_size - i].as(); DataType common_type = CommonType(shape1[s1_size - i].dtype(), shape2[s2_size - i].dtype()); - bh.all_vars.push_front(tvm::tir::Var("dim", common_type)); + bh.all_vars.push_front(tvm::tirx::Var("dim", common_type)); if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { bh.common_shape.push_front(cast_if_needed(common_type, shape1[s1_size - i])); bh.vars1.push_front(bh.all_vars[0]); @@ -104,7 +104,7 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap auto& shape = (s1_size > s2_size) ? shape1 : shape2; auto& vars = (s1_size > s2_size) ? bh.vars1 : bh.vars2; for (; i <= max_size; ++i) { - bh.all_vars.push_front(tvm::tir::Var("v", shape[max_size - 1].dtype())); + bh.all_vars.push_front(tvm::tirx::Var("v", shape[max_size - 1].dtype())); bh.common_shape.push_front(shape[max_size - i]); vars.push_front(bh.all_vars[0]); } @@ -112,8 +112,8 @@ inline BroadcastHelper BroadcastShape(const tvm::ffi::Array& shap } inline tvm::ffi::Array InputIndexFromBroadcast( - const tvm::ffi::Array& ovars, const tvm::te::Tensor& T, - const std::deque& my_vars, const std::deque& all_vars) { + const tvm::ffi::Array& ovars, const tvm::te::Tensor& T, + const std::deque& my_vars, const std::deque& all_vars) { tvm::ffi::Array ivars; TVM_FFI_ICHECK_EQ(ovars.size(), all_vars.size()); // N^2, could use a map but NBD. @@ -130,7 +130,7 @@ inline tvm::ffi::Array InputIndexFromBroadcast( // Only inject 0 here if we have not yet reached the dimension of I // (i.e. this must be a 1) if (!found && (ovars.size() - i) <= expected_dims) { - ivars.push_back(tvm::tir::make_zero(ovars[i].dtype())); + ivars.push_back(tvm::tirx::make_zero(ovars[i].dtype())); } } TVM_FFI_ICHECK(expected_dims == ivars.size()); @@ -142,7 +142,7 @@ inline tvm::te::Tensor WithBroadcast(FBinaryExpr op, const tvm::te::Tensor& A, const tvm::te::Tensor& B, const std::string& name = "tensor", const std::string& tag = "") { auto bh = BroadcastShape(A->shape, B->shape); - auto l = [&](tvm::ffi::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { return op(A(InputIndexFromBroadcast(ovars, A, bh.vars1, bh.all_vars)), B(InputIndexFromBroadcast(ovars, B, bh.vars2, bh.all_vars))); }; diff --git a/include/tvm/topi/detail/constant_utils.h b/include/tvm/topi/detail/constant_utils.h index 5bcc64ba125c..a77177984734 100644 --- a/include/tvm/topi/detail/constant_utils.h +++ b/include/tvm/topi/detail/constant_utils.h @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -45,7 +45,7 @@ using namespace tvm::te; * * \return true if the given expr is a constant int or uint, false otherwise. */ -inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } +inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance(); } /*! * \brief Test whether the given Array has every element as constant integer. @@ -58,7 +58,7 @@ inline bool IsConstInt(PrimExpr expr) { return expr->IsInstance array) { bool is_const_int = true; for (auto const& elem : array) { - is_const_int &= !elem.defined() || elem->IsInstance(); + is_const_int &= !elem.defined() || elem->IsInstance(); } return is_const_int; } @@ -123,13 +123,13 @@ inline std::vector GetConstInt64Values(ffi::Array exprs, /*! * \brief Check whether the two expressions are equal or not, if not simplify the expressions and * check again - * \note This is stronger equality check than tvm::tir::Equal + * \note This is stronger equality check than tvm::tirx::Equal * \param lhs First expression * \param rhs Second expression * \return result True if both expressions are equal, else false */ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { - tvm::tir::ExprDeepEqual expr_equal; + tvm::tirx::ExprDeepEqual expr_equal; bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr t = tvm::arith::Analyzer().Simplify(lhs - rhs); diff --git a/include/tvm/topi/detail/extern.h b/include/tvm/topi/detail/extern.h index 674cdabdab6d..14eb54c3ed65 100644 --- a/include/tvm/topi/detail/extern.h +++ b/include/tvm/topi/detail/extern.h @@ -25,7 +25,7 @@ #define TVM_TOPI_DETAIL_EXTERN_H_ #include -#include +#include #include #include @@ -70,15 +70,15 @@ inline ffi::Array make_extern(const ffi::Array>& ou ffi::Array input_placeholders; for (auto t : inputs) { - input_placeholders.push_back(tvm::tir::decl_buffer(t->shape, t->dtype, t->op->name)); + input_placeholders.push_back(tvm::tirx::decl_buffer(t->shape, t->dtype, t->op->name)); } ffi::Array output_placeholders; for (size_t i = 0; i < out_shapes.size(); ++i) { - output_placeholders.push_back(tvm::tir::decl_buffer(out_shapes[i], out_types[i], name)); + output_placeholders.push_back(tvm::tirx::decl_buffer(out_shapes[i], out_types[i], name)); } auto body = fextern(input_placeholders, output_placeholders); - auto body_stmt = tvm::tir::Evaluate(body); + auto body_stmt = tvm::tirx::Evaluate(body); auto op = ExternOp(name, tag, attrs, inputs, input_placeholders, output_placeholders, body_stmt); @@ -100,11 +100,11 @@ inline ffi::Array make_extern(const ffi::Array>& ou inline PrimExpr pack_buffer(Buffer buf) { TVM_FFI_ICHECK_GT(buf->shape.size(), 0) << "buf shape must have at least one element"; auto shape = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->shape); + tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), buf->shape); PrimExpr strides; if (buf->strides.size() > 0) { - strides = - tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_shape(), buf->strides); + strides = tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_shape(), + buf->strides); } else { strides = 0; } @@ -115,7 +115,7 @@ inline PrimExpr pack_buffer(Buffer buf) { make_const(DataType::Int(32), static_cast(buf->shape.size())), make_const(buf->dtype, 0), buf->elem_offset}; - return tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_stack_make_array(), pack_args); + return tvm::tirx::Call(DataType::Handle(), tvm::tirx::builtin::tvm_stack_make_array(), pack_args); } /*! @@ -128,7 +128,7 @@ inline PrimExpr pack_buffer(Buffer buf) { * \return An expression representing the invocation */ inline PrimExpr call_packed(ffi::Array args) { - return tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args); + return tvm::tirx::Call(DataType::Int(32), tvm::tirx::builtin::tvm_call_packed(), args); } } // namespace detail diff --git a/include/tvm/topi/detail/pad_utils.h b/include/tvm/topi/detail/pad_utils.h index dfb9542e7655..d5d672af685a 100644 --- a/include/tvm/topi/detail/pad_utils.h +++ b/include/tvm/topi/detail/pad_utils.h @@ -25,8 +25,8 @@ #define TVM_TOPI_DETAIL_PAD_UTILS_H_ #include -#include -#include +#include +#include #include diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h index 9f59828a8f43..e70b1542d4a4 100644 --- a/include/tvm/topi/detail/strided_slice.h +++ b/include/tvm/topi/detail/strided_slice.h @@ -24,7 +24,7 @@ #ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_ #define TVM_TOPI_DETAIL_STRIDED_SLICE_H_ -#include +#include #include #include @@ -141,7 +141,7 @@ inline ffi::Array StridedSliceOutputShape( << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size))); } else { - out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim", out_shape[i]->dtype)); + out_shape.Set(axes[i].IntValue(), tvm::tirx::Var("dim", out_shape[i]->dtype)); } } diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 67fa3ade1065..940b1f149ead 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -24,9 +24,9 @@ #ifndef TVM_TOPI_ELEMWISE_H_ #define TVM_TOPI_ELEMWISE_H_ -#include -#include -#include +#include +#include +#include #include #include @@ -212,8 +212,8 @@ inline Tensor sign(const Tensor& x, std::string name = "T_sign", std::string tag PrimExpr zero = make_zero(x->dtype); PrimExpr one = make_const(x->dtype, 1); PrimExpr minus_one = make_const(x->dtype, -1); - auto s1 = tvm::tir::Select((x(i) < zero), minus_one, zero); - auto s2 = tvm::tir::Select((x(i) > zero), one, s1); + auto s1 = tvm::tirx::Select((x(i) < zero), minus_one, zero); + auto s2 = tvm::tirx::Select((x(i) > zero), one, s1); return s2; }, name, tag); @@ -284,7 +284,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", if (expr.dtype().lanes() == type.lanes()) { return expr; } else if (expr.dtype().lanes() == 1 && type.is_vector()) { - return tvm::tir::Broadcast(expr, type.lanes()); + return tvm::tirx::Broadcast(expr, type.lanes()); } } diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 01b82cb3f648..979cb2148c63 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include #include @@ -56,8 +56,8 @@ inline tvm::te::Tensor relu(const tvm::te::Tensor& t, T threshold = static_cast< std::string name = "T_relu", std::string tag = kElementWise) { return tvm::te::compute( t->shape, - [&](const tvm::ffi::Array& i) { - auto threshold_const = tvm::tir::make_const(t->dtype, threshold); + [&](const tvm::ffi::Array& i) { + auto threshold_const = tvm::tirx::make_const(t->dtype, threshold); return tvm::max(t(i), threshold_const); }, name, tag); @@ -78,10 +78,10 @@ inline tvm::te::Tensor leaky_relu(const tvm::te::Tensor& t, double alpha = 0.1, std::string tag = kElementWise) { return tvm::te::compute( t->shape, - [&](const tvm::ffi::Array& i) { + [&](const tvm::ffi::Array& i) { auto value = t(i); - auto calpha = tvm::tir::make_const(value.dtype(), alpha); - return tvm::tir::Select(value > 0, value, value * calpha); + auto calpha = tvm::tirx::make_const(value.dtype(), alpha); + return tvm::tirx::Select(value > 0, value, value * calpha); }, name, tag); } @@ -107,9 +107,9 @@ inline tvm::te::Tensor prelu(const tvm::te::Tensor& x, const tvm::te::Tensor& sl return tvm::te::compute( x->shape, - [&](const tvm::ffi::Array& indices) { + [&](const tvm::ffi::Array& indices) { auto xval = x(indices); - return tvm::tir::Select(xval > 0, xval, xval * slope(indices[axis])); + return tvm::tirx::Select(xval > 0, xval, xval * slope(indices[axis])); }, name, tag); } @@ -194,10 +194,10 @@ inline tvm::te::Tensor pad( } if (!pad_value.defined()) { - pad_value = tvm::tir::make_const(t->dtype, 0); + pad_value = tvm::tirx::make_const(t->dtype, 0); } - auto l = [&](tvm::ffi::Array ovars) { + auto l = [&](tvm::ffi::Array ovars) { tvm::ffi::Array indices; tvm::ffi::Array sel; tvm::ffi::Array pad_idx; @@ -285,7 +285,7 @@ inline tvm::te::Tensor conv2d_nchw(const tvm::te::Tensor& I, const tvm::te::Tens auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); - auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { + auto l = [&](tvm::tirx::Var b, tvm::tirx::Var o, tvm::tirx::Var h, tvm::tirx::Var w) { return tvm::sum(T(b, i, stride_h * h + kh, stride_w * w + kw) * W(o, i, kh, kw), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); @@ -328,7 +328,7 @@ inline tvm::te::Tensor conv2d_hwcn(const tvm::te::Tensor& I, const tvm::te::Tens auto kh = tvm::te::reduce_axis(tvm::Range{0, W->shape[0]}, "kh"); auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {pad_h, pad_w}); - auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { + auto l = [&](tvm::tirx::Var b, tvm::tirx::Var o, tvm::tirx::Var h, tvm::tirx::Var w) { return tvm::sum(T(stride_h * h + kh, stride_w * w + kw, i, b) * W(kh, kw, i, o), {i, kh, kw}); }; return tvm::te::compute(output_shape, l, name, tag); @@ -375,7 +375,7 @@ inline tvm::te::Tensor depthwise_conv2d_nchw(const tvm::te::Tensor& I, const tvm auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[3]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); - auto l = [&](tvm::tir::Var b, tvm::tir::Var o, tvm::tir::Var h, tvm::tir::Var w) { + auto l = [&](tvm::tirx::Var b, tvm::tirx::Var o, tvm::tirx::Var h, tvm::tirx::Var w) { return tvm::sum(T(b, indexdiv(i, pCM), stride_h * h + kh, stride_w * w + kw) * W(indexdiv(i, pCM), indexmod(o, pCM), kh, kw), {i, kh, kw}); @@ -404,7 +404,7 @@ inline tvm::te::Tensor depthwise_conv2d_nhwc(const tvm::te::Tensor& I, const tvm auto kw = tvm::te::reduce_axis(tvm::Range{0, W->shape[1]}, "kw"); auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), pad_h, pad_w, tvm::PrimExpr(0)}); - auto l = [&](tvm::tir::Var b, tvm::tir::Var h, tvm::tir::Var w, tvm::tir::Var o) { + auto l = [&](tvm::tirx::Var b, tvm::tirx::Var h, tvm::tirx::Var w, tvm::tirx::Var o) { return tvm::sum(T(b, stride_h * h + kh, stride_w * w + kw, indexdiv(i, pCM)) * W(kh, kw, indexdiv(i, pCM), indexmod(o, pCM)), {kh, kw, i}); @@ -455,12 +455,12 @@ inline tvm::te::Tensor group_conv2d_ngchw(const tvm::te::Tensor& I, const tvm::t auto T = (pad_h == 0 && pad_w == 0) ? I : pad(I, {tvm::PrimExpr(0), tvm::PrimExpr(0), tvm::PrimExpr(0), pad_h, pad_w}); - auto l = [&](tvm::ffi::Array args) { - tvm::tir::Var b = args[0]; - tvm::tir::Var g = args[1]; - tvm::tir::Var o = args[2]; - tvm::tir::Var h = args[3]; - tvm::tir::Var w = args[4]; + auto l = [&](tvm::ffi::Array args) { + tvm::tirx::Var b = args[0]; + tvm::tirx::Var g = args[1]; + tvm::tirx::Var o = args[2]; + tvm::tirx::Var h = args[3]; + tvm::tirx::Var w = args[4]; return tvm::sum(I(b, g, i, stride_h * h + kh, stride_w * w + kw) * W(g, i, o, kh, kw), {i, kh, kw}); }; @@ -507,7 +507,7 @@ inline tvm::te::Tensor space_to_batch_nd(const tvm::te::Tensor& data, // pad the input with paddings provided if (!pad_value.defined()) { - pad_value = tvm::tir::make_const(data->dtype, 0); + pad_value = tvm::tirx::make_const(data->dtype, 0); } padded_t = pad(data, pad_before_int32, pad_after_int32, pad_value); @@ -666,19 +666,19 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T // prediction->shape = (C,), targets->shape = (), weights->shape = (C,) auto T = tvm::te::compute( {}, - [&](const tvm::ffi::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(); - return tvm::tir::Select(c != ignore_index, -predictions(c) * weights(c), - tvm::tir::make_const(predictions->dtype, 0)); + return tvm::tirx::Select(c != ignore_index, -predictions(c) * weights(c), + tvm::tirx::make_const(predictions->dtype, 0)); }, name, tag); if (reduction == "mean") { auto W = tvm::te::compute( {}, - [&](const tvm::ffi::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(); - return tvm::tir::Select(c != ignore_index, weights(c), - tvm::tir::make_const(predictions->dtype, 0)); + return tvm::tirx::Select(c != ignore_index, weights(c), + tvm::tirx::make_const(predictions->dtype, 0)); }, name, tag); return topi::divide(T, W); @@ -688,7 +688,7 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T } auto T = tvm::te::compute( targets->shape, - [&](const tvm::ffi::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); tvm::ffi::Array pred_indices; pred_indices.push_back(target_indices[0]); // batch index @@ -696,18 +696,18 @@ inline Tensor nll_loss(const Tensor& predictions, const Tensor& targets, const T for (size_t i = 1; i < target_indices.size(); i++) { pred_indices.push_back(target_indices[i]); // indices for multidimensional loss } - return tvm::tir::Select(c != ignore_index, -predictions(pred_indices) * weights(c), - tvm::tir::make_const(predictions->dtype, 0)); + return tvm::tirx::Select(c != ignore_index, -predictions(pred_indices) * weights(c), + tvm::tirx::make_const(predictions->dtype, 0)); }, name, tag); TVM_FFI_ICHECK(T->shape.size() != 0); if (reduction == "mean") { auto W = tvm::te::compute( targets->shape, - [&](const tvm::ffi::Array& target_indices) { + [&](const tvm::ffi::Array& target_indices) { auto c = targets(target_indices); - return tvm::tir::Select(c != ignore_index, weights(c), - tvm::tir::make_const(predictions->dtype, 0)); + return tvm::tirx::Select(c != ignore_index, weights(c), + tvm::tirx::make_const(predictions->dtype, 0)); }, name, tag); return topi::divide(topi::sum(T, tvm::ffi::Array(nullptr)), diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h index 970d74b1c612..69e9aae4840e 100644 --- a/include/tvm/topi/nn/pooling.h +++ b/include/tvm/topi/nn/pooling.h @@ -144,17 +144,17 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh); out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww); - PrimExpr out_idx_lower_h = tir::Select( + PrimExpr out_idx_lower_h = tirx::Select( pad_inds[height_axis] < kernel_height, make_const(pad_inds[height_axis].dtype(), 0), (pad_inds[height_axis] - kernel_height) / stride_height + 1); - PrimExpr out_idx_lower_w = tir::Select( + PrimExpr out_idx_lower_w = tirx::Select( pad_inds[width_axis] < kernel_width, make_const(pad_inds[width_axis].dtype(), 0), (pad_inds[width_axis] - kernel_width) / stride_width + 1); return tvm::sum( - tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h, - out_idx[width_axis] >= out_idx_lower_w), - mp_inds(out_idx) == idx), + tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, + out_idx[width_axis] >= out_idx_lower_w), + mp_inds(out_idx) == idx), out_grad(out_idx), make_const(x->dtype, 0)), {windowh, windoww}); }, @@ -176,11 +176,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww)); PrimExpr out_idx_lower_h = - tir::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0), - (pad_h_idx - kernel_height) / stride_height + 1); + tirx::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0), + (pad_h_idx - kernel_height) / stride_height + 1); PrimExpr out_idx_lower_w = - tir::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0), - (pad_w_idx - kernel_width) / stride_width + 1); + tirx::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0), + (pad_w_idx - kernel_width) / stride_width + 1); PrimExpr divide_factor; // number of pooled elements if (count_include_pad) { @@ -197,10 +197,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x, max((h_end - h_start) * (w_end - w_start), make_const(h_end.dtype(), 1)); } return tvm::sum( - tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h, - out_idx[height_axis] < out_height), - tir::And(out_idx[width_axis] >= out_idx_lower_w, - out_idx[width_axis] < out_width)), + tvm::if_then_else(tirx::And(tirx::And(out_idx[height_axis] >= out_idx_lower_h, + out_idx[height_axis] < out_height), + tirx::And(out_idx[width_axis] >= out_idx_lower_w, + out_idx[width_axis] < out_width)), out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)), {windowh, windoww}); }, @@ -310,7 +310,7 @@ inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const Pr inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) { PrimExpr tmp = indexdiv((out_index + 1) * idim, odim); - return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1); + return tvm::tirx::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1); } /*! @@ -340,7 +340,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou auto get_iter_vars = [=](const ffi::Array& output, bool reduce_indices) { ffi::Array indices; for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]); - ffi::Array reduce_axes; + ffi::Array reduce_axes; for (size_t i = 0; i < n_dim; ++i) { auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]); auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]); @@ -361,7 +361,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou out_shape, [&](const ffi::Array& output) { ffi::Array indices; - ffi::Array reduce_axes; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, true); return tvm::max(x(indices), reduce_axes); // NOLINT(*) }, @@ -372,7 +372,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou out_shape, [&](const ffi::Array& output) { ffi::Array indices; - ffi::Array reduce_axes; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, true); return tvm::sum(x(indices), reduce_axes); }, @@ -382,7 +382,7 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const ffi::Array& ou out_shape, [&](const ffi::Array& output) { ffi::Array indices; - ffi::Array reduce_axes; + ffi::Array reduce_axes; std::tie(indices, reduce_axes) = get_iter_vars(output, false); PrimExpr divide_factor = tvm::cast(x->dtype, 1); diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 5345cc8e0ea9..73c5fc31ce77 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -283,12 +283,12 @@ inline FCommReduce MakeCommReducer(FCombine fcombine, FIdentity fidentity, auto result = fcombine(lhs, rhs); auto id_elem = fidentity(dtypes); - auto cond = condition != nullptr ? *condition : tir::const_true(); + auto cond = condition != nullptr ? *condition : tirx::const_true(); - auto combiner = tvm::tir::CommReducer(lhs, rhs, result, id_elem); + auto combiner = tvm::tirx::CommReducer(lhs, rhs, result, id_elem); ffi::Array outputs; for (size_t i = 0; i < exprs.size(); ++i) { - outputs.push_back(tvm::tir::Reduce(combiner, exprs, axis, cond, static_cast(i), {})); + outputs.push_back(tvm::tirx::Reduce(combiner, exprs, axis, cond, static_cast(i), {})); } return outputs; }; @@ -470,14 +470,14 @@ inline FCommReduce MakeArgminReducer(bool select_last_index = false) { } PrimExpr update_index = is_smaller || (is_same && proper_index); - result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(is_smaller, lhs[1], rhs[1])); // val + result.push_back(tvm::tirx::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tirx::Select(is_smaller, lhs[1], rhs[1])); // val return result; }; auto fidentity = [&](std::vector types) { ffi::Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::max_value(types[1])); // val + result.push_back(tvm::tirx::make_const(types[0], -1)); // idx + result.push_back(tvm::max_value(types[1])); // val return result; }; return MakeCommReducer(fcombine, fidentity, "argmin"); @@ -532,14 +532,14 @@ inline FCommReduce MakeArgmaxReducer(bool select_last_index = false) { } PrimExpr update_index = is_bigger || (is_same && proper_index); - result.push_back(tvm::tir::Select(update_index, lhs[0], rhs[0])); // idx - result.push_back(tvm::tir::Select(is_bigger, lhs[1], rhs[1])); // val + result.push_back(tvm::tirx::Select(update_index, lhs[0], rhs[0])); // idx + result.push_back(tvm::tirx::Select(is_bigger, lhs[1], rhs[1])); // val return result; }; auto fidentity = [&](std::vector types) { ffi::Array result; - result.push_back(tvm::tir::make_const(types[0], -1)); // idx - result.push_back(tvm::min_value(types[1])); // val + result.push_back(tvm::tirx::make_const(types[0], -1)); // idx + result.push_back(tvm::min_value(types[1])); // val return result; }; return MakeCommReducer(fcombine, fidentity, "argmax"); @@ -601,7 +601,7 @@ inline FCommReduce MakeTupleSumReducer() { auto fidentity = [](std::vector types) { ffi::Array result; for (size_t i = 0; i < types.size(); ++i) { - result.push_back(tvm::tir::make_const(types[i], 0)); + result.push_back(tvm::tirx::make_const(types[i], 0)); } return result; }; diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 24a1521c1e96..93938c601dd7 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -46,9 +46,9 @@ #include "tvm/ir/expr.h" #include "tvm/runtime/data_type.h" -#include "tvm/tir/expr.h" -#include "tvm/tir/op.h" -#include "tvm/tir/var.h" +#include "tvm/tirx/expr.h" +#include "tvm/tirx/op.h" +#include "tvm/tirx/var.h" namespace tvm { namespace topi { @@ -650,8 +650,8 @@ inline ffi::Array split_indices_array(const Tensor& x, ffi::Array(); - auto extent_var = extent.as(); + auto idx_var = index.as(); + auto extent_var = extent.as(); if (idx_var && extent_var && idx_var->name_hint == extent_var->name_hint) { return index; @@ -741,7 +741,7 @@ inline te::Tensor dynamic_strided_slice_with_axes( return te::compute( out_shape, - [&](const ffi::Array& indices) { + [&](const ffi::Array& indices) { ffi::Array real_indices = indices.Map([](const auto& var) -> PrimExpr { return var; }); @@ -793,7 +793,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array& out_shape.push_back( analyzer.Simplify(GetLength(begin[i], end[i], strides[i], x->shape[i], assume_inbound))); } else { - out_shape.push_back(tvm::tir::Var("dim")); + out_shape.push_back(tvm::tirx::Var("dim")); } } @@ -803,7 +803,7 @@ inline Tensor dynamic_strided_slice(const Tensor& x, const ffi::Array& return te::compute( out_shape, - [&](const ffi::Array& indices) { + [&](const ffi::Array& indices) { ffi::Array real_indices; for (size_t i = 0; i < num_slice_axes; ++i) { real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1)); @@ -919,7 +919,7 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const ffi::Array return te::compute( out_shape, - [&](const ffi::Array& indices) { + [&](const ffi::Array& indices) { ffi::Array real_indices; for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < axes.size(); ++i) { @@ -1102,7 +1102,7 @@ inline Tensor sequence_mask(const Tensor& data, const Tensor& valid_length, doub len_index.push_back(bid); PrimExpr ret = tvm::if_then_else(tvm::cast(valid_length->dtype, tid) >= valid_length(len_index), - tvm::tir::make_const(data->dtype, mask_value), data(out_index)); + tvm::tirx::make_const(data->dtype, mask_value), data(out_index)); return ret; }, name, tag); @@ -1277,7 +1277,7 @@ inline Tensor take(const Tensor& a, ffi::Variant indices, int PrimExpr in_bounds = idx >= 0 && idx < axis_dim; return tvm::if_then_else( in_bounds, a(real_indices), - tvm::tir::make_const(a->dtype, std::numeric_limits::quiet_NaN())); + tvm::tirx::make_const(a->dtype, std::numeric_limits::quiet_NaN())); }, name, tag); } else { // mode == "wrap" @@ -1332,11 +1332,11 @@ inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y, auto x_bh = detail::BroadcastShape(x->shape, oshape); auto y_bh = detail::BroadcastShape(y->shape, oshape); - auto select = [&](tvm::ffi::Array ovars) { + auto select = [&](tvm::ffi::Array ovars) { auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars)); auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars)); auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars)); - return tvm::tir::Select(c != 0, true_val, false_val); + return tvm::tirx::Select(c != 0, true_val, false_val); }; return compute(oshape, select, name, tag); @@ -1614,7 +1614,7 @@ inline tvm::te::Tensor matmul(const tvm::te::Tensor& A, const tvm::te::Tensor& B std::string name = "T_matmul", std::string tag = kMatMul) { tvm::ffi::Array output_shape{A->shape[trans_a ? 1 : 0], B->shape[trans_b ? 0 : 1]}; auto k = tvm::te::reduce_axis(tvm::Range{0, A->shape[trans_a ? 0 : 1]}, "k"); - auto l = [&](tvm::tir::Var i, tvm::tir::Var j) { + auto l = [&](tvm::tirx::Var i, tvm::tirx::Var j) { return tvm::sum((trans_a ? A[k][i] : A[i][k]) * (trans_b ? B[j][k] : B[k][j]), {k}); }; return tvm::te::compute(output_shape, l, name, tag); @@ -1808,7 +1808,7 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, TVM_FFI_ICHECK(src_layout_struct.defined() && dst_layout_struct.defined()) << "cannot convert from/to undefined layout"; - auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct); + auto layout_converter = tirx::BijectiveLayout(src_layout_struct, dst_layout_struct); TVM_FFI_ICHECK(layout_converter.defined()) << "cannot convert from " << src_layout << " to " << dst_layout; @@ -1938,7 +1938,7 @@ inline Tensor auto_scheduler_layout_transform( * A'[a, b, c, d] = A[a * 4 + c, b * 16 + d] */ inline Tensor meta_schedule_layout_transform( - const Tensor& src, const tir::IndexMap& index_map, + const Tensor& src, const tirx::IndexMap& index_map, const ffi::String name = "T_meta_schedule_layout_trans", const ffi::String tag = kInjective) { arith::Analyzer analyzer; ffi::Array iter_domain; @@ -2053,7 +2053,7 @@ inline Tensor one_hot(const Tensor& indices, const PrimExpr on_value, const Prim } auto idx = iter_vars[true_axis]; - return tir::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast); + return tirx::Select(indices(indices_indices) == idx, on_value_cast, off_value_cast); }, name, tag); } @@ -2250,7 +2250,7 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b return te::compute( output_shape, - [&](const ffi::Array& indices) { + [&](const ffi::Array& indices) { ffi::Array real_indices; for (size_t i = 0; i < num_dynamic_axes; ++i) { auto ind = make_const(DataType::Int(64), i); diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index bba21cb690cc..dfcc8e20ab3a 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -45,8 +45,8 @@ from .ir import container from . import ir -# tvm.tir -from . import tir +# tvm.tirx +from . import tirx # tvm.s_tir from . import s_tir diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index f7d0ab82aad5..fc5d3c9aea04 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -21,7 +21,7 @@ import tvm_ffi -from tvm import ir, tir +from tvm import ir, tirx from tvm.arith import IntSet from tvm.runtime import Object @@ -126,7 +126,7 @@ def __init__(self): self._get_enabled_extensions = _mod("get_enabled_extensions") self._set_enabled_extensions = _mod("set_enabled_extensions") - def const_int_bound(self, expr: tir.PrimExpr) -> ConstIntBound: + def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. Parameters @@ -141,12 +141,12 @@ def const_int_bound(self, expr: tir.PrimExpr) -> ConstIntBound: """ return self._const_int_bound(expr) - def const_int_bound_is_bound(self, var: tir.Var) -> bool: + def const_int_bound_is_bound(self, var: tirx.Var) -> bool: """Check if a variable is bound to a range. Parameters ---------- - var : tvm.tir.Var + var : tvm.tirx.Var The variable. Returns @@ -156,7 +156,7 @@ def const_int_bound_is_bound(self, var: tir.Var) -> bool: """ return self._const_int_bound_is_bound(var) - def modular_set(self, expr: tir.PrimExpr) -> ModularSet: + def modular_set(self, expr: tirx.PrimExpr) -> ModularSet: """Find a modular set that expr belongs to. Parameters @@ -171,7 +171,7 @@ def modular_set(self, expr: tir.PrimExpr) -> ModularSet: """ return self._modular_set(expr) - def simplify(self, expr: tir.PrimExpr, steps: int = 2) -> tir.PrimExpr: + def simplify(self, expr: tirx.PrimExpr, steps: int = 2) -> tirx.PrimExpr: """Simplify expression via both rewrite and canonicalization. Parameters @@ -191,7 +191,7 @@ def simplify(self, expr: tir.PrimExpr, steps: int = 2) -> tir.PrimExpr: """ return self._simplify(expr, steps) - def rewrite_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr: + def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """Simplify expression via rewriting rules. Parameters @@ -213,7 +213,7 @@ def rewrite_simplify_stats(self): def reset_rewrite_simplify_stats(self): self._reset_rewrite_simplify_stats() - def canonical_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr: + def canonical_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr: """Simplify expression via canonicalization. Parameters @@ -228,7 +228,7 @@ def canonical_simplify(self, expr: tir.PrimExpr) -> tir.PrimExpr: """ return self._canonical_simplify(expr) - def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet: + def int_set(self, expr: tirx.PrimExpr, dom_map: dict[tirx.Var, IntSet]) -> IntSet: """Compute a symbolic IntSet that covers expr for all values in dom_map. Parameters @@ -236,7 +236,7 @@ def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet: expr : PrimExpr The expression. - dom_map : Dict[tvm.tir.Var, tvm.arith.IntSet] + dom_map : Dict[tvm.tirx.Var, tvm.arith.IntSet] The domain for variables to be relaxed. Returns @@ -247,7 +247,7 @@ def int_set(self, expr: tir.PrimExpr, dom_map: dict[tir.Var, IntSet]) -> IntSet: return self._int_set(expr, dom_map) def can_prove( - self, expr: tir.PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT + self, expr: tirx.PrimExpr, strength: ProofStrength = ProofStrength.DEFAULT ) -> bool: """Check whether we can prove expr to be true. @@ -266,20 +266,20 @@ def can_prove( """ return self._can_prove(expr, strength) - def bind(self, var: tir.Var, expr: tir.PrimExpr | ir.Range) -> None: + def bind(self, var: tirx.Var, expr: tirx.PrimExpr | ir.Range) -> None: """Bind a variable to the expression. Parameters ---------- - var : tvm.tir.Var + var : tvm.tirx.Var The variable. - expr : Union[tir.PrimExpr, ir.Range] + expr : Union[tirx.PrimExpr, ir.Range] The expression or the range to bind to. """ return self._bind(var, expr) - def constraint_scope(self, constraint: tir.PrimExpr) -> ConstraintScope: + def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope: """Create a constraint scope. Parameters @@ -310,12 +310,12 @@ def _fenter(): return ConstraintScope(_fenter) - def update(self, var: tir.Var, info: ConstIntBound, override: bool = False) -> None: + def update(self, var: tirx.Var, info: ConstIntBound, override: bool = False) -> None: """Update infomation about var Parameters ---------- - var : tvm.tir.Var + var : tvm.tirx.Var The variable. info : tvm.Object @@ -329,7 +329,7 @@ def update(self, var: tir.Var, info: ConstIntBound, override: bool = False) -> N else: raise TypeError(f"Do not know how to handle type {type(info)}") - def can_prove_equal(self, lhs: tir.PrimExpr, rhs: tir.PrimExpr) -> bool: + def can_prove_equal(self, lhs: tirx.PrimExpr, rhs: tirx.PrimExpr) -> bool: """Whether we can prove that lhs == rhs Parameters diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 08421b79d90b..9b114ce810ed 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -76,9 +76,9 @@ class IntConstraints(Object): Parameters ---------- - variables : List[tvm.tir.Var] + variables : List[tvm.tirx.Var] The variables in the constraints. Must be integers - ranges : Map[tvm.tir.Var, tvm.ir.Range] + ranges : Map[tvm.tirx.Var, tvm.ir.Range] The ranges of the variables. relations : List[tvm.ir.PrimExpr] The relations between the variables (either equations or inequalities) @@ -108,10 +108,10 @@ class IntConstraintsTransform(Object): source integer constraints, e.g., {a + b = 0 | a >= 0, b >= 0} dst : arith.IntConstraints integer constraints equivalent to the source, e.g., {m - n = 0 | m >= 0, n <= 0} - src_to_dst : Map[tvm.tir.Var, tvm.ir.PrimExpr] + src_to_dst : Map[tvm.tirx.Var, tvm.ir.PrimExpr] mapping from variables in the src to the variables in the dst, e.g., {a -> m, b -> -n} - dst_to_src : Map[tvm.tir.Var, tvm.ir.PrimExpr] + dst_to_src : Map[tvm.tirx.Var, tvm.ir.PrimExpr] mapping from variables in the dst to the variables in the src, e.g., {m -> a, n -> -b} """ @@ -129,9 +129,9 @@ def solve_linear_equations(equations, variables=None, ranges=None): ---------- equations: List[tvm.ir.PrimExpr] or IntConstraints The equations of the variables - variables : Optional[List[tvm.tir.Var]] + variables : Optional[List[tvm.tirx.Var]] The variables in the system. - ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]] + ranges : Optional[Map[tvm.tirx.Var, tvm.ir.Range]] The ranges of the variables. Returns @@ -157,9 +157,9 @@ def solve_linear_inequalities(equations, variables=None, ranges=None, deskew_ran ---------- equations : List[tvm.ir.PrimExpr] or IntConstraints The inequalities of the variables - variables : Optional[List[tvm.tir.Var]] + variables : Optional[List[tvm.tirx.Var]] The variables in the system. - ranges : Optional[Map[tvm.tir.Var, tvm.ir.Range]] + ranges : Optional[Map[tvm.tirx.Var, tvm.ir.Range]] The ranges of the variables. deskew_range: Optional[bool] Whether deskew the result ranges to be started from zero. diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index da2c66c5e844..efaf1b72e73c 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -29,7 +29,7 @@ def detect_linear_equation(expr, var_list): expr : PrimExpr The expression to be matched. - var_list : List[tvm.tir.Var] + var_list : List[tvm.tirx.Var] A list of variables. Returns @@ -49,7 +49,7 @@ def detect_clip_bound(expr, var_list): expr : PrimExpr The expression to be matched. - var_list : List[tvm.tir.Var] + var_list : List[tvm.tirx.Var] A list of variables. Returns diff --git a/python/tvm/contrib/cblas.py b/python/tvm/contrib/cblas.py index 37302277dea4..d2dd8012a8fc 100644 --- a/python/tvm/contrib/cblas.py +++ b/python/tvm/contrib/cblas.py @@ -45,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], transa, transb ), name="C", @@ -79,7 +79,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs return te.extern( (b, n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cblas.batch_matmul" if not iterative else "tvm.contrib.cblas.batch_matmul_iterative", diff --git a/python/tvm/contrib/cublas.py b/python/tvm/contrib/cublas.py index b66a36143da8..a999d70f2a99 100644 --- a/python/tvm/contrib/cublas.py +++ b/python/tvm/contrib/cublas.py @@ -45,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cublas.matmul", ins[0], ins[1], outs[0], transa, transb ), dtype=dtype, @@ -79,7 +79,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None): return te.extern( (b, n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cublas.batch_matmul", ins[0], ins[1], outs[0], transa, transb ), dtype=dtype, diff --git a/python/tvm/contrib/cublaslt.py b/python/tvm/contrib/cublaslt.py index b59ebc737039..44cb9c9f0bee 100644 --- a/python/tvm/contrib/cublaslt.py +++ b/python/tvm/contrib/cublaslt.py @@ -47,7 +47,7 @@ def matmul(lhs, rhs, transa=False, transb=False, n=0, m=0, dtype=None): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cublaslt.matmul", ins[0], ins[1], outs[0], transa, transb ), dtype=dtype, diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py index d77b4afe0ea2..9820bc8a29b6 100644 --- a/python/tvm/contrib/cudnn.py +++ b/python/tvm/contrib/cudnn.py @@ -603,7 +603,7 @@ def conv_forward( x_shape = list(x.shape) - if isinstance(x.shape[0], tvm.tir.expr.IntImm): + if isinstance(x.shape[0], tvm.tirx.expr.IntImm): oshape = conv_output_shape( tensor_format, pad, @@ -658,7 +658,7 @@ def conv_forward( return te.extern( oshape, [x, w], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cudnn.conv2d.forward", conv_mode, tensor_format, @@ -681,7 +681,7 @@ def conv_forward( return te.extern( oshape, [x, w], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cudnn.conv3d.forward", conv_mode, tensor_format, @@ -753,7 +753,7 @@ def conv_backward_data( conv_dtype = dy.dtype if conv_dtype is None else conv_dtype pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation) - assert isinstance(dy.shape[0], tvm.tir.expr.IntImm), ( + assert isinstance(dy.shape[0], tvm.tirx.expr.IntImm), ( "Dynamic batch is not supported for cudnn conv2d backwad data yet." ) @@ -782,7 +782,7 @@ def conv_backward_data( return te.extern( dx_shape, [dy, w], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cudnn.conv2d.backward_data", conv_mode, tensor_format, @@ -847,7 +847,7 @@ def conv_backward_filter( x_shape = list(x.shape) - assert isinstance(x.shape[0], tvm.tir.expr.IntImm), ( + assert isinstance(x.shape[0], tvm.tirx.expr.IntImm), ( "Dynamic batch is not supported for cudnn conv2d backwad filter yet." ) @@ -883,7 +883,7 @@ def conv_backward_filter( return te.extern( dw_shape, [dy, x], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cudnn.conv2d.backward_filter", conv_mode, tensor_format, @@ -923,7 +923,7 @@ def softmax(x, axis=-1): return te.extern( x.shape, [x], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cudnn.softmax.forward", ins[0], outs[0], axis ), name="y", @@ -949,7 +949,7 @@ def log_softmax(x, axis=-1): return te.extern( x.shape, [x], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cudnn.log_softmax.forward", ins[0], outs[0], axis ), name="y", diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index f7e02a09852f..ce9a46ba7004 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -108,7 +108,7 @@ def select_gemm_kernel( ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic workloads.""" - if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): + if any(isinstance(s, tvm.tirx.Any) for s in [MM, KK, NN]): out = cutlass_profiler.get_default( op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched ) @@ -249,7 +249,7 @@ def handle_conv2d( else: conv_kind = ConvKind.Fprop - if any(isinstance(s, tvm.tir.Any) for s in d_shape): + if any(isinstance(s, tvm.tirx.Any) for s in d_shape): out = cutlass_profiler.get_default( op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32, conv_kind, strides ) @@ -417,7 +417,7 @@ def is_shape_valid_for_cutlass_matmul( as well as ND x 2D and 2D x ND. For example, it cannot handle matmul with shape (2, 1, 4, 8) x (2, 3, 8, 16), because the batch stride of lhs is not constant. """ - if not isinstance(lhs_shape[-1], tvm.tir.expr.IntImm | int): + if not isinstance(lhs_shape[-1], tvm.tirx.expr.IntImm | int): # Reduction axis must be constant return False diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index bcddffa5b8a4..eab4c46c5b77 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -29,7 +29,7 @@ import tvm_ffi from tvm.runtime import Object -from tvm.tir import IntImm +from tvm.tirx import IntImm from . import _ffi_api as ffi from .attention_operation import ( @@ -539,7 +539,7 @@ def get_flattened_batch_dim(arg_name, batch_rank): attrs["M"] = annotations["M"] attrs["group_size"] = annotations["group_size"] - if not isinstance(attrs["M"], tvm.tir.IntImm): + if not isinstance(attrs["M"], tvm.tirx.IntImm): attrs["M"] = get_flattened_batch_dim( func_args[lhs_arg_idx], int(annotations["batch_rank"]) ) @@ -888,7 +888,7 @@ def get_batch_on_arg(arg_name, arg_shape): attrs = {"input": func_args[0], "gamma": func_args[1], "beta": func_args[2]} attrs.update(dict(annotations)) - if not isinstance(attrs["M"], tvm.tir.IntImm): + if not isinstance(attrs["M"], tvm.tirx.IntImm): attrs["M"] = get_flattened_batch_dim(func_args[0], int(attrs["batch_rank"])) code = instantiate_layer_norm_template(attrs) @@ -899,7 +899,7 @@ def get_batch_on_arg(arg_name, arg_shape): attrs = {"input": func_args[0], "weight": func_args[1]} attrs.update(dict(annotations)) - if not isinstance(attrs["M"], tvm.tir.IntImm): + if not isinstance(attrs["M"], tvm.tirx.IntImm): attrs["M"] = get_flattened_batch_dim(func_args[0], int(attrs["batch_rank"])) code = instantiate_rms_norm_template(attrs) diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 62fb952a18eb..82a9b6e71fa6 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -22,7 +22,7 @@ import re from enum import auto as enum_auto -from tvm.tir.expr import FloatImm, IntImm +from tvm.tirx.expr import FloatImm, IntImm class GeneratorTarget(enum.Enum): diff --git a/python/tvm/contrib/dnnl.py b/python/tvm/contrib/dnnl.py index 444be67ffb8d..eebd7e166014 100644 --- a/python/tvm/contrib/dnnl.py +++ b/python/tvm/contrib/dnnl.py @@ -47,7 +47,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.dnnl.matmul", ins[0], ins[1], outs[0], transa, transb ), name="C", @@ -142,7 +142,7 @@ def dnnl_conv2d( return te.extern( out_shape, [src, weights], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.dnnl.conv2d", ins[0], ins[1], diff --git a/python/tvm/contrib/hexagon/hexagon_profiler.py b/python/tvm/contrib/hexagon/hexagon_profiler.py index bbf360bc0ab4..aaec36688e37 100644 --- a/python/tvm/contrib/hexagon/hexagon_profiler.py +++ b/python/tvm/contrib/hexagon/hexagon_profiler.py @@ -52,7 +52,7 @@ def __init__( if self._android_serial_number is None: raise RuntimeError("ANDROID_SERIAL_NUMBER must be set for profiling") - if ("tir.instrument_lwp", True) in config.items(): + if ("tirx.instrument_lwp", True) in config.items(): # Set profiling mode self._profiling_mode = "lwp" diff --git a/python/tvm/contrib/hexagon/tools.py b/python/tvm/contrib/hexagon/tools.py index c1e1191b5e61..bf7913cf6235 100644 --- a/python/tvm/contrib/hexagon/tools.py +++ b/python/tvm/contrib/hexagon/tools.py @@ -128,9 +128,9 @@ def link_shared(so_name, objs, extra_args=None): """ # The list of object files can be passed as built-in Python strings, - # or as tvm.tir.StringImm's. + # or as tvm.tirx.StringImm's. def to_str(s): - if isinstance(s, tvm.tir.StringImm): + if isinstance(s, tvm.tirx.StringImm): return s.value assert isinstance(s, str), 'argument "' + str(s) + '" should be a string or StrImm' return s @@ -199,9 +199,9 @@ def link_shared_macos(so_name, objs, extra_args=None): """ # The list of object files can be passed as built-in Python strings, - # or as tvm.tir.StringImm's. + # or as tvm.tirx.StringImm's. def to_str(s): - if isinstance(s, tvm.tir.StringImm): + if isinstance(s, tvm.tirx.StringImm): return s.value assert isinstance(s, str), 'argument "' + str(s) + '" should be a string or StrImm' return s diff --git a/python/tvm/contrib/hipblas.py b/python/tvm/contrib/hipblas.py index da71ca64c37a..078d6fef6e65 100644 --- a/python/tvm/contrib/hipblas.py +++ b/python/tvm/contrib/hipblas.py @@ -45,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, dtype=None): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.hipblas.matmul", ins[0], ins[1], outs[0], transa, transb ), dtype=dtype, @@ -79,7 +79,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None): return te.extern( (b, n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.hipblas.batch_matmul", ins[0], ins[1], outs[0], transa, transb ), dtype=dtype, diff --git a/python/tvm/contrib/mkl.py b/python/tvm/contrib/mkl.py index 303ae62bf651..cc223b5f08ff 100644 --- a/python/tvm/contrib/mkl.py +++ b/python/tvm/contrib/mkl.py @@ -45,7 +45,7 @@ def matmul(lhs, rhs, transa=False, transb=False, **kwargs): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.mkl.matmul", ins[0], ins[1], outs[0], transa, transb ), name="C", @@ -78,7 +78,7 @@ def matmul_u8s8s32(lhs, rhs, transa=False, transb=False, **kwargs): return te.extern( (n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.mkl.matmul_u8s8s32", ins[0], ins[1], outs[0], transa, transb ), name="C", @@ -112,7 +112,7 @@ def batch_matmul(lhs, rhs, transa=False, transb=False, iterative=False, **kwargs return te.extern( (b, n, m), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.mkl.batch_matmul" if not iterative else "tvm.contrib.mkl.batch_matmul_iterative", diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 9e4d69efdd2f..e6068db7c94c 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -50,7 +50,7 @@ def fully_connected_inference(lhs, rhs, nthreads=1): return te.extern( (m,), [lhs, rhs], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.nnpack.fully_connected_inference", ins[0], ins[1], outs[0], nthreads ), name="C", @@ -114,7 +114,7 @@ def convolution_inference( return te.extern( (batch, output_channels, output_height, output_width), [data, kernel, bias] if bias is not None else [data, kernel], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], @@ -177,7 +177,7 @@ def convolution_inference_without_weight_transform( return te.extern( (batch, output_channels, output_height, output_width), [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.nnpack.convolution_inference_without_weight_transform", ins[0], ins[1], @@ -223,7 +223,7 @@ def convolution_inference_weight_transform( return te.extern( (output_channels, input_channels, transform_tile_size, transform_tile_size), [kernel], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.nnpack.convolution_inference_weight_transform", ins[0], outs[0], diff --git a/python/tvm/contrib/random.py b/python/tvm/contrib/random.py index 66d7426ef204..0fcfc567a438 100644 --- a/python/tvm/contrib/random.py +++ b/python/tvm/contrib/random.py @@ -43,7 +43,7 @@ def randint(low, high, size, dtype="int32"): return te.extern( size, [], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.random.randint", int(low), int(high), outs[0] ), dtype=dtype, @@ -77,7 +77,7 @@ def uniform(low, high, size): return te.extern( size, [], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.random.uniform", float(low), float(high), outs[0] ), dtype="float32", @@ -107,7 +107,7 @@ def normal(loc, scale, size): return te.extern( size, [], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.random.normal", float(loc), float(scale), outs[0] ), dtype="float32", diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 8847b9031d60..a7ff16f8f880 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -25,7 +25,7 @@ from tvm.ir.module import IRModule from tvm.runtime import Executable from tvm.target import Target -from tvm.tir import PrimFunc +from tvm.tirx import PrimFunc def build( @@ -37,7 +37,7 @@ def build( Build a function with a signature, generating code for devices coupled with target information. - This function is deprecated. Use `tvm.compile` or `tvm.tir.build` instead. + This function is deprecated. Use `tvm.compile` or `tvm.tirx.build` instead. Parameters ---------- @@ -54,10 +54,10 @@ def build( A module combining both host and device code. """ warnings.warn( - "build is deprecated. Use `tvm.compile` or `tvm.tir.build` instead.", + "build is deprecated. Use `tvm.compile` or `tvm.tirx.build` instead.", DeprecationWarning, ) - return tvm.tir.build(mod, target, pipeline) + return tvm.tirx.build(mod, target, pipeline) def _contains_relax(mod: PrimFunc | IRModule) -> bool: @@ -108,5 +108,5 @@ def compile( # pylint: disable=redefined-builtin relax_pipeline=relax_pipeline, tir_pipeline=tir_pipeline, ) - lib = tvm.tir.build(mod, target, pipeline=tir_pipeline) + lib = tvm.tirx.build(mod, target, pipeline=tir_pipeline) return Executable(lib) diff --git a/python/tvm/exec/popen_worker.py b/python/tvm/exec/popen_worker.py index 0fd77ccb17e2..5d63abd4668d 100644 --- a/python/tvm/exec/popen_worker.py +++ b/python/tvm/exec/popen_worker.py @@ -1,4 +1,20 @@ # Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 01b6fe20c1d5..5451b65ef6c0 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -146,7 +146,7 @@ def make_node(type_key, **kwargs): .. code-block:: python x = tvm.ir.make_node("ir.IntImm", dtype="int32", value=10, span=None) - assert isinstance(x, tvm.tir.IntImm) + assert isinstance(x, tvm.tirx.IntImm) assert x.value == 10 """ if type_key == "ir.DictAttrs": diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 225cc5b59d0c..e466ef850043 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -163,9 +163,9 @@ def structural_equal(lhs, rhs, map_free_vars=False): - Normal node: equality is recursively defined without the restriction of graph nodes. - Vars(tir::Var, relax::Var) are graph nodes. + Vars(tirx::Var, relax::Var) are graph nodes. - A var-type node(e.g. tir::Var) can be mapped as equal to another var + A var-type node(e.g. tirx::Var) can be mapped as equal to another var with the same type if one of the following condition holds: - They appear in a same definition point(e.g. function argument). diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index b534213f8164..3dab9d02d54e 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -102,7 +102,7 @@ def __call__(self, *args: RelaxExpr) -> BaseExpr: return relax.Call(self, args) elif all(isinstance(x, Number | PrimExpr) for x in args): - return tvm.tir.call_tir(self, *args) + return tvm.tirx.call_tir(self, *args) arg_types = [type(x) for x in args] raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}") diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index d3bf850740dc..51e951957c3e 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -25,11 +25,11 @@ from enum import IntEnum import tvm -from tvm import IRModule, tir +from tvm import IRModule, tirx from tvm.relax.expr import Binding, Call, DataflowBlock, Expr, Function, GlobalVar, Var from tvm.relax.struct_info import FuncStructInfo, StructInfo from tvm.relax.ty import Type -from tvm.tir import Buffer, IndexMap, PrimFunc, SBlock +from tvm.tirx import Buffer, IndexMap, PrimFunc, SBlock from . import _ffi_api @@ -52,7 +52,7 @@ def get_static_type(sinfo: StructInfo) -> Type: def erase_to_well_defined( sinfo: StructInfo, - shape_var_map: dict[tir.Var, tir.PrimExpr] | None = None, + shape_var_map: dict[tirx.Var, tirx.PrimExpr] | None = None, var_map: dict[Var, Expr] | None = None, ) -> StructInfo: """Erase sinfo into a well defined form. @@ -65,7 +65,7 @@ def erase_to_well_defined( sinfo : StructInfo The input struct info. - shape_var_map : Dict[tir.Var, tir.PrimExpr] + shape_var_map : Dict[tirx.Var, tirx.PrimExpr] Specifies the defined shape vars and the values they should map to. var_map : Dict[Var, Expr] @@ -167,7 +167,7 @@ def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore -def tir_vars_in_struct_info(sinfo: StructInfo) -> list[tir.Var]: +def tir_vars_in_struct_info(sinfo: StructInfo) -> list[tirx.Var]: """Get the TIR variables that appear in the input struct info. The returned list is deduplicated - each TIR variable will appear at most once. @@ -178,13 +178,13 @@ def tir_vars_in_struct_info(sinfo: StructInfo) -> list[tir.Var]: Returns ------- - ret : List[tir.Var] + ret : List[tirx.Var] The list of TIR variables that appear in the input struct info. """ return _ffi_api.TIRVarsInStructInfo(sinfo) # type: ignore -def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> list[tir.Var]: +def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> list[tirx.Var]: """Get the TIR variables that may be defined from input struct info. The returned list is deduplicated - each TIR variable will appear at most once. @@ -195,14 +195,14 @@ def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> list[tir.Var]: Returns ------- - ret : List[tir.Var] + ret : List[tirx.Var] The list of TIR variables that can be defined from the StructInfo """ return _ffi_api.DefinableTIRVarsInStructInfo(sinfo) # type: ignore -def collect_non_negative_expressions(sinfo: StructInfo) -> list[tir.PrimExpr]: +def collect_non_negative_expressions(sinfo: StructInfo) -> list[tirx.PrimExpr]: """Collect TIR expressions used in non-negative contexts Get TIR variables that are non-negative within the context where @@ -220,7 +220,7 @@ def collect_non_negative_expressions(sinfo: StructInfo) -> list[tir.PrimExpr]: Returns ------- - ret : List[tir.Var] + ret : List[tirx.Var] The list of TIR variables that can be defined from the StructInfo @@ -363,7 +363,7 @@ def post_order_visit(expr, fvisit): return _ffi_api.post_order_visit(expr, fvisit) # type: ignore -def has_reshape_pattern(func: tir.PrimFunc) -> bool: +def has_reshape_pattern(func: tirx.PrimFunc) -> bool: """Check if the given PrimFunc is essentially doing a reshape operation. The reshape operation also includes expand_dims, squeeze, flatten, etc. @@ -374,7 +374,7 @@ def has_reshape_pattern(func: tir.PrimFunc) -> bool: Parameters ---------- - func : tir.PrimFunc + func : tirx.PrimFunc The function to be examined. Returns diff --git a/python/tvm/relax/analysis/estimate_memory_usage.py b/python/tvm/relax/analysis/estimate_memory_usage.py index 0de0c7ea344f..8dac1905140b 100644 --- a/python/tvm/relax/analysis/estimate_memory_usage.py +++ b/python/tvm/relax/analysis/estimate_memory_usage.py @@ -124,7 +124,7 @@ def calculate_size(self, shape: Expr, dtype_str: str) -> int: ) size: int = 1 for dim_len in shape.values: - if not isinstance(dim_len, tvm.tir.IntImm): + if not isinstance(dim_len, tvm.tirx.IntImm): self.total_dyn_size_tensor_num += 1 return -1 size *= dim_len.value diff --git a/python/tvm/relax/backend/adreno/clml.py b/python/tvm/relax/backend/adreno/clml.py index 002d976b9cb3..02baae58b111 100644 --- a/python/tvm/relax/backend/adreno/clml.py +++ b/python/tvm/relax/backend/adreno/clml.py @@ -18,7 +18,7 @@ """Pattern table for CLML backend""" import tvm -from tvm import IRModule, relax, tir +from tvm import IRModule, relax, tirx from tvm.ir.transform import PassContext, module_pass from tvm.relax import transform from tvm.relax.dpl.pattern import ( @@ -631,7 +631,7 @@ def _check_dequantize_matmul(ctx: relax.transform.PatternCheckContext) -> bool: if not ( (len(root.struct_info.shape) == 3) - and isinstance(root.struct_info.shape[0], tir.IntImm) + and isinstance(root.struct_info.shape[0], tirx.IntImm) and (root.struct_info.dtype == "float16") and (root.struct_info.shape[0] == 1) ): diff --git a/python/tvm/relax/backend/adreno/pipeline.py b/python/tvm/relax/backend/adreno/pipeline.py index caf84c0e7a0f..dfc504d368c2 100644 --- a/python/tvm/relax/backend/adreno/pipeline.py +++ b/python/tvm/relax/backend/adreno/pipeline.py @@ -44,7 +44,7 @@ def legalize_passes(target: tvm.target.Target): # pylint: disable=unused-argume pass_list.extend( [ - tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), + tvm.tirx.transform.BindTarget(tvm.target.Target.current(allow_none=False)), relax.transform.DecomposeOpsForInference(), ] ) diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py index 006f2153af9d..a0907a76784c 100644 --- a/python/tvm/relax/backend/cuda/cublas.py +++ b/python/tvm/relax/backend/cuda/cublas.py @@ -74,7 +74,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: lhs_shape = lhs.struct_info.shape.values rhs_shape = rhs.struct_info.shape.values - if not isinstance(lhs_shape[-1], tvm.tir.expr.IntImm | int): + if not isinstance(lhs_shape[-1], tvm.tirx.expr.IntImm | int): # Reduction axis must be constant return False @@ -82,7 +82,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: if lhs_shape[-1] % 4 != 0: # Reduction axis must be multiples of 4 for IGEMM return False - if not isinstance(rhs_shape[-1], tvm.tir.expr.IntImm | int) or rhs_shape[-1] % 4 != 0: + if not isinstance(rhs_shape[-1], tvm.tirx.expr.IntImm | int) or rhs_shape[-1] % 4 != 0: # Rows number must be multiples of 4 for IGEMM return False elif lhs_dtype == "float8_e4m3fn" and rhs_dtype == "float8_e4m3fn": @@ -102,12 +102,12 @@ def _check_matmul(context: PatternCheckContext) -> bool: # cuBLAS FP8 operations require all tensors being aligned to 16 bytes. if ( - not isinstance(rhs_shape[-1], tvm.tir.expr.IntImm | int) + not isinstance(rhs_shape[-1], tvm.tirx.expr.IntImm | int) or rhs_shape[-1] % (16 // DataType(lhs_dtype).itemsize) != 0 ): return False if ( - not isinstance(rhs_shape[-2], tvm.tir.expr.IntImm | int) + not isinstance(rhs_shape[-2], tvm.tirx.expr.IntImm | int) or rhs_shape[-2] % (16 // DataType(out_dtype).itemsize) != 0 ): return False @@ -122,7 +122,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: bias = context.annotated_expr["bias"] bias_shape = bias.struct_info.shape.values bias_batches = reduce(operator.mul, bias_shape[:-1], 1) - if not isinstance(bias_batches, tvm.tir.expr.IntImm | int) or int(bias_batches) > 1: + if not isinstance(bias_batches, tvm.tirx.expr.IntImm | int) or int(bias_batches) > 1: # cuBLAS only supports bias vector return False @@ -133,8 +133,8 @@ def _check_matmul(context: PatternCheckContext) -> bool: # must be equal. If lhs is batched but rhs is not, we can use the regular GEMM by # flattening all batch axes into the M axis. return ( - isinstance(lhs_batches, tvm.tir.Var) - or isinstance(rhs_batches, tvm.tir.Var) + isinstance(lhs_batches, tvm.tirx.Var) + or isinstance(rhs_batches, tvm.tirx.Var) or (analyzer.can_prove_equal(lhs_batches, rhs_batches)) or (analyzer.can_prove(lhs_batches >= 1) and analyzer.can_prove(rhs_batches == 1)) ) diff --git a/python/tvm/relax/backend/cuda/cudnn.py b/python/tvm/relax/backend/cuda/cudnn.py index 3cc8720f4269..2be11fd04d47 100644 --- a/python/tvm/relax/backend/cuda/cudnn.py +++ b/python/tvm/relax/backend/cuda/cudnn.py @@ -182,7 +182,7 @@ def visit_function_(self, f): out_size_1d = _shape_1d(f.ret_struct_info.shape) # This needs to be in sync with the actual value that the kernel expects. workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] - if not isinstance(workspace_size_bytes, int | tvm.tir.expr.IntImm): + if not isinstance(workspace_size_bytes, int | tvm.tirx.expr.IntImm): # Tempororay workaround for dynamic shape workload. Will be removed when # workspace for dynamic shape workload is implemented. workspace_size_bytes = 8 diff --git a/python/tvm/relax/backend/cuda/cutlass.py b/python/tvm/relax/backend/cuda/cutlass.py index c656a6ae77f9..b15b3b698508 100644 --- a/python/tvm/relax/backend/cuda/cutlass.py +++ b/python/tvm/relax/backend/cuda/cutlass.py @@ -136,7 +136,7 @@ def _check_conv2d(context: PatternCheckContext) -> bool: # Check if any dimensions are symbolic. for dim in data.struct_info.shape.values: - if isinstance(dim, tvm.tir.Var): + if isinstance(dim, tvm.tirx.Var): return False # pylint: disable=invalid-name @@ -551,7 +551,7 @@ def visit_function_(self, f): out_size_1d = _shape_1d(f.ret_struct_info.shape) # This needs to be in sync with the actual value that the kernel expects. workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] - if not isinstance(workspace_size_bytes, int | tvm.tir.expr.IntImm): + if not isinstance(workspace_size_bytes, int | tvm.tirx.expr.IntImm): # Tempororay workaround for dynamic shape workload. Will be removed when # workspace for dynamic shape workload is implemented. workspace_size_bytes = 8 diff --git a/python/tvm/relax/backend/dispatch_sort_scan.py b/python/tvm/relax/backend/dispatch_sort_scan.py index b0d74769d9a8..c901aceed3e3 100644 --- a/python/tvm/relax/backend/dispatch_sort_scan.py +++ b/python/tvm/relax/backend/dispatch_sort_scan.py @@ -58,7 +58,7 @@ def apply_dlight_gpu_fallback( if sch is not None: assert len(sch) == 1 self.builder_.update_func( - gvar, sch[0].mod["main"].with_attr("tir.is_scheduled", True) + gvar, sch[0].mod["main"].with_attr("tirx.is_scheduled", True) ) def _append_calls_to_update(self, tir_call: relax.Call, target: Target) -> None: diff --git a/python/tvm/relax/backend/gpu_generic/cumsum.py b/python/tvm/relax/backend/gpu_generic/cumsum.py index ae2060175c8f..bd2cec3bcd50 100644 --- a/python/tvm/relax/backend/gpu_generic/cumsum.py +++ b/python/tvm/relax/backend/gpu_generic/cumsum.py @@ -19,8 +19,8 @@ import math -from tvm.script import tir as T -from tvm.tir import PrimFunc +from tvm.script import tirx as T +from tvm.tirx import PrimFunc def _is_power_of_two(n: int): @@ -154,7 +154,7 @@ def update_cross_block( @T.prim_func(private=True) def cumsum(var_a: T.handle, var_out: T.handle): - T.func_attr({"tir.is_scheduled": True}) # prevent further scheduling + T.func_attr({"tirx.is_scheduled": True}) # prevent further scheduling m, n = T.int64(), T.int64() A = T.match_buffer(var_a, [m, n], dtype=in_dtype) Out = T.match_buffer(var_out, [m, n], dtype=out_dtype) diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py b/python/tvm/relax/backend/gpu_generic/sampling.py index b21e488218bc..1e039ac19405 100644 --- a/python/tvm/relax/backend/gpu_generic/sampling.py +++ b/python/tvm/relax/backend/gpu_generic/sampling.py @@ -21,8 +21,8 @@ from collections.abc import Callable import tvm -from tvm.script import tir as T -from tvm.tir import PrimFunc +from tvm.script import tirx as T +from tvm.tirx import PrimFunc def _is_power_of_two(n: int): @@ -265,7 +265,7 @@ def parallel_sampling_from_prob( var_row_indices: T.handle, var_sampled_token_ids: T.handle, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) n, vocab_size, batch_size = T.int64(), T.int64(), T.int64() # match buffers prob = T.match_buffer(var_prob, (n, vocab_size), prob_dtype) diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index c9c853d702de..a328684b7c7e 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py @@ -29,7 +29,7 @@ wildcard, ) from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _with_bias_activation_pattern( diff --git a/python/tvm/relax/backend/rocm/hipblas.py b/python/tvm/relax/backend/rocm/hipblas.py index 1c43777e23fd..15eaf80ab0c4 100644 --- a/python/tvm/relax/backend/rocm/hipblas.py +++ b/python/tvm/relax/backend/rocm/hipblas.py @@ -56,7 +56,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: lhs_shape = lhs.struct_info.shape.values rhs_shape = rhs.struct_info.shape.values - if not isinstance(lhs_shape[-1], tvm.tir.expr.IntImm | int): + if not isinstance(lhs_shape[-1], tvm.tirx.expr.IntImm | int): # Reduction axis must be constant return False @@ -75,7 +75,7 @@ def _check_matmul(context: PatternCheckContext) -> bool: bias = context.annotated_expr["bias"] bias_shape = bias.struct_info.shape.values bias_batches = reduce(operator.mul, bias_shape[:-1], 1) - if not isinstance(bias_batches, tvm.tir.expr.IntImm | int) or int(bias_batches) > 1: + if not isinstance(bias_batches, tvm.tirx.expr.IntImm | int) or int(bias_batches) > 1: # hipblas only supports bias vector return False @@ -84,8 +84,8 @@ def _check_matmul(context: PatternCheckContext) -> bool: # must be equal. If lhs is batched but rhs is not, we can use the regular GEMM by # flattening all batch axes into the M axis. return ( - isinstance(lhs_batches, tvm.tir.Var) - or isinstance(rhs_batches, tvm.tir.Var) + isinstance(lhs_batches, tvm.tirx.Var) + or isinstance(rhs_batches, tvm.tirx.Var) or (int(lhs_batches) == int(rhs_batches)) or (lhs_batches >= 1 and rhs_batches == 1) ) diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 016bd27034d1..7840e5b3b4c6 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -24,7 +24,7 @@ import numpy as np import tvm -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import IRModule from tvm.runtime import Device, PackedFunc, Tensor from tvm.target import Target @@ -122,7 +122,7 @@ def _getattr_python_function(name: str) -> Any: def _collect_function_names(self): """Collect names of TIR and Relax functions from IRModule.""" for global_var, func in self.ir_mod.functions_items(): - if isinstance(func, tir.PrimFunc): + if isinstance(func, tirx.PrimFunc): self.tir_func_names.append(global_var.name_hint) elif isinstance(func, relax.Function): self.relax_func_names.append(global_var.name_hint) @@ -134,7 +134,7 @@ def _compile_functions(self): { gv: func for gv, func in self.ir_mod.functions_items() - if isinstance(func, tir.PrimFunc) + if isinstance(func, tirx.PrimFunc) } ) if tir_mod: @@ -315,7 +315,7 @@ def _infer_concrete_shape_from_args(self, shape, in_args): for idx, dim in enumerate(shape): if isinstance(dim, int | np.integer): concrete.append(int(dim)) - elif isinstance(dim, tir.IntImm): + elif isinstance(dim, tirx.IntImm): concrete.append(int(dim.value)) else: concrete.append(None) @@ -475,7 +475,7 @@ def get_function(self, name: str) -> PackedFunc | None: def list_functions(self) -> dict[str, list[str]]: """List all available functions.""" return { - "tir": self.tir_func_names, + "tirx": self.tir_func_names, "relax": self.relax_func_names, "extern": list(self.extern_funcs.keys()), } diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index b10ddf5ebd02..1d6057cec500 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -25,7 +25,7 @@ import tvm from tvm import relax as rx -from tvm import tir +from tvm import tirx from tvm.ir.module import IRModule from tvm.runtime import Object @@ -87,10 +87,10 @@ def __init__(self, block_builder, def_vars): self._bb = block_builder shape_vars = [] for var in def_vars: - if isinstance(var, tvm.tir.Var): + if isinstance(var, tvm.tirx.Var): shape_vars.append(var) else: - raise ValueError("def_vars only can take tir.Var") + raise ValueError("def_vars only can take tirx.Var") # setup a dummy var so shape is in scope. sparam = rx.Var("sparam", rx.ShapeStructInfo(shape_vars)) self._scope_params = [sparam] @@ -112,8 +112,8 @@ class BlockBuilder(Object): -------- .. code-block:: python - m = tir.Var("m", "int32") - n = tir.Var("n", "int32") + m = tirx.Var("m", "int32") + n = tirx.Var("n", "int32") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16") bb = rx.BlockBuilder() @@ -131,7 +131,7 @@ class BlockBuilder(Object): from tvm.relax.testing import nn - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") input_size = 784 hidden_sizes = [128, 32] output_size = 10 @@ -263,12 +263,12 @@ def function( return FunctionScope(self, name, params, attrs, is_pure=pure) - def testing_scope(self, def_vars: list[tir.Var]) -> TestingScope: + def testing_scope(self, def_vars: list[tirx.Var]) -> TestingScope: """Start a scope for unit-testing purposes. Parameters ---------- - def_vars: List[tir.Var] + def_vars: List[tirx.Var] List of symbolic variables that are marked as defined in scope. Returns @@ -453,7 +453,7 @@ def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var: .. code-block:: python bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) @@ -476,7 +476,7 @@ class Module: def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle) -> None: # function attr dict - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32") @@ -503,7 +503,7 @@ def rx_func(x: Tensor((n, m), "float32"), y: Tensor((n, m), "float32")) -> Tenso .. code-block:: python bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", relax.TensorStructInfo([n], "float32")) y = relax.Var("y", relax.TensorStructInfo([n + 1], "float32")) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index f485d8bbbdf1..fefff21432e2 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -893,7 +893,7 @@ def is_call_tir( args : Union[List[DFPattern], Tuple[DFPattern]], optional Arguments in expected call_packed, by default None meaning arbitrary (number of) arguments tir_vars : Optional[DFPattern] - Pattern to match the tuple of integers that are unpacked when calling the tir func. + Pattern to match the tuple of integers that are unpacked when calling the tirx func. Returns ------- CallPattern diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index a6ea733ae9dd..7c7bcc2aeaa7 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -33,7 +33,7 @@ from ..ir import BaseFunc, Node, Span from ..runtime import Scriptable, String -from ..tir import PrimExpr +from ..tirx import PrimExpr from . import _ffi_api # It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 @@ -301,7 +301,7 @@ def elem_offset(self) -> "Expr": This parameter is not stored in the DLTensor, but is instead derived from the DLTensor's byte offset and datatype. This is exposed in Relax for ease of use, and for translation into the - `tir::BufferNode::elem_offset` field when interacting with TIR + `tirx::BufferNode::elem_offset` field when interacting with TIR buffers. """ self._check_for_tensor_struct_info() @@ -861,7 +861,7 @@ class PrimValue(Expr, Scriptable): def __init__(self, value: PrimExpr | int, span: Span | None = None) -> None: if isinstance(value, int): - value = tvm.tir.IntImm("int64", value) + value = tvm.tirx.IntImm("int64", value) self.__init_handle_by_constructor__(_ffi_api.PrimValue, value, span) # type: ignore @@ -1040,15 +1040,15 @@ def __call__(self, *args): """ return Call(self, args, None, None) - def bind_symbolic_vars(self, binding_map: Mapping[str | tvm.tir.Var, PrimExpr]) -> "Function": + def bind_symbolic_vars(self, binding_map: Mapping[str | tvm.tirx.Var, PrimExpr]) -> "Function": """Return a new function with updated symbolic variable Parameters ---------- - binding_map: Mapping[str | tvm.tir.Var, PrimExpr] + binding_map: Mapping[str | tvm.tirx.Var, PrimExpr] The mapping of values to be replaced. Keys may be either - a `tir.Var` or a string name of the variable. If the + a `tirx.Var` or a string name of the variable. If the variables are referred to by name, the name must uniquely identify a symbolic variable in the function. @@ -1062,7 +1062,7 @@ def bind_symbolic_vars(self, binding_map: Mapping[str | tvm.tir.Var, PrimExpr]) # Relax uses int64 for symbolic variables, but the FFI # converts python integers into int32. binding_map = { - key: tvm.tir.const(value, "int64") if isinstance(value, int) else value + key: tvm.tirx.const(value, "int64") if isinstance(value, int) else value for key, value in binding_map.items() } @@ -1108,7 +1108,7 @@ def _normalize_value(value): if isinstance(value, int): # Relax uses int64 for symbolic variables, but the FFI # converts python integers into int32. - return tvm.tir.const(value, "int64") + return tvm.tirx.const(value, "int64") elif isinstance(value, _np.ndarray | tvm.runtime.Tensor): return tvm.relax.const(value) else: @@ -1205,7 +1205,7 @@ class TEPlaceholderOp(tvm.te.tensor.Operation): def te_tensor( - value: Expr, tir_var_map: dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" + value: Expr, tir_var_map: dict[tvm.tirx.Var, tvm.tirx.PrimExpr], name: str = "rxplaceholder" ): """Create a TE tensor from relax expression, with TIR variables in the tensor shape substituted by the given mapping @@ -1215,7 +1215,7 @@ def te_tensor( value : Expr The relax expression, which is required to have TensorStructInfo. - tir_var_map : Dict[tvm.tir.Var, tvm.tir.PrimExpr] + tir_var_map : Dict[tvm.tirx.Var, tvm.tirx.PrimExpr] The mapping to substitute the TIR variables appeared in the shape of the input Expr. diff --git a/python/tvm/relax/frontend/nn/_tensor_op.py b/python/tvm/relax/frontend/nn/_tensor_op.py index d7acc48fd22a..6285f8008f7d 100644 --- a/python/tvm/relax/frontend/nn/_tensor_op.py +++ b/python/tvm/relax/frontend/nn/_tensor_op.py @@ -17,7 +17,7 @@ # ruff: noqa: F821 """Adding member operators to nn.Tensor.""" -from tvm import tir +from tvm import tirx def _op(): @@ -31,7 +31,7 @@ def _convert_scalar(scalar, ref) -> "Tensor": if isinstance(scalar, Tensor): return scalar - if isinstance(scalar, tir.FloatImm | tir.IntImm): + if isinstance(scalar, tirx.FloatImm | tirx.IntImm): return Tensor.from_scalar(scalar.value, dtype=ref.dtype) if isinstance(scalar, int | float): return Tensor.from_scalar(scalar, dtype=ref.dtype) diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 2ac461cbff89..40659e1623d7 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -35,7 +35,7 @@ import numpy as np # type: ignore import tvm.runtime -from tvm import tir +from tvm import tirx from tvm.ir import IRModule from tvm.ir.transform import Pass from tvm.runtime import Device @@ -133,7 +133,7 @@ def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> @staticmethod def placeholder( - shape: Sequence[int | str | tir.PrimExpr], + shape: Sequence[int | str | tirx.PrimExpr], dtype: str, name: str = "tensor", ) -> "Tensor": @@ -141,20 +141,20 @@ def placeholder( never be created directly by users in usual cases, and the only exception is to indicate the shape/dtype of return values of an external function. - If shape is a string `name`, we create a symbolic shape `tvm.tir.Var(name, "int64")`. + If shape is a string `name`, we create a symbolic shape `tvm.tirx.Var(name, "int64")`. """ new_shape = [] for expr in shape: - if isinstance(expr, int | tir.IntImm): + if isinstance(expr, int | tirx.IntImm): expr = int(expr) assert expr >= 0 new_shape.append(expr) continue if isinstance(expr, str): - expr = tir.Var(expr, "int64") + expr = tirx.Var(expr, "int64") new_shape.append(expr) continue - if not isinstance(expr, tir.PrimExpr): + if not isinstance(expr, tirx.PrimExpr): raise TypeError(f"Invalid shape: {shape}") assert expr.dtype == "int64" new_shape.append(expr) @@ -169,21 +169,21 @@ def placeholder( ) @property - def shape(self) -> list[int | tir.PrimExpr]: + def shape(self) -> list[int | tirx.PrimExpr]: """Returns the shape of the tensor as a list of integers. - An integer can be a python int or tvm.tir.PrimExpr, depending on whether the shape is - fully static, for example, [1, 2, tvm.tir.Var("n")] is a valid shape where the last + An integer can be a python int or tvm.tirx.PrimExpr, depending on whether the shape is + fully static, for example, [1, 2, tvm.tirx.Var("n")] is a valid shape where the last dimension is dynamic while the first two dimensions are always static constants. Returns ------- - shape : List[Union[int, tir.PrimExpr]] + shape : List[Union[int, tirx.PrimExpr]] The shape of the tensor """ - def _simplify(expr: tir.PrimExpr): - return expr.value if isinstance(expr, tir.IntImm) else expr + def _simplify(expr: tirx.PrimExpr): + return expr.value if isinstance(expr, tirx.IntImm) else expr shape_sinfo: ShapeStructInfo = self._expr.struct_info.shape.struct_info return [_simplify(x) for x in shape_sinfo.values] @@ -225,7 +225,7 @@ class Parameter(Tensor): def __init__( self, - shape: Sequence[int | str | tir.PrimExpr], + shape: Sequence[int | str | tirx.PrimExpr], dtype: str | None = None, ) -> None: """Create a parameter with given shape and dtype. The parameter is not bound to any @@ -233,9 +233,9 @@ def __init__( Parameters ---------- - shape : Sequence[Union[int, str, tir.PrimExpr]] + shape : Sequence[Union[int, str, tirx.PrimExpr]] The shape of the parameter. If it is a string `name`, we create a symbolic shape - `tvm.tir.Var(name, "int64")`. + `tvm.tirx.Var(name, "int64")`. dtype : Optional[str] The data type of the parameter. If not specified, the default dtype will be used. """ diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index a6cf334d1acb..374d7059194b 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -21,7 +21,7 @@ import threading import typing -from tvm import tir +from tvm import tirx from tvm.ir import IRModule from .... import relax as rx @@ -164,8 +164,8 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many- effects: list[tuple[str, core.Effect]] | None, ): # pylint: disable=protected-access - # symbolic shape's name mapping to its tir.Var for reuse - str2var_params: dict[str, tir.Var] = {} + # symbolic shape's name mapping to its tirx.Var for reuse + str2var_params: dict[str, tirx.Var] = {} def _unwrap_ret(expr: typing.Any) -> typing.Any: if isinstance(expr, core.Tensor | core.Object): @@ -177,7 +177,7 @@ def _unwrap_ret(expr: typing.Any) -> typing.Any: raise TypeError(f"Unsupported return type: {type(expr)}") def _convert_input(arg): - if isinstance(arg, tir.Var): + if isinstance(arg, tirx.Var): return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) if isinstance(arg, core.Tensor | core.Object): return arg._expr # pylint: disable=protected-access @@ -193,18 +193,18 @@ def _convert_input(arg): def _params(mode: str) -> list[rx.Var]: inputs: list[rx.Var] = [] - def _get_var(shape_var: tir.Var) -> tir.Var: + def _get_var(shape_var: tirx.Var) -> tirx.Var: name = shape_var.name if name in str2var_params: return str2var_params[name] - var = tir.Var(name, "int64") + var = tirx.Var(name, "int64") str2var_params[name] = var return var for name, param in params: # Make sure the a symbolic shape is not re-registered (same as _method_spec_to_inputs) # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens` - new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape] + new_shape = [_get_var(x) if isinstance(x, tirx.Var) else x for x in param.shape] var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr inputs.append(var) param._expr = var @@ -265,7 +265,7 @@ def _detuple(arg, var: rx.Var, builder: BlockBuilder): return type(arg.elements)(ret) if isinstance(arg, core.Tensor): return core.Tensor(_expr=var) - if isinstance(arg, tir.Var): + if isinstance(arg, tirx.Var): return arg raise TypeError(f"Unsupported input type: {type(arg)}") @@ -292,14 +292,14 @@ def _detuple(arg, var: rx.Var, builder: BlockBuilder): def _method_spec_to_inputs( spec: _spec.MethodSpec, -) -> list[tir.Var | core.Tensor]: +) -> list[tirx.Var | core.Tensor]: """Convert the MethodSpec to a list of inputs to Module's method.""" - str2var: dict[str, tir.Var] = {} + str2var: dict[str, tirx.Var] = {} - def _get_var(name: str) -> tir.Var: + def _get_var(name: str) -> tirx.Var: if name in str2var: return str2var[name] - var = tir.Var(name, "int64") + var = tirx.Var(name, "int64") str2var[name] = var return var diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index 50d8b390c9a7..f442d491dc57 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -24,7 +24,7 @@ from collections.abc import Callable from pathlib import Path -from tvm import tir +from tvm import tirx from tvm.contrib import cc as _cc from tvm.runtime import Module, load_static_library @@ -53,12 +53,12 @@ def _convert(arg, name: str): if isinstance(arg, core.Tensor): return arg._expr # pylint: disable=protected-access if isinstance(arg, int): - return rx.PrimValue(tir.IntImm("int64", arg)) + return rx.PrimValue(tirx.IntImm("int64", arg)) if isinstance(arg, float): - return rx.PrimValue(tir.FloatImm("float64", arg)) + return rx.PrimValue(tirx.FloatImm("float64", arg)) if isinstance(arg, str): return rx.StringImm(arg) - if isinstance(arg, tir.PrimExpr): + if isinstance(arg, tirx.PrimExpr): return rx.PrimValue(arg) if isinstance(arg, tuple | list): return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)]) diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index d7fc500b948b..f66ff6b6636f 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -25,10 +25,10 @@ import tvm from tvm import relax as rx -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.relax.frontend.nn import Object, Tensor from tvm.runtime import DataType -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target from .position_embedding import llama_rope_with_position_map, switch_rope_freq_func @@ -270,13 +270,13 @@ def merge_attn_output_inplace( lse_self_attn = Tensor(_expr=bb.emit(rx.TupleGetItem(merge_results, 1))).reshape(b, s, h_qo) return o_self_attn, lse_self_attn - def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor: + def get_query_positions(self, total_length: tirx.PrimExpr) -> Tensor: """Get the in-sequence positions of each slot in the query, which are needed for applying positional embeddings in some models. Parameters ---------- - total_length : tir.PrimExpr + total_length : tirx.PrimExpr The summed-up total sequence length of queries in the batch being forwarded. @@ -321,11 +321,11 @@ class FlashInferPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-me def __init__( # pylint: disable=too-many-locals self, attn_kind: Literal["mha", "mla"] | list[Literal["mha", "mla", "mha_sliding"]], - max_batch_size: tir.Var, - max_total_seq_len: tir.Var, - prefill_chunk_size: tir.Var, - page_size: tir.Var, - support_sliding_window: tir.Var, + max_batch_size: tirx.Var, + max_total_seq_len: tirx.Var, + prefill_chunk_size: tirx.Var, + page_size: tirx.Var, + support_sliding_window: tirx.Var, layer_partition: rx.ShapeExpr, num_hidden_layers: int, num_attention_heads: int, @@ -349,23 +349,23 @@ def __init__( # pylint: disable=too-many-locals Parameters ---------- - max_batch_size : tir.Var + max_batch_size : tirx.Var The maximum allowed batch size of the KV cache. It is a symbolic variable whose concrete value is specified at runtime. - max_total_seq_len : tir.Var + max_total_seq_len : tirx.Var The maximum allowed total sequence length of the KV cache. It is a symbolic variable whose concrete value is specified at runtime. - prefill_chunk_size : tir.Var + prefill_chunk_size : tirx.Var The maximum total sequence length in a prefill. It is a symbolic variable whose concrete value is specified at runtime. - page_size : tir.Var + page_size : tirx.Var The size (a.k.a. number of tokens) of each page. It is a symbolic variable whose concrete value is specified at runtime. - support_sliding_window : tir.Var + support_sliding_window : tirx.Var 0 or 1, denoting whether the KV cache supports sliding window. It is a symbolic variable whose concrete value is specified at runtime. @@ -438,10 +438,10 @@ def __init__( # pylint: disable=too-many-locals [ rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_prefill_paged_run"), rx.ExternFunc("batch_prefill_plan")]), rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_decode_run"), rx.ExternFunc("batch_decode_plan")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] if attn_kind_single == "mha" else [rx.Tuple([]) for _ in range(6)] @@ -509,11 +509,11 @@ class TIRPagedKVCache(PagedKVCache): # pylint: disable=too-few-public-methods def __init__( # pylint: disable=too-many-locals self, attn_kind: Literal["mha", "mla"] | list[Literal["mha", "mla", "mha_sliding"]], - max_batch_size: tir.Var, - max_total_seq_len: tir.Var, - prefill_chunk_size: tir.Var, - page_size: tir.Var, - support_sliding_window: tir.Var, + max_batch_size: tirx.Var, + max_total_seq_len: tirx.Var, + prefill_chunk_size: tirx.Var, + page_size: tirx.Var, + support_sliding_window: tirx.Var, layer_partition: rx.ShapeExpr, num_hidden_layers: int, num_attention_heads: int, @@ -537,23 +537,23 @@ def __init__( # pylint: disable=too-many-locals Parameters ---------- - max_batch_size : tir.Var + max_batch_size : tirx.Var The maximum allowed batch size of the KV cache. It is a symbolic variable whose concrete value is specified at runtime. - max_total_seq_len : tir.Var + max_total_seq_len : tirx.Var The maximum allowed total sequence length of the KV cache. It is a symbolic variable whose concrete value is specified at runtime. - prefill_chunk_size : tir.Var + prefill_chunk_size : tirx.Var The maximum total sequence length in a prefill. It is a symbolic variable whose concrete value is specified at runtime. - page_size : tir.Var + page_size : tirx.Var The size (a.k.a. number of tokens) of each page. It is a symbolic variable whose concrete value is specified at runtime. - support_sliding_window : tir.Var + support_sliding_window : tirx.Var 0 or 1, denoting whether the KV cache supports sliding window. It is a symbolic variable whose concrete value is specified at runtime. @@ -630,13 +630,13 @@ def __init__( # pylint: disable=too-many-locals # fmt: off args.extend( [ - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, v_head_dim, dtype, rope_scaling), "tir_attention_prefill_ragged_cpu")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling), "tir_attention_prefill_cpu")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling), "tir_attention_decode_cpu")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling), "tir_attention_prefill_cpu_sliding_window")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling), "tir_attention_decode_cpu_sliding_window")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_cpu")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill_ragged_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, v_head_dim, dtype, rope_scaling), "tir_attention_prefill_ragged_cpu")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling), "tir_attention_prefill_cpu")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling), "tir_attention_decode_cpu")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling), "tir_attention_prefill_cpu_sliding_window")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_decode_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling), "tir_attention_decode_cpu_sliding_window")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(tree_attn_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_cpu")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(tree_attn_with_paged_kv_cache_cpu(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache_cpu")]), rx.Tuple([]), # f_mla_prefill rx.Tuple([bb.add_func(_merge_state_inplace_cpu(dtype), "tir_attention_merge_state_cpu")]), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), @@ -652,20 +652,20 @@ def __init__( # pylint: disable=too-many-locals # fmt: off ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim - args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) + args.append(rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) mha_functions = ( [ - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_decode")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), - rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_decode")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_prefill_sliding_window")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_decode(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, True, rope_scaling, target), "tir_attention_decode_sliding_window")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), + rx.Tuple([rx.StringImm("tirx"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] if attn_kind_single == "mha" else [rx.Tuple([]) for _ in range(6)] ) - mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else []) + mla_function = rx.Tuple([rx.StringImm("tirx"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else []) attn_merge_functions = [ bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), ] @@ -711,7 +711,7 @@ def tir_kv_cache_transpose_append( var_v_data: T.handle, var_position_map: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") num_pages = T.int64() pages_elem_offset = T.int64() @@ -753,7 +753,7 @@ def tir_kv_cache_transpose_append_mla( var_kv_data: T.handle, var_position_map: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) ntoken = T.SizeVar("num_tokens_excluding_cache", "int64") num_pages = T.int64() pages_elem_offset = T.int64() @@ -790,7 +790,7 @@ def tir_kv_cache_debug_get_kv( var_v_data: T.handle, layer_id: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) seqlen = T.SizeVar("num_tokens_including_cache", "int64") page_size = T.SizeVar("page_size", "int64") num_pages = T.int64() @@ -828,7 +828,7 @@ def tir_kv_cache_debug_get_kv_mla( var_compressed_kv_with_k_pe_data: T.handle, layer_id: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) seqlen = T.SizeVar("num_tokens_including_cache", "int64") page_size = T.SizeVar("page_size", "int64") num_pages = T.int64() @@ -854,11 +854,11 @@ def tir_kv_cache_debug_get_kv_mla( def _rope( buffer: T.Buffer, - offset: tir.Var, + offset: tirx.Var, rotary_dim: int, - theta: tir.Var, - scale: tir.Var, - indices: tuple[tir.Var, ...], + theta: tirx.Var, + scale: tirx.Var, + indices: tuple[tirx.Var, ...], qkv_dtype: str, rope_scaling: dict[str, Any], ): @@ -867,14 +867,14 @@ def _rope( offset * scale, d, rotary_dim, theta, "float32" ) cos = cos_freq * buffer[indices].astype("float32") - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d < rotary_dim // 2, -buffer[indices[:-1] + (d + rotary_dim // 2,)], buffer[indices[:-1] + (d - rotary_dim // 2,)], ).astype("float32") expr = (cos + sin).astype(qkv_dtype) for var, value in var_map.items(): - expr = tir.Let(var, value, expr) + expr = tirx.Let(var, value, expr) return expr @@ -1520,7 +1520,7 @@ def batch_prefill_paged_kv( sch = _schedule_prefill_kernel( sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False ) - return sch.mod["main"].with_attr("tir.is_scheduled", True) + return sch.mod["main"].with_attr("tirx.is_scheduled", True) def _attention_decode_cpu( @@ -1559,7 +1559,7 @@ def batch_decode_paged_kv( rope_theta: T.float32, sm_scale: T.float32, ): - T.func_attr({"tir.is_scheduled": True, "global_symbol": global_symbol}) + T.func_attr({"tirx.is_scheduled": True, "global_symbol": global_symbol}) B = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) @@ -1740,7 +1740,7 @@ def batch_decode_paged_kv( rope_theta: T.float32, sm_scale: T.float32, ): - T.func_attr({"tir.is_scheduled": True, "global_symbol": global_symbol}) + T.func_attr({"tirx.is_scheduled": True, "global_symbol": global_symbol}) B = T.int32(is_size_var=True) nnz_pages = T.int32(is_size_var=True) max_num_pages = T.int32(is_size_var=True) @@ -1943,7 +1943,7 @@ def merge_state_inplace_cpu( v_other: T.handle, s_other: T.handle, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) N = T.int32(is_size_var=True) H = T.int32(is_size_var=True) D = T.int32(is_size_var=True) @@ -1996,7 +1996,7 @@ def merge_state_inplace( v_other: T.handle, s_other: T.handle, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) N = T.int32(is_size_var=True) H = T.int32(is_size_var=True) D = T.int32(is_size_var=True) @@ -2302,7 +2302,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches sch = _schedule_prefill_kernel( sch, LOAD_VEC, bdx, num_warps, tile_x, tile_y, tile_z, False, False ) - return sch.mod["main"].with_attr("tir.is_scheduled", True) + return sch.mod["main"].with_attr("tirx.is_scheduled", True) def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v, dtype, rope_scaling: dict[str, Any]): @@ -2694,7 +2694,7 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches # pylint: enable=line-too-long,too-many-branches sch = tvm.s_tir.Schedule(batch_prefill_ragged_kv) sch = _schedule_prefill_kernel(sch, LOAD_VEC, bdx, num_warps, tile_x, d_v, tile_z, True, False) - return sch.mod["main"].with_attr("tir.is_scheduled", True) + return sch.mod["main"].with_attr("tirx.is_scheduled", True) def _attention_prefill_mla( @@ -2963,7 +2963,7 @@ def batch_prefill_paged_kv_mla( sch = _schedule_prefill_kernel( sch, LOAD_VEC, bdx, num_warps, tile_x, d_latent, tile_z, False, True ) - return sch.mod["main"].with_attr("tir.is_scheduled", True) + return sch.mod["main"].with_attr("tirx.is_scheduled", True) def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target): @@ -2976,7 +2976,7 @@ def copy_single_page( tgt_page_id: T.int64, copy_length: T.int64, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) num_pages = T.int32() pages_elem_offset = T.int64() pages = T.match_buffer( @@ -3023,7 +3023,7 @@ def copy_single_page_mla( tgt_page_id: T.int64, copy_length: T.int64, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) num_pages = T.int32() pages_elem_offset = T.int64() pages = T.match_buffer( @@ -3051,7 +3051,7 @@ def copy_single_page_cpu( tgt_page_id: T.int64, copy_length: T.int64, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) num_pages = T.int32() pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype) @@ -3090,7 +3090,7 @@ def compact_kv_copy( var_copy_src_dst_pos: T.handle, batch_size: T.int32, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) num_pages = T.int32() total_copy_length = T.int32() copy_length_indptr_elem_offset = T.int32() @@ -3147,7 +3147,7 @@ def compact_kv_copy_cpu( var_copy_src_dst_pos: T.handle, batch_size: T.int32, ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) num_pages = T.int32() total_copy_length = T.int32() copy_length_indptr_elem_offset = T.int32() diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index 2fc438cacd1d..dec80c50c270 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -22,22 +22,22 @@ from functools import partial from typing import Any -from tvm import tir +from tvm import tirx from tvm.relax.frontend.nn import Tensor, op -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name -def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): +def rope_freq_default(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str): """Compute the inverse frequency of RoPE and then return the cosine and sine of it. Parameters ---------- - s : tir.Var + s : tirx.Var The position index. - d : tir.Var + d : tirx.Var The dimension index. d_range : int @@ -57,28 +57,28 @@ def rope_freq_default(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: sin_freq : Tensor The sine of the inverse frequency. - var_map: Dict[tir.Var, tir.PrimExpr] + var_map: Dict[tirx.Var, tirx.PrimExpr] The common expression map. """ - freq = s / tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) - freq_var = tir.Var("freq", "float32") - cos_freq = tir.cos(freq_var).astype(dtype) - sin_freq = tir.sin(freq_var).astype(dtype) + freq = s / tirx.power(theta, d * 2 % d_range / tirx.const(d_range, "float32")) + freq_var = tirx.Var("freq", "float32") + cos_freq = tirx.cos(freq_var).astype(dtype) + sin_freq = tirx.sin(freq_var).astype(dtype) return cos_freq, sin_freq, {freq_var: freq} -def rope_freq_gptj(s: tir.Var, d: tir.Var, d_range: int, theta: float, dtype: str): +def rope_freq_gptj(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str): """Compute the inverse frequency of RoPE for gptj RoPE scaling.""" - freq = s / tir.power(theta, 2 * (d // 2) % d_range / tir.const(d_range, "float32")) - freq_var = tir.Var("freq", "float32") - cos_freq = tir.cos(freq_var).astype(dtype) - sin_freq = tir.sin(freq_var).astype(dtype) + freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(d_range, "float32")) + freq_var = tirx.Var("freq", "float32") + cos_freq = tirx.cos(freq_var).astype(dtype) + sin_freq = tirx.sin(freq_var).astype(dtype) return cos_freq, sin_freq, {freq_var: freq} def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals - s: tir.Var, - d: tir.Var, + s: tirx.Var, + d: tirx.Var, d_range: int, theta: float, dtype: str, @@ -88,18 +88,20 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals original_max_position_embeddings: float, ): """Compute the inverse frequency of RoPE for llama4 RoPE scaling.""" - orig_freq = tir.const(1, "float32") / tir.power( - theta, 2 * (d // 2) / tir.const(d_range, "float32") + orig_freq = tirx.const(1, "float32") / tirx.power( + theta, 2 * (d // 2) / tirx.const(d_range, "float32") ) - orig_freq_var = tir.Var("orig_freq", "float32") + orig_freq_var = tirx.Var("orig_freq", "float32") llama4_inv_scaling_factor = 1.0 / factor if high_freq_factor == low_freq_factor: - wavelength = tir.const(2 * math.pi, "float32") / orig_freq_var - threshold_wavelen = tir.const(original_max_position_embeddings / low_freq_factor, "float32") + wavelength = tirx.const(2 * math.pi, "float32") / orig_freq_var + threshold_wavelen = tirx.const( + original_max_position_embeddings / low_freq_factor, "float32" + ) - scaled_freq = tir.if_then_else( + scaled_freq = tirx.if_then_else( wavelength > threshold_wavelen, orig_freq_var / factor, orig_freq_var ) smoothed_freq = s * scaled_freq @@ -110,14 +112,14 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals llama4_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor llama4_beta = low_freq_factor * inv_diff_freq_factor - smooth = tir.max(0.0, tir.min(1.0, llama4_alpha * orig_freq_var - llama4_beta)) + smooth = tirx.max(0.0, tirx.min(1.0, llama4_alpha * orig_freq_var - llama4_beta)) smoothed_freq = s * ( (1.0 - smooth) * orig_freq_var * llama4_inv_scaling_factor + smooth * orig_freq_var ) - smoothed_freq_var = tir.Var("smoothed_freq", "float32") - cos_freq = tir.cos(smoothed_freq_var).astype(dtype) - sin_freq = tir.sin(smoothed_freq_var).astype(dtype) + smoothed_freq_var = tirx.Var("smoothed_freq", "float32") + cos_freq = tirx.cos(smoothed_freq_var).astype(dtype) + sin_freq = tirx.sin(smoothed_freq_var).astype(dtype) return ( cos_freq, sin_freq, @@ -126,8 +128,8 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals - s: tir.Var, - d: tir.Var, + s: tirx.Var, + d: tirx.Var, d_range: int, theta: float, dtype: str, @@ -137,21 +139,21 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals original_max_position_embeddings: float, ): """Compute the inverse frequency of RoPE for llama3 RoPE scaling.""" - orig_freq = tir.const(1, "float32") / tir.power( - theta, d * 2 % d_range / tir.const(d_range, "float32") + orig_freq = tirx.const(1, "float32") / tirx.power( + theta, d * 2 % d_range / tirx.const(d_range, "float32") ) - orig_freq_var = tir.Var("orig_freq", "float32") + orig_freq_var = tirx.Var("orig_freq", "float32") inv_diff_freq_factor = 1.0 / (high_freq_factor - low_freq_factor) llama3_inv_scaling_factor = 1.0 / factor llama3_alpha = original_max_position_embeddings / (2 * math.pi) * inv_diff_freq_factor llama3_beta = low_freq_factor * inv_diff_freq_factor - smooth = tir.max(0.0, tir.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) + smooth = tirx.max(0.0, tirx.min(1.0, llama3_alpha * orig_freq_var - llama3_beta)) smoothed_freq = s * ( (1.0 - smooth) * orig_freq_var * llama3_inv_scaling_factor + smooth * orig_freq_var ) - smoothed_freq_var = tir.Var("smoothed_freq", "float32") - cos_freq = tir.cos(smoothed_freq_var).astype(dtype) - sin_freq = tir.sin(smoothed_freq_var).astype(dtype) + smoothed_freq_var = tirx.Var("smoothed_freq", "float32") + cos_freq = tirx.cos(smoothed_freq_var).astype(dtype) + sin_freq = tirx.sin(smoothed_freq_var).astype(dtype) return ( cos_freq, sin_freq, @@ -160,8 +162,8 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals def rope_freq_longrope( # pylint: disable=too-many-arguments - s: tir.Var, - d: tir.Var, + s: tirx.Var, + d: tirx.Var, d_range: int, theta: float, dtype: str, @@ -176,21 +178,21 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments if scale > 1.0 else 1.0 ) - divisor = tir.power(theta, d * 2 % d_range / tir.const(d_range, "float32")) + divisor = tirx.power(theta, d * 2 % d_range / tirx.const(d_range, "float32")) if ext_factors is not None: divisor = ext_factors[d % (d_range // 2)] * divisor freq = s / divisor - freq_var = tir.Var("freq", "float32") - cos_freq = (tir.cos(freq_var) * scaling_factor).astype(dtype) - sin_freq = (tir.sin(freq_var) * scaling_factor).astype(dtype) + freq_var = tirx.Var("freq", "float32") + cos_freq = (tirx.cos(freq_var) * scaling_factor).astype(dtype) + sin_freq = (tirx.sin(freq_var) * scaling_factor).astype(dtype) return cos_freq, sin_freq, {freq_var: freq} def yarn_find_correction_dim( num_rotations: int, - d: tir.Var, + d: tirx.Var, max_position_embeddings: int, - inv_theta_log_scale: float | tir.PrimExpr | None = None, + inv_theta_log_scale: float | tirx.PrimExpr | None = None, ): """Inverse dim formula to find dim based on number of rotations""" return ( @@ -201,9 +203,9 @@ def yarn_find_correction_dim( def yarn_find_correction_range( low_rot: int, high_rot: int, - d: tir.Var, + d: tirx.Var, max_position_embeddings: int, - inv_theta_log_scale: float | tir.PrimExpr | None = None, + inv_theta_log_scale: float | tirx.PrimExpr | None = None, ): """Find the correction range based on the number of rotations""" low = yarn_find_correction_dim( @@ -212,27 +214,27 @@ def yarn_find_correction_range( high = yarn_find_correction_dim( high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale ) - return tir.max(low, 0), tir.min(high, d - 1) + return tirx.max(low, 0), tirx.min(high, d - 1) def rope_freq_yarn( - s: tir.Var, - d: tir.Var, + s: tirx.Var, + d: tirx.Var, d_range: int, - theta: float | tir.PrimExpr, + theta: float | tirx.PrimExpr, dtype: str, original_max_position_embeddings: int, scaling_factor: float, beta_fast: int, beta_slow: int, - inv_theta_log_scale: float | tir.PrimExpr | None = None, + inv_theta_log_scale: float | tirx.PrimExpr | None = None, ): # pylint: disable=too-many-arguments, too-many-locals """Compute the inverse frequency of RoPE for yarn RoPE scaling.""" - exponent = d * 2 % d_range / tir.const(d_range, "float32") - freq_power = tir.power(theta, exponent) - freq_extra = tir.const(1, "float32") / freq_power - freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power) + exponent = d * 2 % d_range / tirx.const(d_range, "float32") + freq_power = tirx.power(theta, exponent) + freq_extra = tirx.const(1, "float32") / freq_power + freq_inter = tirx.const(1, "float32") / (scaling_factor * freq_power) low, high = yarn_find_correction_range( beta_fast, @@ -241,15 +243,15 @@ def rope_freq_yarn( original_max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale, ) - high = tir.if_then_else(low == high, high + 0.001, high) - inv_freq_mask = tir.const(1, "float32") - tir.max( - tir.min((d - low) / (high - low), 1.0), 0.0 + high = tirx.if_then_else(low == high, high + 0.001, high) + inv_freq_mask = tirx.const(1, "float32") - tirx.max( + tirx.min((d - low) / (high - low), 1.0), 0.0 ).astype("float32") inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask freq = s * inv_freq - freq_var = tir.Var("freq", "float32") - cos_freq = tir.cos(freq_var).astype(dtype) - sin_freq = tir.sin(freq_var).astype(dtype) + freq_var = tirx.Var("freq", "float32") + cos_freq = tirx.cos(freq_var).astype(dtype) + sin_freq = tirx.sin(freq_var).astype(dtype) return cos_freq, sin_freq, {freq_var: freq} @@ -302,7 +304,7 @@ def switch_rope_freq_func(rope_scaling: dict[str, Any]) -> Callable: def llama_rope( # pylint: disable=too-many-arguments qkv: Tensor, - total_seq_len: tir.Var, + total_seq_len: tirx.Var, theta: float, scale: float, num_q_heads: int, @@ -318,7 +320,7 @@ def llama_rope( # pylint: disable=too-many-arguments qkv : Tensor The fused QKV tensor of shape: [batch_size, seq_len, #q_heads + #kv_heads * 2, head_dim] - total_seq_len : tir.Var + total_seq_len : tirx.Var The total sequence length after being concatenated with KVCache. It is used to compute the offset of RoPE. @@ -357,35 +359,35 @@ def llama_rope( # pylint: disable=too-many-arguments if rotary_dim is None: rotary_dim = head_dim dtype = qkv.dtype - scale = tir.const(scale, dtype) + scale = tirx.const(scale, dtype) def _rope( # pylint: disable=too-many-arguments x: T.Buffer, - b: tir.Var, - s: tir.Var, - h: tir.Var, - d: tir.Var, - offset: tir.Var, + b: tirx.Var, + s: tirx.Var, + h: tirx.Var, + d: tirx.Var, + offset: tirx.Var, ): cos_freq, sin_freq, var_map = switch_rope_freq_func(rope_scaling)( (s + offset) * scale, d, rotary_dim, theta, dtype ) cos = cos_freq * x[b, s, h, d] if rope_scaling["rope_type"] == "gptj": - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d % 2 == 0, -x[b, s, h, d + 1], x[b, s, h, d - 1], ) else: - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d < rotary_dim // 2, -x[b, s, h, d + rotary_dim // 2], x[b, s, h, d - rotary_dim // 2], ) expr = cos + sin for var, value in var_map.items(): - expr = tir.Let(var, value, expr) + expr = tirx.Let(var, value, expr) return expr @T.prim_func(private=True) @@ -399,7 +401,7 @@ def fused_rope( # pylint: disable=too-many-locals T.func_attr( { "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": True, + "tirx.noalias": True, } ) batch_size = T.int64() @@ -481,7 +483,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments fused_heads = num_q_heads + num_kv_heads * 2 if rotary_dim is None: rotary_dim = head_dim - scale = tir.const(scale, "float32") + scale = tirx.const(scale, "float32") is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling: original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] @@ -490,10 +492,10 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments def _rope( # pylint: disable=too-many-arguments x: T.Buffer, - s: tir.Var, - h: tir.Var, - d: tir.Var, - pos: tir.Var, + s: tirx.Var, + h: tirx.Var, + d: tirx.Var, + pos: tirx.Var, ext_factors: T.Buffer | None = None, ): kwargs = {} @@ -504,20 +506,20 @@ def _rope( # pylint: disable=too-many-arguments ) cos = cos_freq * x[s, h, d].astype("float32") if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d % 2 == 0, -x[s, h, d + 1], x[s, h, d - 1], ).astype("float32") else: - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d < rotary_dim // 2, -x[s, h, d + rotary_dim // 2], x[s, h, d - rotary_dim // 2], ).astype("float32") expr = (cos + sin).astype(dtype) for var, value in var_map.items(): - expr = tir.Let(var, value, expr) + expr = tirx.Let(var, value, expr) return expr @T.prim_func @@ -532,7 +534,7 @@ def fused_rope( # pylint: disable=too-many-locals T.func_attr( { "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": True, + "tirx.noalias": True, } ) seq_len = T.int32() @@ -574,7 +576,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals T.func_attr( { "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": True, + "tirx.noalias": True, } ) seq_len = T.int64() @@ -707,7 +709,7 @@ def llama4_rope_with_position_map( # pylint: disable=too-many-arguments fused_heads = num_q_heads + num_kv_heads * 2 if rotary_dim is None: rotary_dim = head_dim - scale = tir.const(scale, "float32") + scale = tirx.const(scale, "float32") is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling: original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] @@ -716,10 +718,10 @@ def llama4_rope_with_position_map( # pylint: disable=too-many-arguments def _rope( # pylint: disable=too-many-arguments x: T.Buffer, - s: tir.Var, - h: tir.Var, - d: tir.Var, - pos: tir.Var, + s: tirx.Var, + h: tirx.Var, + d: tirx.Var, + pos: tirx.Var, ext_factors: T.Buffer | None = None, ): kwargs = {} @@ -730,21 +732,21 @@ def _rope( # pylint: disable=too-many-arguments ) cos = cos_freq * x[s, h, d].astype("float32") if "rope_type" in rope_scaling and rope_scaling["rope_type"] == "gptj": - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d % 2 == 0, -x[s, h, d + 1], x[s, h, d - 1], ).astype("float32") else: # Data layout is different for llama4 vs llama3 - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d % 2 == 0, -x[s, h, d + 1], x[s, h, d - 1], ).astype("float32") expr = (cos + sin).astype(dtype) for var, value in var_map.items(): - expr = tir.Let(var, value, expr) + expr = tirx.Let(var, value, expr) return expr @T.prim_func(private=True) @@ -759,7 +761,7 @@ def fused_rope( # pylint: disable=too-many-locals T.func_attr( { "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": True, + "tirx.noalias": True, } ) seq_len = T.int32() @@ -801,7 +803,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals T.func_attr( { "op_pattern": 8, # 2 means injective, 8 means opaque - "tir.noalias": True, + "tirx.noalias": True, } ) seq_len = T.int64() diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index d32669f9423a..c55aa3eceb22 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -22,9 +22,9 @@ import math from typing import Any -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.runtime import DataType -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target from .position_embedding import switch_rope_freq_func @@ -39,11 +39,11 @@ def _var(dtype): def _rope( buffer: T.Buffer, - offset: tir.Var, + offset: tirx.Var, rotary_dim: int, - theta: tir.Var, - scale: tir.Var, - indices: tuple[tir.Var, ...], + theta: tirx.Var, + scale: tirx.Var, + indices: tuple[tirx.Var, ...], qkv_dtype: str, rope_scaling: dict[str, Any], ): @@ -52,14 +52,14 @@ def _rope( offset * scale, d, rotary_dim, theta, "float32" ) cos = cos_freq * buffer[indices].astype("float32") - sin = sin_freq * tir.if_then_else( + sin = sin_freq * tirx.if_then_else( d < rotary_dim // 2, -buffer[indices[:-1] + (d + rotary_dim // 2,)], buffer[indices[:-1] + (d - rotary_dim // 2,)], ).astype("float32") expr = (cos + sin).astype(qkv_dtype) for var, value in var_map.items(): - expr = tir.Let(var, value, expr) + expr = tirx.Let(var, value, expr) return expr @@ -69,11 +69,11 @@ def _check_tree_order(tree_order_indptr, tree_order, batch, row, col, kv_len, qo tree_start = kv_len - tree_order_len child_idx_in_tree = row + tree_order_len - qo_len parent_idx_in_tree = col - tree_start - return tir.all( + return tirx.all( col < kv_len, - tir.any( + tirx.any( col < tree_start, - tir.all( + tirx.all( tree_order[tree_order_indptr[batch] + child_idx_in_tree, 0] >= tree_order[tree_order_indptr[batch] + parent_idx_in_tree, 0], tree_order[tree_order_indptr[batch] + child_idx_in_tree, 0] @@ -648,7 +648,7 @@ def apply_to_md(sch, block): apply_to_qkv_load(sch, sch.get_sblock("KV_load")) apply_to_md(sch, sch.get_sblock("lse_store")) - return sch.mod["main"].with_attr("tir.is_scheduled", True) + return sch.mod["main"].with_attr("tirx.is_scheduled", True) def tree_attn_with_paged_kv_cache_cpu(h_kv, h_q, d, dtype, rope_scaling: dict[str, Any]): @@ -1341,4 +1341,4 @@ def apply_to_md(sch, block): apply_to_qkv_load(sch, sch.get_sblock("K_load")) apply_to_qkv_load(sch, sch.get_sblock("V_load")) apply_to_md(sch, sch.get_sblock("lse_store")) - return sch.mod["main"].with_attr("tir.is_scheduled", True) + return sch.mod["main"].with_attr("tirx.is_scheduled", True) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index ee16f25d2881..cf6c827b9ef2 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -21,7 +21,7 @@ from collections.abc import Sequence from tvm import relax as rx -from tvm import tir +from tvm import tirx from . import op from .core import Effect, Module, ModuleList, Parameter, Tensor, get_default_dtype @@ -101,8 +101,8 @@ class Linear(Module): def __init__( self, - in_features: int | str | tir.PrimExpr, - out_features: int | str | tir.PrimExpr, + in_features: int | str | tirx.PrimExpr, + out_features: int | str | tirx.PrimExpr, bias: bool = True, dtype: str | None = None, out_dtype: str | None = None, @@ -242,7 +242,7 @@ def __init__( # pylint: disable=too-many-arguments if isinstance(self.in_channels, int): in_channels = int(self.in_channels / self.groups) else: - in_channels = tir.floordiv(self.in_channels, self.groups) + in_channels = tirx.floordiv(self.in_channels, self.groups) # Expand kernel size if provided an integer. if isinstance(kernel_size, int): @@ -316,7 +316,7 @@ def __init__( # pylint: disable=too-many-arguments if isinstance(self.in_channels, int): in_channels = int(self.in_channels / self.groups) else: - in_channels = tir.floordiv(self.in_channels, self.groups) + in_channels = tirx.floordiv(self.in_channels, self.groups) # Expand kernel size if given an integer. if isinstance(kernel_size, int): @@ -657,13 +657,13 @@ def to(self, dtype: str | None = None) -> None: if dtype is not None: self.dtype = dtype - def view(self, seq_len: tir.Var) -> Tensor: + def view(self, seq_len: tirx.Var) -> Tensor: """ View the last elements in KVCache. Parameters ---------- - seq_len : tir.Var + seq_len : tirx.Var The number of last elements to view. Returns @@ -714,8 +714,8 @@ class Embedding(Module): def __init__( self, - num: int | str | tir.PrimExpr, - dim: int | str | tir.PrimExpr, + num: int | str | tirx.PrimExpr, + dim: int | str | tirx.PrimExpr, dtype: str | None = None, ): self.num = num diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 3c1fa442b007..53c21ad56a35 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -26,8 +26,8 @@ import numpy as np from tvm import te -from tvm import tir as _tir -from tvm.script import tir as T +from tvm import tirx as _tir +from tvm.script import tirx as T from ... import expr as rx from ... import op as _op diff --git a/python/tvm/relax/frontend/nn/subroutine.py b/python/tvm/relax/frontend/nn/subroutine.py index 87f272071202..c62491be9ff5 100644 --- a/python/tvm/relax/frontend/nn/subroutine.py +++ b/python/tvm/relax/frontend/nn/subroutine.py @@ -168,7 +168,7 @@ def _get_subroutine( gvar = block_builder.emit_func_output(out) # The relax.Var instances in model_params, along with any - # tir.Var instances in the struct info, appear in both the + # tirx.Var instances in the struct info, appear in both the # calling scope and as parameters for the subroutine. To # maintain SSA, replace all relax and TIR variables in the # subroutine. diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 3dc575ae778c..edcb5e99addd 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -48,10 +48,10 @@ import onnx.onnx_ml_pb2 import tvm -from tvm import TVMError, relax, tir, topi +from tvm import TVMError, relax, tirx, topi from tvm.ir import IRModule from tvm.ir.supply import NameSupply -from tvm.tir.generic import cast +from tvm.tirx.generic import cast from tvm.topi.utils import get_const_tuple from ..common import autopad @@ -110,7 +110,7 @@ def get_constant( return var -def get_value(token, value_dict: dict[str, tvm.tir.SizeVar]) -> int | tvm.tir.SizeVar: +def get_value(token, value_dict: dict[str, tvm.tirx.SizeVar]) -> int | tvm.tirx.SizeVar: """Converts to token to an integer value if it a constant, otherwise it generates a SizeVar Parameters @@ -123,7 +123,7 @@ def get_value(token, value_dict: dict[str, tvm.tir.SizeVar]) -> int | tvm.tir.Si Returns ------- - Union[int, tvm.tir.SizeVar] + Union[int, tvm.tirx.SizeVar] The decoded token """ @@ -131,14 +131,14 @@ def get_value(token, value_dict: dict[str, tvm.tir.SizeVar]) -> int | tvm.tir.Si return int(token) except ValueError: if token not in value_dict or token == "?": - value_dict[token] = tvm.tir.SizeVar(token, "int64") + value_dict[token] = tvm.tirx.SizeVar(token, "int64") value = value_dict[token] return value def parse_shape_name( - name: str, value_dict: dict[str, tvm.tir.SizeVar] -) -> tir.PrimExpr | tvm.tir.SizeVar: + name: str, value_dict: dict[str, tvm.tirx.SizeVar] +) -> tirx.PrimExpr | tvm.tirx.SizeVar: """Converts expressions in the shape dimension name to prim expressions. Parameters @@ -151,7 +151,7 @@ def parse_shape_name( Returns ------- - Union[tir.PrimExpr, tvm.tir.SizeVar] + Union[tirx.PrimExpr, tvm.tirx.SizeVar] The expression of the shape dimension. """ @@ -188,7 +188,7 @@ def parse_shape_name( def get_info( - info_proto: onnx.onnx_ml_pb2.ValueInfoProto, value_dict: dict[str, tvm.tir.SizeVar] + info_proto: onnx.onnx_ml_pb2.ValueInfoProto, value_dict: dict[str, tvm.tirx.SizeVar] ) -> tuple[str, list, str, list, dict]: """Extract the shape from a ValueInfoProto. @@ -237,7 +237,7 @@ def get_numpy(tensor_proto: onnx.onnx_ml_pb2.TensorProto) -> _np.ndarray: def get_prim_expr_list( inputs: relax.Constant | relax.ShapeExpr, -) -> list[int | tir.PrimExpr]: +) -> list[int | tirx.PrimExpr]: """Attempt to convert a variable to list of PrimExpr if possible. Parameters @@ -247,7 +247,7 @@ def get_prim_expr_list( Returns ------- - ret : List[Union[int, tir.PrimExpr]] + ret : List[Union[int, tirx.PrimExpr]] The input value converted to a list of PrimExpr if possible. """ if isinstance(inputs, relax.Constant): @@ -320,7 +320,7 @@ def _impl_v13(cls, bb, inputs, attr, params): def _to_numpy(x): if isinstance(x, relax.PrimValue): x = x.value - if isinstance(x, tir.IntImm | tir.FloatImm): + if isinstance(x, tirx.IntImm | tirx.FloatImm): x = x.value return _np.array(x) else: @@ -635,8 +635,8 @@ def _impl_v13(cls, bb, inputs, attr, params): dtype = indices.struct_info.dtype axis_len = int(inputs[0].struct_info.shape[axis]) argmax = relax.op.argmax(indices, axis=axis) - on_value = relax.PrimValue(tvm.tir.const(1.0, dtype)) - off_value = relax.PrimValue(tvm.tir.const(0.0, dtype)) + on_value = relax.PrimValue(tvm.tirx.const(1.0, dtype)) + off_value = relax.PrimValue(tvm.tirx.const(0.0, dtype)) one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis) return one_hot @@ -771,7 +771,7 @@ def _impl_v13(cls, bb, inputs, attr, params): to_type = get_type(attr["to"]) if isinstance(inputs[0], relax.ShapeExpr): shape = inputs[0] - if all([isinstance(x, tir.IntImm) for x in shape]): + if all([isinstance(x, tirx.IntImm) for x in shape]): shape = [int(x) for x in shape] return relax.const(shape, to_type) if isinstance(inputs[0], relax.Constant): @@ -896,7 +896,7 @@ def _impl_v11(cls, bb, inputs, attr, params): if condition.struct_info.ndim != 1: raise ValueError("Condition tensor is expected to be a 1D boolean tensor") indices = relax.op.nonzero(condition) - num_nonzero = tir.Var("num_nonzero", "int64") + num_nonzero = tirx.Var("num_nonzero", "int64") indices = bb.match_cast(indices, relax.TensorStructInfo([1, num_nonzero], "int64")) indices = relax.op.reshape(indices, [-1]) @@ -1486,7 +1486,7 @@ def _impl_v9(cls, bb, inputs, attr, params): shape = relax.ShapeExpr(list(shape.data.numpy())) # Special case where requested shape are constant - if len(shape) == 1 and all([isinstance(x, tir.IntImm) for x in shape]): + if len(shape) == 1 and all([isinstance(x, tirx.IntImm) for x in shape]): shape = [int(x) for x in shape] return relax.const(_np.full(shape, value, dtype), dtype) @@ -1887,7 +1887,7 @@ def _impl_v13(cls, bb, inputs, attr, params): assert all(len(i) == 1 for i in [starts, ends, steps]) sliced_values = shape_data[starts[0] : ends[0] : steps[0]] - if all([isinstance(val, tir.IntImm | int) for val in sliced_values]): + if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]): return relax.const([x.value for x in sliced_values], "int64") else: return relax.ShapeExpr(sliced_values) @@ -1895,7 +1895,7 @@ def _impl_v13(cls, bb, inputs, attr, params): # If all `starts`, `ends`, and `steps` are constant, use strict mode # Otherwise, we assume the slice is inbound. assume_inbound = not all( - [isinstance(param, tir.IntImm | int) for param in [*starts, *ends, *steps]] + [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]] ) # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr @@ -1981,7 +1981,7 @@ def _tensor_length(expr): return None length = shape.values[0] - if not isinstance(length, tir.IntImm): + if not isinstance(length, tirx.IntImm): return None return length.value @@ -2021,7 +2021,7 @@ def _impl_v13(cls, bb, inputs, attr, params): ) output_shape = bb.normalize(relax.op.tensor_to_shape(output_shape_tensor)) output_shape_vars = [ - tir.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len)) + tirx.Var(f"tile_dim_{i}", "int64") for i in range(max(data_ndim, reps_len)) ] bb.match_cast(output_shape, relax.ShapeStructInfo(output_shape_vars)) return bb.emit_te(topi.dyn_tile, data, output_shape_vars, reps_len) @@ -2047,11 +2047,11 @@ def _impl_v13(cls, bb, inputs, attr, params): assert len(data_shape) == len(target_shape) # Apply ONNX v13 Expand broadcasting rules for i, s in enumerate(target_shape): - if isinstance(s, tvm.tir.IntImm): + if isinstance(s, tvm.tirx.IntImm): if s.value == -1: # -1 means preserve the input dimension target_shape[i] = data_shape[i] - elif isinstance(data_shape[i], tvm.tir.IntImm) and data_shape[i].value == 1: + elif isinstance(data_shape[i], tvm.tirx.IntImm) and data_shape[i].value == 1: # Input dimension is 1, can broadcast to any target dimension >= 1 if s.value < 1: raise ValueError( @@ -2059,7 +2059,8 @@ def _impl_v13(cls, bb, inputs, attr, params): f"at possition {i}. Target dimensions must be >= 1." ) elif ( - isinstance(data_shape[i], tvm.tir.IntImm) and s.value == data_shape[i].value + isinstance(data_shape[i], tvm.tirx.IntImm) + and s.value == data_shape[i].value ): # Dimensions match, no change needed pass @@ -2068,7 +2069,7 @@ def _impl_v13(cls, bb, inputs, attr, params): # This would "squeeze" the dimension - preserve input for safety target_shape[i] = data_shape[i] else: - if isinstance(data_shape[i], tvm.tir.IntImm): + if isinstance(data_shape[i], tvm.tirx.IntImm): raise ValueError( f"ONNX Expand: Cannot broadcast input shape {original_data_shape} " f"to target shape {original_target_shape}. " @@ -2139,7 +2140,7 @@ def _impl_v13(cls, bb, inputs, attr, params): shape_vars = [] for i in range(shape_ndim): - shape_vars.append(tvm.tir.Var(f"x_{i}", "int64")) + shape_vars.append(tvm.tirx.Var(f"x_{i}", "int64")) bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) # Applying broadcasting rules for dynamic shapes @@ -2149,7 +2150,7 @@ def _impl_v13(cls, bb, inputs, attr, params): padded_data = data if target_ndim > data_ndim: - padded_data_shape = [tir.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape + padded_data_shape = [tirx.IntImm("int64", 1)] * (target_ndim - data_ndim) + data_shape padded_data = bb.normalize(relax.op.reshape(data, relax.ShapeExpr(padded_data_shape))) return bb.normalize(relax.op.broadcast_to(padded_data, relax.ShapeExpr(shape_vars))) @@ -2781,7 +2782,7 @@ def _impl_v13(cls, bb, inputs, attr, params): if axis == 0: new_shape = (1, -1) else: - shape_flags = [isinstance(x, tvm.script.tir.IntImm) for x in data_shape[0:axis]] + shape_flags = [isinstance(x, tvm.script.tirx.IntImm) for x in data_shape[0:axis]] if all(shape_flags): data_shape = [x.value for x in data_shape[0:axis]] @@ -3398,7 +3399,7 @@ def _impl_v11(cls, bb, inputs, attr, params): axis=axis, ) - unique_numbers = tir.Var("unique_numbers", "int64") + unique_numbers = tirx.Var("unique_numbers", "int64") input_shape = data.struct_info.shape dtype = data.struct_info.dtype @@ -3428,7 +3429,7 @@ def _impl_v11(cls, bb, inputs, attr, params): # ONNX spec: inverse_indices is always 1D # When axis is None: shape is [X.size] # When axis is specified: shape is [X.shape[axis]] - inverse_shape = (tir.Var("inverse_numbers", "int64"),) + inverse_shape = (tirx.Var("inverse_numbers", "int64"),) inverse_sinfo = relax.TensorStructInfo(inverse_shape, "int64") outputs.append(bb.match_cast(unique[tuple_idx], inverse_sinfo)) tuple_idx += 1 @@ -3448,7 +3449,7 @@ class NonZero(OnnxOpConverter): def _impl_v9(cls, bb, inputs, attr, params): ndim = inputs[0].struct_info.ndim ndim = 1 if ndim == 0 else ndim - nonzero_numbers = tir.Var("nonzero_numbers", "int64") + nonzero_numbers = tirx.Var("nonzero_numbers", "int64") return bb.match_cast( relax.op.nonzero(inputs[0]), relax.TensorStructInfo((ndim, nonzero_numbers), "int64") ) diff --git a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py index 45f2989539b5..a5196095eee0 100644 --- a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py +++ b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py @@ -22,7 +22,7 @@ from typing import Any import tvm -from tvm import relax, tir +from tvm import relax, tirx class StableHLOImporter: @@ -130,7 +130,7 @@ def get_shape(self, inpt_type) -> list[Any]: for i in range(shape_type.rank): # get_dim_size if shape_type.is_dynamic_dim(i): - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") ret.append(n) else: ret.append(shape_type.get_dim_size(i)) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 3a7a62ba391d..803b4b7e11be 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -26,7 +26,7 @@ from functools import reduce import tvm -from tvm import relax, tir +from tvm import relax, tirx class BaseFXGraphImporter(metaclass=abc.ABCMeta): @@ -1870,7 +1870,7 @@ def is_squeezable(idx): for d in shape: if isinstance(d, int) and d == 1: return True - # Check for tir.IntImm + # Check for tirx.IntImm if hasattr(d, "value") and d.value == 1: return True return False @@ -1942,13 +1942,13 @@ def _normalize_bound(bound): max_index_val = 9223372036854775807 def _adjust(val): - if isinstance(val, int | tir.IntImm): + if isinstance(val, int | tirx.IntImm): int_val = int(val) if int_val >= max_index_val: return input_shape[axis] if int_val < 0: return input_shape[axis] + int_val - if isinstance(input_shape[axis], int | tir.IntImm) and int_val > int( + if isinstance(input_shape[axis], int | tirx.IntImm) and int_val > int( input_shape[axis] ): return input_shape[axis] @@ -1998,7 +1998,7 @@ def _roll(self, node: fx.Node) -> relax.Var: original_shape = self.shape_of(input_tensor) def to_int(val): - if isinstance(val, tir.IntImm): + if isinstance(val, tirx.IntImm): return int(val.value) elif isinstance(val, int): return val diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index d21a74acf1c7..9a0b6e1c58df 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -92,7 +92,7 @@ def to_torch_tensor(nd_tensor): for s in tensor.shape: if isinstance(s, torch.SymInt): if str(s) not in shape_vars: - shape_vars[str(s)] = tvm.tir.Var(str(s), "int64") + shape_vars[str(s)] = tvm.tirx.Var(str(s), "int64") shape.append(shape_vars[str(s)]) else: shape.append(s) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index fe69b6f76863..2487b904c6f1 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -480,9 +480,9 @@ def _lstm(self, node: fx.Node) -> relax.Var: else: seq_len, batch_size, input_size = input_shape - seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len - batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size - input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size + seq_len = int(seq_len) if isinstance(seq_len, tvm.tirx.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tirx.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tirx.IntImm) else input_size # Extract hidden size from the LSTM parameters # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh] # weight_ih shape: (4 * hidden_size, input_size) @@ -784,9 +784,9 @@ def _gru(self, node: fx.Node) -> relax.Var: else: seq_len, batch_size, input_size = input_shape - seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len - batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size - input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size + seq_len = int(seq_len) if isinstance(seq_len, tvm.tirx.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tirx.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tirx.IntImm) else input_size # Extract hidden size from parameters # For bidirectional: params has weights for both directions @@ -1181,8 +1181,8 @@ def _as_strided(self, node: fx.Node) -> relax.Var: assert storage_offset == 0, "as_strided with non-zero storage_offset is not supported yet" # Only handle view-like cases where the provided strides align with a contiguous layout. - can_check = all(isinstance(dim, int | tvm.tir.IntImm) for dim in size) and all( - isinstance(st, int | tvm.tir.IntImm) for st in stride + can_check = all(isinstance(dim, int | tvm.tirx.IntImm) for dim in size) and all( + isinstance(st, int | tvm.tirx.IntImm) for st in stride ) if can_check: expected_stride = [] @@ -1313,8 +1313,8 @@ def _import_branch_subgraph( # Create fresh SizeVars to avoid sharing with the caller function. if orig_si.shape is not None: new_shape = [ - tvm.tir.SizeVar(s.name, s.dtype) - if isinstance(s, tvm.tir.SizeVar) + tvm.tirx.SizeVar(s.name, s.dtype) + if isinstance(s, tvm.tirx.SizeVar) else s for s in orig_si.shape ] @@ -1734,8 +1734,8 @@ def create_convert_map( } def _process_derived_symbol( - self, symbol, torch_symbol_to_relax_var: dict[str, tvm.tir.Var] - ) -> tuple[str, tvm.tir.PrimExpr | None]: + self, symbol, torch_symbol_to_relax_var: dict[str, tvm.tirx.Var] + ) -> tuple[str, tvm.tirx.PrimExpr | None]: """Process a sympy symbol to generate a descriptive name and TIR expression.""" import sympy @@ -1748,10 +1748,10 @@ def _process_derived_symbol( tir_expr = None for arg in symbol.args: if isinstance(arg, sympy.Integer): - term = tvm.tir.IntImm("int64", int(arg)) + term = tvm.tirx.IntImm("int64", int(arg)) elif isinstance(arg, sympy.Symbol): term = torch_symbol_to_relax_var.setdefault( - str(arg), tvm.tir.SizeVar(str(arg), "int64") + str(arg), tvm.tirx.SizeVar(str(arg), "int64") ) else: _, term = self._process_derived_symbol(arg, torch_symbol_to_relax_var) @@ -1766,14 +1766,14 @@ def _process_derived_symbol( elif isinstance(symbol, sympy.Add): tir_expr = tir_expr + term - if isinstance(tir_expr, tvm.tir.Add): + if isinstance(tir_expr, tvm.tirx.Add): for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]: - if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var): + if isinstance(const, tvm.tirx.IntImm) and isinstance(var, tvm.tirx.Var): return f"{var.name}___{const.value}", tir_expr - if isinstance(tir_expr, tvm.tir.Mul): + if isinstance(tir_expr, tvm.tirx.Mul): for const, var in [(tir_expr.a, tir_expr.b), (tir_expr.b, tir_expr.a)]: - if isinstance(const, tvm.tir.IntImm) and isinstance(var, tvm.tir.Var): + if isinstance(const, tvm.tirx.IntImm) and isinstance(var, tvm.tirx.Var): return f"{var.name}_{const.value}", tir_expr return str(symbol), tir_expr @@ -1784,7 +1784,7 @@ def create_input_vars( """Create relax input vars.""" parameters_buffers_constants = OrderedDict() user_inputs = OrderedDict() - torch_symbol_to_relax_var: dict[str, tvm.tir.Var] = {} + torch_symbol_to_relax_var: dict[str, tvm.tirx.Var] = {} range_constraints = {} if hasattr(exported_program, "range_constraints"): @@ -1837,7 +1837,7 @@ def create_input_vars( ) size_var = torch_symbol_to_relax_var.setdefault( - symbol_name, tvm.tir.SizeVar(symbol_name, "int64") + symbol_name, tvm.tirx.SizeVar(symbol_name, "int64") ) relax_shape.append(size_var) else: diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py index 3cab1ece9e96..377874f705e9 100644 --- a/python/tvm/relax/op/_op_gradient.py +++ b/python/tvm/relax/op/_op_gradient.py @@ -25,7 +25,7 @@ from tvm.base import TVMError from tvm.relax.struct_info import ShapeStructInfo -from ...tir import PrimExpr +from ...tirx import PrimExpr from ..block_builder import BlockBuilder from ..expr import Call, Expr, ShapeExpr, Var from .base import register_gradient diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 28a7aa897a50..d257a8c8e6d4 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -98,12 +98,12 @@ def call_tir( tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """ - Call a tir.prim_func and return the output. + Call a tirx.prim_func and return the output. Parameters ---------- gvar : GlobalVar - The GlobalVar referring to a tir PrimFunc. + The GlobalVar referring to a tirx PrimFunc. args : Expr The input arguments. @@ -141,14 +141,14 @@ def call_tir_with_grad( tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """ - Call a tir.prim_func and return the output. This intrinsic will bind a te gradient function + Call a tirx.prim_func and return the output. This intrinsic will bind a te gradient function (refered by te_grad_name) to the call_tir_with_grad node. The te gradient function will be called by the Gradient pass. Parameters ---------- gvar : GlobalVar - The GlobalVar referring to a tir PrimFunc. + The GlobalVar referring to a tirx PrimFunc. args : Expr The input arguments. diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py index b09f8686ac48..07ff674dd09e 100644 --- a/python/tvm/relax/op/distributed/distributed.py +++ b/python/tvm/relax/op/distributed/distributed.py @@ -73,14 +73,14 @@ def call_tir_local_view( tir_vars: ShapeExpr | tuple[PrimExpr] | list[PrimExpr] | None = None, ) -> Call: """ - Call a tir.prim_func and return the output. The prim_func should be a worker-local function + Call a tirx.prim_func and return the output. The prim_func should be a worker-local function that is actually executed on each worker, instead of the unpartitioned function. The output of this operator is DTensor or a tuple of DTensors. Parameters ---------- gvar : GlobalVar - The GlobalVar referring to a tir PrimFunc. + The GlobalVar referring to a tirx PrimFunc. args : Expr The input arguments. diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index 25333bf20d54..3ce70fc545fb 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -19,7 +19,7 @@ from collections.abc import Callable from tvm.ir.expr import PrimExpr -from tvm.tir import FloatImm, IndexMap, IntImm +from tvm.tirx import FloatImm, IndexMap, IntImm from ..expr import Expr, PrimValue, ShapeExpr from ..expr import Tuple as RxTuple diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py index e4fdd71df4b7..645ab94a1ade 100644 --- a/python/tvm/relax/op/memory/view.py +++ b/python/tvm/relax/op/memory/view.py @@ -29,7 +29,7 @@ from collections.abc import Sequence from tvm.relax import DataTypeImm, Expr, PrimValue, ShapeExpr -from tvm.tir import PrimExpr +from tvm.tirx import PrimExpr from . import _ffi_api diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index e467494f2b71..e30ba550c74b 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -17,7 +17,7 @@ """Relax Neural Network (NN) operators""" from tvm import DataType, relax -from tvm.tir import FloatImm +from tvm.tirx import FloatImm from ...expr import Expr from . import _ffi_api diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index d1bbd7897aeb..4bff72f17f90 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -25,7 +25,7 @@ import tvm from tvm.ir import Array, EnvFunc, Span, VDevice from tvm.runtime import DataType -from tvm.tir import PrimExpr +from tvm.tirx import PrimExpr from . import _ffi_api, expr, ty from .expr import Expr, ShapeExpr, StructInfo @@ -93,11 +93,11 @@ def __init__( "but the specified dtype was {dtype}." ) elif isinstance(value, int | float): - value = tvm.tir.const(value, dtype) + value = tvm.tirx.const(value, dtype) # Use relax's default integer type if not otherwise specified. if isinstance(value, int): - value = tvm.tir.IntImm("int64", value) + value = tvm.tirx.IntImm("int64", value) if value is None: self.__init_handle_by_constructor__(_ffi_api.PrimStructInfoFromDtype, dtype, span) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 636dea360d8e..4aa930ba6e1e 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -175,7 +175,7 @@ def display_attrs(attr_key): # we want to wrap strings in quotes # (__repr__ would work but it uses single quotes) attr_val = wrap_quotes(attr_val) - elif isinstance(attr_val, tvm.tir.IntImm): + elif isinstance(attr_val, tvm.tirx.IntImm): if attr_val.dtype == "bool": attr_val = bool(attr_val.value) else: diff --git a/python/tvm/relax/testing/attention.py b/python/tvm/relax/testing/attention.py index 4e72fa22f19e..a4157cabb6e1 100644 --- a/python/tvm/relax/testing/attention.py +++ b/python/tvm/relax/testing/attention.py @@ -19,7 +19,7 @@ import tvm from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py index 8e36680c95a4..97a3b83eb77d 100644 --- a/python/tvm/relax/testing/nn.py +++ b/python/tvm/relax/testing/nn.py @@ -24,7 +24,7 @@ import numpy as np # type: ignore import tvm -from tvm import relax, tir, topi +from tvm import relax, tirx, topi from tvm.relax.op.grad.grad import end_checkpoint, start_checkpoint @@ -287,7 +287,7 @@ def init_params(mod: tvm.IRModule) -> list[tvm.runtime.Tensor]: if isinstance(v, relax.ShapeExpr): shape = [] for i in v: - if isinstance(i, tir.IntImm): + if isinstance(i, tirx.IntImm): shape.append(int(i)) else: raise TypeError("cannot initialize for unknown-shape parameters.") diff --git a/python/tvm/relax/testing/vm.py b/python/tvm/relax/testing/vm.py index 2537d8e97da1..cbffb92912ba 100644 --- a/python/tvm/relax/testing/vm.py +++ b/python/tvm/relax/testing/vm.py @@ -88,5 +88,5 @@ def check_saved_func(vm: relax.VirtualMachine, func_name: str, *inputs: list[Any @tvm.register_global_func("test.vm.check_if_defined") -def check_if_defined(obj: tvm.Object) -> tvm.tir.IntImm: +def check_if_defined(obj: tvm.Object) -> tvm.tirx.IntImm: return tvm.runtime.convert(obj is not None) diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index a32279c181ec..eb6b6f488a75 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -20,7 +20,7 @@ import tvm from tvm import TVMError from tvm.ir.module import IRModule -from tvm.tir.expr import IntImm +from tvm.tirx.expr import IntImm from ..analysis import well_formed from ..expr import Tuple diff --git a/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py b/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py index c6485bc6391c..2aaebe527efb 100644 --- a/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py +++ b/python/tvm/relax/transform/fold_batch_norm_to_conv2d_for_inference.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, unused-argument, redefined-argument-from-local """Relax Fold Batchnorm into Conv2D.""" -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir.module import IRModule from tvm.ir.transform import PassContext from tvm.relax import Expr @@ -81,7 +81,7 @@ def rewriter(expr, matches): bn_attrs = bn_op.attrs bn_variance = relax.op.add( - bn_variance, relax.PrimValue(tir.FloatImm("float32", bn_attrs["epsilon"])) + bn_variance, relax.PrimValue(tirx.FloatImm("float32", bn_attrs["epsilon"])) ) dino = relax.op.sqrt(bn_variance) wt = relax.op.divide(bn_weight, dino) diff --git a/python/tvm/relax/transform/fuse_transpose_matmul.py b/python/tvm/relax/transform/fuse_transpose_matmul.py index 141f926cd3f8..ecefa876120f 100644 --- a/python/tvm/relax/transform/fuse_transpose_matmul.py +++ b/python/tvm/relax/transform/fuse_transpose_matmul.py @@ -24,7 +24,7 @@ """ import tvm -from tvm import IRModule, relax, te, tir +from tvm import IRModule, relax, te, tirx from tvm.relax.dpl.pattern import is_op, wildcard from tvm.relax.expr_functor import PyExprMutator, mutator @@ -130,9 +130,9 @@ def multiply_compute(idx_reduce): a_dim = a_shape[i if is_a_larger else i - offset] b_dim = b_shape[i if not is_a_larger else i - offset] dim_equal = a_dim == b_dim - if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: - a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 - b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + if not isinstance(dim_equal, tirx.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tirx.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tirx.IntImm) and b_dim == 1 a_indices.append(0 if a_dim_is_one else idx_spatial[i]) b_indices.append(0 if b_dim_is_one else idx_spatial[i]) else: diff --git a/python/tvm/relax/transform/legalize_ops/ccl.py b/python/tvm/relax/transform/legalize_ops/ccl.py index c3d1dee3e3e5..104d155bf1c4 100644 --- a/python/tvm/relax/transform/legalize_ops/ccl.py +++ b/python/tvm/relax/transform/legalize_ops/ccl.py @@ -18,7 +18,7 @@ # ruff: noqa: RUF005 """Default legalization function for ccl operators.""" -from tvm import arith, tir, topi +from tvm import arith, tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr, ShapeExpr @@ -101,7 +101,7 @@ def _transpose_for_ccl(_bb: BlockBuilder, expr: Expr, axis: int, num_workers: in f"is {shape_value} while num_workers is {num_workers}" ) new_shape.append(num_workers) - new_shape.append(tir.div(shape_value, num_workers)) + new_shape.append(tirx.div(shape_value, num_workers)) else: new_shape.append(shape_value) reshape_var = _bb.emit_te(topi.reshape, expr, new_shape) diff --git a/python/tvm/relax/transform/legalize_ops/common.py b/python/tvm/relax/transform/legalize_ops/common.py index 1d76d8614b41..6cb50d70ab74 100644 --- a/python/tvm/relax/transform/legalize_ops/common.py +++ b/python/tvm/relax/transform/legalize_ops/common.py @@ -20,7 +20,7 @@ import tvm from tvm import te -from tvm.tir import FloatImm, IntImm +from tvm.tirx import FloatImm, IntImm from ...block_builder import BlockBuilder from ...expr import Call, Constant, Expr @@ -73,9 +73,9 @@ def _try_convert_to_scalar_const( return value # preserve the data type of the constant if dtype.startswith("float"): - return tvm.tir.FloatImm(dtype, value) + return tvm.tirx.FloatImm(dtype, value) elif dtype.startswith("int") or dtype.startswith("uint") or dtype.startswith("bool"): - return tvm.tir.IntImm(dtype, value) + return tvm.tirx.IntImm(dtype, value) return expr diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index 3ab432056f2d..99b4449ebf79 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -20,7 +20,7 @@ import numpy as np -from tvm import tir, topi +from tvm import tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr, PrimValue, const @@ -108,7 +108,7 @@ def _arange(bb: BlockBuilder, call: Call) -> Expr: dtype = call.attrs.dtype def is_const_scalar(x: PrimValue): - return isinstance(x.value, tir.IntImm | tir.FloatImm) + return isinstance(x.value, tirx.IntImm | tirx.FloatImm) if all([is_const_scalar(x) for x in call.args]): return const(np.arange(start.value, end.value, step.value, dtype=dtype), dtype=dtype) diff --git a/python/tvm/relax/transform/legalize_ops/distributed.py b/python/tvm/relax/transform/legalize_ops/distributed.py index d92d7e0049cf..acd6bd4a4514 100644 --- a/python/tvm/relax/transform/legalize_ops/distributed.py +++ b/python/tvm/relax/transform/legalize_ops/distributed.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for distir-related operators.""" -from tvm import relax, tir +from tvm import relax, tirx from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -30,7 +30,7 @@ def _redistribute_replica_to_shard(_bb: BlockBuilder, call: Call) -> Expr: num_workers = call.attrs.num_workers axis = call.attrs.axis - worker_id_symbol = tir.Var("worker_id", "int64") + worker_id_symbol = tirx.Var("worker_id", "int64") worker_id_var = _bb.emit( call_pure_packed("runtime.disco.worker_id", sinfo_args=[ShapeStructInfo(None)]) ) diff --git a/python/tvm/relax/transform/legalize_ops/grad.py b/python/tvm/relax/transform/legalize_ops/grad.py index 029b106539f3..53222b5d50a7 100644 --- a/python/tvm/relax/transform/legalize_ops/grad.py +++ b/python/tvm/relax/transform/legalize_ops/grad.py @@ -19,10 +19,10 @@ import logging -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T -from tvm.script.ir_builder.tir.utils import buffer_proxy +from tvm.script.ir_builder import tirx as T +from tvm.script.ir_builder.tirx.utils import buffer_proxy from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -56,7 +56,7 @@ def te_nll_loss_backward(output_grad, predictions, targets, weights, reduction, if ignore_index >= 0: weights = te.compute( weights.shape, - lambda i: tir.Select(i == ignore_index, tir.const(0, weights.dtype), weights(i)), + lambda i: tirx.Select(i == ignore_index, tirx.const(0, weights.dtype), weights(i)), "weights_new", ) @@ -73,18 +73,18 @@ def te_nll_loss_backward(output_grad, predictions, targets, weights, reduction, if predictions.ndim == 1: return te.compute( predictions.shape, - lambda i: tir.Select( - i == targets(), -all_weights() * output_grad(), tir.const(0, predictions.dtype) + lambda i: tirx.Select( + i == targets(), -all_weights() * output_grad(), tirx.const(0, predictions.dtype) ), "pred_grad", ) return te.compute( predictions.shape, - lambda *i: tir.Select( + lambda *i: tirx.Select( i[1] == targets(*i[:1], *i[2:]), -all_weights(*i[:1], *i[2:]) * output_grad(*i[:1], *i[2:]), - tir.const(0, predictions.dtype), + tirx.const(0, predictions.dtype), ), "pred_grad", ) @@ -181,7 +181,7 @@ def gen_ir(output_grad_ptr, x_ptr, indices_ptr, out_ptr): with T.seq_scope(): # Init loop (zero-fill output buffer) with T.serial(fused_shape) as i: - out[i] = tir.const(0, dtype=x_ptr.dtype) + out[i] = tirx.const(0, dtype=x_ptr.dtype) # Accumulation loop if axis is not None: @@ -219,7 +219,7 @@ def gen_ir(output_grad_ptr, x_ptr, indices_ptr, out_ptr): return ib.get() shape = x.shape - out_buf = tir.decl_buffer(shape, x.dtype, "out_buf") + out_buf = tirx.decl_buffer(shape, x.dtype, "out_buf") return te.extern( [shape], diff --git a/python/tvm/relax/transform/legalize_ops/index.py b/python/tvm/relax/transform/legalize_ops/index.py index 1d521bae4fb0..9bc47bd676e5 100644 --- a/python/tvm/relax/transform/legalize_ops/index.py +++ b/python/tvm/relax/transform/legalize_ops/index.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for index operators.""" -from tvm import te, tir, topi +from tvm import te, tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -45,7 +45,7 @@ def _relax_tuple_to_tir(relax_tuple): if len(call.args) == 4: data, axes, begin, end = call.args - strides = [tir.IntImm("int64", 1)] * len(axes.struct_info.fields) + strides = [tirx.IntImm("int64", 1)] * len(axes.struct_info.fields) elif len(call.args) == 5: data, axes, begin, end, strides = call.args strides = _relax_tuple_to_tir(strides) @@ -80,21 +80,23 @@ def _dynamic_strided_slice(bb: BlockBuilder, call: Call) -> Expr: def shape_func(data, begin, end, strides): def _compute(i): def canonicalize_index(index, extent, strides): - begin_range = tir.Select(strides < 0, tir.const(-1, "int64"), tir.const(0, "int64")) - end_range = tir.Select(strides < 0, extent - 1, extent) - index = tir.Select(index < 0, index + extent, index) - return tir.Min(tir.Max(index, begin_range), end_range) + begin_range = tirx.Select( + strides < 0, tirx.const(-1, "int64"), tirx.const(0, "int64") + ) + end_range = tirx.Select(strides < 0, extent - 1, extent) + index = tirx.Select(index < 0, index + extent, index) + return tirx.Min(tirx.Max(index, begin_range), end_range) def get_length(begin, end, strides, length): begin = canonicalize_index(begin, length, strides) end = canonicalize_index(end, length, strides) - len1 = tir.ceildiv(begin - end, -strides) - len2 = tir.ceildiv(end - begin, strides) - return tir.Select(strides < 0, len1, len2) + len1 = tirx.ceildiv(begin - end, -strides) + len2 = tirx.ceildiv(end - begin, strides) + return tirx.Select(strides < 0, len1, len2) - length = tir.const(-1, "int64") + length = tirx.const(-1, "int64") for idx in range(data.ndim): - length = tir.Select(i == tir.const(idx, "int64"), data.shape[idx], length) + length = tirx.Select(i == tirx.const(idx, "int64"), data.shape[idx], length) return get_length(begin[i], end[i], strides[i], length) @@ -113,7 +115,7 @@ def get_length(begin, end, strides, length): # 2. Convert tensor to shape and match cast with new symbolic vars ndim = int(output_shape.struct_info.shape[0]) output_shape = bb.emit(tensor_to_shape(output_shape)) - output_shape_vars = [tir.Var("s", "int64") for i in range(ndim)] + output_shape_vars = [tirx.Var("s", "int64") for i in range(ndim)] bb.match_cast(output_shape, ShapeStructInfo(output_shape_vars)) # 3. Pass the output shape vars to TOPI diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py b/python/tvm/relax/transform/legalize_ops/inspect_op.py index 6edba59bc4e0..1bbdc5d7a1b0 100644 --- a/python/tvm/relax/transform/legalize_ops/inspect_op.py +++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py @@ -19,7 +19,7 @@ import enum -from tvm.script import tir as T +from tvm.script import tirx as T from ... import op from ...block_builder import BlockBuilder @@ -28,7 +28,7 @@ class TVMStructFieldKind(enum.IntEnum): - """Equivalent to tvm::tir::builtin::TVMStructFieldKind + """Equivalent to tvm::tirx::builtin::TVMStructFieldKind This does not use `enum.auto()` to define the values, because `enum.auto()` starts from 1, and this must match the C++ @@ -55,7 +55,7 @@ class TVMStructFieldKind(enum.IntEnum): def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr: @T.prim_func(private=True) def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> T.int64: - T.func_attr({"tir.is_host": True, "tir.is_scheduled": True}) + T.func_attr({"tirx.is_host": True, "tirx.is_scheduled": True}) assert T.int64(0) <= axis, "Specified axis may not be negative" ndim: T.int32 = T.tvm_struct_get( dlpack_handle, 0, int(TVMStructFieldKind.kDLTensorNDim), "int32" @@ -97,7 +97,7 @@ def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> T.int64: def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr: @T.prim_func(private=True) def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64: - T.func_attr({"tir.is_host": True, "tir.is_scheduled": True}) + T.func_attr({"tirx.is_host": True, "tirx.is_scheduled": True}) byte_offset: T.uint64 = T.tvm_struct_get( dlpack_handle, 0, int(TVMStructFieldKind.kDLTensorByteOffset), "uint64" ) @@ -111,7 +111,7 @@ def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64: def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr: @T.prim_func(private=True) def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64: - T.func_attr({"tir.is_host": True, "tir.is_scheduled": True}) + T.func_attr({"tirx.is_host": True, "tirx.is_scheduled": True}) byte_offset: T.uint64 = T.tvm_struct_get( dlpack_handle, 0, int(TVMStructFieldKind.kDLTensorByteOffset), "uint64" ) diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py index f8dcab0a752a..d8dd8aa3b0cc 100644 --- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py +++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for linear algebra operators.""" -from tvm import relax, te, tir, topi +from tvm import relax, te, tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr, Tuple, TupleGetItem, Var @@ -62,9 +62,9 @@ def multiply_compute(idx_reduce): a_dim = a_shape[i if is_a_larger else i - offset] b_dim = b_shape[i if not is_a_larger else i - offset] dim_equal = a_dim == b_dim - if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: - a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 - b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + if not isinstance(dim_equal, tirx.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tirx.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tirx.IntImm) and b_dim == 1 a_indices.append(0 if a_dim_is_one else idx_spatial[i]) b_indices.append(0 if b_dim_is_one else idx_spatial[i]) else: diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 41b191a3bf8b..2a1d249ef737 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -19,11 +19,11 @@ """Default legalization function for manipulate operators.""" import tvm -from tvm import relax, s_tir, te, tir, topi +from tvm import relax, s_tir, te, tirx, topi from tvm.relax.op.base import call_tir from tvm.relax.struct_info import TensorStructInfo from tvm.relax.utils import gen_call_tir_inputs -from tvm.tir.expr import IntImm +from tvm.tirx.expr import IntImm from ...block_builder import BlockBuilder from ...expr import Call, Expr, ShapeExpr, Tuple, TupleGetItem, Var @@ -108,7 +108,7 @@ def _permute_dims(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.split") def _split(bb: BlockBuilder, call: Call) -> Expr: - if isinstance(call.attrs.indices_or_sections, tir.IntImm): + if isinstance(call.attrs.indices_or_sections, tirx.IntImm): indices_or_sections = call.attrs.indices_or_sections.value else: indices_or_sections = call.attrs.indices_or_sections @@ -312,7 +312,7 @@ def te_layout_transform(data, name): def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep) - index_map: tvm.tir.IndexMap = call.attrs.index_map + index_map: tvm.tirx.IndexMap = call.attrs.index_map pad_value = call.attrs.pad_value if pad_value is not None: pad_value = pad_value.value @@ -322,14 +322,14 @@ def set_axis_sep(axis_sep: list, sch: s_tir.schedule, buffer_type: str): else: pad_value = 0.0 - axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators - input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators + axis_separators: tvm.tirx.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators + input_axis_separators: tvm.tirx.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators # Convert to list from array axis_separators = [int(sep) for sep in axis_separators] primfunc_name = "te_layout_transform" _, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape) - if not isinstance(padding_predicate, tvm.tir.expr.IntImm): + if not isinstance(padding_predicate, tvm.tirx.expr.IntImm): primfunc_name += "_with_pad" if len(axis_separators) != 0: primfunc_name += "_axis_separator" diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index fe01ec62410d..4234aa831e4b 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -20,7 +20,7 @@ import logging import math -from tvm import s_tir, te, tir, topi +from tvm import s_tir, te, tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -46,7 +46,7 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr: kernel_layout = s_tir.layout(call.attrs.kernel_layout) ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] - if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + if not isinstance(ic, tirx.IntImm) or not isinstance(oc, tirx.IntImm): logging.info( "Conv1D where number of groups is more than one and input or output " "channel size is symbolic cannot be legalized by TOPI at this moment." @@ -87,7 +87,7 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr: kernel_layout = s_tir.layout(call.attrs.kernel_layout) ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] - if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + if not isinstance(ic, tirx.IntImm) or not isinstance(oc, tirx.IntImm): logging.info( "Conv2D where number of groups is more than one and input or output " "channel size is symbolic cannot be legalized by TOPI at this moment." @@ -128,7 +128,7 @@ def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr: kernel_layout = s_tir.layout(call.attrs.kernel_layout) ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")] oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")] - if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm): + if not isinstance(ic, tirx.IntImm) or not isinstance(oc, tirx.IntImm): logging.info( "Conv3D where number of groups is more than one and input or output " "channel size is symbolic cannot be legalized by TOPI at this moment." @@ -497,14 +497,14 @@ def _nn_prelu(bb: BlockBuilder, call: Call) -> Expr: def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: def te_gelu(x: te.Tensor): dtype = x.dtype - erf_inp = x * tir.const(0.5**0.5, dtype) + erf_inp = x * tirx.const(0.5**0.5, dtype) if dtype == "float16": erf = topi.math.cast(topi.erf(topi.math.cast(erf_inp, "float32")), "float16") else: erf = topi.erf(erf_inp) - return x * (tir.const(0.5, dtype) + erf * tir.const(0.5, dtype)) + return x * (tirx.const(0.5, dtype) + erf * tirx.const(0.5, dtype)) return bb.call_te(te_gelu, call.args[0], primfunc_name_hint="gelu") @@ -514,14 +514,14 @@ def _nn_gelu_tanh(bb: BlockBuilder, call: Call) -> Expr: def te_gelu_tanh(x: te.Tensor): dtype = x.dtype return ( - tir.const(0.5, dtype) + tirx.const(0.5, dtype) * x * ( - tir.const(1.0, dtype) + tirx.const(1.0, dtype) + topi.tanh( - tir.const(math.sqrt(2.0 / math.pi), dtype) + tirx.const(math.sqrt(2.0 / math.pi), dtype) * x - * (1 + tir.const(0.044715, dtype) * x * x) + * (1 + tirx.const(0.044715, dtype) * x * x) ) ) ) @@ -533,14 +533,14 @@ def te_gelu_tanh(x: te.Tensor): def _nn_selu(bb: BlockBuilder, call: Call) -> Expr: def te_selu(x: te.Tensor): dtype = x.dtype - alpha = tir.const(1.6732632423543772848170429916717, dtype) - scale = tir.const(1.0507009873554804934193349852946, dtype) + alpha = tirx.const(1.6732632423543772848170429916717, dtype) + scale = tirx.const(1.0507009873554804934193349852946, dtype) # Compute SELU # SELU(x) = scale*(max(0,x)+min(0,a*(exp(x)-1))) - positive_part = topi.maximum(x, tir.const(0, dtype)) + positive_part = topi.maximum(x, tirx.const(0, dtype)) negative_part = topi.minimum( - tir.const(0, dtype), alpha * (topi.exp(x) - tir.const(1, dtype)) + tirx.const(0, dtype), alpha * (topi.exp(x) - tirx.const(1, dtype)) ) return scale * (positive_part + negative_part) @@ -669,7 +669,7 @@ def _te_attention( k: te.Tensor, v: te.Tensor, bias: te.Tensor, - scale: tir.FloatImm, + scale: tirx.FloatImm, causal_mask: str | None, ) -> te.Tensor: batch_size, seq_len, num_head, head_dim = q.shape @@ -684,7 +684,7 @@ def _te_attention( if scale is not None: p = topi.multiply(p, scale) else: - p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim))) + p = topi.divide(p, tirx.sqrt(tirx.Cast(p.dtype, head_dim))) if bias is not None: p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv]) p = topi.add(p, bias) @@ -693,9 +693,9 @@ def _te_attention( s = topi.nn.softmax(p) else: if causal_mask == "TopLeft": - offset = tir.IntImm("int32", 0) + offset = tirx.IntImm("int32", 0) elif causal_mask == "BottomRight": - offset = tir.abs(seq_len - seq_len_kv).astype("int32") + offset = tirx.abs(seq_len - seq_len_kv).astype("int32") else: raise NotImplementedError() p_masked = topi.trilu(p, k=offset, upper=False) diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py index 206417c6d0a1..caec63ffa8c0 100644 --- a/python/tvm/relax/transform/legalize_ops/qdq.py +++ b/python/tvm/relax/transform/legalize_ops/qdq.py @@ -18,7 +18,7 @@ """Default legalization function for quantize/dequantize operators.""" import tvm -from tvm import te, tir +from tvm import te, tirx from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -26,13 +26,13 @@ def clip_cast(val, dtype): - const_min = tvm.tir.min_value(dtype) - const_max = tvm.tir.max_value(dtype) + const_min = tvm.tirx.min_value(dtype) + const_max = tvm.tirx.max_value(dtype) return te.max(te.min(val, const_max), const_min).astype(dtype) def is_const_scalar(x): - return isinstance(x, tvm.tir.IntImm | tvm.tir.FloatImm) + return isinstance(x, tvm.tirx.IntImm | tvm.tirx.FloatImm) @register_legalize("relax.quantize") @@ -46,8 +46,8 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr: def te_quantize( data: te.Tensor, - scale: te.Tensor | tir.IntImm | tir.FloatImm, - zp: te.Tensor | tir.IntImm | tir.FloatImm, + scale: te.Tensor | tirx.IntImm | tirx.FloatImm, + zp: te.Tensor | tirx.IntImm | tirx.FloatImm, ): def quantize_compute(*indices): scale_value = scale if is_const_scalar(scale) else scale[indices[axis]] @@ -94,8 +94,8 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr: def te_dequantize( data: te.Tensor, - scale: te.Tensor | tir.IntImm | tir.FloatImm, - zp: te.Tensor | tir.IntImm | tir.FloatImm, + scale: te.Tensor | tirx.IntImm | tirx.FloatImm, + zp: te.Tensor | tirx.IntImm | tirx.FloatImm, ): def dequantize_compute(*indices): scale_value = scale if is_const_scalar(scale) else scale[indices[axis]] diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py index f9967e5f7bdd..4db7e6b49281 100644 --- a/python/tvm/relax/transform/legalize_ops/statistical.py +++ b/python/tvm/relax/transform/legalize_ops/statistical.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Default legalization function for statistical operators.""" -from tvm import te, tir, topi +from tvm import te, tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr @@ -31,21 +31,21 @@ def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr: return statistical_call_te -def _compute_shape_prod(x: te.Tensor, axis: list[tir.IntImm]) -> tir.PrimExpr: - shape_prod = tir.const(1, "int32") +def _compute_shape_prod(x: te.Tensor, axis: list[tirx.IntImm]) -> tirx.PrimExpr: + shape_prod = tirx.const(1, "int32") axes = [_axis.value for _axis in axis] if axis is not None else range(0, len(x.shape)) for dim in axes: shape_prod = shape_prod * x.shape[dim] return shape_prod -def _te_mean(x: te.Tensor, axis: list[tir.IntImm], keepdims: bool) -> te.Tensor: +def _te_mean(x: te.Tensor, axis: list[tirx.IntImm], keepdims: bool) -> te.Tensor: shape_prod = _compute_shape_prod(x, axis) res_sum = topi.sum(x, axis, keepdims) return topi.divide(res_sum, shape_prod) -def _te_variance(x: te.Tensor, axis: list[tir.IntImm], keepdims: bool) -> te.Tensor: +def _te_variance(x: te.Tensor, axis: list[tirx.IntImm], keepdims: bool) -> te.Tensor: dev = x - _te_mean(x, axis, True) return _te_mean(dev * dev, axis, keepdims) # This version has better memory locality and performance @@ -55,7 +55,7 @@ def _te_variance(x: te.Tensor, axis: list[tir.IntImm], keepdims: bool) -> te.Ten def _te_median( - x: te.Tensor, axis: list[tir.IntImm], keepdims: bool + x: te.Tensor, axis: list[tirx.IntImm], keepdims: bool ) -> te.Tensor | tuple[te.Tensor, te.Tensor]: # currently only supports one axis or no axis ~ same pytorch # todo: support multiple axis ~ same numpy @@ -97,7 +97,7 @@ def _mean(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.std") def _std(bb: BlockBuilder, call: Call) -> Expr: - def te_std(x: te.Tensor, axis: list[tir.IntImm], keepdims: bool) -> te.Tensor: + def te_std(x: te.Tensor, axis: list[tirx.IntImm], keepdims: bool) -> te.Tensor: return topi.sqrt(_te_variance(x, axis, keepdims)) return bb.call_te( diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index a3303fe4ad14..f95dfa35b642 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -16,7 +16,7 @@ # under the License. """Default legalization function for vision network related operators.""" -from tvm import relax, te, tir, topi +from tvm import relax, te, tirx, topi from ...block_builder import BlockBuilder from ...expr import Call, Expr, TupleGetItem @@ -76,18 +76,18 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E # Build slicing parameters using TE to avoid high-level Relax ops during legalization def build_begin(): - return te.compute((2,), lambda i: tir.const(0, "int64"), name="begin") + return te.compute((2,), lambda i: tirx.const(0, "int64"), name="begin") def build_strides(): - return te.compute((2,), lambda i: tir.const(1, "int64"), name="strides") + return te.compute((2,), lambda i: tirx.const(1, "int64"), name="strides") def build_end(count_tensor): # end = [count_tensor[0], 3] def compute_end(i): - return tir.if_then_else( + return tirx.if_then_else( i == 0, - tir.Cast("int64", count_tensor[0]), - tir.const(3, "int64"), + tirx.Cast("int64", count_tensor[0]), + tirx.const(3, "int64"), ) return te.compute((2,), compute_end, name="end") diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index e70392e88d58..a65a59c53c4f 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -33,7 +33,7 @@ from tvm.relax import Expr, StructInfo, Var from tvm.relax.dpl import DFPattern from tvm.runtime import Object, Tensor -from tvm.tir import IndexMap, PrimFunc +from tvm.tirx import IndexMap, PrimFunc from ..expr import Var from . import _ffi_api @@ -672,14 +672,14 @@ def BindParams( def BindSymbolicVars( - binding_map: Mapping[str | tvm.tir.Var, tvm.tir.PrimExpr], + binding_map: Mapping[str | tvm.tirx.Var, tvm.tirx.PrimExpr], func_name: str | None = None, ) -> tvm.ir.transform.Pass: """Bind params of function of the module to constant tensors. Parameters ---------- - binding_map : Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr] + binding_map : Mapping[Union[str, tvm.tirx.Var], tvm.tirx.PrimExpr] The map from symbolic varname to integer. func_name : Optional[str] @@ -693,7 +693,7 @@ def BindSymbolicVars( # Relax uses int64 for symbolic variables, but the FFI # converts python integers into int32. binding_map = { - key: tvm.tir.const(value, "int64") if isinstance(value, int) else value + key: tvm.tirx.const(value, "int64") if isinstance(value, int) else value for key, value in binding_map.items() } return _ffi_api.BindSymbolicVars(binding_map, func_name) # type: ignore @@ -819,7 +819,7 @@ def FuseTIR() -> tvm.ir.transform.Pass: Returns ------- ret : tvm.transform.Pass - The registered pass for tir fusion. + The registered pass for tirx fusion. """ return _ffi_api.FuseTIR() # type: ignore @@ -976,9 +976,9 @@ def MergeCompositeFunctions() -> tvm.ir.transform.Pass: def AttachAttrLayoutFreeBuffers() -> tvm.ir.transform.Pass: - """Attach layout free buffers to the tir::PrimFunc. + """Attach layout free buffers to the tirx::PrimFunc. - This pass is used to attach layout free buffers to the tir::PrimFunc according to + This pass is used to attach layout free buffers to the tirx::PrimFunc according to the function usage in the relax function. Currently, the layout free buffers are the model weights and relax constants. @@ -1153,7 +1153,7 @@ def add( B: T.Buffer((2, 3), "float32"), T_add: T.Buffer((2, 3), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(2, 3): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1167,7 +1167,7 @@ def multiply( B: T.Buffer((2, 3), "float32"), T_multiply: T.Buffer((2, 3), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(2, 3): with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1839,7 +1839,7 @@ class TestReplaceBinding: def __init__(self): # create a new VarBinding - m, n = tir.Var("m", "int64"), tir.Var("n", "int64") + m, n = tirx.Var("m", "int64"), tirx.Var("n", "int64") lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n], "float32")) val = relax.const(np.random.rand(24, 56)) self.new_binding = relax.VarBinding(lv0, val) diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 75f62c525d4b..c0445a1b4914 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -29,11 +29,11 @@ import tvm -from .. import tir +from .. import tirx from ..ir import Array, Attrs, Map, Type, VDevice from ..te import Tensor as te_Tensor from ..te import create_prim_func -from ..tir import PrimExpr +from ..tirx import PrimExpr from . import _ffi_api from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm, te_tensor from .expr import Tuple as rx_Tuple @@ -91,23 +91,23 @@ def convert_to_expr(value: Any) -> Expr: Notes ----- - 1. `tvm.tir.StringImm` is not allowed because of ambiguity, + 1. `tvm.tirx.StringImm` is not allowed because of ambiguity, which can be either `relax.StringImm` or `relax.PrimValue`. """ if isinstance(value, int): - return PrimValue(tir.IntImm("int64", value)) + return PrimValue(tirx.IntImm("int64", value)) if isinstance(value, float): - return PrimValue(tir.FloatImm("float64", value)) + return PrimValue(tirx.FloatImm("float64", value)) tvm_value = tvm_ffi.convert(value) # Case 1 if isinstance(tvm_value, Expr): # type: ignore return tvm_value # Note`` 1 - if isinstance(tvm_value, tir.StringImm): + if isinstance(tvm_value, tirx.StringImm): raise TypeError( - "Cannot convert `tir.StringImm` to `relax.Expr` because of ambiguity," + "Cannot convert `tirx.StringImm` to `relax.Expr` because of ambiguity," "which can be either `relax.StringImm` or `relax.PrimValue` " ) # Case 2 @@ -145,7 +145,7 @@ def copy_with_new_vars(func: Function) -> Function: def gen_call_tir_inputs( func: Callable, *args: Any, **kwargs: Any -) -> tuple[tir.PrimFunc, Expr, list[TensorStructInfo], ShapeExpr | None]: +) -> tuple[tirx.PrimFunc, Expr, list[TensorStructInfo], ShapeExpr | None]: """Generate the inputs for call_tir according to the te function. This function converts arguments from relax expression to te tensor, The callback func should return a te tensor or a list of te tensors. @@ -165,26 +165,26 @@ def gen_call_tir_inputs( Returns ------- - ret : Tuple[tir.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]] - ret contains the inputs for call_tir, including a tir prim_func, args, + ret : Tuple[tirx.PrimFunc, Expr, List[TensorStructInfo], Optional[ShapeExpr]] + ret contains the inputs for call_tir, including a tirx prim_func, args, out_sinfo, and tir_vars. """ - tir_var_map: dict[tir.Var, tir.PrimExpr] = {} + tir_var_map: dict[tirx.Var, tirx.PrimExpr] = {} call_tir_args = [] create_primfunc_args = [] - # extra list of tir expression arguments + # extra list of tirx expression arguments # that are not covered by Tensor extra_tir_args_list = [] - def _copy_undefined_var(expr: tir.PrimExpr): - def _visit_expr(e: tir.PrimExpr): - if isinstance(e, tir.Var) and e not in tir_var_map: - new_var = tir.Var(e.name, e.dtype) + def _copy_undefined_var(expr: tirx.PrimExpr): + def _visit_expr(e: tirx.PrimExpr): + if isinstance(e, tirx.Var) and e not in tir_var_map: + new_var = tirx.Var(e.name, e.dtype) tir_var_map[e] = new_var - tir.stmt_functor.post_order_visit(expr, _visit_expr) + tirx.stmt_functor.post_order_visit(expr, _visit_expr) def _convert_te_arg(te_args: Any) -> Any: """Helper function used to convert Relax expressions to TE tensor. @@ -207,7 +207,7 @@ def _convert_te_arg(te_args: Any) -> Any: te_args : Any Argument to convert to TE - tir_var_map : Dict[tir.Var, tir.PrimExpr] + tir_var_map : Dict[tirx.Var, tirx.PrimExpr] The TIR variable mapping, which maps TIR variables on the Relax function side to the new set of variables used on the PrimFunc side. @@ -258,7 +258,7 @@ def _convert_te_arg_helper(arg): else: name = f"scalar_input_{n_args}" - tir_param = tir.Var(name, arg.struct_info.dtype) + tir_param = tirx.Var(name, arg.struct_info.dtype) call_tir_args.append(arg) create_primfunc_args.append(tir_param) @@ -277,9 +277,9 @@ def _convert_te_arg_helper(arg): "emit_te only supports dict with string as the key currently" ) return {k: _convert_te_arg_helper(arg[k]) for k in arg} - elif isinstance(arg, tir.PrimExpr): + elif isinstance(arg, tirx.PrimExpr): _copy_undefined_var(arg) - new_arg = tir.stmt_functor.substitute(arg, tir_var_map) + new_arg = tirx.stmt_functor.substitute(arg, tir_var_map) extra_tir_args_list.append(new_arg) return new_arg elif isinstance(arg, int | float | str | Type | Attrs) or arg is None: @@ -291,7 +291,7 @@ def _convert_te_arg_helper(arg): def _get_unbound_tir_vars( args: list[te_Tensor], extra_tir_args: list[PrimExpr] - ) -> list[tir.Var]: + ) -> list[tirx.Var]: """get unbound TIR vars (i.e TIR vars used in the shape but is not itself a dimension of a shape)""" @@ -302,15 +302,15 @@ def _populate_bound_vars(expr): if isinstance(expr, te_Tensor): for dim in expr.shape: _populate_bound_vars(dim) - elif isinstance(expr, tir.Var): + elif isinstance(expr, tirx.Var): bound_vars.add(expr) def _populate_used_vars(expr): if isinstance(expr, te_Tensor): for dim in expr.shape: _populate_used_vars(dim) - elif isinstance(expr, tir.PrimExpr): - used_vars.update(tir.analysis.undefined_vars(expr)) + elif isinstance(expr, tirx.PrimExpr): + used_vars.update(tirx.analysis.undefined_vars(expr)) for arg in itertools.chain(args, extra_tir_args): _populate_used_vars(arg) @@ -340,10 +340,10 @@ def _get_vdevice(arg: Any) -> VDevice | None: return vdevice def _shape_with_old_tir_var( - shape_values: list[tir.PrimExpr], tir_var_inverse_map: dict[tir.Var, tir.PrimExpr] + shape_values: list[tirx.PrimExpr], tir_var_inverse_map: dict[tirx.Var, tirx.PrimExpr] ): return ShapeExpr( - [tir.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values] + [tirx.stmt_functor.substitute(value, tir_var_inverse_map) for value in shape_values] ) primfunc_attrs = kwargs.pop("primfunc_attrs", None) diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index ffe92a1daf8c..68592d67f870 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -21,7 +21,7 @@ from tvm import relax from tvm.ir.module import IRModule from tvm.runtime import Executable -from tvm.tir.function import PrimFunc +from tvm.tirx.function import PrimFunc from . import _ffi_api @@ -154,7 +154,7 @@ def _vmlink( tir_ext_libs = [] if tir_mod is not None and len(tir_mod.get_global_vars()) > 0: tir_mod = _auto_attach_system_lib_prefix(tir_mod, target, system_lib) - lib = tvm.tir.build(tir_mod, target=target, pipeline=tir_pipeline) + lib = tvm.tirx.build(tir_mod, target=target, pipeline=tir_pipeline) for ext_mod in ext_libs: if _is_device_module(ext_mod): tir_ext_libs.append(ext_mod) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 431fdbee533c..6ba9abb032af 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -155,7 +155,7 @@ def script( ir_prefix : str = "I" The prefix of AST nodes from tvm.ir tir_prefix : str = "T" - The prefix of AST nodes from tvm.tir + The prefix of AST nodes from tvm.tirx relax_prefix : str = "R" The prefix of AST nodes from tvm.relax module_alias : str = "cls" @@ -332,7 +332,7 @@ def show( ir_prefix : str = "I" The prefix of AST nodes from tvm.ir tir_prefix : str = "T" - The prefix of AST nodes from tvm.tir + The prefix of AST nodes from tvm.tirx relax_prefix : str = "R" The prefix of AST nodes from tvm.relax module_alias : str = "cls" diff --git a/python/tvm/s_tir/__init__.py b/python/tvm/s_tir/__init__.py index ce91d4009ee7..bba0dbff9fcf 100644 --- a/python/tvm/s_tir/__init__.py +++ b/python/tvm/s_tir/__init__.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name """S-TIR namespace for scheduable TensorIR""" -from tvm.tir.function import TensorIntrin +from tvm.tirx.function import TensorIntrin # dlight depends on compiler-only C++ functions (e.g. s_tir.schedule.GetSBlockRealize), # so skip it in runtime-only builds. diff --git a/python/tvm/s_tir/analysis/__init__.py b/python/tvm/s_tir/analysis/__init__.py index 6956b515e8e6..c1d4667e6247 100644 --- a/python/tvm/s_tir/analysis/__init__.py +++ b/python/tvm/s_tir/analysis/__init__.py @@ -22,11 +22,11 @@ import tvm from tvm.ir import IRModule -from tvm.tir.expr import Var -from tvm.tir.stmt import SBlock, BufferRegion +from tvm.tirx.expr import Var +from tvm.tirx.stmt import SBlock, BufferRegion -from tvm.tir import Buffer, Stmt -from tvm.tir.function import PrimFunc +from tvm.tirx import Buffer, Stmt +from tvm.tirx.function import PrimFunc from . import _ffi_api @@ -38,7 +38,7 @@ def get_sblock_access_region( Parameters ---------- - block: tvm.tir.SBlock + block: tvm.tirx.SBlock The block in which we are detecting read/write regions. buffer_var_map : Dict[Var, Buffer] @@ -63,7 +63,7 @@ def get_sblock_read_write_region( Parameters ---------- - block: tvm.tir.SBlock + block: tvm.tirx.SBlock The block in which we are detecting read/write regions. buffer_var_map : Dict[Var, Buffer] @@ -85,7 +85,7 @@ def detect_buffer_access_lca(func: PrimFunc) -> dict[Buffer, Stmt]: Parameters ---------- - func: tvm.tir.PrimFunc + func: tvm.tirx.PrimFunc The function to be detected. Returns @@ -128,7 +128,7 @@ def verify_gpu_code(func: PrimFunc, constraints: dict[str, int]) -> bool: Parameters ---------- - func: tvm.tir.PrimFunc + func: tvm.tirx.PrimFunc The module to be verified. constraints : Dict[str, int] diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index 8cce273d9422..a63fb4346d34 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -19,12 +19,12 @@ """The TIR backend compilation pipeline for Adreno""" import tvm -from tvm import s_tir, tir -from tvm.tir import pipeline as tir_pipeline +from tvm import s_tir, tirx +from tvm.tirx import pipeline as tir_pipeline def default_tir_pipeline(): - """The default tir pipeline used in tvm.tir.build""" + """The default tirx pipeline used in tvm.tirx.build""" @tvm.transform.module_pass(opt_level=0) def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: @@ -44,56 +44,56 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.LowerAutoCopy(), s_tir.transform.UnifyThreadBinding(), s_tir.transform.LowerMatchBuffer(), - tir.transform.Simplify(), + tirx.transform.Simplify(), s_tir.transform.InjectPermutedLayout(), s_tir.transform.AnnotateIrregularLoop(), s_tir.transform.InjectSoftwarePipeline(), s_tir.transform.TransformMmaBufferLayout(), s_tir.transform.LowerOpaqueBlock(), s_tir.backend.adreno.transform.InjectTextureAlloc(), - tir.transform.FlattenBuffer(), - tir.transform.BF16ComputeLegalize(), - tir.transform.NarrowDataType(32), + tirx.transform.FlattenBuffer(), + tirx.transform.BF16ComputeLegalize(), + tirx.transform.NarrowDataType(32), s_tir.transform.LoopPartition(), - tir.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), + tirx.transform.VectorizeLoop(not bool(config.get("tirx.disable_vectorize", False))), s_tir.transform.InjectVirtualThread(), s_tir.transform.InjectDoubleBuffer(), ] - if not bool(config.get("tir.disable_storage_rewrite", False)): - passes.append(tir.transform.StorageRewrite()) - if config.get("tir.use_async_copy", False): + if not bool(config.get("tirx.disable_storage_rewrite", False)): + passes.append(tirx.transform.StorageRewrite()) + if config.get("tirx.use_async_copy", False): passes.append(s_tir.transform.LowerAsyncDMA()) passes.extend( [ s_tir.transform.HoistIfThenElse(), - tir.transform.UnrollLoop(), + tirx.transform.UnrollLoop(), s_tir.transform.RenormalizeSplitPattern(), - tir.transform.Simplify(), - tir.transform.RemoveNoOp(), + tirx.transform.Simplify(), + tirx.transform.RemoveNoOp(), s_tir.transform.RewriteUnsafeSelect(), ] ) # Additional passes based on configuration. - if bool(config.get("tir.instrument_bound_checkers", False)): + if bool(config.get("tirx.instrument_bound_checkers", False)): passes.append(s_tir.transform.InstrumentBoundCheckers()) - if bool(config.get("tir.ptx_ldg32", False)): + if bool(config.get("tirx.ptx_ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32(True)) - if not bool(config.get("tir.disable_cse_tir", False)): - passes.append(tir.transform.CommonSubexprElim()) - if bool(config.get("tir.instrument_lwp", False)): + if not bool(config.get("tirx.disable_cse_tir", False)): + passes.append(tirx.transform.CommonSubexprElim()) + if bool(config.get("tirx.instrument_lwp", False)): passes.append(s_tir.transform.InstrumentProfileIntrinsics()) passes.extend( [ # Bind the target first so that target-specific attributes are available. - tir.transform.FP8ComputeLegalize(), + tirx.transform.FP8ComputeLegalize(), # VerifyVTCMLimit must occur before LowerVtcmAlloc. s_tir.transform.VerifyVTCMLimit(), s_tir.transform.LowerVtcmAlloc(), - tir.transform.VerifyMemory(), - tir.transform.AnnotateEntryFunc(), + tirx.transform.VerifyMemory(), + tirx.transform.AnnotateEntryFunc(), ] ) - if bool(config.get("tir.detect_global_barrier", False)): + if bool(config.get("tirx.detect_global_barrier", False)): passes.append(s_tir.transform.ThreadSync("global")) passes.extend( [ @@ -104,20 +104,20 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.LowerThreadAllreduce(), ] ) - if bool(config.get("tir.use_async_copy", False)): + if bool(config.get("tirx.use_async_copy", False)): passes.append(s_tir.transform.InjectPTXAsyncCopy()) - if bool(config.get("tir.ptx_ldg32", False)): + if bool(config.get("tirx.ptx_ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32()) passes.extend( [ - tir.transform.AnnotateDeviceRegions(), - tir.transform.SplitHostDevice(), + tirx.transform.AnnotateDeviceRegions(), + tirx.transform.SplitHostDevice(), # MergeSharedMemoryAllocations must follow SplitHostDevice. s_tir.transform.MergeSharedMemoryAllocations(), - tir.transform.MakePackedAPI(), - tir.transform.FP8StorageLegalize(), - tir.transform.BF16StorageLegalize(), - tir.transform.LowerDeviceKernelLaunch(), + tirx.transform.MakePackedAPI(), + tirx.transform.FP8StorageLegalize(), + tirx.transform.BF16StorageLegalize(), + tirx.transform.LowerDeviceKernelLaunch(), ] ) mod = tvm.ir.transform.Sequential(passes)(mod) diff --git a/python/tvm/s_tir/dlight/adreno/convolution.py b/python/tvm/s_tir/dlight/adreno/convolution.py index 7a59c62b1a65..324b15007e5e 100644 --- a/python/tvm/s_tir/dlight/adreno/convolution.py +++ b/python/tvm/s_tir/dlight/adreno/convolution.py @@ -17,7 +17,7 @@ # pylint: disable=missing-docstring, invalid-name """A Conv2d schedule rule for Adreno GPU operators.""" -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from .. import analysis @@ -62,16 +62,16 @@ def schedule_conv2d(sch: s_tir.Schedule, blk: s_tir.schedule.SBlockRV): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc | s_tir.Schedule, + func: tirx.PrimFunc | s_tir.Schedule, target: Target, _: bool, ) -> s_tir.Schedule | None: - if not (isinstance(func, tir.PrimFunc | s_tir.Schedule)) or not self.is_target_available( + if not (isinstance(func, tirx.PrimFunc | s_tir.Schedule)) or not self.is_target_available( target ): return None - if isinstance(func, tir.PrimFunc): + if isinstance(func, tirx.PrimFunc): sch = s_tir.Schedule(func) sch.work_on("main") elif isinstance(func, s_tir.Schedule): diff --git a/python/tvm/s_tir/dlight/adreno/fallback.py b/python/tvm/s_tir/dlight/adreno/fallback.py index f628583194c8..b453916b460f 100644 --- a/python/tvm/s_tir/dlight/adreno/fallback.py +++ b/python/tvm/s_tir/dlight/adreno/fallback.py @@ -33,7 +33,7 @@ # under the License. """Dlight Adreno Fallback Schedules""" -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from .. import analysis @@ -168,13 +168,13 @@ def schedule_fallback(sch): def apply( # pylint: disable=too-many-locals self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: # pylint: disable=invalid-name - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) diff --git a/python/tvm/s_tir/dlight/adreno/layout_transform.py b/python/tvm/s_tir/dlight/adreno/layout_transform.py index aebd21aa38e6..7068f827102c 100644 --- a/python/tvm/s_tir/dlight/adreno/layout_transform.py +++ b/python/tvm/s_tir/dlight/adreno/layout_transform.py @@ -37,7 +37,7 @@ "Schedules for Texture Based Layout Transforms" -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from .. import analysis @@ -53,17 +53,17 @@ def __init__(self, use_op_name=True): # TODO: Try using Coalesced Writes... def apply( # pylint: disable=too-many-locals self, - func: tir.PrimFunc | s_tir.Schedule, + func: tirx.PrimFunc | s_tir.Schedule, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: # pylint: disable=invalid-name - if not (isinstance(func, tir.PrimFunc | s_tir.Schedule)) or not self.is_target_available( + if not (isinstance(func, tirx.PrimFunc | s_tir.Schedule)) or not self.is_target_available( target ): return None - if isinstance(func, tir.PrimFunc): + if isinstance(func, tirx.PrimFunc): sch = s_tir.Schedule(func) sch.work_on("main") elif isinstance(func, s_tir.Schedule): diff --git a/python/tvm/s_tir/dlight/adreno/pool.py b/python/tvm/s_tir/dlight/adreno/pool.py index 1805fce65dcf..4bd071858e87 100644 --- a/python/tvm/s_tir/dlight/adreno/pool.py +++ b/python/tvm/s_tir/dlight/adreno/pool.py @@ -18,7 +18,7 @@ # ruff: noqa: F841 """Pool schedule rule for Adreno operators.""" -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from .. import analysis @@ -29,7 +29,7 @@ class Pool2D(AdrenoScheduleRule): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> s_tir.Schedule: diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py index 9a8f6a20dccf..2110f55c8e2a 100644 --- a/python/tvm/s_tir/dlight/analysis/common_analysis.py +++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py @@ -24,7 +24,7 @@ from tvm_ffi import get_global_func -from tvm import ir, s_tir, tir +from tvm import ir, s_tir, tirx from tvm.runtime import DataType from tvm.s_tir import Schedule from tvm.s_tir.schedule import SBlockRV @@ -35,15 +35,15 @@ class IterInfo: """Information about a loop/iter var.""" kind: Literal["S", "R", "O"] - var: tir.Var - _dom: tir.PrimExpr + var: tirx.Var + _dom: tirx.PrimExpr loop_rv: s_tir.schedule.LoopRV def __init__( self, kind: Literal["S", "R", "O"], - var: tir.Var, - dom: tir.PrimExpr, + var: tirx.Var, + dom: tirx.PrimExpr, loop_rv: s_tir.schedule.LoopRV, ): """Construct an IterInfo object.""" @@ -53,9 +53,9 @@ def __init__( self.loop_rv = loop_rv @property - def dom(self) -> int | tir.PrimExpr: + def dom(self) -> int | tirx.PrimExpr: """The iteration domain of the loop.""" - return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom + return int(self._dom) if isinstance(self._dom, tirx.IntImm) else self._dom def __str__(self) -> str: return f'Iter("{self.kind}", {self.dom})' @@ -76,16 +76,16 @@ def __repr__(self) -> str: class BufferInfo: "Information about Buffer. Provides useful analysis" - buf_region: tir.BufferRegion + buf_region: tirx.BufferRegion shape: tuple[int] assoc_lps: list[s_tir.schedule.LoopRV | None] - assoc_lps_info: list[tir.For | None] + assoc_lps_info: list[tirx.For | None] def __init__( self, sch: s_tir.Schedule, block_rv: s_tir.schedule.SBlockRV, - buf_region: tir.BufferRegion, + buf_region: tirx.BufferRegion, lps: list[s_tir.schedule.LoopRV] | None, ): block = sch.get(block_rv) @@ -97,34 +97,34 @@ def __init__( lpvar_lp = dict([loop.loop_var, lp] for loop, lp in zip(loops, lps)) var_lp = dict(zip(iter_vars, [lpvar_lp.get(val, None) for val in iter_values])) - def extract_index_types(buf: tir.BufferRegion) -> BufIndex: + def extract_index_types(buf: tirx.BufferRegion) -> BufIndex: buf_index = [] for expr in buf.region: expr = expr.min dim = None - if isinstance(expr, tir.expr.Add) and isinstance(expr.b, tir.expr.Var): + if isinstance(expr, tirx.expr.Add) and isinstance(expr.b, tirx.expr.Var): var_add = expr.b if ( - isinstance(expr, tir.expr.Mul) - and isinstance(expr.a, tir.expr.Var) - and isinstance(expr.b, tir.expr.IntImm) + isinstance(expr, tirx.expr.Mul) + and isinstance(expr.a, tirx.expr.Var) + and isinstance(expr.b, tirx.expr.IntImm) ): mul = expr.b var_mul = expr.a dim = MergeIndex(var_mul, mul, var_add) elif ( - isinstance(expr, tir.expr.FloorMod) - and isinstance(expr.a, tir.expr.Var) - and isinstance(expr.b, tir.expr.IntImm) + isinstance(expr, tirx.expr.FloorMod) + and isinstance(expr.a, tirx.expr.Var) + and isinstance(expr.b, tirx.expr.IntImm) ): dim = RemIndex(expr.a, expr.b) elif ( - isinstance(expr, tir.expr.FloorDiv) - and isinstance(expr.a, tir.expr.Var) - and isinstance(expr.b, tir.expr.IntImm) + isinstance(expr, tirx.expr.FloorDiv) + and isinstance(expr.a, tirx.expr.Var) + and isinstance(expr.b, tirx.expr.IntImm) ): dim = DivIndex(expr.a, expr.b) - elif isinstance(expr, tir.expr.Var): + elif isinstance(expr, tirx.expr.Var): dim = Index(expr) buf_index.append(dim) return buf_index @@ -186,7 +186,7 @@ def __init__( self.iters = iters self._reduction_block = reduction_block - def dom(self) -> list[int | tir.PrimExpr]: + def dom(self) -> list[int | tirx.PrimExpr]: """The iteration domain of the block.""" return [i.dom for i in self.iters] @@ -211,7 +211,7 @@ def is_injective(self) -> bool: def is_elementwise(self, sch: s_tir.Schedule) -> bool: """Whether the SBlock is elementwise, i.e. trivial mapping between read/write region""" - def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: + def _check_unit_var_range(dom: ir.Range, var: tirx.Var) -> bool: return dom.min.same_as(var) and dom.extent == 1 if not self.is_injective(): @@ -293,10 +293,10 @@ def normalize_prim_func(sch: s_tir.Schedule) -> list[SBlockInfo] | None: except Exception: # pylint: disable=broad-except return None - def _iter_kind(i: tir.IterVar) -> str: + def _iter_kind(i: tirx.IterVar) -> str: return { - tir.IterVar.DataPar: "S", - tir.IterVar.CommReduce: "R", + tirx.IterVar.DataPar: "S", + tirx.IterVar.CommReduce: "R", }.get(i.iter_type, "O") blocks: list[SBlockInfo] = [] @@ -321,8 +321,8 @@ def _iter_kind(i: tir.IterVar) -> str: def get_sblock_info(sch: s_tir.Schedule, block: s_tir.schedule.SBlockRV) -> SBlockInfo: - def _iter_kind(loop: tir.IterVar) -> str: - return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") + def _iter_kind(loop: tirx.IterVar) -> str: + return {tirx.IterVar.DataPar: "S", tirx.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") def _is_reduction_block(block: s_tir.schedule.SBlockRV): for iter_var in sch.get(block).iter_vars: @@ -384,8 +384,8 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> SBlockRV: def collect_block_iter_vars_used_in_access_region( - block: tir.SBlock, region: list[ir.Range] -) -> set[tir.Var]: + block: tirx.SBlock, region: list[ir.Range] +) -> set[tirx.Var]: """Collect the block iter variables used in the access region of a buffer region.""" tir_vars = set() for expr in region: @@ -395,19 +395,19 @@ def collect_block_iter_vars_used_in_access_region( return tir_vars -def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> set[tir.Var]: +def collect_vars_used_in_prim_expr(expr: tirx.PrimExpr) -> set[tirx.Var]: """Collect the variables used in the PrimExpr.""" tir_vars = set() def _collect_tir_var(expr): - if isinstance(expr, tir.Var): + if isinstance(expr, tirx.Var): tir_vars.add(expr) - tir.stmt_functor.post_order_visit(expr, _collect_tir_var) + tirx.stmt_functor.post_order_visit(expr, _collect_tir_var) return tir_vars -def detect_dominant_read(block: tir.SBlock) -> tir.PrimExpr: +def detect_dominant_read(block: tirx.SBlock) -> tirx.PrimExpr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 diff --git a/python/tvm/s_tir/dlight/analysis/gemv.py b/python/tvm/s_tir/dlight/analysis/gemv.py index 79b1b1b1a5ec..75d5b17dfd27 100644 --- a/python/tvm/s_tir/dlight/analysis/gemv.py +++ b/python/tvm/s_tir/dlight/analysis/gemv.py @@ -16,7 +16,7 @@ # under the License. """Analysis for GEMV.""" -from tvm import arith, ir, s_tir, tir +from tvm import arith, ir, s_tir, tirx from .common_analysis import ( SBlockInfo, @@ -26,7 +26,7 @@ ) -def get_reduction_expr(block: tir.SBlock) -> tir.PrimExpr | None: +def get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: """Extracts the reduction expression from a TIR block. This function checks whether the given TIR block follows a reduction pattern @@ -34,30 +34,30 @@ def get_reduction_expr(block: tir.SBlock) -> tir.PrimExpr | None: Parameters: ---------- - block : tir.SBlock + block : tirx.SBlock The TIR block to analyze. Returns: ------- - Optional[tir.PrimExpr] + Optional[tirx.PrimExpr] The reduction expression (`Y`) if detected, otherwise None. """ buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): + if not isinstance(buffer_store, tirx.BufferStore): return None - if not isinstance(buffer_store.value, tir.Add): + if not isinstance(buffer_store.value, tirx.Add): return None if not ir.structural_equal( buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), + tirx.BufferLoad(buffer_store.buffer, block.body.indices), map_free_vars=True, ): return None return buffer_store.value.b -def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tir.Buffer] | None: +def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tirx.Buffer] | None: """Check if the block is a GEMV. Parameters @@ -72,7 +72,7 @@ def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tir.Buffer] | N Returns ------- - ret : Optional[List[tir.Buffer]] + ret : Optional[List[tirx.Buffer]] The vector buffers used in the GEMV if it is a GEMV, otherwise None. """ block = block_info.block_rv @@ -104,7 +104,7 @@ def normalize( block_info: SBlockInfo, ) -> bool | None: """Normalize the main block.""" - block_stmt: tir.SBlock = sch.get(block_info.block_rv) + block_stmt: tirx.SBlock = sch.get(block_info.block_rv) access = arith.normalize_to_iter_sum( detect_dominant_read(block_stmt), input_iters={i.var: i.dom for i in block_stmt.iter_vars}, diff --git a/python/tvm/s_tir/dlight/base/schedule_rule.py b/python/tvm/s_tir/dlight/base/schedule_rule.py index 8764ceb3eb18..240b10d9d084 100644 --- a/python/tvm/s_tir/dlight/base/schedule_rule.py +++ b/python/tvm/s_tir/dlight/base/schedule_rule.py @@ -18,7 +18,7 @@ from collections.abc import Callable -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target @@ -35,7 +35,7 @@ class ScheduleRule: # pylint: disable=too-few-public-methods def apply( self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, tunable: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: @@ -43,7 +43,7 @@ def apply( Parameters ---------- - func : tir.PrimFunc + func : tirx.PrimFunc The PrimFunc to apply the ScheduleRule to. target : Target The compilation target the schedule is supposed to be built for. @@ -64,7 +64,7 @@ def from_callable( ) -> Callable[ [ Callable[ - [tir.PrimFunc, Target, bool], + [tirx.PrimFunc, Target, bool], None | s_tir.Schedule | list[s_tir.Schedule], ], ], @@ -86,7 +86,7 @@ def from_callable( .. code-block:: python @ScheduleRule.from_callable("MyRule") - def my_rule(func: tir.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule] + def my_rule(func: tirx.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule] # Do something with func and target """ @@ -94,7 +94,7 @@ def decorator(f) -> "ScheduleRule": # pylint: disable=invalid-name class _Rule(ScheduleRule): def apply( self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, tunable: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: diff --git a/python/tvm/s_tir/dlight/base/transform.py b/python/tvm/s_tir/dlight/base/transform.py index 049a920fefac..18014e40c07d 100644 --- a/python/tvm/s_tir/dlight/base/transform.py +++ b/python/tvm/s_tir/dlight/base/transform.py @@ -19,7 +19,7 @@ or a space for MetaSchedule tuning """ -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.ir import IRModule from tvm.ir.transform import PassContext, module_pass from tvm.target import Target @@ -27,15 +27,15 @@ from .schedule_rule import ScheduleRule -def _is_scheduled(func: tir.PrimFunc) -> bool: - if not isinstance(func, tir.PrimFunc): +def _is_scheduled(func: tirx.PrimFunc) -> bool: + if not isinstance(func, tirx.PrimFunc): return False - if "tir.is_scheduled" not in func.attrs: + if "tirx.is_scheduled" not in func.attrs: return False - return func.attrs["tir.is_scheduled"] == 1 + return func.attrs["tirx.is_scheduled"] == 1 -def _get_target(func: tir.PrimFunc) -> Target: +def _get_target(func: tirx.PrimFunc) -> Target: target = func.attrs.get("target") if target is None: return Target.current(allow_none=False) @@ -64,14 +64,14 @@ def transform_module( # pylint: disable=missing-function-docstring ) -> IRModule: updated_functions = {} for g_var, func in mod.functions_items(): - if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): + if isinstance(func, tirx.PrimFunc) and not _is_scheduled(func): target = _get_target(func) sch = _apply_rules(func, target, self.rules, tunable=False) if sch is not None: assert len(sch) == 1 updated_functions[g_var] = ( - sch[0].mod["main"].with_attr("tir.is_scheduled", True) + sch[0].mod["main"].with_attr("tirx.is_scheduled", True) ) for g_var, func in updated_functions.items(): mod[g_var] = func @@ -79,7 +79,7 @@ def transform_module( # pylint: disable=missing-function-docstring def _apply_rules( - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, rules: list[ScheduleRule], tunable: bool, diff --git a/python/tvm/s_tir/dlight/base/utils.py b/python/tvm/s_tir/dlight/base/utils.py index 41abae75aff5..e16c22b8c5f8 100644 --- a/python/tvm/s_tir/dlight/base/utils.py +++ b/python/tvm/s_tir/dlight/base/utils.py @@ -17,7 +17,7 @@ # pylint: disable=missing-docstring """Utility methods for generic GPU.""" -from tvm import DataType, s_tir, tir +from tvm import DataType, s_tir, tirx from tvm.target import Target @@ -28,8 +28,8 @@ def get_bytes(dtype: DataType | str) -> int: def get_extent(sch: s_tir.Schedule, loop_rv: s_tir.schedule.LoopRV): - loop: tir.For = sch.get(loop_rv) - return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent + loop: tirx.For = sch.get(loop_rv) + return loop.extent.value if isinstance(loop.extent, tirx.IntImm) else loop.extent def auto_vectorize(sch: s_tir.Schedule, loop: s_tir.schedule.LoopRV, max_vec: int): @@ -65,7 +65,7 @@ def max_threads_per_block(target: Target) -> int: def suggest_threads_per_block( target: Target, - loops: list[tir.For], + loops: list[tirx.For], max_threads_for_dynamic_loop: int = 32, ) -> list[int]: if target.kind.name == "cuda": @@ -82,7 +82,7 @@ def suggest_threads_per_block( dynamic: list[int] = [] for i, loop in enumerate(loops): loop_extent = loop.extent - if isinstance(loop_extent, tir.IntImm): + if isinstance(loop_extent, tirx.IntImm): loop_extent = loop_extent.value extent = 1 while extent <= loop_extent and extent <= threads: diff --git a/python/tvm/s_tir/dlight/benchmark/bench.py b/python/tvm/s_tir/dlight/benchmark/bench.py index 34d361fa752a..aa7aefc02cb2 100644 --- a/python/tvm/s_tir/dlight/benchmark/bench.py +++ b/python/tvm/s_tir/dlight/benchmark/bench.py @@ -24,7 +24,7 @@ from tvm.ir import IRModule from tvm.s_tir.meta_schedule.runner import EvaluatorConfig from tvm.s_tir.meta_schedule.testing.tune_utils import generate_input_data -from tvm.tir import PrimFunc +from tvm.tirx import PrimFunc from .extract import extract_all_func_info_from_relax, extract_func_info_from_prim_func from .utils import ( @@ -121,7 +121,7 @@ def benchmark( # append scalar input tensors for rotary embedding input_tensors.extend(scalar_input_tensors) # build locally - rt_mod = tvm.tir.build(mod, target=target) + rt_mod = tvm.tirx.build(mod, target=target) # set up evaluator config evaluator_config = EvaluatorConfig._normalized( # pylint: disable=protected-access evaluator_config diff --git a/python/tvm/s_tir/dlight/benchmark/extract.py b/python/tvm/s_tir/dlight/benchmark/extract.py index f25d018ed2f2..33d9b4402b0c 100644 --- a/python/tvm/s_tir/dlight/benchmark/extract.py +++ b/python/tvm/s_tir/dlight/benchmark/extract.py @@ -29,7 +29,7 @@ import tvm from tvm import relax -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.s_tir.dlight.benchmark import benchmark_prim_func @@ -128,11 +128,11 @@ def extract_dynamic_var( for arg in flattened_arg_list: if isinstance(arg, relax.TensorStructInfo): for val in arg.shape.values: - if isinstance(val, tvm.tir.Var): + if isinstance(val, tvm.tirx.Var): dym_var_dict[gv][str(val)] = val.dtype elif isinstance(arg, relax.ShapeStructInfo): for val in arg.values: - if isinstance(val, tvm.tir.Var): + if isinstance(val, tvm.tirx.Var): dym_var_dict[gv][str(val)] = val.dtype else: raise NotImplementedError @@ -159,19 +159,19 @@ def update_records( def extract_func_info_from_prim_func( - func: tvm.tir.PrimFunc, -) -> tuple[list[tuple[tuple[tvm.tir.Var | int, ...], str]], dict[str, str]]: + func: tvm.tirx.PrimFunc, +) -> tuple[list[tuple[tuple[tvm.tirx.Var | int, ...], str]], dict[str, str]]: """Extract function input information from a PrimFunc. Parameters ---------- - func : tvm.tir.PrimFunc + func : tvm.tirx.PrimFunc The PrimFunc to be analyzed. Returns ------- result : Tuple[ - List[Tuple[Tuple[Union[tvm.tir.Var, int], ...], str]], + List[Tuple[Tuple[Union[tvm.tirx.Var, int], ...], str]], Dict[str, str], ] The function input information and dynamic shape variable dictionary. @@ -182,9 +182,9 @@ def extract_func_info_from_prim_func( buffer = func.buffer_map[param] shape = [] for dim in buffer.shape: - if isinstance(dim, tvm.tir.IntImm): + if isinstance(dim, tvm.tirx.IntImm): shape.append(dim.value) - elif isinstance(dim, tvm.tir.Var): + elif isinstance(dim, tvm.tirx.Var): dym_var[str(dim)] = str(dim.dtype) shape.append(dim) else: @@ -223,7 +223,7 @@ def extract_all_func_info_from_relax( raw_args = binding.value.args functor = raw_args[0] if isinstance(functor, tvm.ir.GlobalVar) and isinstance( - mod.functions[functor], tvm.tir.PrimFunc + mod.functions[functor], tvm.tirx.PrimFunc ): args = extract_shape(raw_args[1:]) + extract_shape(binding.value) if isinstance(functor, tvm.ir.GlobalVar): @@ -240,7 +240,7 @@ def extract_prim_func( # pylint: disable=too-many-arguments model_name: str, relax_func_name: str, prim_func_name: str, - func: tvm.tir.PrimFunc, + func: tvm.tirx.PrimFunc, *, func_args: list[tuple[tuple[tvm.relax.expr.Call | int, ...], str]] | None = None, dym_var_dict: dict[str, str] | None = None, @@ -258,7 +258,7 @@ def extract_prim_func( # pylint: disable=too-many-arguments The name of the Relax function. prim_func_name: str The name of the prim function. - func: tvm.tir.PrimFunc + func: tvm.tirx.PrimFunc The PrimFunc to be extracted. func_args: Optional[List[Tuple[Tuple[Union[tvm.relax.expr.Call, int], ...], str]]] The arguments of the prim function, including both static and dynamic shape arguments. diff --git a/python/tvm/s_tir/dlight/benchmark/utils.py b/python/tvm/s_tir/dlight/benchmark/utils.py index b91dab290467..e25a5968499c 100644 --- a/python/tvm/s_tir/dlight/benchmark/utils.py +++ b/python/tvm/s_tir/dlight/benchmark/utils.py @@ -93,7 +93,7 @@ def populuate_input_shape( for dim in tensor_shape: if isinstance(dim, int): shape.append(dim) - elif isinstance(dim, tvm.tir.IntImm): + elif isinstance(dim, tvm.tirx.IntImm): shape.append(dim.value) else: shape.append(dym_var_sample[str(dim)]) diff --git a/python/tvm/s_tir/dlight/cpu/gemv.py b/python/tvm/s_tir/dlight/cpu/gemv.py index 3d97864bccbd..c4d86de40629 100644 --- a/python/tvm/s_tir/dlight/cpu/gemv.py +++ b/python/tvm/s_tir/dlight/cpu/gemv.py @@ -16,7 +16,7 @@ # under the License. """A rule for GEMV and DecodeGEMV.""" -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from ..analysis import SBlockInfo, normalize_prim_func @@ -30,11 +30,11 @@ class GEMV(CPUScheduleRule): def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements, no-else-return self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) block_infos = normalize_prim_func(sch) @@ -77,7 +77,7 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, too-many-positio sch: s_tir.Schedule, target: Target, block: s_tir.schedule.SBlockRV, - vector_input_buffers: list[tir.Buffer], + vector_input_buffers: list[tirx.Buffer], epilogue_info: SBlockInfo | None, ): """Schedule the inner reduction block.""" diff --git a/python/tvm/s_tir/dlight/gpu/fallback.py b/python/tvm/s_tir/dlight/gpu/fallback.py index 73c525f4a5d9..bfc7b5bb3f93 100644 --- a/python/tvm/s_tir/dlight/gpu/fallback.py +++ b/python/tvm/s_tir/dlight/gpu/fallback.py @@ -17,7 +17,7 @@ # pylint: disable=missing-docstring """A fallback schedule rule for GPU operators.""" -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from .. import base @@ -34,11 +34,11 @@ class Fallback(GPUScheduleRule): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> s_tir.Schedule: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None max_threads_per_block = base.max_threads_per_block(target) diff --git a/python/tvm/s_tir/dlight/gpu/gemv.py b/python/tvm/s_tir/dlight/gpu/gemv.py index 5c4ba19e91d1..7e555ce94625 100644 --- a/python/tvm/s_tir/dlight/gpu/gemv.py +++ b/python/tvm/s_tir/dlight/gpu/gemv.py @@ -19,7 +19,7 @@ from functools import reduce -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.target import Target from ..analysis import ( @@ -38,11 +38,11 @@ class GEMV(GPUScheduleRule): def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) block_infos = normalize_prim_func(sch) @@ -89,7 +89,7 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, un sch: s_tir.Schedule, target: Target, block: s_tir.schedule.SBlockRV, - vector_input_buffers: list[tir.Buffer], + vector_input_buffers: list[tirx.Buffer], epilogue_info: SBlockInfo | None, ): """Schedule the inner reduction block.""" @@ -147,7 +147,7 @@ def apply( for buf in vector_input_buffers: dtype_bytes = get_bytes(buf.dtype) buf_size = ( - reduce(lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1)) + reduce(lambda x, y: x * y, buf.shape, tirx.IntImm(buf.shape[0].dtype, 1)) * dtype_bytes ) shared_mem_usage += buf_size @@ -158,7 +158,7 @@ def apply( LOAD_V_SHARED = ( LOAD_V_SHARED - and isinstance(shared_mem_usage, tir.IntImm) + and isinstance(shared_mem_usage, tirx.IntImm) and shared_mem_usage.value <= int(target.attrs["max_shared_memory_per_block"]) ) @@ -182,8 +182,8 @@ def apply( V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") sch.compute_at(V_shared, tr, preserve_unit_loops=True) l = sch.get_loops(block=V_shared)[-1] - loop: tir.For = sch.get(l) - if isinstance(loop.extent, tir.IntImm): + loop: tirx.For = sch.get(l) + if isinstance(loop.extent, tirx.IntImm): # avoid introducing predicates when vector length is too large vec_length = max( min( @@ -429,7 +429,7 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un sch: s_tir.Schedule, target: Target, block: s_tir.schedule.SBlockRV, - vector_input_buffers: list[tir.Buffer], + vector_input_buffers: list[tirx.Buffer], epilogue_info: SBlockInfo | None, ): """Schedule the outer reduction block.""" @@ -634,7 +634,7 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid sch: s_tir.Schedule, target: Target, block: s_tir.schedule.SBlockRV, - vector_input_buffers: list[tir.Buffer], + vector_input_buffers: list[tirx.Buffer], epilogue_info: SBlockInfo | None, ): """Schedule the outer reduction block.""" diff --git a/python/tvm/s_tir/dlight/gpu/general_reduction.py b/python/tvm/s_tir/dlight/gpu/general_reduction.py index 0005153b3ea7..2befacbf6298 100644 --- a/python/tvm/s_tir/dlight/gpu/general_reduction.py +++ b/python/tvm/s_tir/dlight/gpu/general_reduction.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """Reduction rule for operators including softmax, layer norm, RMS norm, etc""" -from tvm import arith, s_tir, tir +from tvm import arith, s_tir, tirx from tvm.target import Target from ..analysis import normalize_prim_func @@ -30,11 +30,11 @@ class GeneralReduction(GPUScheduleRule): def apply( # pylint: disable=too-many-locals self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None if target.kind.name == "cuda": @@ -92,17 +92,17 @@ def f_layout_mapping(*iters): target_layout_iters.append(iters[num_matched]) num_matched += 1 else: - target_layout_iters.append(tir.const(0, iters[0].dtype)) + target_layout_iters.append(tirx.const(0, iters[0].dtype)) # If all the iters of the last block can match, return the new layout. if num_matched == len(iters): return target_layout_iters # Otherwise, fallback to appending zeros in the beginning. - return [tir.const(0, iters[0].dtype)] * ( + return [tirx.const(0, iters[0].dtype)] * ( len(dom_kind) - num_last_block_iter ) + list(iters) - index_map = tir.IndexMap.from_func(f_layout_mapping, ndim=num_last_block_iter) + index_map = tirx.IndexMap.from_func(f_layout_mapping, ndim=num_last_block_iter) sch.transform_block_layout(block_infos[-1].block_rv, index_map) try: @@ -134,16 +134,16 @@ def f_layout_mapping(*iters): for block_iter, loop_rv in zip(spatial_block.iter_vars, loops): block_var_to_loop_var[block_iter.var] = sch.get(loop_rv).loop_var - def _visit_expr(e: tir.PrimExpr): - if isinstance(e, tir.Var) and e in block_var_to_loop_var: + def _visit_expr(e: tirx.PrimExpr): + if isinstance(e, tirx.Var) and e in block_var_to_loop_var: spatial_loops.add(block_var_to_loop_var[e]) for buffer_read in spatial_block.reads: buffer = buffer_read.buffer if buffer in reduced_buffers: for read_range in buffer_read.region: - tir.stmt_functor.post_order_visit(read_range.min, _visit_expr) - tir.stmt_functor.post_order_visit(read_range.extent, _visit_expr) + tirx.stmt_functor.post_order_visit(read_range.min, _visit_expr) + tirx.stmt_functor.post_order_visit(read_range.extent, _visit_expr) s_loops = [] other_loops = [] diff --git a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py index ab08ec23e21d..197e1f897d94 100644 --- a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py @@ -20,7 +20,7 @@ from functools import reduce from typing import Literal -from tvm import arith, ir, s_tir, tir +from tvm import arith, ir, s_tir, tirx from tvm.target import Target from ..analysis import ( @@ -34,23 +34,23 @@ from .base import GPUScheduleRule -def _get_reduction_expr(block: tir.SBlock) -> tir.PrimExpr | None: +def _get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: # Detect and return `Y` in `X[...] = X[...] + Y` buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): + if not isinstance(buffer_store, tirx.BufferStore): return None - if not isinstance(buffer_store.value, tir.Add): + if not isinstance(buffer_store.value, tirx.Add): return None if not ir.structural_equal( buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), + tirx.BufferLoad(buffer_store.buffer, block.body.indices), map_free_vars=True, ): return None return buffer_store.value.b -def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tir.Buffer] | None: +def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tirx.Buffer] | None: """Check if the block is a low batch GEMM. Parameters @@ -65,7 +65,7 @@ def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tir.Buffer] | N Returns ------- - ret : Optional[List[tir.Buffer]] + ret : Optional[List[tirx.Buffer]] The vector-like buffers used in the low batch GEMM if it is a low batch GEMM, otherwise None. """ @@ -85,16 +85,16 @@ def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tir.Buffer] | N const_iter_vars = set( iter_var.var for iter_var in block_stmt.iter_vars - if isinstance(iter_var.dom.extent, tir.IntImm) + if isinstance(iter_var.dom.extent, tirx.IntImm) ) if len(block_stmt.iter_vars) - len(const_iter_vars) != 1: return None symbolic_iter_var = next( iter_var for iter_var in block_stmt.iter_vars - if not isinstance(iter_var.dom.extent, tir.IntImm) + if not isinstance(iter_var.dom.extent, tirx.IntImm) ) - if symbolic_iter_var.iter_type != tir.stmt.IterVar.DataPar: + if symbolic_iter_var.iter_type != tirx.stmt.IterVar.DataPar: return None ret = [ read.buffer @@ -111,7 +111,7 @@ def is_gemv(sch: s_tir.Schedule, block_info: SBlockInfo) -> list[tir.Buffer] | N return ret if 0 < len(ret) < len(block_stmt.reads) else None -def detect_dominant_read(block: tir.SBlock, const_iter_vars: set[tir.Var]) -> tir.PrimExpr: +def detect_dominant_read(block: tirx.SBlock, const_iter_vars: set[tirx.Var]) -> tirx.PrimExpr: """Detect the dominant read indices in the block.""" dominant_read = None num_read_iters = -1 @@ -133,11 +133,11 @@ def normalize( block_info: SBlockInfo, ) -> bool | None: """Normalize the main block.""" - block_stmt: tir.SBlock = sch.get(block_info.block_rv) + block_stmt: tirx.SBlock = sch.get(block_info.block_rv) const_iter_vars = set( iter_var.var for iter_var in block_stmt.iter_vars - if isinstance(iter_var.dom.extent, tir.IntImm) + if isinstance(iter_var.dom.extent, tirx.IntImm) ) dynamic_iter_vars = set( iter_var.var for iter_var in block_stmt.iter_vars if iter_var.var not in const_iter_vars @@ -196,11 +196,11 @@ def __init__(self, bucket=4): def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) block_infos = normalize_prim_func(sch) @@ -293,7 +293,7 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, un block: s_tir.schedule.SBlockRV, dequantize_block: s_tir.schedule.SBlockRV | None, pad_input_block: s_tir.schedule.SBlockRV | None, - vector_input_buffers: list[tir.Buffer], + vector_input_buffers: list[tirx.Buffer], epilogue_info: SBlockInfo | None, batch_pad: int, ): @@ -351,12 +351,12 @@ def apply( shared_mem_usage = 0 for buf in vector_input_buffers: buf_size = reduce( - lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1) + lambda x, y: x * y, buf.shape, tirx.IntImm(buf.shape[0].dtype, 1) ) * get_bytes(buf.dtype) shared_mem_usage += buf_size LOAD_V_SHARED = ( LOAD_V_SHARED - and isinstance(shared_mem_usage, tir.IntImm) + and isinstance(shared_mem_usage, tirx.IntImm) and shared_mem_usage.value <= int(target.attrs["max_shared_memory_per_block"]) ) @@ -380,8 +380,8 @@ def apply( V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") sch.compute_at(V_shared, tr, preserve_unit_loops=True) l = sch.get_loops(block=V_shared)[-1] - loop: tir.For = sch.get(l) - if isinstance(loop.extent, tir.IntImm): + loop: tirx.For = sch.get(l) + if isinstance(loop.extent, tirx.IntImm): # avoid introducing predicates when vector length is too large vec_length = max( min( @@ -605,7 +605,7 @@ def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, un block: s_tir.schedule.SBlockRV, dequantize_block: s_tir.schedule.SBlockRV | None, pad_input_block: s_tir.schedule.SBlockRV | None, - vector_input_buffers: list[tir.Buffer], + vector_input_buffers: list[tirx.Buffer], epilogue_info: SBlockInfo | None, batch_pad: int, ): diff --git a/python/tvm/s_tir/dlight/gpu/matmul.py b/python/tvm/s_tir/dlight/gpu/matmul.py index 0bacb3e00237..ef9392cc2f3c 100644 --- a/python/tvm/s_tir/dlight/gpu/matmul.py +++ b/python/tvm/s_tir/dlight/gpu/matmul.py @@ -21,13 +21,13 @@ from dataclasses import dataclass from enum import Enum -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.ir import Range from tvm.s_tir.schedule.schedule import SBlockRV -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target -from tvm.tir import IterVar, PrimExpr, Var -from tvm.tir.analysis import undefined_vars +from tvm.tirx import IterVar, PrimExpr, Var +from tvm.tirx.analysis import undefined_vars from ..analysis import IterInfo, SBlockInfo, get_root_block from .base import GPUScheduleRule @@ -138,17 +138,17 @@ class IterTrait: def _is_one(x: PrimExpr) -> bool: - return isinstance(x, tir.IntImm) and x.value == 1 + return isinstance(x, tirx.IntImm) and x.value == 1 def make_iter_fusion_index_map( traits: list[IterTrait], kind_order: list[IterKind], -) -> tir.IndexMap: +) -> tirx.IndexMap: fused_iters: dict[IterKind, PrimExpr] = {} - input_iters: list[tir.Var] = [] + input_iters: list[tirx.Var] = [] for i, trait in enumerate(traits): - v_i = tir.Var(f"i{i}", trait.extent.dtype) + v_i = tirx.Var(f"i{i}", trait.extent.dtype) input_iters.append(v_i) if trait.kind == IterKind.kIter_T: continue @@ -159,19 +159,19 @@ def make_iter_fusion_index_map( else: fused_iters[trait.kind] = v_i - final_indices: list[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order + final_indices: list[tirx.PrimExpr] = [ + fused_iters.get(kind, tirx.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order ] - return tir.IndexMap(input_iters, final_indices, None) + return tirx.IndexMap(input_iters, final_indices, None) -def detect_iter_traits(block: tir.SBlock) -> tuple[list[IterTrait]] | None: +def detect_iter_traits(block: tirx.SBlock) -> tuple[list[IterTrait]] | None: """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] Parameters ---------- - block : tir.SBlock + block : tirx.SBlock The block to be analyzed Returns @@ -215,7 +215,7 @@ def get_access_axes(region: list[Range]) -> set[Var]: kind = IterKind.kIter_J else: return None - elif iter_var.iter_type == tir.IterVar.CommReduce: + elif iter_var.iter_type == tirx.IterVar.CommReduce: if var in A_axes and var in B_axes and var not in C_axes: kind = IterKind.kIter_K else: @@ -236,17 +236,17 @@ def get_access_axes(region: list[Range]) -> set[Var]: return A_traits, B_traits, C_traits, block_traits -def get_index_map(block: tir.SBlock) -> tuple[tir.IndexMap, ...] | None: +def get_index_map(block: tirx.SBlock) -> tuple[tirx.IndexMap, ...] | None: """Get index maps for the block Parameters ---------- - block : tir.SBlock + block : tirx.SBlock The block to be analyzed Returns ------- - index_maps : Optional[Tuple[tir.IndexMap]] + index_maps : Optional[Tuple[tirx.IndexMap]] The index maps for the block, or None if the block is not a gemm-liked kernel """ traits = detect_iter_traits(block) @@ -276,8 +276,8 @@ def get_index_map(block: tir.SBlock) -> tuple[tir.IndexMap, ...] | None: def get_sblock_info(sch: s_tir.Schedule, block: s_tir.schedule.SBlockRV) -> SBlockInfo: - def _iter_kind(loop: tir.IterVar) -> str: - return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") + def _iter_kind(loop: tirx.IterVar) -> str: + return {tirx.IterVar.DataPar: "S", tirx.IterVar.CommReduce: "R"}.get(loop.iter_type, "O") def _is_reduction_block(block: s_tir.schedule.SBlockRV): for iter_var in sch.get(block).iter_vars: @@ -326,7 +326,7 @@ def is_spatial(block: SBlockRV) -> bool: return reduction_blocks -def get_in_out_dtypes(block: tir.SBlock) -> tuple[str]: +def get_in_out_dtypes(block: tirx.SBlock) -> tuple[str]: """ Detect In/Out data types for the given block based on the analysis if read/write buffers. """ @@ -348,7 +348,7 @@ class MetalMatmul(GPUScheduleRule): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> s_tir.Schedule | None: @@ -356,7 +356,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring get_simdgroup_intrin_group, ) - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) root_block = get_root_block(sch) @@ -489,7 +489,7 @@ class MatmulTensorization(GPUScheduleRule): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> s_tir.Schedule | None: @@ -497,7 +497,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring get_wmma_intrin_group, ) - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) root_block = get_root_block(sch) @@ -605,7 +605,7 @@ def fetch_to_shared(block, idx, ndim): sch.vectorize(f_3) sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "tirx.manifest_shared_memory_local_stage", 1) sch.annotate(block_read, "double_buffer_scope", 0) return block_read @@ -710,7 +710,7 @@ class MatmulInt8Tensorization(GPUScheduleRule): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> s_tir.Schedule | None: @@ -718,7 +718,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring get_wmma_intrin_group, ) - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) root_block = get_root_block(sch) @@ -826,7 +826,7 @@ def fetch_to_shared(block, idx, ndim): sch.vectorize(f_3) sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "tirx.manifest_shared_memory_local_stage", 1) sch.annotate(block_read, "double_buffer_scope", 0) return block_read @@ -964,11 +964,11 @@ def get_configs(self, target: Target) -> Config: def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> s_tir.Schedule | None: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) config = self.get_configs(target) @@ -1026,7 +1026,7 @@ def is_inner_reduction(block_stmt, iter_infos): # the batch dimension is not taken into consideration. for item_var in block_stmt.iter_vars[1:]: extent = item_var.dom.extent - if isinstance(extent, tir.expr.IntImm): + if isinstance(extent, tirx.expr.IntImm): if extent.value <= minimal_tensorize_threshold: apply_tensorization = False if apply_tensorization: @@ -1146,9 +1146,9 @@ def get_max_factor(n, factors): mb, ms, n, k = reduction_loops if not ( - isinstance(sch.get(n).extent, tir.IntImm) - and isinstance(sch.get(mb).extent, tir.IntImm) - and not isinstance(sch.get(ms).extent, tir.IntImm) + isinstance(sch.get(n).extent, tirx.IntImm) + and isinstance(sch.get(mb).extent, tirx.IntImm) + and not isinstance(sch.get(ms).extent, tirx.IntImm) ): return None diff --git a/python/tvm/s_tir/dlight/gpu/reduction.py b/python/tvm/s_tir/dlight/gpu/reduction.py index f0c58ff3c423..af310c25c5ce 100644 --- a/python/tvm/s_tir/dlight/gpu/reduction.py +++ b/python/tvm/s_tir/dlight/gpu/reduction.py @@ -19,7 +19,7 @@ # TODO: combine reduction rule and general reduction rule into one file. from collections.abc import Mapping -from tvm import arith, ir, s_tir, tir +from tvm import arith, ir, s_tir, tirx from tvm.target import Target from ..analysis import ( @@ -32,16 +32,16 @@ from .base import GPUScheduleRule -def _get_reduction_expr(block: tir.SBlock) -> tir.PrimExpr | None: +def _get_reduction_expr(block: tirx.SBlock) -> tirx.PrimExpr | None: # Detect and return `Y` in `X[...] = X[...] + Y` buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): + if not isinstance(buffer_store, tirx.BufferStore): return None - if not isinstance(buffer_store.value, tir.Add): + if not isinstance(buffer_store.value, tirx.Add): return None if not ir.structural_equal( buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), + tirx.BufferLoad(buffer_store.buffer, block.body.indices), map_free_vars=True, ): return None @@ -57,11 +57,11 @@ class Reduction(GPUScheduleRule): def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None sch = s_tir.Schedule(func) block_infos = normalize_prim_func(sch) diff --git a/python/tvm/s_tir/dlight/gpu/rmsnorm.py b/python/tvm/s_tir/dlight/gpu/rmsnorm.py index 9497a1589713..5b053b14d594 100644 --- a/python/tvm/s_tir/dlight/gpu/rmsnorm.py +++ b/python/tvm/s_tir/dlight/gpu/rmsnorm.py @@ -18,10 +18,10 @@ """A RMS norm schedule rule for GPU operators.""" import tvm -from tvm import tir +from tvm import tirx from tvm.target import Target -from tvm.tir import BufferStore, SBlock -from tvm.tir.expr import BufferLoad, Call, Cast +from tvm.tirx import BufferStore, SBlock +from tvm.tirx.expr import BufferLoad, Call, Cast from ..base import ScheduleRule @@ -68,7 +68,7 @@ def identify_rsqrt_block(block: SBlock) -> bool: call = store.value op = call.op - return op == tvm.ir.op.Op.get("tir.rsqrt") + return op == tvm.ir.op.Op.get("tirx.rsqrt") class RMSNorm(ScheduleRule): @@ -76,7 +76,7 @@ class RMSNorm(ScheduleRule): def apply( # pylint: disable=too-many-locals,missing-docstring self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> "tvm.s_tir.Schedule": diff --git a/python/tvm/s_tir/dlight/gpu/transpose.py b/python/tvm/s_tir/dlight/gpu/transpose.py index 66f4d613d19e..3e5a28225bce 100644 --- a/python/tvm/s_tir/dlight/gpu/transpose.py +++ b/python/tvm/s_tir/dlight/gpu/transpose.py @@ -16,7 +16,7 @@ # under the License. """Reduction rule for operators including softmax, layer norm, RMS norm, etc""" -from tvm import arith, s_tir, tir +from tvm import arith, s_tir, tirx from tvm.s_tir import Schedule from tvm.s_tir.schedule import SBlockRV from tvm.target import Target @@ -31,9 +31,9 @@ class Transpose(GPUScheduleRule): def is_transpose(self, sch: Schedule, block_rv: SBlockRV): block = sch.get(block_rv) - if isinstance(block.body, tir.BufferStore): + if isinstance(block.body, tirx.BufferStore): rhs = block.body.value - if isinstance(rhs, tir.BufferLoad): + if isinstance(rhs, tirx.BufferLoad): lhs_indices = block.body.indices rhs_indices = rhs.indices if list(lhs_indices) != list(rhs_indices) and set(lhs_indices) == set(rhs_indices): @@ -42,12 +42,12 @@ def is_transpose(self, sch: Schedule, block_rv: SBlockRV): def apply( # pylint: disable=too-many-locals self, - func: tir.PrimFunc, + func: tirx.PrimFunc, target: Target, _: bool, ) -> None | s_tir.Schedule | list[s_tir.Schedule]: # pylint: disable=invalid-name - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + if not isinstance(func, tirx.PrimFunc) or not self.is_target_available(target): return None if target.kind.name == "cuda": len_tx = 16 diff --git a/python/tvm/s_tir/meta_schedule/arg_info.py b/python/tvm/s_tir/meta_schedule/arg_info.py index bd52cc7a9680..e01b742e9972 100644 --- a/python/tvm/s_tir/meta_schedule/arg_info.py +++ b/python/tvm/s_tir/meta_schedule/arg_info.py @@ -22,7 +22,7 @@ from tvm.ir import IRModule from tvm.runtime import DataType, Object, ShapeTuple -from tvm.tir import PrimFunc +from tvm.tirx import PrimFunc from . import _ffi_api from .utils import _json_de_tvm diff --git a/python/tvm/s_tir/meta_schedule/database/json_database.py b/python/tvm/s_tir/meta_schedule/database/json_database.py index 8ca76c7e30c1..0dea9873b34a 100644 --- a/python/tvm/s_tir/meta_schedule/database/json_database.py +++ b/python/tvm/s_tir/meta_schedule/database/json_database.py @@ -43,7 +43,7 @@ class JSONDatabase(Database): - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. - For the definition of the anchor block, see tir/analysis/analysis.py. + For the definition of the anchor block, see tirx/analysis/analysis.py. """ path_workload: str diff --git a/python/tvm/s_tir/meta_schedule/database/memory_database.py b/python/tvm/s_tir/meta_schedule/database/memory_database.py index 3fd8ec0663f7..6fa78c3b9622 100644 --- a/python/tvm/s_tir/meta_schedule/database/memory_database.py +++ b/python/tvm/s_tir/meta_schedule/database/memory_database.py @@ -37,7 +37,7 @@ class MemoryDatabase(Database): - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. - For the definition of the anchor block, see tir/analysis/analysis.py. + For the definition of the anchor block, see tirx/analysis/analysis.py. """ def __init__( diff --git a/python/tvm/s_tir/meta_schedule/database/schedule_fn_database.py b/python/tvm/s_tir/meta_schedule/database/schedule_fn_database.py index 62057705fc4d..c1be2bc0b971 100644 --- a/python/tvm/s_tir/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/s_tir/meta_schedule/database/schedule_fn_database.py @@ -44,7 +44,7 @@ class ScheduleFnDatabase(Database): - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. - For the definition of the anchor block, see tir/analysis/analysis.py. + For the definition of the anchor block, see tirx/analysis/analysis.py. """ def __init__( diff --git a/python/tvm/s_tir/meta_schedule/relax_integration.py b/python/tvm/s_tir/meta_schedule/relax_integration.py index 161e159a9a42..08fe1a434d1d 100644 --- a/python/tvm/s_tir/meta_schedule/relax_integration.py +++ b/python/tvm/s_tir/meta_schedule/relax_integration.py @@ -30,7 +30,7 @@ from tvm.ir.transform import PassContext from tvm.runtime import Tensor from tvm.target import Target -from tvm.tir.expr import IntImm +from tvm.tirx.expr import IntImm from .builder import Builder from .cost_model import CostModel @@ -80,7 +80,7 @@ def extract_tasks( - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. - For the definition of the anchor block, see tir/analysis/analysis.py. + For the definition of the anchor block, see tirx/analysis/analysis.py. Returns ------- @@ -140,7 +140,7 @@ def extracted_tasks_to_tune_contexts( get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), fork_seed(seed, n=len(extracted_tasks)), ): - if task.mod.attrs.get("tir.is_scheduled", False): + if task.mod.attrs.get("tirx.is_scheduled", False): warnings.warn("The task {task.task_name} is already scheduled, skipping it.") continue tasks.append( @@ -228,7 +228,7 @@ def tune_relax( - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" variant is used for the extracted blocks or in case no anchor block is found. - For the definition of the anchor block, see tir/analysis/analysis.py. + For the definition of the anchor block, see tirx/analysis/analysis.py. Returns ------- @@ -341,7 +341,7 @@ def _tune_relax( - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. - For the definition of the anchor block, see tir/analysis/analysis.py. + For the definition of the anchor block, see tirx/analysis/analysis.py. Returns ------- diff --git a/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py b/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py index 7f34d81c208b..6e1e85797f3e 100644 --- a/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py +++ b/python/tvm/s_tir/meta_schedule/schedule/cuda/layout_transform.py @@ -430,7 +430,7 @@ def unpack_list(target_list) -> list: block_read = sch.reindex_cache_read( block_write, read_buffer_index=0, - index_map=tvm.tir.IndexMap.from_func( + index_map=tvm.tirx.IndexMap.from_func( lambda *loops: [loops[dst_to_src_map[i]] for i, _ in enumerate(loops)], ndim=len(new_src_layout_str), ), diff --git a/python/tvm/s_tir/meta_schedule/testing/te_workload.py b/python/tvm/s_tir/meta_schedule/testing/te_workload.py index ae40249d2799..11118652b8b6 100644 --- a/python/tvm/s_tir/meta_schedule/testing/te_workload.py +++ b/python/tvm/s_tir/meta_schedule/testing/te_workload.py @@ -19,7 +19,7 @@ # pylint: disable=missing-docstring -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.target import Target @@ -298,7 +298,7 @@ def _dilate(*indices): index_tuple.append(indices[i]) if not_zero: not_zero = te.all(*not_zero) - return te.if_then_else(not_zero, padded(*index_tuple), tir.const(0.0, padded.dtype)) + return te.if_then_else(not_zero, padded(*index_tuple), tirx.const(0.0, padded.dtype)) return padded(*index_tuple) # convolution stage @@ -655,7 +655,7 @@ def softmax_mn(m, n) -> tuple[te.Tensor, te.Tensor]: # pylint: disable=invalid- return (a, b) -def create_te_workload(name: str, idx: int) -> tir.PrimFunc: +def create_te_workload(name: str, idx: int) -> tirx.PrimFunc: workload_func, params = CONFIGS[name] return te.create_prim_func(workload_func(*params[idx])) # type: ignore diff --git a/python/tvm/s_tir/meta_schedule/testing/tune_utils.py b/python/tvm/s_tir/meta_schedule/testing/tune_utils.py index b93f381048c9..4ead75bdec7a 100644 --- a/python/tvm/s_tir/meta_schedule/testing/tune_utils.py +++ b/python/tvm/s_tir/meta_schedule/testing/tune_utils.py @@ -72,7 +72,7 @@ def create_calculator(backend: str) -> Callable: Parameters ---------- backend : str - The backend to use, only tir is supported for now. + The backend to use, only tirx is supported for now. Returns ------- @@ -97,7 +97,7 @@ def f_calculator( The input data as a dictionary. """ try: - if backend == "tir": + if backend == "tirx": data = [v for _, v in sorted(input_data.items(), key=lambda x: x[0])] rt_mod(*data) return data diff --git a/python/tvm/s_tir/meta_schedule/tir_integration.py b/python/tvm/s_tir/meta_schedule/tir_integration.py index 3fe6d090a4de..810d52462de3 100644 --- a/python/tvm/s_tir/meta_schedule/tir_integration.py +++ b/python/tvm/s_tir/meta_schedule/tir_integration.py @@ -23,10 +23,10 @@ from tvm_ffi import register_global_func # isort: on -from tvm import ir, tir +from tvm import ir, tirx from tvm.s_tir.schedule import Schedule as _Schedule from tvm.target import Target -from tvm.tir.expr import IntImm +from tvm.tirx.expr import IntImm from .builder import Builder from .cost_model import CostModel @@ -43,7 +43,7 @@ def tune_tir( # pylint: disable=too-many-locals - mod: ir.IRModule | tir.PrimFunc, + mod: ir.IRModule | tirx.PrimFunc, target: str | Target, work_dir: str, max_trials_global: int, @@ -68,7 +68,7 @@ def tune_tir( # pylint: disable=too-many-locals Parameters ---------- - mod : Union[ir.IRModule, tir.PrimFunc] + mod : Union[ir.IRModule, tirx.PrimFunc] The TIR IRModule to tune. target : Union[str, Target] The target to tune for. @@ -110,12 +110,12 @@ def tune_tir( # pylint: disable=too-many-locals database : Database The database with all tuning records """ - if isinstance(mod, tir.PrimFunc): + if isinstance(mod, tirx.PrimFunc): mod = _normalize_mod(mod) - named_tasks: list[tuple[str, tir.PrimFunc]] = [] + named_tasks: list[tuple[str, tirx.PrimFunc]] = [] for gv, func in mod.functions_items(): # pylint: disable=invalid-name - if isinstance(func, tir.PrimFunc): + if isinstance(func, tirx.PrimFunc): named_tasks.append((gv.name_hint, func)) named_tasks.sort(key=lambda x: x[0]) @@ -165,7 +165,7 @@ def tune_tir( # pylint: disable=too-many-locals @register_global_func("tvm.s_tir.meta_schedule.tune_tir") def _tune_tir( - mod: ir.IRModule | tir.PrimFunc, + mod: ir.IRModule | tirx.PrimFunc, target: str | Target, work_dir: str, max_trials_global: int, @@ -186,7 +186,7 @@ def _tune_tir( Parameters ---------- - mod : Union[ir.IRModule, tir.PrimFunc] + mod : Union[ir.IRModule, tirx.PrimFunc] The TIR function to tune. target : Union[str, Target] The target to tune for. @@ -248,7 +248,7 @@ def _tune_tir( def compile_tir( database: Database, - mod: ir.IRModule | tir.PrimFunc, + mod: ir.IRModule | tirx.PrimFunc, target: Target | str, ) -> _Schedule: """Compile a TIR to s_tir.Schedule, according to the records in the database. @@ -257,7 +257,7 @@ def compile_tir( ---------- database : Database The database of tuning records. - mod : Union[ir.IRModule, tir.PrimFunc] + mod : Union[ir.IRModule, tirx.PrimFunc] The TIR function to tune. target : Union[str, Target] The target to tune for. diff --git a/python/tvm/s_tir/meta_schedule/tune.py b/python/tvm/s_tir/meta_schedule/tune.py index e5f4c5d2dd4f..16e5ae484211 100644 --- a/python/tvm/s_tir/meta_schedule/tune.py +++ b/python/tvm/s_tir/meta_schedule/tune.py @@ -81,7 +81,7 @@ def tune_tasks( - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a given module. The "ignore-tensor" varint is used for the extracted blocks or in case no anchor block is found. For the definition of the anchor block, see - tir/analysis/analysis.py. + tirx/analysis/analysis.py. post_optimization : Optional[Bool] Generate post-optimization using Droplet Search as exploitation space. diff --git a/python/tvm/s_tir/meta_schedule/tune_context.py b/python/tvm/s_tir/meta_schedule/tune_context.py index 09de3afc5ebf..3b0a0daba66d 100644 --- a/python/tvm/s_tir/meta_schedule/tune_context.py +++ b/python/tvm/s_tir/meta_schedule/tune_context.py @@ -30,7 +30,7 @@ from tvm.runtime import Object from tvm.s_tir import Schedule from tvm.target import Target -from tvm.tir import PrimFunc +from tvm.tirx import PrimFunc from . import _ffi_api from .logging import Logger, get_logger, get_logging_func @@ -50,7 +50,7 @@ def _normalize_mod(mod: PrimFunc | IRModule) -> IRModule: if isinstance(mod, PrimFunc): if not (mod.attrs and "global_symbol" in mod.attrs): mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", True) + mod = mod.with_attr("tirx.noalias", True) mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") diff --git a/python/tvm/s_tir/meta_schedule/utils.py b/python/tvm/s_tir/meta_schedule/utils.py index 10c8327b2bbe..b4cd2f4009df 100644 --- a/python/tvm/s_tir/meta_schedule/utils.py +++ b/python/tvm/s_tir/meta_schedule/utils.py @@ -30,7 +30,7 @@ from tvm.ir import Array, IRModule, Map from tvm.rpc import RPCSession from tvm.runtime import PackedFunc -from tvm.tir import FloatImm, IntImm +from tvm.tirx import FloatImm, IntImm def derived_object(cls: type) -> type: diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index ce39b207c94f..f775c0dd1eac 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -19,12 +19,12 @@ """The S-TIR backend compilation pipeline.""" import tvm -from tvm import s_tir, tir -from tvm.tir import pipeline as tir_pipeline +from tvm import s_tir, tirx +from tvm.tirx import pipeline as tir_pipeline def default_s_tir_pipeline(): - """The default tir pipeline used in tvm.tir.build""" + """The default tirx pipeline used in tvm.tirx.build""" @tvm.transform.module_pass(opt_level=0) def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: @@ -43,55 +43,55 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.LowerAutoCopy(), s_tir.transform.UnifyThreadBinding(), s_tir.transform.LowerMatchBuffer(), - tir.transform.Simplify(), + tirx.transform.Simplify(), s_tir.transform.InjectPermutedLayout(), s_tir.transform.AnnotateIrregularLoop(), s_tir.transform.InjectSoftwarePipeline(), s_tir.transform.TransformMmaBufferLayout(), s_tir.transform.LowerOpaqueBlock(), - tir.transform.FlattenBuffer(), - tir.transform.BF16ComputeLegalize(), - tir.transform.NarrowDataType(32), + tirx.transform.FlattenBuffer(), + tirx.transform.BF16ComputeLegalize(), + tirx.transform.NarrowDataType(32), s_tir.transform.LoopPartition(), - tir.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), + tirx.transform.VectorizeLoop(not bool(config.get("tirx.disable_vectorize", False))), s_tir.transform.InjectVirtualThread(), s_tir.transform.InjectDoubleBuffer(), ] - if not bool(config.get("tir.disable_storage_rewrite", False)): - passes.append(tir.transform.StorageRewrite()) - if config.get("tir.use_async_copy", False): + if not bool(config.get("tirx.disable_storage_rewrite", False)): + passes.append(tirx.transform.StorageRewrite()) + if config.get("tirx.use_async_copy", False): passes.append(s_tir.transform.LowerAsyncDMA()) passes.extend( [ s_tir.transform.HoistIfThenElse(), - tir.transform.UnrollLoop(), + tirx.transform.UnrollLoop(), s_tir.transform.RenormalizeSplitPattern(), - tir.transform.Simplify(), - tir.transform.RemoveNoOp(), + tirx.transform.Simplify(), + tirx.transform.RemoveNoOp(), s_tir.transform.RewriteUnsafeSelect(), ] ) # Additional passes based on configuration. - if bool(config.get("tir.instrument_bound_checkers", False)): + if bool(config.get("tirx.instrument_bound_checkers", False)): passes.append(s_tir.transform.InstrumentBoundCheckers()) - if bool(config.get("tir.ptx_ldg32", False)): + if bool(config.get("tirx.ptx_ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32(True)) - if not bool(config.get("tir.disable_cse_tir", False)): - passes.append(tir.transform.CommonSubexprElim()) - if bool(config.get("tir.instrument_lwp", False)): + if not bool(config.get("tirx.disable_cse_tir", False)): + passes.append(tirx.transform.CommonSubexprElim()) + if bool(config.get("tirx.instrument_lwp", False)): passes.append(s_tir.transform.InstrumentProfileIntrinsics()) passes.extend( [ # Bind the target first so that target-specific attributes are available. - tir.transform.FP8ComputeLegalize(), + tirx.transform.FP8ComputeLegalize(), # VerifyVTCMLimit must occur before LowerVtcmAlloc. s_tir.transform.VerifyVTCMLimit(), s_tir.transform.LowerVtcmAlloc(), - tir.transform.VerifyMemory(), - tir.transform.AnnotateEntryFunc(), + tirx.transform.VerifyMemory(), + tirx.transform.AnnotateEntryFunc(), ] ) - if bool(config.get("tir.detect_global_barrier", False)): + if bool(config.get("tirx.detect_global_barrier", False)): passes.append(s_tir.transform.ThreadSync("global")) passes.extend( [ @@ -102,20 +102,20 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I s_tir.transform.LowerThreadAllreduce(), ] ) - if bool(config.get("tir.use_async_copy", False)): + if bool(config.get("tirx.use_async_copy", False)): passes.append(s_tir.transform.InjectPTXAsyncCopy()) - if bool(config.get("tir.ptx_ldg32", False)): + if bool(config.get("tirx.ptx_ldg32", False)): passes.append(s_tir.transform.InjectPTXLDG32()) passes.extend( [ - tir.transform.AnnotateDeviceRegions(), - tir.transform.SplitHostDevice(), + tirx.transform.AnnotateDeviceRegions(), + tirx.transform.SplitHostDevice(), # MergeSharedMemoryAllocations must follow SplitHostDevice. s_tir.transform.MergeSharedMemoryAllocations(), - tir.transform.MakePackedAPI(), - tir.transform.FP8StorageLegalize(), - tir.transform.BF16StorageLegalize(), - tir.transform.LowerDeviceKernelLaunch(), + tirx.transform.MakePackedAPI(), + tirx.transform.FP8StorageLegalize(), + tirx.transform.BF16StorageLegalize(), + tirx.transform.LowerDeviceKernelLaunch(), ] ) mod = tvm.ir.transform.Sequential(passes)(mod) @@ -127,9 +127,9 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I def finalize_host_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" host_pass_list = [ - tir.transform.LowerTVMBuiltin(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerIntrin(), + tirx.transform.LowerTVMBuiltin(), + tirx.transform.LowerCustomDatatypes(), + tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(host_pass_list) diff --git a/python/tvm/s_tir/sblock_dependence_info.py b/python/tvm/s_tir/sblock_dependence_info.py index 968533a7c999..0376d2ca4560 100644 --- a/python/tvm/s_tir/sblock_dependence_info.py +++ b/python/tvm/s_tir/sblock_dependence_info.py @@ -21,7 +21,7 @@ from tvm.ir.module import IRModule from tvm.runtime import Object -from tvm.tir import PrimFunc, SBlock +from tvm.tirx import PrimFunc, SBlock from . import _ffi_api from .sblock_scope import SBlockScope, StmtSRef diff --git a/python/tvm/s_tir/sblock_scope.py b/python/tvm/s_tir/sblock_scope.py index 4406cd4481c5..bd5bb369faa7 100644 --- a/python/tvm/s_tir/sblock_scope.py +++ b/python/tvm/s_tir/sblock_scope.py @@ -22,7 +22,7 @@ from tvm_ffi import register_object from tvm.runtime import Object -from tvm.tir import For, SBlock +from tvm.tirx import For, SBlock from . import _ffi_api diff --git a/python/tvm/s_tir/schedule/analysis.py b/python/tvm/s_tir/schedule/analysis.py index 0b4da193504d..11d083bd2d97 100644 --- a/python/tvm/s_tir/schedule/analysis.py +++ b/python/tvm/s_tir/schedule/analysis.py @@ -19,10 +19,10 @@ import tvm_ffi from tvm.runtime import Object -from tvm.tir.buffer import Buffer -from tvm.tir.expr import PrimExpr -from tvm.tir.function import IndexMap, PrimFunc -from tvm.tir.stmt import For +from tvm.tirx.buffer import Buffer +from tvm.tirx.expr import PrimExpr +from tvm.tirx.function import IndexMap, PrimFunc +from tvm.tirx.stmt import For from . import _ffi_api from .schedule import SBlockRV, Schedule diff --git a/python/tvm/s_tir/schedule/schedule.py b/python/tvm/s_tir/schedule/schedule.py index 5b687b819abe..0433089d2dca 100644 --- a/python/tvm/s_tir/schedule/schedule.py +++ b/python/tvm/s_tir/schedule/schedule.py @@ -25,8 +25,8 @@ from tvm.error import TVMError, register_error from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object -from tvm.tir import Buffer, FloatImm, For, IntImm, PrimFunc, SBlock -from tvm.tir.function import IndexMap +from tvm.tirx import Buffer, FloatImm, For, IntImm, PrimFunc, SBlock +from tvm.tirx.function import IndexMap from . import _ffi_api from ._type_checker import type_checked @@ -2434,12 +2434,12 @@ def decompose_reduction(self, block: SBlockRV | str, loop: LoopRV) -> SBlockRV: @T.prim_func def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i, j, k in tir.grid(128, 128, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: - with tir.init(): + A = tirx.match_buffer(a, [128, 128]) + B = tirx.match_buffer(b, [128, 128]) + C = tirx.match_buffer(c, [128, 128]) + for i, j, k in tirx.grid(128, 128, 128): + with tirx.block([128, 128, tirx.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tirx.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] @@ -2459,15 +2459,15 @@ def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: @T.prim_func def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None: - A = tir.match_buffer(a, [128, 128]) - B = tir.match_buffer(b, [128, 128]) - C = tir.match_buffer(c, [128, 128]) - for i in tir.serial(128): - for j in tir.serial(128): - with tir.block([128, 128]) as [vi, vj]: + A = tirx.match_buffer(a, [128, 128]) + B = tirx.match_buffer(b, [128, 128]) + C = tirx.match_buffer(c, [128, 128]) + for i in tirx.serial(128): + for j in tirx.serial(128): + with tirx.block([128, 128]) as [vi, vj]: C[vi, vj] = 0.0 - for i, j, k in tir.grid(128, 128, 128): - with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + for i, j, k in tirx.grid(128, 128, 128): + with tirx.block([128, 128, tirx.reduce_axis(0, 128)], "C") as [vi, vj, vk]: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] """ @@ -3027,7 +3027,7 @@ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None: ) ) - tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) + tirx.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) Create the schedule and do tensorize: diff --git a/python/tvm/s_tir/schedule/state.py b/python/tvm/s_tir/schedule/state.py index 19d27c06788e..6702930faa1b 100644 --- a/python/tvm/s_tir/schedule/state.py +++ b/python/tvm/s_tir/schedule/state.py @@ -24,7 +24,7 @@ from tvm.ir import IRModule from tvm.runtime import Object -from tvm.tir import For, PrimFunc, SBlock, SBlockRealize +from tvm.tirx import For, PrimFunc, SBlock, SBlockRealize from ..sblock_scope import SBlockScope, StmtSRef from . import _ffi_api diff --git a/python/tvm/s_tir/schedule/testing.py b/python/tvm/s_tir/schedule/testing.py index 7fe0d55087ac..56d4b5bdceb8 100644 --- a/python/tvm/s_tir/schedule/testing.py +++ b/python/tvm/s_tir/schedule/testing.py @@ -23,7 +23,7 @@ import tvm from tvm.ir import IRModule, assert_structural_equal from tvm.s_tir.schedule import Schedule, Trace -from tvm.tir import PrimFunc +from tvm.tirx import PrimFunc def assert_structural_equal_ignore_global_symbol( @@ -72,7 +72,7 @@ def verify_trace_roundtrip( The text format or formats whose round-trip behavior should be validated. If a single string, validate round-trips through """ - from tvm.script import tir as T # pylint: disable=import-outside-toplevel + from tvm.script import tirx as T # pylint: disable=import-outside-toplevel if not isinstance(text_format, str): for opt in text_format: @@ -90,7 +90,7 @@ def verify_trace_roundtrip( elif text_format == "python": py_trace = "\n".join(trace.as_python()) vars_dict = {"T": T} - vars_dict.update(tvm.tir.__dict__) + vars_dict.update(tvm.tirx.__dict__) exec(py_trace, vars_dict, {"sch": new_sch}) # pylint: disable=exec-used else: assert text_format in ("json", "python"), f"Unknown text format: {text_format}" diff --git a/python/tvm/s_tir/schedule/trace.py b/python/tvm/s_tir/schedule/trace.py index f2602eb281d2..b82df6868e8e 100644 --- a/python/tvm/s_tir/schedule/trace.py +++ b/python/tvm/s_tir/schedule/trace.py @@ -23,8 +23,8 @@ from tvm_ffi import register_object as _register_object from tvm.runtime import Object -from tvm.tir.expr import FloatImm, IntImm -from tvm.tir.function import IndexMap +from tvm.tirx.expr import FloatImm, IntImm +from tvm.tirx.function import IndexMap from ...ir import Array, Map, save_json from . import _ffi_api diff --git a/python/tvm/s_tir/tensor_intrin/arm_cpu.py b/python/tvm/s_tir/tensor_intrin/arm_cpu.py index 15c1eed0dfdf..259b75d87f80 100644 --- a/python/tvm/s_tir/tensor_intrin/arm_cpu.py +++ b/python/tvm/s_tir/tensor_intrin/arm_cpu.py @@ -18,10 +18,10 @@ # ruff: noqa: E501, F401 """Intrinsics for ARM tensorization.""" -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder.tir import prim_func as build_prim_func +from tvm.script.ir_builder.tirx import prim_func as build_prim_func from tvm.target.codegen import llvm_version_major from .. import TensorIntrin @@ -167,7 +167,7 @@ def _create_ptrue_mask(dtype): """ Creates a mask that enables all lanes of a scalable vector. """ - return T.broadcast(T.bool(True), tir.get_vscale_expr(dtype)) + return T.broadcast(T.bool(True), tirx.get_vscale_expr(dtype)) def _create_active_lane_mask(tensor, relative_offsets, vertical_limit): @@ -176,7 +176,7 @@ def _create_active_lane_mask(tensor, relative_offsets, vertical_limit): Parameters ---------- - tensor : tvm.tir.Buffer + tensor : tvm.tirx.Buffer The tensor the buffer access will be performed on. relative_offsets : Tuple[PrimExpr, PrimExpr] The vertical and horizontal offsets into the accumulator tile. @@ -251,7 +251,7 @@ def get_sme_transpose_interleave_2svlx2svl_fp32_intrin(cols, rows): The SME TensorIntrin that can be used in tensorizing a schedule. """ - SVF = tir.get_vscale_expr("float32") + SVF = tirx.get_vscale_expr("float32") SVF2 = 2 * SVF @T.prim_func @@ -386,7 +386,7 @@ def get_sme_transpose_interleave_block2_2svl_fp16_intrin(): """ # pylint: enable=line-too-long - SVF = tir.get_vscale_expr("float16") + SVF = tirx.get_vscale_expr("float16") SVF2 = 2 * SVF @T.prim_func @@ -485,7 +485,7 @@ def get_transpose_interleave_intrin_name(in_dtype, out_dtype, extent_cols, exten sme_transpose_interleave_intrin_name = ( ARM_SME_2SVLx2SVL_FP32_TRANSPOSE_INTERLEAVE + f"_{extent_cols}_{extent_rows}" ) - tir.TensorIntrin.register( + tirx.TensorIntrin.register( sme_transpose_interleave_intrin_name, *get_sme_transpose_interleave_2svlx2svl_fp32_intrin(extent_cols, extent_rows), override=True, @@ -584,7 +584,7 @@ def get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(M, K, in_dtype): The SME TensorIntrin that can be used in tensorizing a schedule. """ - SVF = tir.get_vscale_expr("float32") + SVF = tirx.get_vscale_expr("float32") SVF2 = 2 * SVF fmopa_intrin = ( "llvm.aarch64.sme.mopa" if in_dtype == "float32" else "llvm.aarch64.sme.mopa.wide" @@ -629,7 +629,7 @@ def impl(): rows_per_iter = 1 if in_dtype == "float32" else 2 with T.serial(T.ceildiv(K, rows_per_iter)) as k: k_row = k * rows_per_iter - in_dtype_svf = tir.get_vscale_expr(in_dtype) + in_dtype_svf = tirx.get_vscale_expr(in_dtype) # Ideally we'd rely on predicating the loads and use the same predicate # for the outer product operation. However, support for predicated diff --git a/python/tvm/s_tir/tensor_intrin/cuda.py b/python/tvm/s_tir/tensor_intrin/cuda.py index 3e0a877f5aa0..4ef7ffe20c12 100644 --- a/python/tvm/s_tir/tensor_intrin/cuda.py +++ b/python/tvm/s_tir/tensor_intrin/cuda.py @@ -23,9 +23,9 @@ from tvm_ffi import register_global_func from tvm.runtime import convert -from tvm.script import tir as T -from tvm.tir import Cast, IntImm, TensorIntrin -from tvm.tir.function import PrimFunc +from tvm.script import tirx as T +from tvm.tirx import Cast, IntImm, TensorIntrin +from tvm.tirx.function import PrimFunc def shared_16x16_to_ldmatrix_32x8_layout(i, j): @@ -49,7 +49,7 @@ def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id): return row, col -@register_global_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout") +@register_global_func("tirx.index_map.shared_16x16_to_ldmatrix_32x8_layout") def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind): i, j = ind[0], ind[1] thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j) @@ -1749,7 +1749,7 @@ def mma_store_desc(a: T.handle, c: T.handle) -> None: ) -@register_global_func("tir.index_map_m16n8k8.matrixC") +@register_global_func("tirx.index_map_m16n8k8.matrixC") def index_map_m16n8k8_matrixC(ind): i, j = ind[0], ind[1] return convert([(i // 8) // 2, j // 8, (i // 8) % 2, (j % 8) % 2]) diff --git a/python/tvm/s_tir/tensor_intrin/dot_product_common.py b/python/tvm/s_tir/tensor_intrin/dot_product_common.py index 92d4b522e895..1cfae11b6f1f 100644 --- a/python/tvm/s_tir/tensor_intrin/dot_product_common.py +++ b/python/tvm/s_tir/tensor_intrin/dot_product_common.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name,missing-function-docstring """Dot product related intrinsics.""" -from tvm.script import tir as T +from tvm.script import tirx as T from .. import TensorIntrin diff --git a/python/tvm/s_tir/tensor_intrin/hexagon.py b/python/tvm/s_tir/tensor_intrin/hexagon.py index 5df605f1c72c..cbf684ee8aac 100644 --- a/python/tvm/s_tir/tensor_intrin/hexagon.py +++ b/python/tvm/s_tir/tensor_intrin/hexagon.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name,missing-function-docstring """Intrinsics for Hexagon tensorization.""" -from tvm.script import tir as T +from tvm.script import tirx as T from .. import TensorIntrin diff --git a/python/tvm/s_tir/tensor_intrin/metal.py b/python/tvm/s_tir/tensor_intrin/metal.py index 45c482b39e8e..d14fdb3b1540 100644 --- a/python/tvm/s_tir/tensor_intrin/metal.py +++ b/python/tvm/s_tir/tensor_intrin/metal.py @@ -19,8 +19,8 @@ from typing import Literal -from tvm.script import tir as T -from tvm.tir import Buffer, PrimExpr, PrimFunc, TensorIntrin +from tvm.script import tirx as T +from tvm.tirx import Buffer, PrimExpr, PrimFunc, TensorIntrin ######## simdgroup matrix intrinsics ######## diff --git a/python/tvm/s_tir/tensor_intrin/riscv_cpu.py b/python/tvm/s_tir/tensor_intrin/riscv_cpu.py index 01389875b03f..e5590fd63067 100644 --- a/python/tvm/s_tir/tensor_intrin/riscv_cpu.py +++ b/python/tvm/s_tir/tensor_intrin/riscv_cpu.py @@ -23,7 +23,7 @@ import tvm_ffi from tvm.runtime import DataType -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target.codegen import Target, llvm_get_vector_width, target_has_features from .. import TensorIntrin @@ -169,7 +169,7 @@ def rvv_vec_dot_prod_impl( return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl -@tvm_ffi.register_global_func("tir.tensor_intrin.register_rvv_isa_intrinsics") +@tvm_ffi.register_global_func("tirx.tensor_intrin.register_rvv_isa_intrinsics") def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> dict(): """Register RISCV V (vector) intrinsics [x] Implementation follows version 1.0 vector specifications: diff --git a/python/tvm/s_tir/tensor_intrin/rocm.py b/python/tvm/s_tir/tensor_intrin/rocm.py index e64a32d73e52..e8a8bd504696 100644 --- a/python/tvm/s_tir/tensor_intrin/rocm.py +++ b/python/tvm/s_tir/tensor_intrin/rocm.py @@ -18,8 +18,8 @@ """Intrinsics for AMDGPU tensorization.""" from tvm.runtime import convert -from tvm.script import tir as T -from tvm.tir.expr import Cast, IntImm +from tvm.script import tirx as T +from tvm.tirx.expr import Cast, IntImm from .. import TensorIntrin from .dot_product_common import get_dp4a_intrin @@ -363,8 +363,8 @@ def mfma_sync_impl_integer(a: T.handle, b: T.handle, c: T.handle) -> None: C[tx, 0:local_size_out] = T.call_llvm_pure_intrin( T.llvm_lookup_intrinsic_id(mfma_intrin), - T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]), - T.call_intrin("int32", "tir.reinterpret", A[tx, 0:local_size]), + T.call_intrin("int32", "tirx.reinterpret", A[tx, 0:local_size]), + T.call_intrin("int32", "tirx.reinterpret", A[tx, 0:local_size]), C[tx, 0:local_size_out], T.int32(0), T.int32(0), diff --git a/python/tvm/s_tir/tensor_intrin/x86.py b/python/tvm/s_tir/tensor_intrin/x86.py index 3c6af2ad4036..4e8af37e1007 100644 --- a/python/tvm/s_tir/tensor_intrin/x86.py +++ b/python/tvm/s_tir/tensor_intrin/x86.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name,missing-function-docstring """Intrinsics for x86 tensorization.""" -from tvm.script import tir as T +from tvm.script import tirx as T from .. import TensorIntrin diff --git a/python/tvm/s_tir/transform/__init__.py b/python/tvm/s_tir/transform/__init__.py index f11488d29fc9..19bc7281f44b 100644 --- a/python/tvm/s_tir/transform/__init__.py +++ b/python/tvm/s_tir/transform/__init__.py @@ -19,4 +19,4 @@ # pylint: disable=wildcard-import, invalid-name from .transform import * -from ...tir.transform.transform import HoistedConditionals, HoistedLetBindings +from ...tirx.transform.transform import HoistedConditionals, HoistedLetBindings diff --git a/python/tvm/s_tir/transform/_ffi_api.py b/python/tvm/s_tir/transform/_ffi_api.py index b6b6c4b49e07..ed47161debcd 100644 --- a/python/tvm/s_tir/transform/_ffi_api.py +++ b/python/tvm/s_tir/transform/_ffi_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.tir.transform""" +"""FFI APIs for tvm.tirx.transform""" import tvm_ffi diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 73000136877c..5ba1b25bf684 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -40,7 +40,7 @@ class IRBuilderFrame(_Object): .. code-block:: python - from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import tirx as T from tvm.script.ir_builder import IRBuilder with IRBuilder() as builder: @@ -53,7 +53,7 @@ class IRBuilderFrame(_Object): .. code-block:: python - from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import tirx as T from tvm.script.ir_builder import IRBuilder with IRBuilder() as builder: @@ -98,7 +98,7 @@ class IRBuilder(_Object): .. code-block:: python - from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import tirx as T from tvm.script.ir_builder import IRBuilder with IRBuilder() as builder: @@ -106,7 +106,7 @@ class IRBuilder(_Object): # to `builder`'s stack of frames buffer = T.match_buffer(...) - return builder.get() # returns the constructed IR, i.e. tir.PrimFunc + return builder.get() # returns the constructed IR, i.e. tirx.PrimFunc """ def __init__(self) -> None: diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a40517992e83..51aae5350e23 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -584,7 +584,7 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call: Returns ------- call : Call - A newly created call that calls into a tir function. + A newly created call that calls into a tirx function. """ primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) diff --git a/python/tvm/script/ir_builder/tir/__init__.py b/python/tvm/script/ir_builder/tirx/__init__.py similarity index 96% rename from python/tvm/script/ir_builder/tir/__init__.py rename to python/tvm/script/ir_builder/tirx/__init__.py index 25978b5434d6..81da83c022af 100644 --- a/python/tvm/script/ir_builder/tir/__init__.py +++ b/python/tvm/script/ir_builder/tirx/__init__.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Package tvm.script.ir_builder.tir""" +"""Package tvm.script.ir_builder.tirx""" from .ir import * # pylint: disable=wildcard-import,redefined-builtin from .ir import boolean as bool # pylint: disable=redefined-builtin diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tirx/_ffi_api.py similarity index 89% rename from python/tvm/script/ir_builder/tir/_ffi_api.py rename to python/tvm/script/ir_builder/tirx/_ffi_api.py index 746cfaf315e0..12353f63c3d6 100644 --- a/python/tvm/script/ir_builder/tir/_ffi_api.py +++ b/python/tvm/script/ir_builder/tirx/_ffi_api.py @@ -18,4 +18,4 @@ import tvm_ffi -tvm_ffi.init_ffi_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access +tvm_ffi.init_ffi_api("script.ir_builder.tirx", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tirx/external_kernel.py similarity index 97% rename from python/tvm/script/ir_builder/tir/external_kernel.py rename to python/tvm/script/ir_builder/tirx/external_kernel.py index 63d30f809e73..ebe432d69a4c 100644 --- a/python/tvm/script/ir_builder/tir/external_kernel.py +++ b/python/tvm/script/ir_builder/tirx/external_kernel.py @@ -25,7 +25,7 @@ from typing import Any from tvm import __version__ as tvm_version -from tvm import tir +from tvm import tirx from tvm.contrib import nvcc from tvm.runtime import Module, const, load_module @@ -113,7 +113,7 @@ def __init__(self, source_code: str): def compile_to_device_module( # pylint: disable=arguments-differ self, - grid: list[list[int | tir.PrimExpr]], + grid: list[list[int | tirx.PrimExpr]], *args: list[Any], **kwargs: dict[str, Any], ) -> tuple[str, Module, list[Any]]: @@ -182,7 +182,7 @@ def compile_to_device_module( # pylint: disable=arguments-differ def call_kernel( kernel, - launch_args: list[int | tir.PrimExpr | list[int | tir.PrimExpr]], + launch_args: list[int | tirx.PrimExpr | list[int | tirx.PrimExpr]], *args: list[Any], **kwargs: dict[str, Any], ): @@ -194,11 +194,11 @@ def call_kernel( kernel : Any The external kernel to call. - launch_args : List[Union[int, tir.PrimExpr, List[Union[int, tir.PrimExpr]]]] + launch_args : List[Union[int, tirx.PrimExpr, List[Union[int, tirx.PrimExpr]]]] The launch arguments. A list of integers for grid size, block size, and shared memory size. The actual requirements depend on the kernel. - args : List[tir.PrimExpr] + args : List[tirx.PrimExpr] The arguments to pass to the kernel. kwargs : Dict[str, Any] diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tirx/frame.py similarity index 69% rename from python/tvm/script/ir_builder/tir/frame.py rename to python/tvm/script/ir_builder/tirx/frame.py index 5c8b4255c47d..aeced570ba94 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tirx/frame.py @@ -18,59 +18,59 @@ from tvm_ffi import register_object as _register_object -from tvm.tir import Var +from tvm.tirx import Var from ..base import IRBuilderFrame -@_register_object("script.ir_builder.tir.TIRFrame") +@_register_object("script.ir_builder.tirx.TIRFrame") class TIRFrame(IRBuilderFrame): ... -@_register_object("script.ir_builder.tir.PrimFuncFrame") +@_register_object("script.ir_builder.tirx.PrimFuncFrame") class PrimFuncFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.SSBlockFrame") +@_register_object("script.ir_builder.tirx.SSBlockFrame") class SBlockFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.SBlockInitFrame") +@_register_object("script.ir_builder.tirx.SBlockInitFrame") class BlockInitFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.ForFrame") +@_register_object("script.ir_builder.tirx.ForFrame") class ForFrame(TIRFrame): def __enter__(self) -> Var | list[Var]: # type: ignore[override] super().__enter__() return self.vars if len(self.vars) > 1 else self.vars[0] -@_register_object("script.ir_builder.tir.AssertFrame") +@_register_object("script.ir_builder.tirx.AssertFrame") class AssertFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.AttrFrame") +@_register_object("script.ir_builder.tirx.AttrFrame") class AttrFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.WhileFrame") +@_register_object("script.ir_builder.tirx.WhileFrame") class WhileFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.IfFrame") +@_register_object("script.ir_builder.tirx.IfFrame") class IfFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.ThenFrame") +@_register_object("script.ir_builder.tirx.ThenFrame") class ThenFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.ElseFrame") +@_register_object("script.ir_builder.tirx.ElseFrame") class ElseFrame(TIRFrame): ... -@_register_object("script.ir_builder.tir.LaunchThreadFrame") +@_register_object("script.ir_builder.tirx.LaunchThreadFrame") class LaunchThreadFrame(TIRFrame): def __enter__(self) -> Var: super().__enter__() diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tirx/ir.py similarity index 98% rename from python/tvm/script/ir_builder/tir/ir.py rename to python/tvm/script/ir_builder/tirx/ir.py index 92c7d7d6e95e..ce88c563156f 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tirx/ir.py @@ -30,7 +30,7 @@ # isort: on -from tvm import ir, tir +from tvm import ir, tirx from tvm.ir import Type from tvm.ir.base import deprecated from tvm.runtime import String, convert @@ -38,11 +38,11 @@ # pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import Buffer, BufferRegion, IndexMap, PrimExpr, type_annotation -from tvm.tir import op as _tir_op +from tvm.tirx import Buffer, BufferRegion, IndexMap, PrimExpr, type_annotation +from tvm.tirx import op as _tir_op -# import tir.expr for direct ir construction to pass structural_equal comparison -from tvm.tir.expr import ( +# import tirx.expr for direct ir construction to pass structural_equal comparison +from tvm.tirx.expr import ( EQ, GE, GT, @@ -79,7 +79,7 @@ Sub, Var, ) -from tvm.tir.generic import cast +from tvm.tirx.generic import cast from . import _ffi_api, frame from .external_kernel import call_kernel @@ -589,7 +589,7 @@ def _as_range(dom: ir.Range | list[PrimExpr]) -> ir.Range: from tvm.arith import Analyzer # pylint: disable=import-outside-toplevel extent = Analyzer().simplify(dom[1] - dom[0]) - if isinstance(extent, tir.IntImm): + if isinstance(extent, tirx.IntImm): return ir.Range.from_min_extent(dom[0], extent) return ir.Range(dom[0], dom[1]) if hasattr(dom, "dtype"): @@ -1021,7 +1021,7 @@ def Let( # pylint: disable=invalid-name """Create a Let expression binding""" assert len(where) == 1, "T.Let only allows `where` to have exactly one element" var, value = next(iter(where.items())) # pylint: disable=redefined-outer-name - return tir.Let(var, value, expr) + return tirx.Let(var, value, expr) def let( @@ -1050,7 +1050,7 @@ def let( @deprecated("T.let", "T.Let") def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: - return tir.Let(v, value, body) + return tirx.Let(v, value, body) @deprecated("T.let", "T.bind") def let_stmt(v: Var, value: PrimExpr) -> Var: @@ -1245,7 +1245,7 @@ def launch_thread( .. code-block:: python - from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import tirx as T brow = T.env_thread("blockIdx.y") T.launch_thread(brow, 1) @@ -1530,7 +1530,7 @@ def func( def boolean(expr: PrimExpr | None = None, is_size_var: bool = False) -> PrimExpr: - """Construct a new tir.Var with type boolean or cast expression to type boolean. + """Construct a new tirx.Var with type boolean or cast expression to type boolean. Parameters ---------- @@ -1543,7 +1543,7 @@ def boolean(expr: PrimExpr | None = None, is_size_var: bool = False) -> PrimExpr Returns ------- res : PrimExpr - The new tir.Var with type boolean or casted expression with type boolean. + The new tirx.Var with type boolean or casted expression with type boolean. """ return _ffi_api.Boolean(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member @@ -1570,7 +1570,7 @@ def handle( Returns ------- res : PrimExpr - The new tir.Var with type handle or casted expression with type handle. + The new tirx.Var with type handle or casted expression with type handle. """ if dtype == "tensormap": return _ffi_api.TensormapHandle() # type: ignore[attr-defined] # pylint: disable=no-member @@ -1586,7 +1586,7 @@ def handle( def void(expr: PrimExpr | None = None, *, is_size_var: bool = False) -> PrimExpr: - """Construct a new tir.Var with type void or cast expression to type void. + """Construct a new tirx.Var with type void or cast expression to type void. Parameters ---------- @@ -1596,14 +1596,14 @@ def void(expr: PrimExpr | None = None, *, is_size_var: bool = False) -> PrimExpr Returns ------- res : PrimExpr - The new tir.Var with type void or casted expression with type void. + The new tirx.Var with type void or casted expression with type void. """ return _ffi_api.Void(expr, is_size_var) # type: ignore[attr-defined] # pylint: disable=no-member @deprecated("T.var", "T.{dtype}") def var(dtype: str, name: str = "") -> Var: - """Construct a new tir.Var. + """Construct a new tirx.Var. Parameters ---------- @@ -1616,7 +1616,7 @@ def var(dtype: str, name: str = "") -> Var: Returns ------- res : Var - The result tir.Var. + The result tirx.Var. """ return Var(name, dtype) # pylint: disable=no-member @@ -2279,7 +2279,7 @@ def wrapped(*args, **kwargs): "broadcast", "ramp", "cast", - # tvm.tir.expr + # tvm.tirx.expr "Var", "SizeVar", "Reduce", diff --git a/python/tvm/script/ir_builder/tir/triton.py b/python/tvm/script/ir_builder/tirx/triton.py similarity index 95% rename from python/tvm/script/ir_builder/tir/triton.py rename to python/tvm/script/ir_builder/tirx/triton.py index 75184d291645..5c5d2b55673a 100644 --- a/python/tvm/script/ir_builder/tir/triton.py +++ b/python/tvm/script/ir_builder/tirx/triton.py @@ -23,7 +23,7 @@ from packaging import version from triton.runtime.jit import type_canonicalisation_dict -from tvm import tir +from tvm import tirx from tvm.runtime import Module from tvm.topi.utils import get_const_int @@ -45,7 +45,7 @@ def __init__(self, func): def compile_to_device_module( self, - launch_args: list[int | tir.PrimExpr], + launch_args: list[int | tirx.PrimExpr], *args: list[Any], **kwargs: dict[str, Any], ) -> tuple[str, Module, list[Any]]: @@ -82,7 +82,7 @@ def compile_to_device_module( kernel_arg_types = [arg.dtype for arg in kernel_args] if triton_kernel.metadata.shared > 0: # Add shared memory size to the launch arguments - launch_param_tags.append("tir.use_dyn_shared_memory") + launch_param_tags.append("tirx.use_dyn_shared_memory") launch_args.append(triton_kernel.metadata.shared) kernel_module = self._create_cuda_module( @@ -93,7 +93,7 @@ def compile_to_device_module( def _generate_triton_kernel( self, func, *args, **kwargs - ) -> tuple["triton.compiler.CompiledKernel", list[tir.PrimExpr]]: + ) -> tuple["triton.compiler.CompiledKernel", list[tirx.PrimExpr]]: """Deduce the kernel signature and generate the Triton kernel""" kernel_params = func.params @@ -112,7 +112,7 @@ def _generate_triton_kernel( kernel_args.append(arg) continue if arg.dtype == "handle": - assert isinstance(arg, tir.Var) + assert isinstance(arg, tirx.Var) elem_type = arg.type_annotation.element_type.dtype pointer_type = "*" + type_canonicalisation_dict[elem_type] signature[kernel_params[i].name] = pointer_type diff --git a/python/tvm/script/ir_builder/tir/utils.py b/python/tvm/script/ir_builder/tirx/utils.py similarity index 94% rename from python/tvm/script/ir_builder/tir/utils.py rename to python/tvm/script/ir_builder/tirx/utils.py index e88fcbebb1d7..006f9a2ecc2b 100644 --- a/python/tvm/script/ir_builder/tir/utils.py +++ b/python/tvm/script/ir_builder/tirx/utils.py @@ -18,8 +18,8 @@ import contextlib -from tvm import tir -from tvm.tir import Buffer +from tvm import tirx +from tvm.tirx import Buffer from . import frame from . import ir as T @@ -108,7 +108,7 @@ def seq_scope(): T.evaluate(j) result = ib.get() """ - return T.attr(tir.const(0, "int32"), "pragma_scope", tir.StringImm("seq")) + return T.attr(tirx.const(0, "int32"), "pragma_scope", tirx.StringImm("seq")) def _unravel_index(index, shape): @@ -153,7 +153,7 @@ class _BufferProxy: -------- .. code-block:: python - buf = tvm.tir.decl_buffer([2, 3], "float32") + buf = tvm.tirx.decl_buffer([2, 3], "float32") ptr = buffer_proxy(buf) # Read with flat index (converted to [0, 1]) @@ -185,7 +185,7 @@ def _normalize_index(self, index): def __getitem__(self, index): index = self._normalize_index(index) - return tir.BufferLoad(self._buffer, index) + return tirx.BufferLoad(self._buffer, index) def __setitem__(self, index, value): index = self._normalize_index(index) @@ -212,9 +212,9 @@ def buffer_proxy(buf: Buffer) -> _BufferProxy: -------- .. code-block:: python - from tvm.script.ir_builder.tir.utils import buffer_proxy + from tvm.script.ir_builder.tirx.utils import buffer_proxy - buf = tvm.tir.decl_buffer([2, 3], "float32") + buf = tvm.tirx.decl_buffer([2, 3], "float32") ptr = buffer_proxy(buf) # Flat indexing (index 1 -> indices [0, 1]) diff --git a/python/tvm/script/parser/core/dispatch.py b/python/tvm/script/parser/core/dispatch.py index a8dcf0fe7dbc..c05110fcdf7f 100644 --- a/python/tvm/script/parser/core/dispatch.py +++ b/python/tvm/script/parser/core/dispatch.py @@ -96,7 +96,7 @@ def register_op(operand_type: type, op_node_type: AST, operand_index: int): Parameters ---------- operand_type : Type - The type of operands, e.g., tir.PrimExpr, tir.IterVar. + The type of operands, e.g., tirx.PrimExpr, tirx.IterVar. op_node_type : AST The doc AST operator node type, e.g., doc.Add, doc.Eq. @@ -135,7 +135,7 @@ def get_op( Parameters ---------- operand_type : Type - The type of operands, e.g., tir.PrimExpr, tir.IterVar. + The type of operands, e.g., tirx.PrimExpr, tirx.IterVar. op_node_type : AST The doc AST operator node type, e.g., doc.Add, doc.Eq. diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index e1d9260c805d..7c09ced3a715 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -41,15 +41,15 @@ def _default_globals() -> dict[str, Any]: from tvm.script.parser import ( ir, # pylint: disable=import-outside-toplevel relax, # pylint: disable=import-outside-toplevel - tir, # pylint: disable=import-outside-toplevel + tirx, # pylint: disable=import-outside-toplevel ) extra_vars = { "tvm": tvm, "I": ir, "ir": ir, - "T": tir, - "tir": tir, + "T": tirx, + "tirx": tirx, "R": relax, "relax": relax, } @@ -126,7 +126,7 @@ def parse( parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE) try: - tvm.tir.analysis.verify_well_formed(check_ret) + tvm.tirx.analysis.verify_well_formed(check_ret) except Exception as err: # pylint: disable=broad-exception-caught parser.report_error( source_ast, diff --git a/python/tvm/script/parser/core/evaluator.py b/python/tvm/script/parser/core/evaluator.py index a8ce4ce6c1a1..3b940fa6d227 100644 --- a/python/tvm/script/parser/core/evaluator.py +++ b/python/tvm/script/parser/core/evaluator.py @@ -390,8 +390,8 @@ def _eval_if_exp(self, fields: dict[str, Any]) -> Any: orelse = self._eval_expr(fields["orelse"]) if isinstance(test, bool): return body if test else orelse - elif isinstance(test, tvm.tir.PrimExpr) and test.dtype == "bool": - return tvm.tir.op.if_then_else(test, body, orelse) + elif isinstance(test, tvm.tirx.PrimExpr) and test.dtype == "bool": + return tvm.tirx.op.if_then_else(test, body, orelse) else: raise TypeError(f"Expected Python bool or TIR bool, but got {type(test)}") diff --git a/python/tvm/script/parser/relax/dist.py b/python/tvm/script/parser/relax/dist.py index 2aa0397b96a2..cc3da4921546 100644 --- a/python/tvm/script/parser/relax/dist.py +++ b/python/tvm/script/parser/relax/dist.py @@ -32,7 +32,7 @@ redistribute, redistribute_replica_to_shard, ) -from tvm.tir import PrimExpr +from tvm.tirx import PrimExpr from .entry import StructInfoProxy, TensorProxy diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 324a11daff45..a7e6181412e9 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -35,7 +35,7 @@ ) from tvm.relax.expr import Var from tvm.runtime import ObjectConvertible -from tvm.tir import PrimExpr +from tvm.tirx import PrimExpr from ...ir_builder import relax as R from .._core import doc, parse, utils diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index bb1a867331bf..8800ea156a18 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -20,7 +20,7 @@ import numbers from typing import Any -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import GlobalVar, structural_equal from tvm.relax import Expr, StructInfo from tvm.relax.utils import convert_to_expr @@ -47,7 +47,7 @@ def bind_assign_value( ) -> Any: var_table = self.var_table.get() - if isinstance(value, tir.Var): + if isinstance(value, tirx.Var): if value.name and var_name != value.name: self.report_error( node, @@ -56,7 +56,7 @@ def bind_assign_value( ) if var_name in var_table: prev_value = var_table[var_name] - if not isinstance(prev_value, tir.Var): + if not isinstance(prev_value, tirx.Var): self.report_error( node, "Cannot redefine a non-TIR-variable object to a TIR variable. Please " @@ -152,8 +152,8 @@ def is_recursive(node: doc.FunctionDef) -> bool: def collect_symbolic_var_from_prelude( - self: Parser, node: doc.FunctionDef, symbolic_vars: dict[str, tir.Var] -) -> dict[str, tir.Var]: + self: Parser, node: doc.FunctionDef, symbolic_vars: dict[str, tirx.Var] +) -> dict[str, tirx.Var]: prelude_vars = {} for stmt in node.body: if isinstance(stmt, doc.Assign) and all( @@ -184,7 +184,7 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non for var_name in param_sinfo_proxy.get_symbolic_vars(): if var_name not in symbolic_vars: - symbolic_vars[var_name] = tir.Var(var_name, "int64") + symbolic_vars[var_name] = tirx.Var(var_name, "int64") # Update symbolic vars based on symbolic_vars = collect_symbolic_var_from_prelude(self, node, symbolic_vars) @@ -240,7 +240,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.report_error(stmt, "Function must be decorated") dec = self.eval_expr(stmt.decorator_list[-1]) # inline prim_func was found - if dec.dispatch_token == "tir": + if dec.dispatch_token == "tirx": self.report_error(stmt, "inline prim_func is disallowed in Relax IR") self.visit_body(node.body) diff --git a/python/tvm/script/parser/tir/__init__.py b/python/tvm/script/parser/tirx/__init__.py similarity index 90% rename from python/tvm/script/parser/tir/__init__.py rename to python/tvm/script/parser/tirx/__init__.py index 559998970f09..929cf96635dc 100644 --- a/python/tvm/script/parser/tir/__init__.py +++ b/python/tvm/script/parser/tirx/__init__.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. # ruff: noqa: RUF005 -"""The tir parser""" +"""The tirx parser""" from typing import TYPE_CHECKING -from ...ir_builder.tir import * # pylint: disable=redefined-builtin -from ...ir_builder.tir import ir as _tir +from ...ir_builder.tirx import * # pylint: disable=redefined-builtin +from ...ir_builder.tirx import ir as _tir from . import operation as _operation from . import parser as _parser from .entry import Buffer, Ptr diff --git a/python/tvm/script/parser/tir/entry.py b/python/tvm/script/parser/tirx/entry.py similarity index 93% rename from python/tvm/script/parser/tir/entry.py rename to python/tvm/script/parser/tirx/entry.py index d0486b0d9fbc..fdac3b0db454 100644 --- a/python/tvm/script/parser/tir/entry.py +++ b/python/tvm/script/parser/tirx/entry.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The entry point of TVM parser for tir.""" +"""The entry point of TVM parser for tirx.""" import inspect from collections.abc import Callable from tvm.ir.base import deprecated -from tvm.tir import Buffer, PrimFunc +from tvm.tirx import Buffer, PrimFunc -from ...ir_builder.tir import block_name_suffix_context, buffer, ptr +from ...ir_builder.tirx import block_name_suffix_context, buffer, ptr from .._core import parse, scan_macro, utils from ..core.parser import Parser, ScriptMacro @@ -30,7 +30,7 @@ def prim_func( func: Callable | None = None, private: bool = False, check_well_formed=True ) -> PrimFunc | Callable: - """The parsing method for tir prim func, by using `@prim_func` as decorator. + """The parsing method for tirx prim func, by using `@prim_func` as decorator. Parameters ---------- @@ -49,7 +49,7 @@ def prim_func( Returns ------- res : Union[PrimFunc, Callable] - The parsed tir prim func. + The parsed tirx prim func. """ # pylint: disable=unused-argument # (private will be used in the parser, but not immediately) @@ -75,11 +75,11 @@ def decorator_wrapper(func): else: # if there is an optional arg given, return a new decorator # that will then be invoked - setattr(decorator_wrapper, "dispatch_token", "tir") + setattr(decorator_wrapper, "dispatch_token", "tirx") return decorator_wrapper -setattr(prim_func, "dispatch_token", "tir") +setattr(prim_func, "dispatch_token", "tirx") # Semantics of TIR macros: @@ -128,7 +128,7 @@ def macro(*args, hygienic: bool = True) -> Callable: Example: ``` import tvm - from tvm.script import tir as T + from tvm.script import tirx as T x_value = 128 @@ -173,7 +173,7 @@ def wrapper(*args, **kwargs): class BufferProxy: - """Buffer proxy class for constructing tir buffer.""" + """Buffer proxy class for constructing tirx buffer.""" def __or__(self, other): """Support ``T.Buffer | None`` union syntax in annotations.""" @@ -219,7 +219,7 @@ def __getitem__(self, keys) -> Buffer: class PtrProxy: - """Ptr proxy class for constructing tir pointer.""" + """Ptr proxy class for constructing tirx pointer.""" def __or__(self, other): """Support union syntax in annotations.""" diff --git a/python/tvm/script/parser/tir/operation.py b/python/tvm/script/parser/tirx/operation.py similarity index 86% rename from python/tvm/script/parser/tir/operation.py rename to python/tvm/script/parser/tirx/operation.py index 2bb1e44da72b..fd528ba35349 100644 --- a/python/tvm/script/parser/tir/operation.py +++ b/python/tvm/script/parser/tirx/operation.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The tir expression operation registration""" +"""The tirx expression operation registration""" -from tvm import tir +from tvm import tirx from tvm.runtime import DataType, DataTypeCode -from tvm.tir import IntImm -from tvm.tir.expr import FloatImm +from tvm.tirx import IntImm +from tvm.tirx.expr import FloatImm from .._core import OpMethod, doc, register_op @@ -35,7 +35,7 @@ def _and(a, b): if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: return a & b else: - return tir.And(a, b) + return tirx.And(a, b) def _or(a, b): if isinstance(a, bool): @@ -45,7 +45,7 @@ def _or(a, b): if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1: return a | b else: - return tir.Or(a, b) + return tirx.Or(a, b) def _get_type_str(dtype: str): if DataType(dtype).lanes == 1: @@ -74,7 +74,7 @@ def _auto_broadcast(a, b, op): else: a = FloatImm("float32", a) - assert isinstance(a, tir.PrimExpr), "Operand should be a PrimExpr." + assert isinstance(a, tirx.PrimExpr), "Operand should be a PrimExpr." if isinstance(b, int): if ( DataType(a.dtype).type_code == DataTypeCode.INT @@ -90,31 +90,31 @@ def _auto_broadcast(a, b, op): if DataType(a.dtype).lanes == DataType(b.dtype).lanes: return op(a, b) elif DataType(a.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: - broadcast_a = tir.Broadcast(a, DataType(b.dtype).lanes) + broadcast_a = tirx.Broadcast(a, DataType(b.dtype).lanes) return op(broadcast_a, b) elif DataType(b.dtype).lanes == 1 and DataType(a.dtype).lanes != DataType(b.dtype).lanes: - broadcast_b = tir.Broadcast(b, DataType(a.dtype).lanes) + broadcast_b = tirx.Broadcast(b, DataType(a.dtype).lanes) return op(a, broadcast_b) else: raise TypeError("do not know how to deal with it.") def _eq(a, b): - return _auto_broadcast(a, b, tir.EQ) + return _auto_broadcast(a, b, tirx.EQ) def _ne(a, b): - return _auto_broadcast(a, b, tir.NE) + return _auto_broadcast(a, b, tirx.NE) def _lt(a, b): - return _auto_broadcast(a, b, tir.LT) + return _auto_broadcast(a, b, tirx.LT) def _le(a, b): - return _auto_broadcast(a, b, tir.LE) + return _auto_broadcast(a, b, tirx.LE) def _gt(a, b): - return _auto_broadcast(a, b, tir.GT) + return _auto_broadcast(a, b, tirx.GT) def _ge(a, b): - return _auto_broadcast(a, b, tir.GE) + return _auto_broadcast(a, b, tirx.GE) def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name register_op(ty, op, i)(m) @@ -151,10 +151,10 @@ def r(op: type, i: int, m: OpMethod): # pylint: disable=invalid-name for i in [0]: # Case 4. unaryop # doc.Invert <-- is overloaded - r(doc.Not, i, tir.Not) + r(doc.Not, i, tirx.Not) # doc.UAdd <-- is overloaded # doc.USub <-- is overloaded -_register_expr_op(tir.PrimExpr) -_register_expr_op(tir.IterVar) +_register_expr_op(tirx.PrimExpr) +_register_expr_op(tirx.IterVar) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tirx/parser.py similarity index 89% rename from python/tvm/script/parser/tir/parser.py rename to python/tvm/script/parser/tirx/parser.py index b4d6f88edd00..888113c7d4cb 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tirx/parser.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The base parser for tir""" +"""The base parser for tirx""" import contextlib from functools import partial @@ -22,10 +22,10 @@ import tvm from tvm.ir import GlobalVar, PrimType -from tvm.tir import Buffer, IterVar, PrimExpr, Var +from tvm.tirx import Buffer, IterVar, PrimExpr, Var from ...ir_builder import ir as I -from ...ir_builder import tir as T +from ...ir_builder import tirx as T from ...ir_builder.base import IRBuilder from ...ir_builder.base import IRBuilderFrame as Frame from .._core import Parser, dispatch, doc @@ -174,7 +174,7 @@ def range_sugar( ) -> T.frame.ForFrame: """The sugar for python range builtin.""" - # Since `tir.For` do not support reversed iteration semantic, + # Since `tirx.For` do not support reversed iteration semantic, # the step must be checked to be positive integer when use range sugar if step is not None: try: @@ -187,9 +187,9 @@ def range_sugar( return T.serial(start, stop, annotations=annotations, step=step) -@dispatch.register(token="tir", type_name="For") +@dispatch.register(token="tirx", type_name="For") def visit_for(self: Parser, node: doc.For) -> None: - """The for visiting method for tir. + """The for visiting method for tirx. Parameters ---------- @@ -212,9 +212,9 @@ def visit_for(self: Parser, node: doc.For) -> None: self.visit_body(node.body) -@dispatch.register(token="tir", type_name="While") +@dispatch.register(token="tirx", type_name="While") def visit_while(self: Parser, node: doc.While) -> None: - """The while visiting method for tir. + """The while visiting method for tirx. Parameters ---------- @@ -230,9 +230,9 @@ def visit_while(self: Parser, node: doc.While) -> None: self.visit_body(node.body) -@dispatch.register(token="tir", type_name="Assign") +@dispatch.register(token="tirx", type_name="Assign") def visit_assign(self: Parser, node: doc.Assign) -> None: - """The assign visiting method for tir. + """The assign visiting method for tirx. Parameters ---------- @@ -278,9 +278,9 @@ def visit_assign(self: Parser, node: doc.Assign) -> None: self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) -@dispatch.register(token="tir", type_name="AugAssign") +@dispatch.register(token="tirx", type_name="AugAssign") def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: - """The augmented assign visiting method for tir. + """The augmented assign visiting method for tirx. Parameters ---------- @@ -331,9 +331,9 @@ def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) -@dispatch.register(token="tir", type_name="AnnAssign") +@dispatch.register(token="tirx", type_name="AnnAssign") def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: - """The annotated assign visiting method for tir. + """The annotated assign visiting method for tirx. Parameters ---------- @@ -352,9 +352,9 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: T.bind(rhs, var=ann_var) -@dispatch.register(token="tir", type_name="With") +@dispatch.register(token="tirx", type_name="With") def visit_with(self: Parser, node: doc.With) -> None: - """The with visiting method for tir. + """The with visiting method for tirx. Parameters ---------- @@ -379,9 +379,9 @@ def visit_with(self: Parser, node: doc.With) -> None: self.visit_body(node.body) -@dispatch.register(token="tir", type_name="FunctionDef") +@dispatch.register(token="tirx", type_name="FunctionDef") def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: - """The function definition visiting method for tir. + """The function definition visiting method for tirx. Parameters ---------- @@ -404,7 +404,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: if callable(ret_type): ret_type = PrimType(ret_type().dtype) T.func_ret(ret_type) - with self.with_dispatch_token("tir"): + with self.with_dispatch_token("tirx"): # TODO: handle different types of arguments: # - vararg: arg | None # - kwonlyargs: list[arg] @@ -429,9 +429,9 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.function_annotations = supplied_annotation -@dispatch.register(token="tir", type_name="tvm_annotation") +@dispatch.register(token="tirx", type_name="tvm_annotation") def visit_tvm_annotation(self: Parser, node: doc.expr): - """The TVM annotation visiting method for tir. + """The TVM annotation visiting method for tirx. Parameters ---------- @@ -447,9 +447,9 @@ def visit_tvm_annotation(self: Parser, node: doc.expr): return annotation -@dispatch.register(token="tir", type_name="Expr") +@dispatch.register(token="tirx", type_name="Expr") def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: - """The expr statement visiting method for tir. + """The expr statement visiting method for tirx. Parameters ---------- @@ -474,24 +474,24 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: elif isinstance(res, PrimExpr): T.evaluate(res) elif isinstance(res, int | bool): - T.evaluate(tvm.tir.const(res)) + T.evaluate(tvm.tirx.const(res)) elif isinstance(res, tvm.relax.Call) and not res.args: # Using GlobalVar.__call__ with no arguments is ambiguous, as # each IR has a different function Call representation. If # this occurs, convert to the TIR representation. - T.evaluate(tvm.tir.call_tir(res.op)) + T.evaluate(tvm.tirx.call_tir(res.op)) elif isinstance(res, str): # Ignore docstrings pass - elif isinstance(res, tvm.tir.stmt.BufferStore): + elif isinstance(res, tvm.tirx.stmt.BufferStore): T.buffer_store(res.buffer, res.value, res.indices, res.predicate) else: self.report_error(node, f"Parsing resulted in unexpected type {type(res)}") -@dispatch.register(token="tir", type_name="If") +@dispatch.register(token="tirx", type_name="If") def visit_if(self: Parser, node: doc.If) -> None: - """The if visiting method for tir. + """The if visiting method for tirx. Parameters ---------- @@ -503,7 +503,7 @@ def visit_if(self: Parser, node: doc.If) -> None: """ with self.var_table.with_frame(): predicate = self.eval_expr(node.test) - if isinstance(predicate, PrimExpr | tvm.tir.expr.ExprOp): + if isinstance(predicate, PrimExpr | tvm.tirx.expr.ExprOp): with T.If(self.eval_expr(node.test)): with T.Then(): with self.var_table.with_frame(): @@ -526,9 +526,9 @@ def visit_if(self: Parser, node: doc.If) -> None: ) -@dispatch.register(token="tir", type_name="Assert") +@dispatch.register(token="tirx", type_name="Assert") def visit_assert(self: Parser, node: doc.Assert) -> None: - """The assert visiting method for tir. + """The assert visiting method for tirx. Parameters ---------- @@ -556,7 +556,7 @@ def visit_assert(self: Parser, node: doc.Assert) -> None: f"got {len(msg)} elements", ) kind_str, parts = msg - if isinstance(kind_str, tvm.tir.StringImm): + if isinstance(kind_str, tvm.tirx.StringImm): kind_str = kind_str.value if not isinstance(kind_str, str): self.report_error( @@ -568,16 +568,16 @@ def visit_assert(self: Parser, node: doc.Assert) -> None: message = parts if isinstance(message, list | tuple): - message = [p.value if isinstance(p, tvm.tir.StringImm) else str(p) for p in message] + message = [p.value if isinstance(p, tvm.tirx.StringImm) else str(p) for p in message] frame = T.Assert(cond, message, error_kind=kind) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() -@dispatch.register(token="tir", type_name="Return") +@dispatch.register(token="tirx", type_name="Return") def visit_return(self: Parser, node: doc.Return) -> None: - """The return visiting method for tir. + """The return visiting method for tirx. Parameters ---------- @@ -590,12 +590,12 @@ def visit_return(self: Parser, node: doc.Return) -> None: value = self.eval_expr(node.value) if value is None: self.report_error(node, "Expression to be returned must be a PrimExpr") - T.evaluate(tvm.tir.ret(value)) + T.evaluate(tvm.tirx.ret(value)) -@dispatch.register(token="tir", type_name="Continue") +@dispatch.register(token="tirx", type_name="Continue") def visit_continue(self: Parser, node: doc.Continue) -> None: # pylint:disable=unused-argument - """The continue visiting method for tir. + """The continue visiting method for tirx. Parameters ---------- @@ -605,12 +605,12 @@ def visit_continue(self: Parser, node: doc.Continue) -> None: # pylint:disable= node : doc.Continue The doc AST continue node. """ - T.evaluate(tvm.tir.continue_loop()) + T.evaluate(tvm.tirx.continue_loop()) -@dispatch.register(token="tir", type_name="Break") +@dispatch.register(token="tirx", type_name="Break") def visit_break(self: Parser, node: doc.Break) -> None: # pylint:disable=unused-argument - """The continue visiting method for tir. + """The continue visiting method for tirx. Parameters ---------- @@ -620,12 +620,12 @@ def visit_break(self: Parser, node: doc.Break) -> None: # pylint:disable=unused node : doc.Break The doc AST break node. """ - T.evaluate(tvm.tir.break_loop()) + T.evaluate(tvm.tirx.break_loop()) -@dispatch.register(token="tir", type_name="tvm_declare_function") +@dispatch.register(token="tirx", type_name="tvm_declare_function") def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar: - """The function declaration step for tir + """The function declaration step for tirx Parameters ---------- @@ -662,5 +662,5 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar IRBuilder.name(arg.arg, ann) arg_annotations.append(ann) - func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type) + func_signature = tvm.tirx.PrimFunc(arg_annotations, None, ret_type=ret_type) return I.decl_function(node.name, func_signature) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index a5f3a08c210b..2f2b04995704 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -24,7 +24,7 @@ from tvm_ffi.access_path import AccessPath from tvm.runtime import Object -from tvm.tir import FloatImm, IntImm +from tvm.tirx import FloatImm, IntImm from . import _ffi_api diff --git a/python/tvm/script/tir.py b/python/tvm/script/tirx.py similarity index 89% rename from python/tvm/script/tir.py rename to python/tvm/script/tirx.py index 8c4057e6ac9e..21a2c6f42712 100644 --- a/python/tvm/script/tir.py +++ b/python/tvm/script/tirx.py @@ -17,4 +17,4 @@ # ruff: noqa: F403 """TVM Script APIs of TVM Python Package for TIR""" -from .parser.tir import * # pylint: disable=redefined-builtin,unused-wildcard-import,wildcard-import +from .parser.tirx import * # pylint: disable=redefined-builtin,unused-wildcard-import,wildcard-import diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py index 3715aaf070ba..d7a47836b2f5 100644 --- a/python/tvm/target/datatype.py +++ b/python/tvm/target/datatype.py @@ -24,20 +24,20 @@ import tvm from tvm.runtime import DataType, convert -from tvm.tir import call_intrin -from tvm.tir.expr import ( +from tvm.tirx import call_intrin +from tvm.tirx.expr import ( BinaryOpExpr as _BinaryOpExpr, ) -from tvm.tir.expr import ( +from tvm.tirx.expr import ( Call as _Call, ) -from tvm.tir.expr import ( +from tvm.tirx.expr import ( Cast as _Cast, ) -from tvm.tir.expr import ( +from tvm.tirx.expr import ( FloatImm as _FloatImm, ) -from tvm.tir.op import call_pure_extern +from tvm.tirx.op import call_pure_extern def register(type_name, type_code): @@ -338,7 +338,7 @@ def lower_ite(ite_op): ---------- ite_op : Op Takes an if then else op and returns a - call to tir.if_then_else function, passing the op's + call to tirx.if_then_else function, passing the op's arguments. The return type of the call if a uint of the same width as the custom type is returned. """ @@ -350,7 +350,7 @@ def lower_ite(ite_op): dtype += "x" + str(t.lanes) return call_intrin( dtype, - "tir.if_then_else", + "tirx.if_then_else", convert(ite_op.args[0]), convert(ite_op.args[1]), convert(ite_op.args[2]), @@ -366,7 +366,7 @@ def lower_call_pure_extern(op): ---------- ite_op : Op Takes a call_pure_extern op and returns a - call to tir.call_pure_extern function, passing the op's + call to tirx.call_pure_extern function, passing the op's arguments. The return type of the call if a uint of the same width as the custom type is returned. """ @@ -376,4 +376,4 @@ def lower_call_pure_extern(op): dtype = "uint" + str(t.bits) if t.lanes > 1: dtype += "x" + str(t.lanes) - return call_intrin(dtype, "tir.call_pure_extern", *op.args) + return call_intrin(dtype, "tirx.call_pure_extern", *op.args) diff --git a/python/tvm/target/intrin.py b/python/tvm/target/intrin.py index a79f13acc45a..0bb37adf6149 100644 --- a/python/tvm/target/intrin.py +++ b/python/tvm/target/intrin.py @@ -17,7 +17,7 @@ """Target dependent intrinsic registration.""" from tvm.ir import register_intrin_lowering -from tvm.tir import call_pure_extern +from tvm.tirx import call_pure_extern def _rule_float_suffix(op): @@ -41,7 +41,7 @@ def _rule_float_suffix(op): register_intrin_lowering : The registration function for intrinsic lowering rule. """ name = op.op.name - assert name.startswith("tir.") + assert name.startswith("tirx.") prefix = name[4:] if op.dtype == "float32": @@ -77,6 +77,6 @@ def _rule_float_direct(op): # opencl pattern for exp -register_intrin_lowering("tir.exp", target="opencl", f=_rule_float_direct, level=99) +register_intrin_lowering("tirx.exp", target="opencl", f=_rule_float_direct, level=99) # default pattern for exp -register_intrin_lowering("tir.exp", target="default", f=_rule_float_suffix, level=99) +register_intrin_lowering("tirx.exp", target="default", f=_rule_float_suffix, level=99) diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py index f94dcd7699a7..8ad8671605ad 100644 --- a/python/tvm/te/__init__.py +++ b/python/tvm/te/__init__.py @@ -18,16 +18,16 @@ # pylint: disable=unused-import, redefined-builtin, wildcard-import """Namespace for Tensor Expression Language""" -# expose all operators in tvm tir.op -from tvm.tir import any, all, min_value, max_value, trace -from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, sqrt, rsqrt, floor, ceil -from tvm.tir import sinh, cosh, log2, log10 -from tvm.tir import asin, asinh, acos, acosh, atan, atanh -from tvm.tir import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else -from tvm.tir import isnan, isfinite, isinf -from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, logaddexp -from tvm.tir import comm_reducer, min, max, sum -from tvm.tir import add, subtract, multiply +# expose all operators in tvm tirx.op +from tvm.tirx import any, all, min_value, max_value, trace +from tvm.tirx import exp, erf, tanh, sigmoid, log, tan, cos, sin, sqrt, rsqrt, floor, ceil +from tvm.tirx import sinh, cosh, log2, log10 +from tvm.tirx import asin, asinh, acos, acosh, atan, atanh +from tvm.tirx import trunc, abs, round, nearbyint, power, popcount, fmod, if_then_else +from tvm.tirx import isnan, isfinite, isinf +from tvm.tirx import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod, logaddexp +from tvm.tirx import comm_reducer, min, max, sum +from tvm.tirx import add, subtract, multiply from .tensor import TensorSlice, Tensor from .tag import tag_scope diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index e01ace8da572..fdc6c5f95a46 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -22,8 +22,8 @@ from numbers import Integral as _Integral import tvm.arith._ffi_api -import tvm.tir -import tvm.tir._ffi_api +import tvm.tirx +import tvm.tirx._ffi_api from tvm.ir import Array from tvm.runtime import convert @@ -90,7 +90,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=N if tag != "": raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.get_current().tag - shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape + shape = (shape,) if isinstance(shape, tvm.tirx.PrimExpr) else shape # for python3 shape = tuple([int(s) if isinstance(s, float) else s for s in shape]) out_ndim = len(shape) @@ -125,7 +125,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=N f"args={len(arg_names)}, dimension={out_ndim}" ) - dim_var = [tvm.tir.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])] + dim_var = [tvm.tirx.IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])] body = fcompute(*[v.var for v in dim_var]) if not isinstance(body, list | tuple): @@ -199,7 +199,7 @@ def scan(init, update, state_placeholder, inputs=None, name="scan", tag="", attr inputs = [] if len(init) != len(update) or len(init) != len(state_placeholder): raise ValueError("init, update, state_placeholder must have same length") - axis = tvm.tir.IterVar((init[0].shape[0], update[0].shape[0]), f"{name}.idx", 3) + axis = tvm.tirx.IterVar((init[0].shape[0], update[0].shape[0]), f"{name}.idx", 3) op = _ffi_api.ScanOp(name, tag, attrs, axis, init, update, state_placeholder, inputs) res = [op.output(i) for i in range(len(update))] return res[0] if len(res) == 1 else res @@ -233,12 +233,12 @@ def extern( .. note:: **Parameters** - - **ins** (list of :any:`tvm.tir.Buffer`) - Placeholder for each inputs - - **outs** (list of :any:`tvm.tir.Buffer`) - Placeholder for each outputs + - **ins** (list of :any:`tvm.tirx.Buffer`) - Placeholder for each inputs + - **outs** (list of :any:`tvm.tirx.Buffer`) - Placeholder for each outputs **Returns** - - **stmt** (:any:`tvm.tir.Stmt`) - The statement that carries out array computation. + - **stmt** (:any:`tvm.tirx.Stmt`) - The statement that carries out array computation. name: str, optional The name hint of the tensor @@ -247,10 +247,10 @@ def extern( The data types of outputs, by default dtype will be same as inputs. - in_buffers: tvm.tir.Buffer or list of tvm.tir.Buffer, optional + in_buffers: tvm.tirx.Buffer or list of tvm.tirx.Buffer, optional Input buffers. - out_buffers: tvm.tir.Buffer or list of tvm.tir.Buffer, optional + out_buffers: tvm.tirx.Buffer or list of tvm.tirx.Buffer, optional Output buffers. @@ -275,7 +275,7 @@ def extern( A = te.placeholder((n, l), name="A") B = te.placeholder((l, m), name="B") C = te.extern((n, m), [A, B], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], 0, 0), name="C") """ @@ -283,8 +283,8 @@ def extern( if tag != "": raise ValueError("nested tag is not allowed for now") tag = _tag.TagScope.get_current().tag - shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr | _Integral) else shape - if shape == () or isinstance(shape[0], tvm.tir.PrimExpr | _Integral): + shape = (shape,) if isinstance(shape, tvm.tirx.PrimExpr | _Integral) else shape + if shape == () or isinstance(shape[0], tvm.tirx.PrimExpr | _Integral): shape = [shape] if in_buffers is not None: in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers @@ -306,8 +306,8 @@ def extern( raise ValueError("expect inputs to be tensor") if in_buffers is None: input_placeholders.append( - tvm.tir.decl_buffer( - t.shape, t.dtype, t.op.name, elem_offset=tvm.tir.Var("elem_offset", "int32") + tvm.tirx.decl_buffer( + t.shape, t.dtype, t.op.name, elem_offset=tvm.tirx.Var("elem_offset", "int32") ) ) types.add(t.dtype) @@ -323,12 +323,14 @@ def extern( for shp, dt in zip(shape, dtype): output_placeholders.append( - tvm.tir.decl_buffer(shp, dt, name, elem_offset=tvm.tir.Var("elem_offset", "int32")) + tvm.tirx.decl_buffer( + shp, dt, name, elem_offset=tvm.tirx.Var("elem_offset", "int32") + ) ) body = fcompute(input_placeholders, output_placeholders) - if isinstance(body, tvm.tir.PrimExpr): - body = tvm.tir.Evaluate(body) - if not isinstance(body, tvm.tir.Stmt): + if isinstance(body, tvm.tirx.PrimExpr): + body = tvm.tirx.Evaluate(body) + if not isinstance(body, tvm.tirx.Stmt): raise ValueError( f"Function '{fcompute.__name__}' should return PrimExpr or Stmt, but it returned " f"'{type(body)}'" @@ -339,7 +341,7 @@ def extern( return res[0] if len(res) == 1 else res -def extern_primfunc(input_tensors: list[_tensor.Tensor], primfunc: tvm.tir.PrimFunc, **kwargs): +def extern_primfunc(input_tensors: list[_tensor.Tensor], primfunc: tvm.tirx.PrimFunc, **kwargs): """Compute tensors via a schedulable TIR PrimFunc Parameters @@ -439,10 +441,10 @@ def var(name="tindex", dtype="int32", span=None): Returns ------- - var : tir.Var + var : tirx.Var The result symbolic variable. """ - return tvm.tir.Var(name, dtype, span) + return tvm.tirx.Var(name, dtype, span) def const(value, dtype="int32", span=None): @@ -464,7 +466,7 @@ def const(value, dtype="int32", span=None): const : PrimExpr The result constant expr. """ - return tvm.tir.const(value, dtype, span) + return tvm.tirx.const(value, dtype, span) def size_var(name="size", dtype="int32", span=None): @@ -486,7 +488,7 @@ def size_var(name="size", dtype="int32", span=None): var : SizeVar The result symbolic shape variable. """ - return tvm.tir.SizeVar(name, dtype, span) + return tvm.tirx.SizeVar(name, dtype, span) def thread_axis(dom=None, tag="", name="", span=None): @@ -517,7 +519,7 @@ def thread_axis(dom=None, tag="", name="", span=None): if not tag: raise ValueError("tag must be given as Positional or keyword argument") name = name if name else tag - return tvm.tir.IterVar(dom, name, 1, tag, span) + return tvm.tirx.IterVar(dom, name, 1, tag, span) def reduce_axis(dom, name="rv", thread_tag="", span=None): @@ -542,17 +544,17 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None): axis : IterVar An iteration variable representing the value. """ - return tvm.tir.IterVar(dom, name, 2, thread_tag, span) + return tvm.tirx.IterVar(dom, name, 2, thread_tag, span) def create_prim_func( - ops: list[_tensor.Tensor | tvm.tir.Var], index_dtype_override: str | None = None -) -> tvm.tir.PrimFunc: + ops: list[_tensor.Tensor | tvm.tirx.Var], index_dtype_override: str | None = None +) -> tvm.tirx.PrimFunc: """Create a TensorIR PrimFunc from tensor expression Parameters ---------- - ops : List[Union[_tensor.Tensor, tvm.tir.Var]] + ops : List[Union[_tensor.Tensor, tvm.tirx.Var]] The source expression. Example @@ -594,7 +596,7 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: Returns ------- - func : tir.PrimFunc + func : tirx.PrimFunc The created function. """ if not isinstance(ops, list | tuple | Array): @@ -602,4 +604,4 @@ def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: return _ffi_api.CreatePrimFunc(ops, index_dtype_override) -AXIS_SEPARATOR = tvm.tir.IndexMap.AXIS_SEPARATOR +AXIS_SEPARATOR = tvm.tirx.IndexMap.AXIS_SEPARATOR diff --git a/python/tvm/te/tag.py b/python/tvm/te/tag.py index ab1d0b44f20d..13c490c78827 100644 --- a/python/tvm/te/tag.py +++ b/python/tvm/te/tag.py @@ -91,6 +91,6 @@ def tag_scope(tag): # or use tag_scope as decorator @tvm.te.tag_scope(tag="conv") def compute_relu(data): - return te.compute(data.shape, lambda *i: tvm.tir.Select(data(*i) < 0, 0.0, data(*i))) + return te.compute(data.shape, lambda *i: tvm.tirx.Select(data(*i) < 0, 0.0, data(*i))) """ return TagScope(tag) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 6b21dad13186..531915c6798a 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -20,8 +20,8 @@ import tvm_ffi from tvm.runtime import Object, ObjectConvertible -from tvm.tir import DataProducer -from tvm.tir import expr as _expr +from tvm.tirx import DataProducer +from tvm.tirx import expr as _expr from . import _ffi_api diff --git a/python/tvm/testing/tir.py b/python/tvm/testing/tir.py index 184eb4e98cbb..af5c8839a2ca 100644 --- a/python/tvm/testing/tir.py +++ b/python/tvm/testing/tir.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, import-outside-toplevel, unused-variable -"""Common utility functions in TVM tir""" +"""Common utility functions in TVM tirx""" def mma_schedule( diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index a6744a3b9e01..c5937eec4ec8 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -89,7 +89,7 @@ def test_something(): import tvm.contrib.hexagon._ci_env_check as hexagon import tvm.contrib.utils import tvm.te -import tvm.tir +import tvm.tirx from tvm.contrib import cudnn, nvcc, rocm from tvm.error import TVMError from tvm.target import codegen @@ -266,10 +266,10 @@ def assert_prim_expr_equal(lhs, rhs): Parameters ---------- - lhs : tvm.tir.PrimExpr + lhs : tvm.tirx.PrimExpr The left operand. - rhs : tvm.tir.PrimExpr + rhs : tvm.tirx.PrimExpr The left operand. """ ana = tvm.arith.Analyzer() @@ -292,13 +292,13 @@ def check_bool_expr_is_true(bool_expr, vranges, cond=None): ---------- bool_expr : tvm.ir.PrimExpr Boolean expression to check - vranges: Dict[tvm.tir.expr.Var, tvm.ir.Range] + vranges: Dict[tvm.tirx.expr.Var, tvm.ir.Range] Free variables and their ranges cond: tvm.ir.PrimExpr extra conditions needs to be satisfied. """ if cond is not None: - bool_expr = tvm.te.any(tvm.tir.Not(cond), bool_expr) + bool_expr = tvm.te.any(tvm.tirx.Not(cond), bool_expr) def _run_expr(expr, vranges): """Evaluate expr for every value of free variables @@ -307,7 +307,7 @@ def _run_expr(expr, vranges): def _compute_body(*us): vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)} - return tvm.tir.stmt_functor.substitute(expr, vmap) + return tvm.tirx.stmt_functor.substitute(expr, vmap) A = tvm.te.compute([r.extent.value for v, r in vranges.items()], _compute_body) args = [tvm.runtime.empty(A.shape, A.dtype)] @@ -335,7 +335,7 @@ def check_int_constraints_trans_consistency(constraints_trans, vranges=None): ---------- constraints_trans : arith.IntConstraintsTransform Integer constraints transformation - vranges: Dict[tvm.tir.Var, tvm.ir.Range] + vranges: Dict[tvm.tirx.Var, tvm.ir.Range] Free variables and their ranges """ if vranges is None: @@ -347,28 +347,28 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): all_vranges.update({v: r for v, r in constraints1.ranges.items()}) # Check that the transformation is injective - cond_on_vars = tvm.tir.const(1, "bool") + cond_on_vars = tvm.tirx.const(1, "bool") for v in constraints1.variables: if v in varmap: # variable mapping is consistent - v_back = ana.simplify(tvm.tir.stmt_functor.substitute(varmap[v], backvarmap)) + v_back = ana.simplify(tvm.tirx.stmt_functor.substitute(varmap[v], backvarmap)) cond_on_vars = tvm.te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true - cond_subst = tvm.tir.stmt_functor.substitute( - tvm.te.all(tvm.tir.const(1, "bool"), *constraints2.relations), backvarmap + cond_subst = tvm.tirx.stmt_functor.substitute( + tvm.te.all(tvm.tirx.const(1, "bool"), *constraints2.relations), backvarmap ) # We have to include relations from vranges too for v in constraints2.variables: if v in constraints2.ranges: r = constraints2.ranges[v] range_cond = tvm.te.all(v >= r.min, v < r.min + r.extent) - range_cond = tvm.tir.stmt_functor.substitute(range_cond, backvarmap) + range_cond = tvm.tirx.stmt_functor.substitute(range_cond, backvarmap) cond_subst = tvm.te.all(cond_subst, range_cond) cond_subst = ana.simplify(cond_subst) check_bool_expr_is_true( tvm.te.all(cond_subst, cond_on_vars), all_vranges, - cond=tvm.te.all(tvm.tir.const(1, "bool"), *constraints1.relations), + cond=tvm.te.all(tvm.tirx.const(1, "bool"), *constraints1.relations), ) _check_forward( diff --git a/python/tvm/tir/__init__.py b/python/tvm/tirx/__init__.py similarity index 100% rename from python/tvm/tir/__init__.py rename to python/tvm/tirx/__init__.py diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tirx/_ffi_api.py similarity index 92% rename from python/tvm/tir/_ffi_api.py rename to python/tvm/tirx/_ffi_api.py index 17737a22abb8..96b67a0ef74b 100644 --- a/python/tvm/tir/_ffi_api.py +++ b/python/tvm/tirx/_ffi_api.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.tir""" +"""FFI APIs for tvm.tirx""" import tvm_ffi -tvm_ffi.init_ffi_api("tir", __name__) +tvm_ffi.init_ffi_api("tirx", __name__) diff --git a/python/tvm/tir/analysis/__init__.py b/python/tvm/tirx/analysis/__init__.py similarity index 100% rename from python/tvm/tir/analysis/__init__.py rename to python/tvm/tirx/analysis/__init__.py diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tirx/analysis/_ffi_api.py similarity index 90% rename from python/tvm/tir/transform/_ffi_api.py rename to python/tvm/tirx/analysis/_ffi_api.py index b7077c1f243e..3429e4cfc018 100644 --- a/python/tvm/tir/transform/_ffi_api.py +++ b/python/tvm/tirx/analysis/_ffi_api.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.tir.transform""" +"""FFI APIs for tvm.tirx.analysis""" import tvm_ffi -tvm_ffi.init_ffi_api("tir.transform", __name__) +tvm_ffi.init_ffi_api("tirx.analysis", __name__) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tirx/analysis/analysis.py similarity index 95% rename from python/tvm/tir/analysis/analysis.py rename to python/tvm/tirx/analysis/analysis.py index ea78cd87c1b1..e7aa97e99dd7 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tirx/analysis/analysis.py @@ -19,8 +19,8 @@ # pylint: disable=invalid-name from tvm.ir import IRModule -from tvm.tir.expr import Var -from tvm.tir.stmt import PrimExpr +from tvm.tirx.expr import Var +from tvm.tirx.stmt import PrimExpr from .. import Stmt from ..function import PrimFunc @@ -69,7 +69,7 @@ def verify_ssa(func: PrimFunc) -> bool: Parameters ---------- - func: tvm.tir.PrimFunc + func: tvm.tirx.PrimFunc The module to be verified. Returns @@ -85,7 +85,7 @@ def verify_memory(func: PrimFunc) -> bool: Parameters ---------- - func: tvm.tir.PrimFunc + func: tvm.tirx.PrimFunc The module to be verified. Returns @@ -122,7 +122,7 @@ def verify_well_formed(obj: PrimFunc | IRModule, assert_mode: bool = True) -> bo Parameters ---------- - obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule] + obj: Union[tvm.tirx.PrimFunc, tvm.ir.IRModule] The function or module to be verified. assert_mode: bool diff --git a/python/tvm/tir/backend/__init__.py b/python/tvm/tirx/backend/__init__.py similarity index 100% rename from python/tvm/tir/backend/__init__.py rename to python/tvm/tirx/backend/__init__.py diff --git a/python/tvm/tir/backend/adreno/__init__.py b/python/tvm/tirx/backend/adreno/__init__.py similarity index 100% rename from python/tvm/tir/backend/adreno/__init__.py rename to python/tvm/tirx/backend/adreno/__init__.py diff --git a/python/tvm/tir/buffer.py b/python/tvm/tirx/buffer.py similarity index 97% rename from python/tvm/tir/buffer.py rename to python/tvm/tirx/buffer.py index bb00cfe90bb4..8e54d2234d78 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tirx/buffer.py @@ -27,7 +27,7 @@ from . import _ffi_api -@tvm_ffi.register_object("tir.Buffer") +@tvm_ffi.register_object("tirx.Buffer") class Buffer(Object, Scriptable): """Symbolic data buffer in TVM. @@ -225,8 +225,8 @@ def __getitem__(self, indices): stop = self.shape[i] if index.stop is None else index.stop step = 1 if index.step is None else index.step # We should ensure the dtype of start is the same with that of step. - if isinstance(start, tvm.tir.expr.PrimExpr) and isinstance(step, int): - step = tvm.tir.expr.IntImm(start.dtype, step) + if isinstance(start, tvm.tirx.expr.PrimExpr) and isinstance(step, int): + step = tvm.tirx.expr.IntImm(start.dtype, step) lanes = analyzer.simplify((stop - start + step - 1) // step) if lanes == 1: expr_indices.append(start) @@ -269,7 +269,7 @@ def decl_buffer( name : str, optional The name of the buffer. - data : tir.Var, optional + data : tirx.Var, optional The data pointer in the buffer. strides: array of Expr @@ -308,7 +308,7 @@ def decl_buffer( Returns ------- - buffer : tvm.tir.Buffer + buffer : tvm.tirx.Buffer The created buffer Note @@ -357,6 +357,6 @@ def decl_buffer( ) -@tvm_ffi.register_object("tir.DataProducer") +@tvm_ffi.register_object("tirx.DataProducer") class DataProducer(Object): pass diff --git a/python/tvm/tir/build.py b/python/tvm/tirx/build.py similarity index 90% rename from python/tvm/tir/build.py rename to python/tvm/tirx/build.py index 5f6bf8adbfb5..d310afee7938 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tirx/build.py @@ -22,7 +22,7 @@ from tvm import ir from tvm.ir.module import IRModule from tvm.target import Target -from tvm.tir import PrimFunc +from tvm.tirx import PrimFunc def split_host_device_mods(mod: IRModule) -> tuple[IRModule, dict[Target, IRModule]]: @@ -72,14 +72,14 @@ def main_kernel(A: T.handle, B: T.handle, C: T.handle, length: T.int32): T.func_attr({"target": T.target({"arch": "sm_90", "keys": ["cuda", "gpu"], "kind": "cuda"}), "calling_conv": 2, # kDeviceKernelLaunch for device kernels - "tir.is_global_func": True}) + "tirx.is_global_func": True}) # ... kernel implementation @T.prim_func def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle): T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}), "calling_conv": 1, # kCPackedFunc for entry functions - "tir.is_entry_func": True}) + "tirx.is_entry_func": True}) # ... main function implementation The function will return: @@ -102,8 +102,8 @@ def is_host_func(f): target = f.attrs.get("target", tvm.target.Target("llvm")) return str(target.kind) in ["llvm", "c"] - host_mod = tvm.tir.transform.Filter(is_host_func)(mod) - device_mod = tvm.tir.transform.Filter(lambda f: not is_host_func(f))(mod) + host_mod = tvm.tirx.transform.Filter(is_host_func)(mod) + device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod) # TODO(syfeng): Here we use str as key since target hash is not correct target_str2target = {} device_func_dict = {} @@ -121,8 +121,8 @@ def is_host_func(f): def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: """Build a runtime module from an IRModule and a Target.""" - if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False): - mod = tvm.tir.transform.SkipAssert()(mod) + if tvm.ir.transform.PassContext.current().config.get("tirx.disable_assert", False): + mod = tvm.tirx.transform.SkipAssert()(mod) build_f_name = "target.build." + target.kind.name bf = tvm.get_global_func(build_f_name) if bf is None: @@ -188,7 +188,7 @@ def build( assert target_to_bind is not None target_to_bind = Target(target_to_bind) - # Step 1: Determine the target to search for tir pipeline + # Step 1: Determine the target to search for tirx pipeline target = Target.current() if target is None else target if target is None: for func in mod.functions.values(): @@ -212,25 +212,25 @@ def build( target_to_bind = target_to_bind.with_host(target_host) # Step 3: Bind the target to the input module - mod = tvm.tir.transform.BindTarget(target_to_bind)(mod) + mod = tvm.tirx.transform.BindTarget(target_to_bind)(mod) - # Step 4: Apply the tir pipeline + # Step 4: Apply the tirx pipeline if pipeline is not None: # custom pipeline if isinstance(pipeline, str): - pipeline = tvm.tir.get_tir_pipeline(pipeline) + pipeline = tvm.tirx.get_tir_pipeline(pipeline) else: # default pipeline depends on the target - pipeline = tvm.tir.get_default_tir_pipeline(target) + pipeline = tvm.tirx.get_default_tir_pipeline(target) mod = pipeline(mod) # Step 5: Get host and device modules host_mod, device_mod_dict = split_host_device_mods(mod) # Step 6: Apply finalization passes - host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod) + host_mod = tvm.tirx.pipeline.finalize_host_passes()(host_mod) device_mod_dict = { - target: tvm.tir.pipeline.finalize_device_passes()(device_mod) + target: tvm.tirx.pipeline.finalize_device_passes()(device_mod) for target, device_mod in device_mod_dict.items() } @@ -238,4 +238,4 @@ def build( return tir_to_runtime(host_mod, device_mod_dict, target_host) -tvm.register_global_func("tir.build", build) +tvm.register_global_func("tirx.build", build) diff --git a/python/tvm/tir/expr.py b/python/tvm/tirx/expr.py similarity index 95% rename from python/tvm/tir/expr.py rename to python/tvm/tirx/expr.py index 0aa686e3a846..e38026855cc3 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tirx/expr.py @@ -22,9 +22,9 @@ .. code-block:: python - x = tvm.tir.Var("n", "int32") + x = tvm.tirx.Var("n", "int32") y = x + 2 - assert(isinstance(y, tvm.tir.Add)) + assert(isinstance(y, tvm.tirx.Add)) assert(y.a == x) """ @@ -180,7 +180,7 @@ def __ge__(self, other: PrimExpr) -> PrimExpr: def __nonzero__(self): raise ValueError( "Cannot use and / or / not operator to Expr, hint: " - + "use tvm.tir.all / tvm.tir.any instead" + + "use tvm.tirx.all / tvm.tirx.any instead" ) def __bool__(self) -> bool: @@ -345,7 +345,7 @@ class LogicalExpr(PrimExprWithOp): pass -@tvm_ffi.register_object("tir.Var") +@tvm_ffi.register_object("tirx.Var") class Var(PrimExprWithOp): """Symbolic variable. @@ -368,7 +368,7 @@ def __init__(self, name: str, dtype: str | ir.Type, span: Span | None = None) -> self.__init_handle_by_constructor__(_ffi_api.Var, name, dtype, span) # type: ignore -@tvm_ffi.register_object("tir.SizeVar") +@tvm_ffi.register_object("tirx.SizeVar") class SizeVar(Var): """Symbolic variable to represent a tensor index size which is greater or equal to zero. @@ -390,7 +390,7 @@ def __init__(self, name: str, dtype: str | ir.Type, span: Span | None = None) -> self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore -@tvm_ffi.register_object("tir.IterVar") +@tvm_ffi.register_object("tirx.IterVar") class IterVar(ExprOp, Object, Scriptable): """Represent iteration variable. @@ -468,7 +468,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.CommReducer") +@tvm_ffi.register_object("tirx.CommReducer") class CommReducer(Object, Scriptable): """Commutative reduce operator @@ -513,7 +513,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.Reduce") +@tvm_ffi.register_object("tirx.Reduce") class Reduce(PrimExprWithOp): """Reduce node. @@ -646,7 +646,7 @@ def __bool__(self) -> bool: return self.__nonzero__() -@tvm_ffi.register_object("tir.StringImm") # type: ignore +@tvm_ffi.register_object("tirx.StringImm") # type: ignore class StringImm(ConstExpr): """String constant. @@ -678,7 +678,7 @@ def __hash__(self) -> int: return PrimExpr.__hash__(self) -@tvm_ffi.register_object("tir.Cast") +@tvm_ffi.register_object("tirx.Cast") class Cast(PrimExprWithOp): """Cast expression. @@ -700,7 +700,7 @@ def __init__(self, dtype, value, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Cast, dtype, value, span) # type: ignore -@tvm_ffi.register_object("tir.Add") +@tvm_ffi.register_object("tirx.Add") class Add(BinaryOpExpr): """Add node. @@ -720,7 +720,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Sub") +@tvm_ffi.register_object("tirx.Sub") class Sub(BinaryOpExpr): """Sub node. @@ -740,7 +740,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Mul") +@tvm_ffi.register_object("tirx.Mul") class Mul(BinaryOpExpr): """Mul node. @@ -760,7 +760,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Div") +@tvm_ffi.register_object("tirx.Div") class Div(BinaryOpExpr): """Div node. @@ -780,7 +780,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Mod") +@tvm_ffi.register_object("tirx.Mod") class Mod(BinaryOpExpr): """Mod node. @@ -800,7 +800,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.FloorDiv") +@tvm_ffi.register_object("tirx.FloorDiv") class FloorDiv(BinaryOpExpr): """FloorDiv node. @@ -820,7 +820,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.FloorMod") +@tvm_ffi.register_object("tirx.FloorMod") class FloorMod(BinaryOpExpr): """FloorMod node. @@ -840,7 +840,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Min") +@tvm_ffi.register_object("tirx.Min") class Min(BinaryOpExpr): """Min node. @@ -860,7 +860,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Max") +@tvm_ffi.register_object("tirx.Max") class Max(BinaryOpExpr): """Max node. @@ -880,7 +880,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.EQ") +@tvm_ffi.register_object("tirx.EQ") class EQ(CmpExpr): """EQ node. @@ -900,7 +900,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.NE") +@tvm_ffi.register_object("tirx.NE") class NE(CmpExpr): """NE node. @@ -920,7 +920,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.LT") +@tvm_ffi.register_object("tirx.LT") class LT(CmpExpr): """LT node. @@ -940,7 +940,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.LE") +@tvm_ffi.register_object("tirx.LE") class LE(CmpExpr): """LE node. @@ -960,7 +960,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.GT") +@tvm_ffi.register_object("tirx.GT") class GT(CmpExpr): """GT node. @@ -980,7 +980,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.GE") +@tvm_ffi.register_object("tirx.GE") class GE(CmpExpr): """GE node. @@ -1000,7 +1000,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.And") +@tvm_ffi.register_object("tirx.And") class And(LogicalExpr): """And node. @@ -1020,7 +1020,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Or") +@tvm_ffi.register_object("tirx.Or") class Or(LogicalExpr): """Or node. @@ -1043,7 +1043,7 @@ def __init__(self, a: PrimExpr, b: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Or, a, b, span) # type: ignore -@tvm_ffi.register_object("tir.Not") +@tvm_ffi.register_object("tirx.Not") class Not(LogicalExpr): """Not node. @@ -1062,14 +1062,14 @@ def __init__(self, a: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Not, a, span) # type: ignore -@tvm_ffi.register_object("tir.Select") +@tvm_ffi.register_object("tirx.Select") class Select(PrimExprWithOp): """Select node. Note ---- Select may compute both true_value and false_value. - Use :py:class:`tvm.tir.if_then_else` instead if you want to + Use :py:class:`tvm.tirx.if_then_else` instead if you want to get a conditional expression that only evaluates the correct branch. @@ -1110,7 +1110,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.BufferLoad") +@tvm_ffi.register_object("tirx.BufferLoad") class BufferLoad(PrimExprWithOp): """Buffer load node. @@ -1149,7 +1149,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.ProducerLoad") +@tvm_ffi.register_object("tirx.ProducerLoad") class ProducerLoad(PrimExprWithOp): """Producer load node. @@ -1179,7 +1179,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.Ramp") +@tvm_ffi.register_object("tirx.Ramp") class Ramp(PrimExprWithOp): """Ramp node. @@ -1214,7 +1214,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.Broadcast") +@tvm_ffi.register_object("tirx.Broadcast") class Broadcast(PrimExprWithOp): """Broadcast node. @@ -1237,7 +1237,7 @@ def __init__(self, value: PrimExpr, lanes: PrimExpr, span: Span | None = None) - self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes, span) # type: ignore -@tvm_ffi.register_object("tir.Shuffle") +@tvm_ffi.register_object("tirx.Shuffle") class Shuffle(PrimExprWithOp): """Shuffle node. @@ -1278,7 +1278,7 @@ class CallEffectKind: Opaque = UpdateState -@tvm_ffi.register_object("tir.Call") +@tvm_ffi.register_object("tirx.Call") class Call(PrimExprWithOp): """Call node. @@ -1305,11 +1305,11 @@ def __init__( self, dtype: str, op: Op | str, args: list[PrimExpr], span: Span | None = None ) -> None: if isinstance(op, str): - if not op.startswith("tir."): + if not op.startswith("tirx."): raise ValueError( ( "Cannot handle str op argument %s. This function only handles str " - + "argument with the tir namespace. If you are " + + "argument with the tirx namespace. If you are " + "certain about the intrinsic name, pass in Op.get(name) instead" ) % op @@ -1318,7 +1318,7 @@ def __init__( self.__init_handle_by_constructor__(_ffi_api.Call, dtype, op, args, span) # type: ignore -@tvm_ffi.register_object("tir.Let") +@tvm_ffi.register_object("tirx.Let") class Let(PrimExprWithOp): """Let node. diff --git a/python/tvm/tir/function.py b/python/tvm/tirx/function.py similarity index 93% rename from python/tvm/tir/function.py rename to python/tvm/tirx/function.py index f755fc1b0742..67b7149c4609 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tirx/function.py @@ -35,22 +35,22 @@ from .expr import PrimExpr, Var -@tvm_ffi.register_object("tir.PrimFunc") +@tvm_ffi.register_object("tirx.PrimFunc") class PrimFunc(BaseFunc, Scriptable): """A function declaration expression. Parameters ---------- - params: List[Union[tvm.tir.Var, tvm.tir.Buffer]] + params: List[Union[tvm.tirx.Var, tvm.tirx.Buffer]] List of input parameters to the function. - body: tvm.tir.Stmt + body: tvm.tirx.Stmt The body of the function. ret_type: tvm.ir.Type The return type annotation of the function. - buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer] + buffer_map : Map[tvm.tirx.Var, tvm.tirx.Buffer] The buffer binding map. attrs: Optional[tvm.Attrs] @@ -150,7 +150,7 @@ def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None: .. code-block:: python a, _, m, n = mem_copy.params - func = mem_copy.specialize({a: tir.decl_buffer((16, 16))}) + func = mem_copy.specialize({a: tirx.decl_buffer((16, 16))}) # or func = mem_copy.specialize({n: 16, m: 16}) @@ -176,7 +176,7 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: return _ffi_api.Specialize(self, param_map) # type: ignore -@tvm_ffi.register_object("tir.TensorIntrin") +@tvm_ffi.register_object("tirx.TensorIntrin") class TensorIntrin(Object): """A tensor intrinsic. @@ -230,7 +230,7 @@ def get(name: str, allow_missing: bool = False) -> Optional["TensorIntrin"]: return _ffi_api.TensorIntrinGet(name, allow_missing) # pylint: type: ignore -@tvm_ffi.register_object("tir.IndexMap") +@tvm_ffi.register_object("tirx.IndexMap") class IndexMap(Object): """A mapping from multi-dimensional indices to another set of multi-dimensional indices @@ -278,10 +278,10 @@ def from_func( mapping_function : Callable The function to map from source indices to target indices. - The function should accept `tir.Var` parameters and return - a either a `tir.PrimExpr`, or a list of `tir.PrimExpr`. - Returning a `tir.PrimExpr` is equivalent to returning a - list of length 1 containing that `tir.PrimExpr`. + The function should accept `tirx.Var` parameters and return + a either a `tirx.PrimExpr`, or a list of `tirx.PrimExpr`. + Returning a `tirx.PrimExpr` is equivalent to returning a + list of length 1 containing that `tirx.PrimExpr`. ndim: Optional[int] @@ -333,12 +333,12 @@ def from_func_with_separators( mapping_function : Callable The function to map from source indices to target indices. - The function should accept tir.Var parameters and return - either a `tir.PrimExpr` or a list. Each element of the - returned list should be either a `tir.PrimExpr` or the + The function should accept tirx.Var parameters and return + either a `tirx.PrimExpr` or a list. Each element of the + returned list should be either a `tirx.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`. Returning a - `tir.PrimExpr` is equivalent to returning a list of length - 1 containing that `tir.PrimExpr`. + `tirx.PrimExpr` is equivalent to returning a list of length + 1 containing that `tirx.PrimExpr`. ndim: Optional[int] @@ -379,13 +379,13 @@ def from_func_with_separators( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ]: - args.append(tvm.tir.Var(name, index_dtype)) + args.append(tvm.tirx.Var(name, index_dtype)) elif param.kind == inspect.Parameter.VAR_POSITIONAL: var_arg_name = name elif param.kind == inspect.Parameter.KEYWORD_ONLY: - kwargs[name] = tvm.tir.Var(name, index_dtype) + kwargs[name] = tvm.tirx.Var(name, index_dtype) else: raise ValueError("transform_layout mapping may not have *args") @@ -397,7 +397,7 @@ def from_func_with_separators( assert ndim is not None, "ndim must be specified when *args is used" num_var_args = ndim - len(args) - len(kwargs) for i in range(num_var_args): - args.append(tvm.tir.Var(f"{var_arg_name}_{i}", index_dtype)) + args.append(tvm.tirx.Var(f"{var_arg_name}_{i}", index_dtype)) mapping = mapping_function(*args, **kwargs) diff --git a/python/tvm/tir/functor.py b/python/tvm/tirx/functor.py similarity index 99% rename from python/tvm/tir/functor.py rename to python/tvm/tirx/functor.py index 9c684d8d990b..4619c0b51fbb 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tirx/functor.py @@ -96,7 +96,7 @@ ------- .. code-block:: python - @tir.functor.stmt_expr_visitor + @tirx.functor.stmt_expr_visitor class MyStmtExprVisitor(PyStmtExprVisitor): # customize visit function def visit_call_(self, op: Call) -> None: @@ -128,7 +128,7 @@ def visit_call_(self, op: Call) -> None: ------- .. code-block:: python - @tir.functor.stmt_expr_mutator + @tirx.functor.stmt_expr_mutator class MyStmtExprMutator(PyStmtExprMutator): # customize rewrite function def visit_add_(self, op: Add) -> PrimExpr: @@ -144,7 +144,7 @@ def visit_add_(self, op: Add) -> PrimExpr: """ -@tvm_ffi.register_object("tir.PyStmtExprVisitor") +@tvm_ffi.register_object("tirx.PyStmtExprVisitor") class _PyStmtExprVisitor(tvm_ffi.core.Object): """ An internal wrapper to interface between C++ and Python StmtExprVisitor. @@ -945,7 +945,7 @@ def visit_string_imm_(self, op: StringImm) -> None: _ffi_api.PyStmtExprVisitorDefaultVisitExpr(self._outer(), op) # type: ignore -@tvm_ffi.register_object("tir.PyStmtExprMutator") +@tvm_ffi.register_object("tirx.PyStmtExprMutator") class _PyStmtExprMutator(tvm_ffi.core.Object): """ A TVM object to support customization of StmtExprMutator on the python side. diff --git a/python/tvm/tir/generic.py b/python/tvm/tirx/generic.py similarity index 98% rename from python/tvm/tir/generic.py rename to python/tvm/tirx/generic.py index c9ee7b618071..132a09e302a3 100644 --- a/python/tvm/tir/generic.py +++ b/python/tvm/tirx/generic.py @@ -16,7 +16,7 @@ # under the License. """Generic opertors in TVM. We follow the numpy naming convention for this interface -(e.g., tvm.tir.generic.multitply ~ numpy.multiply). +(e.g., tvm.tirx.generic.multitply ~ numpy.multiply). The default implementation is used by tvm.ExprOp. """ diff --git a/python/tvm/tir/op.py b/python/tvm/tirx/op.py similarity index 89% rename from python/tvm/tir/op.py rename to python/tvm/tirx/op.py index da9a9aecd121..6b4a636f3061 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tirx/op.py @@ -22,7 +22,7 @@ import tvm_ffi import tvm -from tvm import tir +from tvm import tirx from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span from tvm.runtime import const @@ -34,8 +34,8 @@ def _pack_buffer(buf, span=None): """Build intrinsics that packs the buffer.""" - shape = Call("handle", "tir.tvm_stack_make_shape", buf.shape, span) - strides = Call("handle", "tir.tvm_stack_make_shape", buf.strides, span) if buf.strides else 0 + shape = Call("handle", "tirx.tvm_stack_make_shape", buf.shape, span) + strides = Call("handle", "tirx.tvm_stack_make_shape", buf.strides, span) if buf.strides else 0 pack_args = [ buf.data, shape, @@ -44,7 +44,7 @@ def _pack_buffer(buf, span=None): const(0, dtype=buf.dtype), buf.elem_offset, ] - return Call("handle", Op.get("tir.tvm_stack_make_array"), pack_args, span) + return Call("handle", Op.get("tirx.tvm_stack_make_array"), pack_args, span) def call_packed_lowered(*args, span=None): @@ -73,7 +73,7 @@ def call_packed_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_packed_lowered"), call_args, span) + return Call("int32", Op.get("tirx.tvm_call_packed_lowered"), call_args, span) def call_cpacked_lowered(*args, span=None): @@ -99,7 +99,7 @@ def call_cpacked_lowered(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_cpacked_lowered"), call_args, span) + return Call("int32", Op.get("tirx.tvm_call_cpacked_lowered"), call_args, span) def call_packed(*args, span=None): @@ -130,7 +130,7 @@ def call_packed(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_packed"), call_args, span) + return Call("int32", Op.get("tirx.tvm_call_packed"), call_args, span) def call_cpacked(*args, span=None): @@ -157,7 +157,7 @@ def call_cpacked(*args, span=None): te.extern : Create tensor with extern function call. """ call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] - return Call("int32", Op.get("tir.tvm_call_cpacked"), call_args, span) + return Call("int32", Op.get("tirx.tvm_call_cpacked"), call_args, span) def call_intrin(dtype, func_name, *args, span=None): @@ -210,7 +210,7 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) + return Call(dtype, Op.get("tirx.call_pure_extern"), [func_name, *args], span) def call_extern(dtype, func_name, *args, span=None): @@ -235,13 +235,13 @@ def call_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, Op.get("tir.call_extern"), [func_name, *args], span=span) + return Call(dtype, Op.get("tirx.call_extern"), [func_name, *args], span=span) def _require_float_arg(op_name, x): - x = tir.convert(x) + x = tirx.convert(x) if "float" not in x.dtype and "bfloat" not in x.dtype: - raise TypeError(f"tir.{op_name} only supports floating-point inputs, but got {x.dtype}") + raise TypeError(f"tirx.{op_name} only supports floating-point inputs, but got {x.dtype}") return x @@ -280,8 +280,8 @@ def call_llvm_intrin(dtype, name, *args, span=None): raise ValueError(f"Unknown llvm intrinsic function {name}") return call_intrin( dtype, - Op.get("tir.call_llvm_intrin"), - tvm.tir.const(llvm_id, "uint32"), + Op.get("tirx.call_llvm_intrin"), + tvm.tirx.const(llvm_id, "uint32"), *args, span=span, ) @@ -322,8 +322,8 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): raise ValueError(f"Unknown llvm intrinsic function {name}") return call_intrin( dtype, - Op.get("tir.call_llvm_pure_intrin"), - tvm.tir.const(llvm_id, "uint32"), + Op.get("tirx.call_llvm_pure_intrin"), + tvm.tirx.const(llvm_id, "uint32"), *args, span=span, ) @@ -345,7 +345,7 @@ def tvm_stack_alloca(dtype_str, num): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_stack_alloca", dtype_str, num) + return call_intrin("handle", "tirx.tvm_stack_alloca", dtype_str, num) def tvm_stack_make_shape(*args): @@ -361,7 +361,7 @@ def tvm_stack_make_shape(*args): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_stack_make_shape", *args) + return call_intrin("handle", "tirx.tvm_stack_make_shape", *args) def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset): @@ -394,7 +394,7 @@ def tvm_stack_make_array(data, shape, strides, ndim, arr_dtype, elem_offset): """ return call_intrin( "handle", - "tir.tvm_stack_make_array", + "tirx.tvm_stack_make_array", data, shape, strides, @@ -417,7 +417,7 @@ def assume(cond=None): call : PrimExpr The call expression. """ - return call_intrin("bool", "tir.assume", cond) + return call_intrin("bool", "tirx.assume", cond) def undef(): @@ -428,7 +428,7 @@ def undef(): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.undef") + return call_intrin("int32", "tirx.undef") def call_tir(global_var: tvm.ir.GlobalVar, *args): @@ -461,7 +461,7 @@ def start_profile_intrinsic(id): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.start_profile_intrinsic", id) + return call_intrin("handle", "tirx.start_profile_intrinsic", id) def end_profile_intrinsic(id): @@ -475,7 +475,7 @@ def end_profile_intrinsic(id): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.end_profile_intrinsic", id) + return call_intrin("handle", "tirx.end_profile_intrinsic", id) def tvm_tuple(*value): @@ -491,7 +491,7 @@ def tvm_tuple(*value): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_tuple", *value) + return call_intrin("handle", "tirx.tvm_tuple", *value) def handle_add_byte_offset(handle, offset): @@ -510,7 +510,7 @@ def handle_add_byte_offset(handle, offset): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.handle_add_byte_offset", handle, offset) + return call_intrin("handle", "tirx.handle_add_byte_offset", handle, offset) def tvm_struct_get(arr, index, field, dtype): @@ -535,7 +535,7 @@ def tvm_struct_get(arr, index, field, dtype): call : PrimExpr The call expression. """ - return call_intrin(dtype, "tir.tvm_struct_get", arr, index, field) + return call_intrin(dtype, "tirx.tvm_struct_get", arr, index, field) def tvm_struct_set(arr, index, field, value): @@ -560,7 +560,7 @@ def tvm_struct_set(arr, index, field, value): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.tvm_struct_set", arr, index, field, value) + return call_intrin("int32", "tirx.tvm_struct_set", arr, index, field, value) def address_of(obj: Buffer | BufferLoad, span: Span | None = None) -> PrimExpr: @@ -582,9 +582,9 @@ def address_of(obj: Buffer | BufferLoad, span: Span | None = None) -> PrimExpr: if isinstance(obj, Buffer): n_dim = len(obj.shape) buffer_load = BufferLoad(obj, [0] * n_dim) - return call_intrin("handle", "tir.address_of", buffer_load, span=span) + return call_intrin("handle", "tirx.address_of", buffer_load, span=span) elif isinstance(obj, BufferLoad): - return call_intrin("handle", "tir.address_of", obj, span=span) + return call_intrin("handle", "tirx.address_of", obj, span=span) else: raise ValueError(f"Invalid object type: {type(obj)}") @@ -605,7 +605,7 @@ def lookup_param(param_name, span=None): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.lookup_param", param_name, span=span) + return call_intrin("handle", "tirx.lookup_param", param_name, span=span) def tvm_thread_allreduce(*freduce_args): @@ -621,7 +621,7 @@ def tvm_thread_allreduce(*freduce_args): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_thread_allreduce", *freduce_args) + return call_intrin("handle", "tirx.tvm_thread_allreduce", *freduce_args) def tvm_thread_invariant(cond): @@ -638,7 +638,7 @@ def tvm_thread_invariant(cond): The call expression. """ assert isinstance(cond, PrimExpr) - return call_intrin(cond.dtype, "tir.tvm_thread_invariant", cond) + return call_intrin(cond.dtype, "tirx.tvm_thread_invariant", cond) def tvm_storage_sync(storage_scope): @@ -654,7 +654,7 @@ def tvm_storage_sync(storage_scope): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.tvm_storage_sync", storage_scope) + return call_intrin("int32", "tirx.tvm_storage_sync", storage_scope) def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): @@ -678,7 +678,7 @@ def tvm_warp_shuffle(mask, value, warp_id, width, warp_size): call : PrimExpr The call expression. """ - return call_intrin(value.dtype, "tir.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) + return call_intrin(value.dtype, "tirx.tvm_warp_shuffle", mask, value, warp_id, width, warp_size) def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): @@ -704,7 +704,7 @@ def tvm_warp_shuffle_up(mask, value, offset, width, warp_size): The call expression. """ return call_intrin( - value.dtype, "tir.tvm_warp_shuffle_up", mask, value, offset, width, warp_size + value.dtype, "tirx.tvm_warp_shuffle_up", mask, value, offset, width, warp_size ) @@ -731,7 +731,7 @@ def tvm_warp_shuffle_down(mask, value, offset, width, warp_size): The call expression. """ return call_intrin( - value.dtype, "tir.tvm_warp_shuffle_down", mask, value, offset, width, warp_size + value.dtype, "tirx.tvm_warp_shuffle_down", mask, value, offset, width, warp_size ) @@ -743,7 +743,7 @@ def tvm_warp_activemask(): call : PrimExpr The call expression. """ - return call_intrin("uint32", "tir.tvm_warp_activemask") + return call_intrin("uint32", "tirx.tvm_warp_activemask") def type_annotation(dtype): @@ -759,7 +759,7 @@ def type_annotation(dtype): call : PrimExpr The call expression. """ - return call_intrin(dtype, "tir.type_annotation") + return call_intrin(dtype, "tirx.type_annotation") def tvm_access_ptr(ptype, data, offset, extent, rw_mask): @@ -787,7 +787,7 @@ def tvm_access_ptr(ptype, data, offset, extent, rw_mask): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.tvm_access_ptr", ptype, data, offset, extent, rw_mask) + return call_intrin("handle", "tirx.tvm_access_ptr", ptype, data, offset, extent, rw_mask) def tvm_throw_last_error(): @@ -798,7 +798,7 @@ def tvm_throw_last_error(): ret : PrimExpr The return expression """ - return call_intrin("handle", "tir.tvm_throw_last_error") + return call_intrin("handle", "tirx.tvm_throw_last_error") def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): @@ -837,7 +837,7 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): """ return call_intrin( "handle", - "tir.tvm_load_matrix_sync", + "tirx.tvm_load_matrix_sync", fragment, m, n, @@ -887,7 +887,7 @@ def tvm_mma_sync( """ return call_intrin( "handle", - "tir.tvm_mma_sync", + "tirx.tvm_mma_sync", fragment_d, index_d, fragment_a, @@ -937,7 +937,7 @@ def tvm_bmma_sync( """ return call_intrin( "handle", - "tir.tvm_bmma_sync", + "tirx.tvm_bmma_sync", fragment_d, index_d, fragment_a, @@ -979,7 +979,7 @@ def tvm_fill_fragment(fragment, m, n, k, index, value): """ return call_intrin( "handle", - "tir.tvm_fill_fragment", + "tirx.tvm_fill_fragment", fragment, m, n, @@ -1025,7 +1025,7 @@ def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): """ return call_intrin( "handle", - "tir.tvm_store_matrix_sync", + "tirx.tvm_store_matrix_sync", fragment, m, n, @@ -1112,7 +1112,7 @@ def ptx_mma( if operator is None: return call_intrin( dtype, - "tir.ptx_mma", + "tirx.ptx_mma", shape, A_layout, B_layout, @@ -1129,7 +1129,7 @@ def ptx_mma( ) return call_intrin( dtype, - "tir.ptx_mma", + "tirx.ptx_mma", shape, A_layout, B_layout, @@ -1229,7 +1229,7 @@ def ptx_mma_sp( """ return call_intrin( dtype, - "tir.ptx_mma_sp", + "tirx.ptx_mma_sp", shape, A_layout, B_layout, @@ -1282,7 +1282,7 @@ def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): """ return call_intrin( dtype, - "tir.mma_store", + "tirx.mma_store", m, n, dst_ptr, @@ -1316,7 +1316,7 @@ def mma_fill(dtype, local_size, local_ptr, offset): """ return call_intrin( dtype, - "tir.mma_fill", + "tirx.mma_fill", local_size, local_ptr, offset, @@ -1360,7 +1360,7 @@ def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, sme """ return call_intrin( dtype, - "tir.ptx_ldmatrix", + "tirx.ptx_ldmatrix", trans, num, type, @@ -1402,7 +1402,7 @@ def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, by """ return call_intrin( dtype, - "tir.ptx_cp_async", + "tirx.ptx_cp_async", shared_ptr, shared_offset, global_ptr, @@ -1447,7 +1447,7 @@ def ptx_cp_async_bulk( """ return call_intrin( dtype, - "tir.ptx_cp_async_bulk", + "tirx.ptx_cp_async_bulk", shared_ptr, shared_offset, global_ptr, @@ -1466,7 +1466,7 @@ def ptx_commit_group(): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_commit_group") + return call_intrin("", "tirx.ptx_commit_group") def ptx_wait_group(num): @@ -1483,7 +1483,7 @@ def ptx_wait_group(num): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_wait_group", num) + return call_intrin("", "tirx.ptx_wait_group", num) def ptx_cp_async_barrier(barrier_id): @@ -1500,7 +1500,7 @@ def ptx_cp_async_barrier(barrier_id): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_cp_async_barrier", barrier_id) + return call_intrin("", "tirx.ptx_cp_async_barrier", barrier_id) def ptx_init_barrier_thread_count(barrier_id, thread_count): @@ -1520,7 +1520,7 @@ def ptx_init_barrier_thread_count(barrier_id, thread_count): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_init_barrier_thread_count", barrier_id, thread_count) + return call_intrin("", "tirx.ptx_init_barrier_thread_count", barrier_id, thread_count) def ptx_arrive_barrier(barrier_id): @@ -1537,7 +1537,7 @@ def ptx_arrive_barrier(barrier_id): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_arrive_barrier", barrier_id) + return call_intrin("", "tirx.ptx_arrive_barrier", barrier_id) def ptx_arrive_barrier_expect_tx(barrier_id, byte_count): @@ -1559,7 +1559,7 @@ def ptx_arrive_barrier_expect_tx(barrier_id, byte_count): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_arrive_barrier_expect_tx", barrier_id, byte_count) + return call_intrin("", "tirx.ptx_arrive_barrier_expect_tx", barrier_id, byte_count) def ptx_wait_barrier(barrier_id): @@ -1576,7 +1576,7 @@ def ptx_wait_barrier(barrier_id): call : PrimExpr The call expression. """ - return call_intrin("", "tir.ptx_wait_barrier", barrier_id) + return call_intrin("", "tirx.ptx_wait_barrier", barrier_id) def create_barriers(barrier_count): @@ -1592,7 +1592,7 @@ def create_barriers(barrier_count): call : PrimExpr The call expression. """ - return call_intrin("", "tir.create_barriers", barrier_count) + return call_intrin("", "tirx.create_barriers", barrier_count) def make_filled_simdgroup_matrix( @@ -1626,7 +1626,7 @@ def make_filled_simdgroup_matrix( call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.make_filled_simdgroup_matrix", d, index, value, col, row) + return call_intrin("handle", "tirx.make_filled_simdgroup_matrix", d, index, value, col, row) def simdgroup_load( @@ -1670,7 +1670,7 @@ def simdgroup_load( """ return call_intrin( "handle", - "tir.simdgroup_load", + "tirx.simdgroup_load", d, index, ptr, @@ -1723,7 +1723,7 @@ def simdgroup_store( """ return call_intrin( "handle", - "tir.simdgroup_store", + "tirx.simdgroup_store", d, index, ptr, @@ -1780,7 +1780,7 @@ def simdgroup_multiply_accumulate( """ return call_intrin( "handle", - "tir.simdgroup_multiply_accumulate", + "tirx.simdgroup_multiply_accumulate", d, index_d, a, @@ -1808,7 +1808,7 @@ def vectorlow(dtype, vec): call : PrimExpr The call expression. """ - return call_intrin(dtype, "tir.vectorlow", vec) + return call_intrin(dtype, "tirx.vectorlow", vec) def vectorhigh(dtype, vec): @@ -1827,7 +1827,7 @@ def vectorhigh(dtype, vec): call : PrimExpr The call expression. """ - return call_intrin(dtype, "tir.vectorhigh", vec) + return call_intrin(dtype, "tirx.vectorhigh", vec) def vectorcombine(dtype, vec1, vec2): @@ -1846,7 +1846,7 @@ def vectorcombine(dtype, vec1, vec2): call : PrimExpr The call expression. """ - return call_intrin(dtype, "tir.vectorcombine", vec1, vec2) + return call_intrin(dtype, "tirx.vectorcombine", vec1, vec2) def dp4a(vec1, vec2, acc=0): @@ -1868,16 +1868,16 @@ def dp4a(vec1, vec2, acc=0): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) + return call_intrin("int32", "tirx.dp4a", vec1, vec2, acc) def ret(val, span=None): - """Create a tir return expression + """Create a tirx return expression Parameters ---------- val : Expr - The returned tir expression, whose data type is int, float or void pointer. + The returned tirx expression, whose data type is int, float or void pointer. span : Optional[Span] The location of this operator in the source code. @@ -1908,7 +1908,7 @@ def thread_return(span=None): def continue_loop(span=None): - """Create a tir intrinsic call to represent continue expression + """Create a tirx intrinsic call to represent continue expression Parameters ---------- @@ -1925,7 +1925,7 @@ def continue_loop(span=None): def break_loop(span=None): - """Create a tir intrinsic call to represent break expression + """Create a tirx intrinsic call to represent break expression Parameters ---------- @@ -2022,13 +2022,13 @@ def trace(args, trace_action="tvm.default_trace_action"): See Also -------- - tvm.tir.call_packed : Creates packed function. + tvm.tirx.call_packed : Creates packed function. """ if not isinstance(args, list): - raise Exception("tvm.tir.trace consumes the args as list type") + raise Exception("tvm.tirx.trace consumes the args as list type") call_args = [_pack_buffer(x) if isinstance(x, Buffer) else x for x in args] call_args.insert(0, trace_action) - return tvm.tir.Call(args[-1].dtype, Op.get("tir.tvm_call_trace_packed"), call_args) + return tvm.tirx.Call(args[-1].dtype, Op.get("tirx.tvm_call_trace_packed"), call_args) def min_value(dtype, span=None): @@ -2123,10 +2123,10 @@ def exp(x): y : PrimExpr The result. """ - x = tir.convert(x) + x = tirx.convert(x) if "int" in x.dtype: - x = tir.Cast("float32", x) - return call_intrin(x.dtype, "tir.exp", x) + x = tirx.Cast("float32", x) + return call_intrin(x.dtype, "tirx.exp", x) def exp2(x): @@ -2142,8 +2142,8 @@ def exp2(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.exp2", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.exp2", x) def exp10(x): @@ -2159,8 +2159,8 @@ def exp10(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.exp10", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.exp10", x) def erf(x): @@ -2176,8 +2176,8 @@ def erf(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.erf", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.erf", x) def tanh(x): @@ -2194,7 +2194,7 @@ def tanh(x): The result. """ x = _require_float_arg("tanh", x) - return call_intrin(x.dtype, "tir.tanh", x) + return call_intrin(x.dtype, "tirx.tanh", x) def sigmoid(x): @@ -2210,8 +2210,8 @@ def sigmoid(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.sigmoid", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.sigmoid", x) def log(x): @@ -2227,8 +2227,8 @@ def log(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.log", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.log", x) def log2(x): @@ -2244,8 +2244,8 @@ def log2(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.log2", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.log2", x) def log10(x): @@ -2261,8 +2261,8 @@ def log10(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.log10", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.log10", x) def log1p(x): @@ -2278,8 +2278,8 @@ def log1p(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.log1p", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.log1p", x) def tan(x): @@ -2296,7 +2296,7 @@ def tan(x): The result. """ x = _require_float_arg("tan", x) - return call_intrin(x.dtype, "tir.tan", x) + return call_intrin(x.dtype, "tirx.tan", x) def cos(x): @@ -2313,7 +2313,7 @@ def cos(x): The result. """ x = _require_float_arg("cos", x) - return call_intrin(x.dtype, "tir.cos", x) + return call_intrin(x.dtype, "tirx.cos", x) def cosh(x): @@ -2330,7 +2330,7 @@ def cosh(x): The result. """ x = _require_float_arg("cosh", x) - return call_intrin(x.dtype, "tir.cosh", x) + return call_intrin(x.dtype, "tirx.cosh", x) def acos(x): @@ -2347,7 +2347,7 @@ def acos(x): The result. """ x = _require_float_arg("acos", x) - return call_intrin(x.dtype, "tir.acos", x) + return call_intrin(x.dtype, "tirx.acos", x) def acosh(x): @@ -2364,7 +2364,7 @@ def acosh(x): The result. """ x = _require_float_arg("acosh", x) - return call_intrin(x.dtype, "tir.acosh", x) + return call_intrin(x.dtype, "tirx.acosh", x) def sin(x): @@ -2381,7 +2381,7 @@ def sin(x): The result. """ x = _require_float_arg("sin", x) - return call_intrin(x.dtype, "tir.sin", x) + return call_intrin(x.dtype, "tirx.sin", x) def sinh(x): @@ -2398,7 +2398,7 @@ def sinh(x): The result. """ x = _require_float_arg("sinh", x) - return call_intrin(x.dtype, "tir.sinh", x) + return call_intrin(x.dtype, "tirx.sinh", x) def asin(x): @@ -2415,7 +2415,7 @@ def asin(x): The result. """ x = _require_float_arg("asin", x) - return call_intrin(x.dtype, "tir.asin", x) + return call_intrin(x.dtype, "tirx.asin", x) def asinh(x): @@ -2432,7 +2432,7 @@ def asinh(x): The result. """ x = _require_float_arg("asinh", x) - return call_intrin(x.dtype, "tir.asinh", x) + return call_intrin(x.dtype, "tirx.asinh", x) def atan(x): @@ -2449,7 +2449,7 @@ def atan(x): The result. """ x = _require_float_arg("atan", x) - return call_intrin(x.dtype, "tir.atan", x) + return call_intrin(x.dtype, "tirx.atan", x) def atanh(x): @@ -2466,7 +2466,7 @@ def atanh(x): The result. """ x = _require_float_arg("atanh", x) - return call_intrin(x.dtype, "tir.atanh", x) + return call_intrin(x.dtype, "tirx.atanh", x) def atan2(x1, x2): @@ -2485,9 +2485,9 @@ def atan2(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tir.atan2", x1, x2) + x1 = tirx.convert(x1) + x2 = tirx.convert(x2) + return call_intrin(x1.dtype, "tirx.atan2", x1, x2) def sqrt(x): @@ -2503,8 +2503,8 @@ def sqrt(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.sqrt", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.sqrt", x) def rsqrt(x): @@ -2520,8 +2520,8 @@ def rsqrt(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.rsqrt", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.rsqrt", x) def clz(x): @@ -2538,7 +2538,7 @@ def clz(x): y : PrimExpr The result. """ - return call_intrin("int32", "tir.clz", x) + return call_intrin("int32", "tirx.clz", x) def floor(x: PrimExprWithOp, span=None): @@ -2766,9 +2766,9 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore + x1 = tirx.convert(x1) + x2 = tirx.convert(x2) + return call_intrin(x1.dtype, "tirx.nextafter", x1, x2) # type: ignore def hypot(x1, x2): @@ -2787,9 +2787,9 @@ def hypot(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore + x1 = tirx.convert(x1) + x2 = tirx.convert(x2) + return call_intrin(x1.dtype, "tirx.hypot", x1, x2) # type: ignore def copysign(x1, x2): @@ -2808,9 +2808,9 @@ def copysign(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore + x1 = tirx.convert(x1) + x2 = tirx.convert(x2) + return call_intrin(x1.dtype, "tirx.copysign", x1, x2) # type: ignore def ldexp(x1, x2): @@ -2829,9 +2829,9 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - x1 = tir.convert(x1) - x2 = tir.convert(x2) - return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore + x1 = tirx.convert(x1) + x2 = tirx.convert(x2) + return call_intrin(x1.dtype, "tirx.ldexp", x1, x2) # type: ignore def likely(cond, span=None): @@ -2889,7 +2889,7 @@ def isnullptr(x, span=None): y : PrimExpr The result. """ - return call_intrin("bool", "tir.isnullptr", x, span=span) # type: ignore + return call_intrin("bool", "tirx.isnullptr", x, span=span) # type: ignore def isfinite(x, span=None): @@ -2987,8 +2987,8 @@ def popcount(x): y : PrimExpr The result. """ - x = tir.convert(x) - return call_intrin(x.dtype, "tir.popcount", x) + x = tirx.convert(x) + return call_intrin(x.dtype, "tirx.popcount", x) def q_multiply_shift(x, y, q, s): @@ -3017,7 +3017,7 @@ def q_multiply_shift(x, y, q, s): y : PrimExpr The result. """ - return call_intrin("int32", "tir.q_multiply_shift", x, y, q, s) + return call_intrin("int32", "tirx.q_multiply_shift", x, y, q, s) def q_multiply_shift_per_axis( @@ -3055,7 +3055,7 @@ def q_multiply_shift_per_axis( """ return call_intrin( "int32", - "tir.q_multiply_shift_per_axis", + "tirx.q_multiply_shift_per_axis", x, y, ls, @@ -3119,9 +3119,9 @@ def fmod(x, y): z : PrimExpr The result. """ - x = tir.convert(x) - y = tir.convert(y) - return call_intrin(x.dtype, "tir.fmod", x, y) + x = tirx.convert(x) + y = tirx.convert(y) + return call_intrin(x.dtype, "tirx.fmod", x, y) def if_then_else(cond, t, f, span=None): @@ -3404,7 +3404,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"): n = te.var("n") m = te.var("m") mysum = te.comm_reducer(lambda x, y: x+y, - lambda t: tvm.tir.const(0, dtype=t), name="mysum") + lambda t: tvm.tirx.const(0, dtype=t), name="mysum") A = te.placeholder((n, m), name="A") k = te.reduce_axis((0, m), name="k") B = te.compute((n,), lambda i: mysum(A[i, k], axis=k), name="B") @@ -3423,9 +3423,9 @@ def _reduce_directly(*args): def _make_reduce(expr, axis, where=None, init=None): code = fcombine.__code__ assert fcombine.__code__.co_argcount == 2 - expr = tir.convert(expr) + expr = tirx.convert(expr) if init is not None: - init = tir.convert(init) + init = tirx.convert(init) if isinstance(expr, Array): size = len(expr) lhs = [] @@ -3459,18 +3459,20 @@ def _make_reduce(expr, axis, where=None, init=None): if not isinstance(axis, list | tuple | tvm.ir.Array): axis = [axis] if where is None: - where = tir.convert(True) + where = tirx.convert(True) if init is None: - outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, []) for i in range(size)) + outputs = tuple( + tvm.tirx.Reduce(combiner, expr, axis, where, i, []) for i in range(size) + ) else: outputs = tuple( - tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size) + tvm.tirx.Reduce(combiner, expr, axis, where, i, init) for i in range(size) ) return outputs[0] if size == 1 else outputs # pylint: disable=keyword-arg-before-vararg def reducer(expr, axis, where=None, init=None, *args): - if isinstance(axis, tvm.tir.IterVar | list | tuple): + if isinstance(axis, tvm.tirx.IterVar | list | tuple): assert not args return _make_reduce(expr, axis, where, init) @@ -3510,7 +3512,7 @@ def reducer(expr, axis, where=None, init=None, *args): # there are two way to use this {0} reducer: # mode 1, accept (expr, axis, where) to produce an Reduce Expr - # tvm.{0} represents tvm.te.{0} or tvm.tir.{0}. + # tvm.{0} represents tvm.te.{0} or tvm.tirx.{0}. B = te.compute((m,), lambda i: tvm.{0}(A[i, k], axis=k), name="B") # mode 2, simply use it with multiple Exprs: @@ -3547,7 +3549,7 @@ def TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint, dt """ return call_intrin( "handle", - "tir.TVMBackendAllocWorkspace", + "tirx.TVMBackendAllocWorkspace", device_type, device_id, nbytes, @@ -3575,7 +3577,7 @@ def TVMBackendFreeWorkspace(device_type, device_id, ptr): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.TVMBackendFreeWorkspace", device_type, device_id, ptr) + return call_intrin("int32", "tirx.TVMBackendFreeWorkspace", device_type, device_id, ptr) def anylist_getitem(list_handle, index): @@ -3589,7 +3591,7 @@ def anylist_getitem(list_handle, index): call : PrimExpr The call expression. """ - return call_intrin("handle", "tir.anylist_getitem", list_handle, index) + return call_intrin("handle", "tirx.anylist_getitem", list_handle, index) def anylist_resetitem(list_handle, index): @@ -3603,7 +3605,7 @@ def anylist_resetitem(list_handle, index): call : PrimExpr The call expression. """ - return call_intrin("int", "tir.anylist_resetitem", list_handle, index) + return call_intrin("int", "tirx.anylist_resetitem", list_handle, index) def anylist_setitem_call_packed(list_handle, index, func_name, *args): @@ -3622,7 +3624,7 @@ def anylist_setitem_call_packed(list_handle, index, func_name, *args): The call expression. """ return call_intrin( - "int", "tir.anylist_setitem_call_packed", list_handle, index, func_name, *args + "int", "tirx.anylist_setitem_call_packed", list_handle, index, func_name, *args ) @@ -3642,7 +3644,7 @@ def anylist_setitem_call_cpacked(list_handle, index, func_name, *args): The call expression. """ return call_intrin( - "int", "tir.anylist_setitem_call_cpacked", list_handle, index, func_name, *args + "int", "tirx.anylist_setitem_call_cpacked", list_handle, index, func_name, *args ) @@ -3654,7 +3656,7 @@ def vscale(): call : PrimExpr Call to the vscale intrinsic """ - return call_intrin("int32", "tir.vscale") + return call_intrin("int32", "tirx.vscale") def get_active_lane_mask(dtype, base, limit): @@ -3675,7 +3677,7 @@ def get_active_lane_mask(dtype, base, limit): limit : PrimExpr An expression representing the limit. """ - return call_intrin(dtype, "tir.get_active_lane_mask", base, limit) + return call_intrin(dtype, "tirx.get_active_lane_mask", base, limit) def get_vscale_expr(dtype: str | tvm_ffi.dtype, min_size: int = 128) -> PrimExpr: @@ -3703,7 +3705,7 @@ def ignore_loop_partition(predicate) -> PrimExpr: predicate : PrimExpr The annotated predicate expression. """ - return call_intrin("bool", "tir.ignore_loop_partition", predicate) + return call_intrin("bool", "tirx.ignore_loop_partition", predicate) # pylint: disable=unnecessary-lambda diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tirx/pipeline.py similarity index 87% rename from python/tvm/tir/pipeline.py rename to python/tvm/tirx/pipeline.py index 3cd633c99a4a..24a6625d1a0f 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tirx/pipeline.py @@ -19,15 +19,15 @@ """The TIR backend compilation pipeline.""" import tvm -from tvm import tir +from tvm import tirx def finalize_host_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" host_pass_list = [ - tir.transform.LowerTVMBuiltin(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerIntrin(), + tirx.transform.LowerTVMBuiltin(), + tirx.transform.LowerCustomDatatypes(), + tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(host_pass_list) @@ -35,10 +35,10 @@ def finalize_host_passes(): # pylint: disable=unused-argument def finalize_device_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" device_pass_list = [ - tir.transform.LowerWarpMemory(), - tir.transform.Simplify(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerIntrin(), + tirx.transform.LowerWarpMemory(), + tirx.transform.Simplify(), + tirx.transform.LowerCustomDatatypes(), + tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(device_pass_list) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tirx/stmt.py similarity index 94% rename from python/tvm/tir/stmt.py rename to python/tvm/tirx/stmt.py index 0ffe31b4f796..8539ea819dab 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tirx/stmt.py @@ -20,10 +20,10 @@ .. code-block:: python - x = tvm.tir.Var("n", "int32") - buffer = tvm.tir.decl_buffer((16,), "float32") - st = tvm.tir.stmt.BufferStore(buffer, 1, (x,)) - assert isinstance(st, tvm.tir.stmt.BufferStore) + x = tvm.tirx.Var("n", "int32") + buffer = tvm.tirx.decl_buffer((16,), "float32") + st = tvm.tirx.stmt.BufferStore(buffer, 1, (x,)) + assert isinstance(st, tvm.tirx.stmt.BufferStore) assert(st.buffer == buffer) """ @@ -44,7 +44,7 @@ class Stmt(Object, Scriptable): """Base class of all the statements.""" -@tvm_ffi.register_object("tir.Bind") +@tvm_ffi.register_object("tirx.Bind") class Bind(Stmt): """Bind node. @@ -78,7 +78,7 @@ def __init__(self, var: Var, value: PrimExpr, span: Span | None = None) -> None: ) -@tvm_ffi.register_object("tir.AssertStmt") +@tvm_ffi.register_object("tirx.AssertStmt") class AssertStmt(Stmt): """AssertStmt node. @@ -136,7 +136,7 @@ class ForKind(IntEnum): THREAD_BINDING = 4 # pylint: disable=invalid-name -@tvm_ffi.register_object("tir.For") +@tvm_ffi.register_object("tirx.For") class For(Stmt): """For node. @@ -157,7 +157,7 @@ class For(Stmt): body : Stmt The body statement. - thread_binding: Optional[tir.IterVar] + thread_binding: Optional[tirx.IterVar] The thread this loop binds to. Only valid if kind is ThreadBinding @@ -208,7 +208,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.While") +@tvm_ffi.register_object("tirx.While") class While(Stmt): """While node. @@ -232,7 +232,7 @@ def __init__(self, condition: PrimExpr, body: Stmt, span: Span | None = None) -> self.__init_handle_by_constructor__(_ffi_api.While, condition, body, span) # type: ignore -@tvm_ffi.register_object("tir.BufferStore") +@tvm_ffi.register_object("tirx.BufferStore") class BufferStore(Stmt): """Buffer store node. @@ -280,7 +280,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.AllocBuffer") +@tvm_ffi.register_object("tirx.AllocBuffer") class AllocBuffer(Stmt): """AllocBuffer node. @@ -310,7 +310,7 @@ def __init__( self.__init_handle_by_constructor__(_ffi_api.AllocBuffer, buffer, annotations, span) -@tvm_ffi.register_object("tir.DeclBuffer") +@tvm_ffi.register_object("tirx.DeclBuffer") class DeclBuffer(Stmt): """DeclBuffer node. @@ -330,7 +330,7 @@ def __init__(self, buffer: Buffer, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.DeclBuffer, buffer, span) -@tvm_ffi.register_object("tir.AttrStmt") +@tvm_ffi.register_object("tirx.AttrStmt") class AttrStmt(Stmt): """AttrStmt node. @@ -376,7 +376,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.SeqStmt") +@tvm_ffi.register_object("tirx.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. @@ -402,7 +402,7 @@ def __len__(self): return len(self.seq) -@tvm_ffi.register_object("tir.IfThenElse") +@tvm_ffi.register_object("tirx.IfThenElse") class IfThenElse(Stmt): """IfThenElse node. @@ -441,7 +441,7 @@ def __init__( ) -@tvm_ffi.register_object("tir.Evaluate") +@tvm_ffi.register_object("tirx.Evaluate") class Evaluate(Stmt): """Evaluate node. @@ -461,7 +461,7 @@ def __init__(self, value: PrimExpr, span: Span | None = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore -@tvm_ffi.register_object("tir.BufferRegion") +@tvm_ffi.register_object("tirx.BufferRegion") class BufferRegion(Object, Scriptable): """BufferRegion node. @@ -481,7 +481,7 @@ def __init__(self, buffer: Buffer, region: list[Range]) -> None: self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) # type: ignore -@tvm_ffi.register_object("tir.MatchBufferRegion") +@tvm_ffi.register_object("tirx.MatchBufferRegion") class MatchBufferRegion(Object, Scriptable): """MatchBufferRegion node. @@ -505,7 +505,7 @@ def __init__(self, buffer: Buffer, source: BufferRegion) -> None: ) -@tvm_ffi.register_object("tir.SBlock") +@tvm_ffi.register_object("tirx.SBlock") class SBlock(Stmt): """SBlock node. @@ -587,7 +587,7 @@ def __init__( ) # type: ignore -@tvm_ffi.register_object("tir.SBlockRealize") +@tvm_ffi.register_object("tirx.SBlockRealize") class SBlockRealize(Stmt): """SBlockRealize node. diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tirx/stmt_functor.py similarity index 93% rename from python/tvm/tir/stmt_functor.py rename to python/tvm/tirx/stmt_functor.py index 1d85935b996b..e058378a8de9 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tirx/stmt_functor.py @@ -24,13 +24,13 @@ def ir_transform(stmt, preorder, postorder, only_enable=None): Parameters ---------- - stmt : tvm.tir.Stmt + stmt : tvm.tirx.Stmt The input to be transformed. preorder: function The function called in before recursive mutation If preorder returns None, then the transform will proceed to recursive call. - If preorder returns a not None tvm.tir.Stmt/Expr, the transformer will simply return it and + If preorder returns a not None tvm.tirx.Stmt/Expr, the transformer will simply return it and won't do further recursion. postorder : function @@ -41,7 +41,7 @@ def ir_transform(stmt, preorder, postorder, only_enable=None): Returns ------- - result : tvm.tir.Stmt + result : tvm.tirx.Stmt The result. """ return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable) # type: ignore @@ -84,7 +84,7 @@ def substitute(node, vmap): Returns ------- - result : tvm.tir.Stmt + result : tvm.tirx.Stmt The result. """ return _ffi_api.Substitute(node, vmap) # type: ignore diff --git a/python/tvm/tir/transform/__init__.py b/python/tvm/tirx/transform/__init__.py similarity index 100% rename from python/tvm/tir/transform/__init__.py rename to python/tvm/tirx/transform/__init__.py diff --git a/python/tvm/tir/analysis/_ffi_api.py b/python/tvm/tirx/transform/_ffi_api.py similarity index 90% rename from python/tvm/tir/analysis/_ffi_api.py rename to python/tvm/tirx/transform/_ffi_api.py index 32d36a63c12d..f932f38be492 100644 --- a/python/tvm/tir/analysis/_ffi_api.py +++ b/python/tvm/tirx/transform/_ffi_api.py @@ -14,8 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.tir.analysis""" +"""FFI APIs for tvm.tirx.transform""" import tvm_ffi -tvm_ffi.init_ffi_api("tir.analysis", __name__) +tvm_ffi.init_ffi_api("tirx.transform", __name__) diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tirx/transform/function_pass.py similarity index 92% rename from python/tvm/tir/transform/function_pass.py rename to python/tvm/tirx/transform/function_pass.py index bb79914c0fd0..e54e915796f7 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tirx/transform/function_pass.py @@ -27,10 +27,10 @@ from . import _ffi_api -@tvm_ffi.register_object("tir.PrimFuncPass") +@tvm_ffi.register_object("tirx.PrimFuncPass") class PrimFuncPass(Pass): - """A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function - pass class should be created through py:func:`tvm.tir.transform.function_pass`. + """A pass that works on each :py:func:`tvm.tirx.PrimFunc` in a module. A function + pass class should be created through py:func:`tvm.tirx.transform.function_pass`. """ @@ -82,7 +82,7 @@ def prim_func_pass( Parameters ---------- - pass_func : Optional[Callable[(tvm.tir.PrimFunc, IRModule, PassContext) -> tvm.tir.PrimFunc]] + pass_func : Optional[Callable[(tvm.tirx.PrimFunc, IRModule, PassContext) -> tvm.tirx.PrimFunc]] The transformation function or class. opt_level : int @@ -111,7 +111,7 @@ def prim_func_pass( .. code-block:: python - @tvm.tir.transform.prim_func_pass(opt_level=1) + @tvm.tirx.transform.prim_func_pass(opt_level=1) class TestReplaceFunc: def __init__(self, new_func): self.new_func = new_func @@ -126,7 +126,7 @@ def transform_function(self, func, mod, ctx): .. code-block:: python - @tvm.tir.transform.prim_func_pass(opt_level=2) + @tvm.tirx.transform.prim_func_pass(opt_level=2) def transform(func, mod, ctx): # my transformations here. return func diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tirx/transform/transform.py similarity index 95% rename from python/tvm/tir/transform/transform.py rename to python/tvm/tirx/transform/transform.py index 7ba5cd5f3a1c..6e18558b0ecd 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -29,11 +29,11 @@ def Apply(ftransform): """Apply ftransform to each function in the Module. - This function is a thin wrapper around tvm.tir.transform.prim_func_pass + This function is a thin wrapper around tvm.tirx.transform.prim_func_pass Parameters ---------- - ftransform: tvm.tir.PrimFunc -> tvm.tir.PrimFunc + ftransform: tvm.tirx.PrimFunc -> tvm.tirx.PrimFunc The transformation pass. Returns @@ -106,7 +106,7 @@ def PointerValueTypeRewrite(): return _ffi_api.PointerValueTypeRewrite() # type: ignore -@_ffi.register_object("tir.transform.UnrollLoopConfig") +@_ffi.register_object("tirx.transform.UnrollLoopConfig") class UnrollLoopConfig(_ir.Attrs): """Config for unroll loop pass""" @@ -124,7 +124,7 @@ def UnrollLoop(): return _ffi_api.UnrollLoop() # type: ignore -@_ffi.register_object("tir.transform.RemoveNoOpConfig") +@_ffi.register_object("tirx.transform.RemoveNoOpConfig") class RemoveNoOpConfig(_ir.Attrs): """Config for remove no op pass""" @@ -211,7 +211,7 @@ def CommonSubexprElim(): return _ffi_api.CommonSubexprElim() # type: ignore -@_ffi.register_object("tir.transform.SimplifyConfig") +@_ffi.register_object("tirx.transform.SimplifyConfig") class SimplifyConfig(_ir.Attrs): """Config for simplify pass""" @@ -230,10 +230,10 @@ def Simplify(): def ConvertSSA(): """Convert an IRModule to be SSA form. - This pass handles cases where the same `tir.Var` appears in + This pass handles cases where the same `tirx.Var` appears in multiple functions within the same module. For example, after extracting a fragment from one function into another, where the - same `tir.Var` may be defined both as within the body of the + same `tirx.Var` may be defined both as within the body of the original function, and as a parameter within the hoisted function. Returns @@ -274,10 +274,10 @@ def MakePackedAPI(): the `DLTensor::shape` array is `[16,32]`.) For dynamic Buffers, in which one or more of these `BufferNode` member - variables use `tir.Var` that are not defined by other PrimFunc + variables use `tirx.Var` that are not defined by other PrimFunc parameters, these are instead used to define the variables based on the corresponding `DLTensor` members. (e.g. A PrimFunc that accepts a - buffer of shape `[tir.Var("n"), tir.Var("m")]`, when passed a + buffer of shape `[tirx.Var("n"), tirx.Var("m")]`, when passed a `DLTensor` of shape `[16,32]`, will define `n = 16` and `n=32`, based on the argument's shape. @@ -447,7 +447,7 @@ class HoistedConditionals(enum.Flag): """ If set, look for hoist candidates in IfElseStmt """ IfElseExpr = 2 - """ If set, look for hoist candidates in tir.if_then_else """ + """ If set, look for hoist candidates in tirx.if_then_else """ BooleanExpression = 4 """ If set, look for hoist candidates in all boolean expressions """ diff --git a/python/tvm/topi/generic_op_impl.py b/python/tvm/topi/generic_op_impl.py index 9968f2ebc096..a26e435df9fe 100644 --- a/python/tvm/topi/generic_op_impl.py +++ b/python/tvm/topi/generic_op_impl.py @@ -66,7 +66,7 @@ def _tensor_bop_impl(lhs, rhs): it performs tensor-scalar {op} operation on an element-wise basis. Otherwise, it performs default generic.{op} operation, as defined - in tvm.tir.generic module. + in tvm.tirx.generic module. Parameters ---------- @@ -93,13 +93,13 @@ def _bind_generic_ops(): """Bind generic operators for Tensor.""" # Check __op_priority__ to make sure the binding happens only once. __op_priority__ = 1 - if __op_priority__ > tvm.tir.generic.__op_priority__: - tvm.tir.generic.__op_priority__ = __op_priority__ - tvm.tir.generic.add = _make_bop(_broadcast.add, tvm.tir.generic.add) - tvm.tir.generic.subtract = _make_bop(_broadcast.subtract, tvm.tir.generic.subtract) - tvm.tir.generic.multiply = _make_bop(_broadcast.multiply, tvm.tir.generic.multiply) - tvm.tir.generic.divide = _make_bop(_broadcast.divide, tvm.tir.generic.divide) - tvm.tir.generic.cast = _math.cast + if __op_priority__ > tvm.tirx.generic.__op_priority__: + tvm.tirx.generic.__op_priority__ = __op_priority__ + tvm.tirx.generic.add = _make_bop(_broadcast.add, tvm.tirx.generic.add) + tvm.tirx.generic.subtract = _make_bop(_broadcast.subtract, tvm.tirx.generic.subtract) + tvm.tirx.generic.multiply = _make_bop(_broadcast.multiply, tvm.tirx.generic.multiply) + tvm.tirx.generic.divide = _make_bop(_broadcast.divide, tvm.tirx.generic.divide) + tvm.tirx.generic.cast = _math.cast _bind_generic_ops() diff --git a/python/tvm/topi/gpu/scan.py b/python/tvm/topi/gpu/scan.py index 7f83535be6fe..91fdadea9ee7 100644 --- a/python/tvm/topi/gpu/scan.py +++ b/python/tvm/topi/gpu/scan.py @@ -23,7 +23,7 @@ from tvm import te from tvm.contrib.thrust import can_use_rocthrust, can_use_thrust from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from ..math import cast, ceil_log2 from ..transform import expand_dims, reshape, squeeze, transpose @@ -31,7 +31,7 @@ def _get_thrust_func_name(tvmop): - tvmop_to_thrust_func_name = {tvm.tir.generic.add: "tvm.contrib.thrust.sum_scan"} + tvmop_to_thrust_func_name = {tvm.tirx.generic.add: "tvm.contrib.thrust.sum_scan"} assert tvmop in tvmop_to_thrust_func_name, f"{tvmop} not supported by thrust" return tvmop_to_thrust_func_name[tvmop] @@ -44,7 +44,7 @@ def _can_use_scan_thrust(binop): if target is None: return False # pylint: disable=comparison-with-callable - return binop == tvm.tir.generic.add and any( + return binop == tvm.tirx.generic.add and any( [ can_use_thrust(target, "tvm.contrib.thrust.sum_scan"), can_use_rocthrust(target, "tvm.contrib.thrust.sum_scan"), @@ -52,7 +52,7 @@ def _can_use_scan_thrust(binop): ) -def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, identity_value=0): +def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tirx.generic.add, identity_value=0): """Low level IR to do exclusive sum scan along rows of 2D input. Parameters @@ -68,7 +68,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + and produce a new TIR expression. By default it uses tvm.tirx.generic.add to compute prefix sum. identity_value: int or float @@ -140,7 +140,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i T.attr( bx, "thread_extent", - tvm.tir.generic.cast( + tvm.tirx.generic.cast( ceil_div(scan_axis_size, max_threads * width), "int32" ), ), @@ -154,7 +154,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i start[0] = width * tid with T.If(start[0] < scan_axis_size): with T.Then(): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) + middle[0] = start[0] + tvm.tirx.indexdiv(width, 2) end[0] = tvm.te.min(start[0] + width, scan_axis_size) with T.If(middle[0] < scan_axis_size): with T.Then(): @@ -188,7 +188,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i T.attr( bx, "thread_extent", - tvm.tir.generic.cast( + tvm.tirx.generic.cast( ceil_div(scan_axis_size, max_threads * width), "int32" ), ), @@ -201,10 +201,10 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i end = T.buffer_proxy(end_buf) tmp = T.buffer_proxy(tmp_buf) start[0] = width * tid - with T.If(tvm.tir.all(start[0] < scan_axis_size)): + with T.If(tvm.tirx.all(start[0] < scan_axis_size)): with T.Then(): - middle[0] = start[0] + tvm.tir.indexdiv(width, 2) - end[0] = tvm.tir.min(start[0] + width, scan_axis_size) + middle[0] = start[0] + tvm.tirx.indexdiv(width, 2) + end[0] = tvm.tirx.min(start[0] + width, scan_axis_size) with T.If(middle[0] < scan_axis_size): with T.Then(): tmp[0] = output[by * scan_axis_size + middle[0] - 1] @@ -218,7 +218,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i return ib.get() -def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generic.add): +def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tirx.generic.add): """Return the sum of the last element of data and the exclusive scan output. The is the reduction of data along each row (for 2-D case). @@ -232,7 +232,7 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + and produce a new TIR expression. By default it uses tvm.tirx.generic.add to compute prefix sum. Returns @@ -280,8 +280,8 @@ def ir(data_buf, data_ex_scan_buf, reduction_buf): return ib.get() - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) - ex_scan_output_buf = tvm.tir.decl_buffer( + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "valid_indices_buf", data_alignment=8) + ex_scan_output_buf = tvm.tirx.decl_buffer( ex_scan_output.shape, ex_scan_output.dtype, "ex_scan_output_buf", data_alignment=8 ) @@ -306,7 +306,7 @@ def scan_thrust( output_dtype, exclusive=True, return_reduction=False, - binop=tvm.tir.generic.add, + binop=tvm.tirx.generic.add, workspace=None, ): """Do exclusive or inclusive scan on 1D or multidimensional input, using thrust. @@ -330,7 +330,7 @@ def scan_thrust( binop: function, optional A binary associative op to use for scan. Since we need to lookup the corresponding thrust function, arbitrariy callables are not supported. Currently only - tvm.tir.generic.add can be passed in. + tvm.tirx.generic.add can be passed in. workspace: Optional[tvm.te.Tensor] A buffer to store intermediate results. The size of the workspace should be sufficiently @@ -346,11 +346,11 @@ def scan_thrust( (N-1)-D tensor storing the reduction of each scan axis. Returned if return_reduction is True. """ - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tirx.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) workspace_buf = ( - tvm.tir.decl_buffer(workspace.shape, workspace.dtype, "workspace_buf", data_alignment=8) + tvm.tirx.decl_buffer(workspace.shape, workspace.dtype, "workspace_buf", data_alignment=8) if workspace is not None else None ) @@ -359,7 +359,7 @@ def f_compute(ins, outs): args = [_get_thrust_func_name(binop), ins[0], outs[0], exclusive] if workspace is not None: args.append(ins[1]) - return tvm.tir.call_packed(*args) + return tvm.tirx.call_packed(*args) output = te.extern( [data.shape], @@ -385,7 +385,7 @@ def exclusive_scan( axis=-1, return_reduction=False, output_dtype=None, - binop=tvm.tir.generic.add, + binop=tvm.tirx.generic.add, identity_value=0, workspace=None, ): @@ -410,7 +410,7 @@ def exclusive_scan( binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + and produce a new TIR expression. By default it uses tvm.tirx.generic.add to compute prefix sum. identity_value: int or float @@ -449,8 +449,8 @@ def do_scan(data, output_dtype): # TIR exclusive scan accepts only 2D or higher-rank inputs. data = expand_dims(data, axis=0) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - output_buf = tvm.tir.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + output_buf = tvm.tirx.decl_buffer(data.shape, output_dtype, "output_buf", data_alignment=8) if return_reduction: output, reduction = te.extern( @@ -518,7 +518,7 @@ def do_scan(data, output_dtype): def inclusive_scan( - data, axis=-1, output_dtype=None, binop=tvm.tir.generic.add, identity_value=0, workspace=None + data, axis=-1, output_dtype=None, binop=tvm.tirx.generic.add, identity_value=0, workspace=None ): """Do inclusive scan on 1D or multidimensional input. @@ -535,7 +535,7 @@ def inclusive_scan( binop: function, optional A binary associative op to use for scan. The function takes two TIR expressions - and produce a new TIR expression. By default it uses tvm.tir.generic.add to compute + and produce a new TIR expression. By default it uses tvm.tirx.generic.add to compute prefix sum. identity_value: int or float @@ -701,7 +701,7 @@ def cumsum( """ return scanop( data=data, - binop=tvm.tir.generic.add, + binop=tvm.tirx.generic.add, identity_value=0, axis=axis, dtype=dtype, @@ -751,7 +751,7 @@ def cumprod( """ return scanop( data=data, - binop=tvm.tir.generic.multiply, + binop=tvm.tirx.generic.multiply, identity_value=1, axis=axis, dtype=dtype, diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py index fed493e40008..8f0e76b0aaff 100644 --- a/python/tvm/topi/gpu/sort.py +++ b/python/tvm/topi/gpu/sort.py @@ -21,7 +21,7 @@ import tvm from tvm import te from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from ..math import cast, ceil_log2 from ..searchsorted import binary_search @@ -99,7 +99,7 @@ def _odd_even_sort( tx, bx, by, ntx, nbx, nby = _get_threads(nthread_tx, nthread_bx, nthread_by) with T.frame_scope( [ - T.attr(tvm.tir.const(0), "hand_threaded", 0), + T.attr(tvm.tirx.const(0), "hand_threaded", 0), T.attr(tx, "thread_extent", ntx), T.attr(bx, "thread_extent", nbx), T.attr(by, "thread_extent", nby), @@ -136,11 +136,11 @@ def _odd_even_sort( [tid + n], ) - T.evaluate(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + T.evaluate(tvm.tirx.Call(None, "tirx.tvm_storage_sync", tvm.runtime.convert(["shared"]))) - idxm = tvm.tir.indexmod + idxm = tvm.tirx.indexmod # OddEvenTransposeSort - current_sort_num = tvm.tir.min(block_size, size - start) + current_sort_num = tvm.tirx.min(block_size, size - start) with T.serial(0, current_sort_num) as k: n = idxm(tid + k, 2) with T.If(tid + n < current_sort_num - 1): @@ -164,7 +164,9 @@ def _odd_even_sort( [tid + n], ) T.buffer_store(tmp_values_swap, temp_values[0], [tid + n + 1]) - T.evaluate(tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))) + T.evaluate( + tvm.tirx.Call(None, "tirx.tvm_storage_sync", tvm.runtime.convert(["shared"])) + ) ## Copy sorted data to output with T.serial(0, 2) as n: @@ -284,7 +286,7 @@ def serial_merge( j_idx = base_idx + j_buf[0] k_idx = base_idx + (kStart + diag + count) - with T.If(tvm.tir.all(i_buf[0] < aStart + aCount, j_buf[0] < bStart + bCount)): + with T.If(tvm.tirx.all(i_buf[0] < aStart + aCount, j_buf[0] < bStart + bCount)): with T.Then(): with T.If(compare(source[i_idx], source[j_idx])): with T.Then(): @@ -490,12 +492,12 @@ def dual_mergepath( target = tvm.target.Target.current() if "vulkan" in str(target): ntx = max_threads - nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32") - nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32") + nbx = tvm.tirx.generic.cast(ceil_div(width, max_threads * thread_work), "int32") + nbz = tvm.tirx.generic.cast(ceil_div(size, width), "int32") else: - ntx = tvm.tir.generic.cast(tvm.te.min(max_threads, width), "int32") - nbx = tvm.tir.generic.cast(ceil_div(width, max_threads * thread_work), "int32") - nbz = tvm.tir.generic.cast(ceil_div(size, width), "int32") + ntx = tvm.tirx.generic.cast(tvm.te.min(max_threads, width), "int32") + nbx = tvm.tirx.generic.cast(ceil_div(width, max_threads * thread_work), "int32") + nbz = tvm.tirx.generic.cast(ceil_div(size, width), "int32") tx, bx, by, _, _, _ = _get_threads(ntx, nbx, nthread_by * nbz) with T.frame_scope( @@ -511,12 +513,12 @@ def dual_mergepath( # calculate the start, mid, and end points of this section start_pos = width * bz - middle = cast(tvm.te.min(start_pos + tvm.tir.indexdiv(width, 2), size), target_dtype) + middle = cast(tvm.te.min(start_pos + tvm.tirx.indexdiv(width, 2), size), target_dtype) end = cast(tvm.te.min(start_pos + width, size), target_dtype) with T.If(start_pos < size): with T.Then(): - even = tvm.tir.indexmod(l2_width, 2) == 0 + even = tvm.tirx.indexmod(l2_width, 2) == 0 with T.If(nbx == 1): with T.Then(): ## merge the start->middle and middle->end arrays @@ -555,7 +557,9 @@ def dual_mergepath( ## if the final sorted data ended up in the swap, copy it to the real output nthread_bx = ceil_div(size, nthread_tx) - with T.If(tvm.tir.all(upper_lim > lower_lim, tvm.tir.indexmod(upper_lim - lower_lim, 2) == 1)): + with T.If( + tvm.tirx.all(upper_lim > lower_lim, tvm.tirx.indexmod(upper_lim - lower_lim, 2) == 1) + ): with T.Then(): tx2, bx2, by2, _, _, _ = _get_threads(nthread_tx, nthread_bx, nthread_by) with T.frame_scope( @@ -629,7 +633,7 @@ def sort_ir( indices_out, value_init_func=( lambda _, tid: ( - tvm.tir.generic.cast(tid, indices_out_orig.dtype) + tvm.tirx.generic.cast(tid, indices_out_orig.dtype) if indices_out is not None else None ) @@ -677,8 +681,10 @@ def sort(data, axis=-1, is_ascend=1): axes = swap(list(range(ndim)), axis) data = transpose(data, axes) - value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - value_buf_swap = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf_swap", data_alignment=8) + value_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + value_buf_swap = tvm.tirx.decl_buffer( + data.shape, data.dtype, "value_buf_swap", data_alignment=8 + ) out = te.extern( [data.shape, data.shape], @@ -731,14 +737,14 @@ def sort_thrust(data, axis=-1, is_ascend=1, workspace=None): axes = swap(list(range(ndim)), axis) data = transpose(data, axes) - value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + value_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + indices_buf = tvm.tirx.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) def f_compute(ins, outs): args = ["tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend] if workspace is not None: args.append(ins[1]) - return tvm.tir.call_packed(*args) + return tvm.tirx.call_packed(*args) out = te.extern( [data.shape, data.shape], @@ -793,10 +799,12 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32", ret_type="indices"): axes = swap(list(range(ndim)), axis) data = transpose(data, axes) - value_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) - value_swap_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "value_swap_buf", data_alignment=8) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) - indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) + value_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8) + value_swap_buf = tvm.tirx.decl_buffer( + data.shape, data.dtype, "value_swap_buf", data_alignment=8 + ) + indices_buf = tvm.tirx.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + indices_swap_buf = tvm.tirx.decl_buffer(data.shape, dtype, "out_swap_buf", data_alignment=8) outs = te.extern( [data.shape, data.shape, data.shape, data.shape], @@ -901,12 +909,12 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): axes = swap(list(range(ndim)), axis) data = transpose(data, axes) - values_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8) - values_swap_buf = tvm.tir.decl_buffer( + values_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8) + values_swap_buf = tvm.tirx.decl_buffer( data.shape, data.dtype, "values_swap_buf", data_alignment=8 ) - indices_buf = tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8) - indices_swap_buf = tvm.tir.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8) + indices_buf = tvm.tirx.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8) + indices_swap_buf = tvm.tirx.decl_buffer(data.shape, dtype, "indies_swap_buf", data_alignment=8) if ret_type == "values": output = te.extern( @@ -1006,23 +1014,23 @@ def topk_thrust( axes = swap(list(range(ndim)), axis) data = transpose(data, axes) - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) if workspace is not None: - workspace_buf = tvm.tir.decl_buffer( + workspace_buf = tvm.tirx.decl_buffer( workspace.shape, workspace.dtype, "workspace_buf", data_alignment=8 ) else: workspace_buf = None out_bufs = [ - tvm.tir.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8), - tvm.tir.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8), + tvm.tirx.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8), + tvm.tirx.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8), ] def f_compute(ins, outs): args = ["tvm.contrib.thrust.sort", ins[0], outs[0], outs[1], is_ascend] if workspace is not None: args.append(ins[1]) - return tvm.tir.call_packed(*args) + return tvm.tirx.call_packed(*args) is_ascend = 1 if is_ascend else 0 @@ -1036,7 +1044,7 @@ def f_compute(ins, outs): tag="topk_gpu", ) - if isinstance(k, tvm.tir.IntImm): + if isinstance(k, tvm.tirx.IntImm): k = k.value if not isinstance(k, int) or k > 0: diff --git a/python/tvm/topi/image/grid_sample.py b/python/tvm/topi/image/grid_sample.py index 3fd0d3699f0e..79032f41b3ae 100644 --- a/python/tvm/topi/image/grid_sample.py +++ b/python/tvm/topi/image/grid_sample.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """affine_grid and grid_sample operator""" -from tvm import te, tir +from tvm import te, tirx def affine_grid(data, target_shape): @@ -47,9 +47,9 @@ def affine_grid(data, target_shape): ) dtype = data.dtype - y_step = tir.const((2.0 - 1e-7) / (target_shape[0] - 1), dtype=dtype) - x_step = tir.const((2.0 - 1e-7) / (target_shape[1] - 1), dtype=dtype) - start = tir.const(-1.0, dtype=dtype) + y_step = tirx.const((2.0 - 1e-7) / (target_shape[0] - 1), dtype=dtype) + x_step = tirx.const((2.0 - 1e-7) / (target_shape[1] - 1), dtype=dtype) + start = tirx.const(-1.0, dtype=dtype) def _compute(n, dim, i, j): y = start + i * y_step @@ -131,7 +131,7 @@ def _get_pixel_value(n, c, h, w): return te.if_then_else( te.all(h >= 0, w >= 0, h < in_height, w < in_width), data[n, c, h, w], - tir.const(0.0, dtype=data.dtype), + tirx.const(0.0, dtype=data.dtype), ) def _unnormalize(h, w): @@ -167,14 +167,14 @@ def __refelection(x, size, corner_start): def __reflect(index, size, corner_start): index_align_corner = te.abs(corner_start - index) size_times = te.truncdiv(index_align_corner.astype("int32"), size).astype("int32") - t = tir.Mod(size_times, 2) + t = tirx.Mod(size_times, 2) extra = index_align_corner - size_times * size - return tir.if_then_else( - tir.EQ(t, 0), extra + corner_start, size - extra + corner_start + return tirx.if_then_else( + tirx.EQ(t, 0), extra + corner_start, size - extra + corner_start ) - return tir.if_then_else( - tir.all(x >= corner_start, x <= size + corner_start), + return tirx.if_then_else( + tirx.all(x >= corner_start, x <= size + corner_start), x, __reflect(x, size, corner_start), ) @@ -189,8 +189,8 @@ def _bilinear_sample(n, c, h, w): y, x = _compute_source_index(n, h, w) y0 = te.floor(y).astype("int32") x0 = te.floor(x).astype("int32") - y1 = y0 + tir.const(1, "int32") - x1 = x0 + tir.const(1, "int32") + y1 = y0 + tirx.const(1, "int32") + x1 = x0 + tirx.const(1, "int32") return ( _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) @@ -361,7 +361,7 @@ def _get_pixel_value(n, c, d, h, w): return te.if_then_else( te.all(d >= 0, h >= 0, w >= 0, d < in_depth, h < in_height, w < in_width), data[n, c, d, h, w], - tir.const(0.0, dtype=data.dtype), + tirx.const(0.0, dtype=data.dtype), ) def _compute_source_index(n, d, h, w): @@ -400,14 +400,14 @@ def __refelection(x, size, corner_start): def __reflect(index, size, corner_start): index_align_corner = te.abs(corner_start - index) size_times = te.truncdiv(index_align_corner.astype("int32"), size).astype("int32") - t = tir.Mod(size_times, 2) + t = tirx.Mod(size_times, 2) extra = index_align_corner - size_times * size - return tir.if_then_else( - tir.EQ(t, 0), extra + corner_start, size - extra + corner_start + return tirx.if_then_else( + tirx.EQ(t, 0), extra + corner_start, size - extra + corner_start ) - return tir.if_then_else( - tir.all(x >= corner_start, x <= size + corner_start), + return tirx.if_then_else( + tirx.all(x >= corner_start, x <= size + corner_start), x, __reflect(x, size, corner_start), ) @@ -421,9 +421,9 @@ def _trilinear_sample(n, c, d, h, w): z0 = te.floor(z).astype("int32") y0 = te.floor(y).astype("int32") x0 = te.floor(x).astype("int32") - z1 = z0 + tir.const(1, "int32") - y1 = y0 + tir.const(1, "int32") - x1 = x0 + tir.const(1, "int32") + z1 = z0 + tirx.const(1, "int32") + y1 = y0 + tirx.const(1, "int32") + x1 = x0 + tirx.const(1, "int32") return ( _get_pixel_value(n, c, z0, y0, x0) * (1 - (x - x0)) * (1 - (y - y0)) * (1 - (z - z0)) diff --git a/python/tvm/topi/image/resize.py b/python/tvm/topi/image/resize.py index e8303daa0256..a0ff38f0f3ca 100644 --- a/python/tvm/topi/image/resize.py +++ b/python/tvm/topi/image/resize.py @@ -28,7 +28,7 @@ def can_convert_multiply_to_intdiv(origin_size, scaled_size): """Check whether can convert multiplication to division""" # Only support IntImm type - if not isinstance(scaled_size, tvm.tir.expr.IntImm): + if not isinstance(scaled_size, tvm.tirx.expr.IntImm): return False div = scaled_size / origin_size.astype("float") @@ -362,10 +362,10 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): if coordinate_transformation_mode == "tf_crop_and_resize": # use extrapolation_value if in_x is out of boundary - value = tvm.tir.if_then_else( + value = tvm.tirx.if_then_else( in_x < 0, extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, value), + tvm.tirx.if_then_else(in_x > image_width - 1, extrapolation_value, value), ) return _cast_output(value, data.dtype, out_dtype=out_dtype) @@ -434,7 +434,7 @@ def resize1d( out_dtype: string, optional Type to return. If left None will be same as input type. - output_shape: tvm.tir.container.Array, optional + output_shape: tvm.tirx.container.Array, optional Shape to return. If left None will be inferred (If shape is determined dynamically, pass out_dtype.shape as output_shape) @@ -470,7 +470,7 @@ def resize1d( for i in range(1): if isinstance(size[i], int): - size[i] = tvm.tir.IntImm("int32", size[i]) + size[i] = tvm.tirx.IntImm("int32", size[i]) def compute_func(*indices): return _resize_1d( @@ -729,16 +729,16 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): raise ValueError("Unknown resize method:", method) if coordinate_transformation_mode == "tf_crop_and_resize": - out = tvm.tir.if_then_else( + out = tvm.tirx.if_then_else( in_y < 0, extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), + tvm.tirx.if_then_else(in_y > image_height - 1, extrapolation_value, value), ) # use extrapolation_value if in_x is out of boundary - value = tvm.tir.if_then_else( + value = tvm.tirx.if_then_else( in_x < 0, extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + tvm.tirx.if_then_else(in_x > image_width - 1, extrapolation_value, out), ) return _cast_output(value, data.dtype, out_dtype=out_dtype) @@ -801,7 +801,7 @@ def resize2d( out_dtype: string, optional Type to return. If left None will be same as input type. - output_shape: tvm.tir.container.Array, optional + output_shape: tvm.tirx.container.Array, optional Shape to return. If left None will be inferred (If shape is determined dynamically, pass out_dtype.shape as output_shape) @@ -837,7 +837,7 @@ def resize2d( for i in range(2): if isinstance(size[i], int): - size[i] = tvm.tir.IntImm("int32", size[i]) + size[i] = tvm.tirx.IntImm("int32", size[i]) def compute_func(*indices): return _resize_2d( @@ -1193,21 +1193,21 @@ def _cast_output(value, data_dtype="float32", out_dtype=None): raise ValueError("Unknown resize method:", method) if coordinate_transformation_mode == "tf_crop_and_resize": - out = tvm.tir.if_then_else( + out = tvm.tirx.if_then_else( in_z < 0, extrapolation_value, - tvm.tir.if_then_else(in_z > image_depth - 1, extrapolation_value, value), + tvm.tirx.if_then_else(in_z > image_depth - 1, extrapolation_value, value), ) - out = tvm.tir.if_then_else( + out = tvm.tirx.if_then_else( in_y < 0, extrapolation_value, - tvm.tir.if_then_else(in_y > image_height - 1, extrapolation_value, value), + tvm.tirx.if_then_else(in_y > image_height - 1, extrapolation_value, value), ) # use extrapolation_value if in_x is out of boundary - value = tvm.tir.if_then_else( + value = tvm.tirx.if_then_else( in_x < 0, extrapolation_value, - tvm.tir.if_then_else(in_x > image_width - 1, extrapolation_value, out), + tvm.tirx.if_then_else(in_x > image_width - 1, extrapolation_value, out), ) return _cast_output(value, data.dtype, out_dtype=out_dtype) @@ -1270,7 +1270,7 @@ def resize3d( out_dtype: string, optional Type to return. If left None will be same as input type. - output_shape: tvm.tir.container.Array, optional + output_shape: tvm.tirx.container.Array, optional Shape to return. If left None will be inferred (If shape is determined dynamically, pass out_dtype.shape as output_shape) @@ -1300,7 +1300,7 @@ def resize3d( for i in range(3): if isinstance(size[i], int): - size[i] = tvm.tir.IntImm("int32", size[i]) + size[i] = tvm.tirx.IntImm("int32", size[i]) def compute_func(*indices): return _resize_3d( diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index d3e5ba42be12..b4e509fb4aa6 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -17,9 +17,9 @@ # pylint: disable=invalid-name """IndexPut operator""" -from tvm import te, tir +from tvm import te, tirx from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from . import utils @@ -153,7 +153,7 @@ def add_func(dst_ptr, dst_index, update): in_buffers.extend(indices) in_buffers.append(values) - out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( [data.shape], in_buffers, diff --git a/python/tvm/topi/math.py b/python/tvm/topi/math.py index d086141371c6..146009e3ba12 100644 --- a/python/tvm/topi/math.py +++ b/python/tvm/topi/math.py @@ -19,7 +19,7 @@ # pylint: disable=redefined-builtin,unused-argument import tvm from tvm import DataType, DataTypeCode, te -from tvm.tir import PrimExpr +from tvm.tirx import PrimExpr from . import cpp, tag from .utils import get_const_tuple @@ -618,9 +618,9 @@ def clip(x, a_min, a_max): ---------- x : tvm.te.Tensor Input argument. - a_min : tvm.tir.PrimExpr + a_min : tvm.tirx.PrimExpr Minimum value. - a_max : tvm.tir.PrimExpr + a_max : tvm.tirx.PrimExpr Maximum value. Returns @@ -632,14 +632,14 @@ def clip(x, a_min, a_max): def _compute(*indices): value = x(*indices) const_min = ( - tvm.tir.Cast(value.dtype, a_min) + tvm.tirx.Cast(value.dtype, a_min) if isinstance(a_min, PrimExpr) - else tvm.tir.const(a_min, value.dtype) + else tvm.tirx.const(a_min, value.dtype) ) const_max = ( - tvm.tir.Cast(value.dtype, a_max) + tvm.tirx.Cast(value.dtype, a_max) if isinstance(a_max, PrimExpr) - else tvm.tir.const(a_max, value.dtype) + else tvm.tirx.const(a_max, value.dtype) ) return tvm.te.max(tvm.te.min(value, const_max), const_min) @@ -669,11 +669,11 @@ def fixed_point_multiply(x, multiplier, shift): def _compute(*indices): value = x(*indices) - return tvm.tir.q_multiply_shift( + return tvm.tirx.q_multiply_shift( value, - tvm.tir.const(multiplier, "int32"), - tvm.tir.const(31, "int32"), - tvm.tir.const(shift, "int32"), + tvm.tirx.const(multiplier, "int32"), + tvm.tirx.const(31, "int32"), + tvm.tirx.const(shift, "int32"), ) return te.compute(x.shape, _compute) @@ -723,14 +723,14 @@ def _compute(*indices): m = y(*param_indices) l_shift = lshift(*param_indices) r_shift = rshift(*param_indices) - return tvm.tir.q_multiply_shift_per_axis( + return tvm.tirx.q_multiply_shift_per_axis( value, m, l_shift, r_shift, - tvm.tir.const(31, "int32"), - tvm.tir.const(is_lshift_required, "bool"), - tvm.tir.const(is_rshift_required, "bool"), + tvm.tirx.const(31, "int32"), + tvm.tirx.const(is_lshift_required, "bool"), + tvm.tirx.const(is_rshift_required, "bool"), ) return te.compute(x.shape, _compute) @@ -758,7 +758,7 @@ def cast(x, dtype, span=None): if isinstance(x, te.tensor.Tensor): return te.compute(x.shape, lambda *i: x(*i).astype(dtype), tag=tag.ELEMWISE) # pylint: disable=import-outside-toplevel - from tvm.tir import _ffi_api + from tvm.tirx import _ffi_api return _ffi_api._cast(dtype, x, span) @@ -849,18 +849,18 @@ def ceil_log2(x): y : tvm.te.Tensor The result. """ - if not isinstance(x, tvm.tir.PrimExpr): - x = tvm.tir.const(x) + if not isinstance(x, tvm.tirx.PrimExpr): + x = tvm.tirx.const(x) if "float" in x.dtype: - return tvm.tir.ceil(tvm.tir.log2(x)) + return tvm.tirx.ceil(tvm.tirx.log2(x)) target = tvm.target.Target.current() if "vulkan" in target.kind.name: - clz = tvm.tir.clz(x) + clz = tvm.tirx.clz(x) bits = int(x.dtype[-2:]) - res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) + res = tvm.tirx.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz) if res.dtype != x.dtype: return cast(res, x.dtype) return res @@ -870,6 +870,6 @@ def ceil_log2(x): "rocm", "webgpu", ]: - return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float32"))), x.dtype) + return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float32"))), x.dtype) - return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype) + return cast(tvm.tirx.ceil(tvm.tirx.log2(cast(x, "float64"))), x.dtype) diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py index eff3060cb0a8..5143238d7799 100644 --- a/python/tvm/topi/nn/batch_matmul.py +++ b/python/tvm/topi/nn/batch_matmul.py @@ -91,13 +91,13 @@ def batch_matmul( else: YB, YK, YJ = get_const_tuple(tensor_b.shape) - assert XK == YK or isinstance(YK, tvm.tir.expr.Var), "shapes of x and y are inconsistent" + assert XK == YK or isinstance(YK, tvm.tirx.expr.Var), "shapes of x and y are inconsistent" k = te.reduce_axis((0, XK), name="k") if oshape is None: assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match" batch = ( - tvm.tir.expr.SizeVar("batch", "int32") - if isinstance(XB, tvm.tir.expr.Var) or isinstance(YB, tvm.tir.expr.Var) + tvm.tirx.expr.SizeVar("batch", "int32") + if isinstance(XB, tvm.tirx.expr.Var) or isinstance(YB, tvm.tirx.expr.Var) else te.max(XB, YB) ) oshape = (batch, XI, YJ) diff --git a/python/tvm/topi/nn/bitserial_conv2d.py b/python/tvm/topi/nn/bitserial_conv2d.py index 76a1940412d9..374e380dd784 100644 --- a/python/tvm/topi/nn/bitserial_conv2d.py +++ b/python/tvm/topi/nn/bitserial_conv2d.py @@ -113,11 +113,11 @@ def _conv(nn, ff, yy, xx): return te.sum( ( ( - tvm.tir.popcount( + tvm.tirx.popcount( PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & Filter_q[ff, rc, ry, rx, b2] ) - - tvm.tir.popcount( + - tvm.tirx.popcount( PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & ~Filter_q[ff, rc, ry, rx, b2] ) @@ -133,7 +133,7 @@ def _conv(nn, ff, yy, xx): b1b2 = (b1 + b2).astype(out_dtype) return te.sum( ( - tvm.tir.popcount( + tvm.tirx.popcount( PadInput_q[nn, rc, b1, yy * stride_h + ry, xx * stride_w + rx] & Filter_q[ff, rc, ry, rx, b2] ) @@ -237,11 +237,11 @@ def _conv(nn, yy, xx, ff): return te.sum( ( ( - tvm.tir.popcount( + tvm.tirx.popcount( PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & Filter_q[ry, rx, rc, ff, b2] ) - - tvm.tir.popcount( + - tvm.tirx.popcount( PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & ~Filter_q[ry, rx, rc, ff, b2] ) @@ -257,7 +257,7 @@ def _conv(nn, yy, xx, ff): b1b2 = (b1 + b2).astype(out_dtype) return te.sum( ( - tvm.tir.popcount( + tvm.tirx.popcount( PadInput_q[nn, yy * stride_h + ry, xx * stride_w + rx, rc, b1] & Filter_q[ry, rx, rc, ff, b2] ) diff --git a/python/tvm/topi/nn/bitserial_dense.py b/python/tvm/topi/nn/bitserial_dense.py index 145b2ede5530..7fb48d00211e 100644 --- a/python/tvm/topi/nn/bitserial_dense.py +++ b/python/tvm/topi/nn/bitserial_dense.py @@ -58,8 +58,8 @@ def bitserial_dense( oshape, lambda i, j: te.sum( ( - tvm.tir.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]) - - tvm.tir.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k]) + tvm.tirx.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]) + - tvm.tirx.popcount(~weight_packed[j, wb, k] & data_packed[i, db, k]) ).astype(out_dtype) << (db + wb).astype(out_dtype), axis=[wb, db, k], @@ -70,7 +70,7 @@ def bitserial_dense( matmul = te.compute( oshape, lambda i, j: te.sum( - tvm.tir.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype) + tvm.tirx.popcount(weight_packed[j, wb, k] & data_packed[i, db, k]).astype(out_dtype) << (db + wb).astype(out_dtype), axis=[wb, db, k], ), diff --git a/python/tvm/topi/nn/bitserial_util.py b/python/tvm/topi/nn/bitserial_util.py index 33bb1d53c04c..b6c3013fc863 100644 --- a/python/tvm/topi/nn/bitserial_util.py +++ b/python/tvm/topi/nn/bitserial_util.py @@ -69,7 +69,7 @@ def bitpack(data, bits, pack_axis, bit_axis, pack_type, name="QuantizeInput"): pack_axis += 1 def _bitpack(*indices): - packed_data = [tvm.tir.const(0, pack_type)] * bits + packed_data = [tvm.tirx.const(0, pack_type)] * bits for k in range(data_width): # Translate indices for packed data back to original idx = [0] * n @@ -85,7 +85,7 @@ def _bitpack(*indices): element = data(*idx) for b in range(bits): - extracted_bit = ((element & tvm.tir.const(masks[b], "int32")) >> b).astype( + extracted_bit = ((element & tvm.tirx.const(masks[b], "int32")) >> b).astype( pack_type ) packed_data[b] = packed_data[b] | extracted_bit diff --git a/python/tvm/topi/nn/bnn.py b/python/tvm/topi/nn/bnn.py index 2e79d5118b12..4d76ac5a6b5d 100644 --- a/python/tvm/topi/nn/bnn.py +++ b/python/tvm/topi/nn/bnn.py @@ -52,7 +52,7 @@ def binarize_pack(data, axis=None, name="PackedInput"): def _binarize_pack(*indices): start_idx = [indices[i] * 32 if i == axis else indices[i] for i in range(n)] - packed = tvm.tir.const(0, "uint32") + packed = tvm.tirx.const(0, "uint32") for j in range(32): idx = [start_idx[i] + j if i == axis else start_idx[i] for i in range(n)] sign = (data(*idx) >= 0).astype("uint32") @@ -90,7 +90,7 @@ def binary_dense(data, weight): k = te.reduce_axis((0, in_dim), name="k") matmul = te.compute( (batch, out_dim), - lambda i, j: te.sum(tvm.tir.popcount(data[i, k] ^ weight[j, k]), axis=k), + lambda i, j: te.sum(tvm.tirx.popcount(data[i, k] ^ weight[j, k]), axis=k), tag="binary_dense", ) diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 7e46ea553925..330fbd6c1c0e 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -353,8 +353,8 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou kh = te.reduce_axis((0, kernel_height), name="kh") kw = te.reduce_axis((0, kernel_width), name="kw") - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod if groups == 1: ic = te.reduce_axis((0, in_channel), name="ic") @@ -480,8 +480,8 @@ def conv2d_NCHWc_OIHWo( kh = te.reduce_axis((0, kernel_height), name="kh") kw = te.reduce_axis((0, kernel_width), name="kw") - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod def compute_conv2d(*args): n, occ, oh, ow, ocb = args @@ -985,8 +985,8 @@ def unpack_NCHWc_to_nchw(packed_out, out_dtype): """ n, oc_chunk, oh, ow, oc_bn = get_const_tuple(packed_out.shape) - idxmod = tvm.tir.indexmod - idxdiv = tvm.tir.indexdiv + idxmod = tvm.tirx.indexmod + idxdiv = tvm.tirx.indexdiv oshape = (n, oc_chunk * oc_bn, oh, ow) unpacked_out = te.compute( diff --git a/python/tvm/topi/nn/deformable_conv2d.py b/python/tvm/topi/nn/deformable_conv2d.py index e48557fee4f9..64b19e7e4f5e 100644 --- a/python/tvm/topi/nn/deformable_conv2d.py +++ b/python/tvm/topi/nn/deformable_conv2d.py @@ -92,12 +92,12 @@ def deformable_conv2d_nchw( ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - zero = tvm.tir.const(0.0, data.dtype) + zero = tvm.tirx.const(0.0, data.dtype) def _bilinear(n, c, h, w): - outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width) + outside = tvm.tirx.any(h < 0, w < 0, h >= in_height, w >= in_width) val = bilinear_sample_nchw(data, (n, c, h, w), in_height - 1, in_width - 1) - return tvm.tir.if_then_else(outside, zero, val) + return tvm.tirx.if_then_else(outside, zero, val) data_deform = te.compute( (batch, in_channel, kernel_h, kernel_w, out_height, out_width), @@ -200,12 +200,12 @@ def deformable_conv2d_nhwc( ry = te.reduce_axis((0, kernel_h), name="ry") rx = te.reduce_axis((0, kernel_w), name="rx") - zero = tvm.tir.const(0.0, data.dtype) + zero = tvm.tirx.const(0.0, data.dtype) def _bilinear(n, h, w, c): - outside = tvm.tir.any(h < 0, w < 0, h >= in_height, w >= in_width) + outside = tvm.tirx.any(h < 0, w < 0, h >= in_height, w >= in_width) val = bilinear_sample_nhwc(data, (n, h, w, c), in_height - 1, in_width - 1) - return tvm.tir.if_then_else(outside, zero, val) + return tvm.tirx.if_then_else(outside, zero, val) data_deform = te.compute( (batch, kernel_h, kernel_w, in_channel, out_height, out_width), diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py index d3700e569d14..c3415541853b 100644 --- a/python/tvm/topi/nn/dense.py +++ b/python/tvm/topi/nn/dense.py @@ -95,7 +95,7 @@ def matmul( reduce_dim_b, out_dim = tensor_b.shape[-2:] batch_dims_b = tensor_b.shape[:-2] - if not isinstance(reduce_dim_a, tvm.tir.Var) and not isinstance(reduce_dim_b, tvm.tir.Var): + if not isinstance(reduce_dim_a, tvm.tirx.Var) and not isinstance(reduce_dim_b, tvm.tirx.Var): assert int(reduce_dim_a) == int(reduce_dim_b), ( f"Reduction dimensions of dense do not match. {reduce_dim_a} vs {reduce_dim_b}." ) @@ -106,8 +106,8 @@ def matmul( for idx, (l, r) in enumerate(zip(batch_dims_a, batch_dims_b)): if ( - not isinstance(l, tvm.tir.Var) - and not isinstance(r, tvm.tir.Var) + not isinstance(l, tvm.tirx.Var) + and not isinstance(r, tvm.tirx.Var) and int(l) != 1 and int(r) != 1 ): @@ -115,7 +115,7 @@ def matmul( "Batch dimensions of dense do not match: " f"{tensor_a.shape[:-2]} vs {tensor_b.shape[:-2]}." ) - if not isinstance(l, tvm.tir.Var) and int(l) == 1: + if not isinstance(l, tvm.tirx.Var) and int(l) == 1: batch_dims_a[idx] = batch_dims_b[idx] k = te.reduce_axis((0, reduce_dim_a), name="k") @@ -123,12 +123,12 @@ def matmul( def compute(*indices): batch_indices_a = indices[-len(tensor_a.shape) : -2] batch_indices_a = [ - i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0 + i if isinstance(dim, tvm.tirx.Var) or int(dim) != 1 else 0 for i, dim in zip(batch_indices_a, tensor_a.shape[:-2]) ] batch_indices_b = indices[-len(tensor_b.shape) : -2] batch_indices_b = [ - i if isinstance(dim, tvm.tir.Var) or int(dim) != 1 else 0 + i if isinstance(dim, tvm.tirx.Var) or int(dim) != 1 else 0 for i, dim in zip(batch_indices_b, tensor_b.shape[:-2]) ] i, j = indices[-2:] @@ -243,8 +243,8 @@ def dense_pack(data, weight, bias=None, out_dtype=None): N, _, packw_bn = get_const_tuple(weight.shape) # out_dim N = N * packw_bn - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod k = te.reduce_axis((0, K), name="k") C = te.compute( (M, N), diff --git a/python/tvm/topi/nn/depth_to_space.py b/python/tvm/topi/nn/depth_to_space.py index abfbb606ccf8..00c4e9a91936 100644 --- a/python/tvm/topi/nn/depth_to_space.py +++ b/python/tvm/topi/nn/depth_to_space.py @@ -49,11 +49,11 @@ def depth_to_space(data, block_size, layout="NCHW", mode="DCR"): """ if layout == "NCHW": in_n, in_c, in_h, in_w = data.shape - channel_factor = tvm.tir.truncdiv(in_c, (block_size * block_size)) + channel_factor = tvm.tirx.truncdiv(in_c, (block_size * block_size)) output_shape = [in_n, channel_factor, in_h * block_size, in_w * block_size] elif layout == "NHWC": in_n, in_h, in_w, in_c = data.shape - channel_factor = tvm.tir.truncdiv(in_c, (block_size * block_size)) + channel_factor = tvm.tirx.truncdiv(in_c, (block_size * block_size)) output_shape = [in_n, in_h * block_size, in_w * block_size, channel_factor] else: raise ValueError("Only NCHW and NHWC layouts are currently supported.") @@ -66,10 +66,10 @@ def _get_indices(*indices): return n, c, y, x def _get_pixel(n, c, y, x): - block_x = tvm.tir.truncdiv(x, block_size) - block_y = tvm.tir.truncdiv(y, block_size) - idx_x = tvm.tir.truncmod(x, block_size) - idx_y = tvm.tir.truncmod(y, block_size) + block_x = tvm.tirx.truncdiv(x, block_size) + block_y = tvm.tirx.truncdiv(y, block_size) + idx_x = tvm.tirx.truncmod(x, block_size) + idx_y = tvm.tirx.truncmod(y, block_size) if mode == "DCR": channel_idx = channel_factor * ((block_size * idx_y) + idx_x) + c else: diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index 7b2f3bedb613..cedda45ac071 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -179,8 +179,8 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No pad_after = [0, 0, pad_down, pad_right] PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") # depthconv stage - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod di = te.reduce_axis((0, filter_height), name="di") dj = te.reduce_axis((0, filter_width), name="dj") Output = te.compute( @@ -271,8 +271,8 @@ def depthwise_conv2d_nhwc( pad_after = [0, pad_down, pad_right, 0] PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") # depthconv stage - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod di = te.reduce_axis((0, filter_height), name="di") dj = te.reduce_axis((0, filter_width), name="dj") @@ -405,8 +405,8 @@ def depthwise_conv2d_backward_weight_nhwc(Input, Out_grad, oshape, fshape, strid dh = te.reduce_axis((0, Out_grad.shape[1].value), name="dh") dw = te.reduce_axis((0, Out_grad.shape[2].value), name="dw") db = te.reduce_axis((0, batch), name="db") - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod Weight_grad = te.compute( (filter_h, filter_w, in_c, channel_multiplier), diff --git a/python/tvm/topi/nn/dilate.py b/python/tvm/topi/nn/dilate.py index 68c5c9b76ce3..536654a43f1e 100644 --- a/python/tvm/topi/nn/dilate.py +++ b/python/tvm/topi/nn/dilate.py @@ -55,8 +55,8 @@ def dilate(data, strides, dilation_value=0.0, name="DilatedInput"): def _dilate(*indices): not_zero = [] index_tuple = [] - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod for i in range(n): if not utils.equal_const_int(strides[i], 1): index_tuple.append(idxdiv(indices[i], strides[i])) @@ -64,9 +64,9 @@ def _dilate(*indices): else: index_tuple.append(indices[i]) if not_zero: - not_zero = tvm.tir.all(*not_zero) - return tvm.tir.if_then_else( - not_zero, data(*index_tuple), tvm.tir.const(dilation_value, data.dtype) + not_zero = tvm.tirx.all(*not_zero) + return tvm.tirx.if_then_else( + not_zero, data(*index_tuple), tvm.tirx.const(dilation_value, data.dtype) ) return data(*index_tuple) diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 1c9d5b693f90..53980b04b96c 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -37,7 +37,7 @@ def relu(x): y : tvm.te.Tensor The result. """ - return te.compute(x.shape, lambda *i: tvm.te.max(x(*i), tvm.tir.const(0, x.dtype))) + return te.compute(x.shape, lambda *i: tvm.te.max(x(*i), tvm.tirx.const(0, x.dtype))) @tvm.te.tag_scope(tag=tag.ELEMWISE) @@ -60,8 +60,8 @@ def leaky_relu(x, alpha): def _compute(*indices): value = x(*indices) - calpha = tvm.tir.const(alpha, value.dtype) - return tvm.tir.Select(value > 0, value, value * calpha) + calpha = tvm.tirx.const(alpha, value.dtype) + return tvm.tirx.Select(value > 0, value, value * calpha) return te.compute(x.shape, _compute) @@ -89,11 +89,11 @@ def softplus(x, beta=1.0, threshold=20.0): def _compute(*indices): value = x(*indices) - b = tvm.tir.const(beta, value.dtype) - t = tvm.tir.const(threshold, value.dtype) + b = tvm.tirx.const(beta, value.dtype) + t = tvm.tirx.const(threshold, value.dtype) - return tvm.tir.Select( - b * value > t, value, (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) + return tvm.tirx.Select( + b * value > t, value, (1 / b) * tvm.tirx.log(1 + tvm.tirx.exp(b * value)) ) return te.compute(x.shape, _compute) @@ -138,6 +138,6 @@ def prelu(x, slope, axis=1): def _compute_channelwise(*indices): xval = x(*indices) - return tvm.tir.Select(xval > 0, xval, xval * slope(indices[axis])) + return tvm.tirx.Select(xval > 0, xval, xval * slope(indices[axis])) return te.compute(x.shape, _compute_channelwise) diff --git a/python/tvm/topi/nn/fifo_buffer.py b/python/tvm/topi/nn/fifo_buffer.py index a07d0e64e1a7..20be54ad1870 100644 --- a/python/tvm/topi/nn/fifo_buffer.py +++ b/python/tvm/topi/nn/fifo_buffer.py @@ -77,7 +77,7 @@ def fifo_buffer(data, buffer, axis): if len(buffer.shape) == 1: return te.compute( buffer.shape, - lambda i: tvm.tir.if_then_else( + lambda i: tvm.tirx.if_then_else( i < buflen - data_size, buffer[i + data_size], data[i - buflen + data_size] ), name="new_buffer", @@ -86,7 +86,7 @@ def fifo_buffer(data, buffer, axis): if axis == 0: return te.compute( buffer.shape, - lambda i, j: tvm.tir.if_then_else( + lambda i, j: tvm.tirx.if_then_else( i < buflen - data_size, buffer[i + data_size, j], data[i - buflen + data_size, j], @@ -96,7 +96,7 @@ def fifo_buffer(data, buffer, axis): if axis == 1: return te.compute( buffer.shape, - lambda i, j: tvm.tir.if_then_else( + lambda i, j: tvm.tirx.if_then_else( j < buflen - data_size, buffer[i, j + data_size], data[i, j - buflen + data_size], @@ -108,7 +108,7 @@ def fifo_buffer(data, buffer, axis): if axis == 0: return te.compute( buffer.shape, - lambda i, j, k: tvm.tir.if_then_else( + lambda i, j, k: tvm.tirx.if_then_else( i < buflen - data_size, buffer[i + data_size, j, k], data[i - buflen + data_size, j, k], @@ -118,7 +118,7 @@ def fifo_buffer(data, buffer, axis): if axis == 1: return te.compute( buffer.shape, - lambda i, j, k: tvm.tir.if_then_else( + lambda i, j, k: tvm.tirx.if_then_else( j < buflen - data_size, buffer[i, j + data_size, k], data[i, j - buflen + data_size, k], @@ -128,7 +128,7 @@ def fifo_buffer(data, buffer, axis): if axis == 2: return te.compute( buffer.shape, - lambda i, j, k: tvm.tir.if_then_else( + lambda i, j, k: tvm.tirx.if_then_else( k < buflen - data_size, buffer[i, j, k + data_size], data[i, j, k - buflen + data_size], @@ -140,7 +140,7 @@ def fifo_buffer(data, buffer, axis): if axis == 0: return te.compute( buffer.shape, - lambda i, j, k, l: tvm.tir.if_then_else( + lambda i, j, k, l: tvm.tirx.if_then_else( i < buflen - data_size, buffer[i + data_size, j, k, l], data[i - buflen + data_size, j, k, l], @@ -150,7 +150,7 @@ def fifo_buffer(data, buffer, axis): if axis == 1: return te.compute( buffer.shape, - lambda i, j, k, l: tvm.tir.if_then_else( + lambda i, j, k, l: tvm.tirx.if_then_else( j < buflen - data_size, buffer[i, j + data_size, k, l], data[i, j - buflen + data_size, k, l], @@ -160,7 +160,7 @@ def fifo_buffer(data, buffer, axis): if axis == 2: return te.compute( buffer.shape, - lambda i, j, k, l: tvm.tir.if_then_else( + lambda i, j, k, l: tvm.tirx.if_then_else( k < buflen - data_size, buffer[i, j, k + data_size, l], data[i, j, k - buflen + data_size, l], @@ -170,7 +170,7 @@ def fifo_buffer(data, buffer, axis): if axis == 3: return te.compute( buffer.shape, - lambda i, j, k, l: tvm.tir.if_then_else( + lambda i, j, k, l: tvm.tirx.if_then_else( l < buflen - data_size, buffer[i, j, k, l + data_size], data[i, j, k, l - buflen + data_size], diff --git a/python/tvm/topi/nn/flatten.py b/python/tvm/topi/nn/flatten.py index 7a007cdb55ad..aac22942053d 100644 --- a/python/tvm/topi/nn/flatten.py +++ b/python/tvm/topi/nn/flatten.py @@ -41,8 +41,8 @@ def flatten(data): for i in range(1, len(ishape)): dim = dim * ishape[i] oshape = [ishape[0], dim] - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod + idxdiv = tvm.tirx.indexdiv + idxmod = tvm.tirx.indexmod def unwrap(idx, shape): index = [] diff --git a/python/tvm/topi/nn/lstm.py b/python/tvm/topi/nn/lstm.py index cceeb7a84434..8f4292bdabc4 100644 --- a/python/tvm/topi/nn/lstm.py +++ b/python/tvm/topi/nn/lstm.py @@ -18,7 +18,7 @@ # ruff: noqa: E731 """General LSTM implementation using TE scan.""" -from tvm import te, tir +from tvm import te, tirx from tvm.topi import tag @@ -34,9 +34,9 @@ def lstm( p_i=None, p_f=None, p_o=None, - f_act=tir.sigmoid, - g_act=tir.tanh, - h_act=tir.tanh, + f_act=tirx.sigmoid, + g_act=tirx.tanh, + h_act=tirx.tanh, reverse=False, weight_layout: str = "IFGO", ): diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py index 813c4421ab55..feed68a854ad 100644 --- a/python/tvm/topi/nn/pad.py +++ b/python/tvm/topi/nn/pad.py @@ -18,7 +18,7 @@ import tvm from tvm import te -from tvm.tir import if_then_else +from tvm.tirx import if_then_else from .. import tag from ..utils import equal_const_int @@ -94,8 +94,8 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs= out_shape = tuple(ana.simplify(dshape[i] + pad_before[i] + pad_after[i]) for i in range(n)) pad_value = ( pad_value - if isinstance(pad_value, tvm.tir.PrimExpr) - else tvm.tir.const(pad_value, data.dtype) + if isinstance(pad_value, tvm.tirx.PrimExpr) + else tvm.tirx.const(pad_value, data.dtype) ) def _pad(*indices): @@ -109,8 +109,8 @@ def _pad(*indices): not_zero.append(indices[i] >= pad_before[i]) not_zero.append(indices[i] < data.shape[i] + pad_before[i]) if not_zero: - not_zero = tvm.tir.all(*not_zero) - return tvm.tir.if_then_else(not_zero, data(*index_tuple), pad_value) + not_zero = tvm.tirx.all(*not_zero) + return tvm.tirx.if_then_else(not_zero, data(*index_tuple), pad_value) return data(*index_tuple) return te.compute(out_shape, _pad, name=name, attrs=attrs) @@ -168,8 +168,8 @@ def _pad(*indices): below.append(indices[i] < pad_before[i]) mapped_tuple = [] for i, axis in enumerate(index_tuple): - mapped_axis = tvm.tir.if_then_else(below[i], -axis - mode, axis) - mapped_axis = tvm.tir.if_then_else( + mapped_axis = tvm.tirx.if_then_else(below[i], -axis - mode, axis) + mapped_axis = tvm.tirx.if_then_else( above[i], (2 * (data.shape[i] - 1)) - axis + mode, mapped_axis ) mapped_tuple.append(mapped_axis) @@ -264,7 +264,7 @@ def _pad(*indices): orig_idx = idx - before clamped_idx = if_then_else( orig_idx < 0, - tvm.tir.const(0, "int32"), # replicate first element + tvm.tirx.const(0, "int32"), # replicate first element if_then_else( orig_idx >= size, size - 1, # replicate last element @@ -311,7 +311,7 @@ def _pad(*indices): before = pad_before[i] orig_idx = idx - before - wrapped_idx = tvm.tir.indexmod(orig_idx + size, size) + wrapped_idx = tvm.tirx.indexmod(orig_idx + size, size) index_tuple.append(wrapped_idx) return data(*index_tuple) diff --git a/python/tvm/topi/nn/pixel_shuffle.py b/python/tvm/topi/nn/pixel_shuffle.py index ccfa13355774..c9fd8dfcc295 100644 --- a/python/tvm/topi/nn/pixel_shuffle.py +++ b/python/tvm/topi/nn/pixel_shuffle.py @@ -44,10 +44,10 @@ def pixel_shuffle(data, upscale_factor, name="PixelShuffle"): ndim = len(data.shape) assert ndim >= 3, "Input must be at least 3D" - upscale_factor_const = tvm.tir.const(upscale_factor, "int32") + upscale_factor_const = tvm.tirx.const(upscale_factor, "int32") c_in, h_in, w_in = data.shape[-3], data.shape[-2], data.shape[-1] - c_out = tvm.tir.floordiv(c_in, upscale_factor_const * upscale_factor_const) + c_out = tvm.tirx.floordiv(c_in, upscale_factor_const * upscale_factor_const) h_out = h_in * upscale_factor_const w_out = w_in * upscale_factor_const @@ -57,10 +57,10 @@ def _compute(*indices): batch_indices = indices[:-3] c_out_idx, h_out_idx, w_out_idx = indices[-3], indices[-2], indices[-1] - h_idx = tvm.tir.floordiv(h_out_idx, upscale_factor_const) + h_idx = tvm.tirx.floordiv(h_out_idx, upscale_factor_const) h_offset = h_out_idx % upscale_factor_const - w_idx = tvm.tir.floordiv(w_out_idx, upscale_factor_const) + w_idx = tvm.tirx.floordiv(w_out_idx, upscale_factor_const) w_offset = w_out_idx % upscale_factor_const c_in_idx = ( diff --git a/python/tvm/topi/nn/qnn.py b/python/tvm/topi/nn/qnn.py index 520f25e2bf23..c8602fdafcef 100644 --- a/python/tvm/topi/nn/qnn.py +++ b/python/tvm/topi/nn/qnn.py @@ -17,7 +17,7 @@ """Quantized Neural Network (QNN) Operators""" import tvm -from tvm import te, tir, topi +from tvm import te, tirx, topi SQNN_DISABLE = 0 SQNN_INT8 = 1 @@ -74,11 +74,11 @@ def _compute_pass_through(value, *indices): # out_dtype::max) def _compute_intn(dtype, value, *indices): assert output_scale is not None and output_zero_point is not None - const_min = tvm.tir.min_value(dtype) - const_max = tvm.tir.max_value(dtype) + const_min = tvm.tirx.min_value(dtype) + const_max = tvm.tirx.max_value(dtype) # Use indexmod to handle both scalar and per-channel QNN parameters. - scale_idx = tir.indexmod(indices[axis], topi.shape(output_scale)[0]) - zp_idx = tir.indexmod(indices[axis], topi.shape(output_zero_point)[0]) + scale_idx = tirx.indexmod(indices[axis], topi.shape(output_scale)[0]) + zp_idx = tirx.indexmod(indices[axis], topi.shape(output_zero_point)[0]) return te.max( te.min( te.round(value[indices] / output_scale[scale_idx]) + output_zero_point[zp_idx], @@ -96,7 +96,7 @@ def _dispatch_sim_quantize(value): ) int8_value = te.compute( data.shape, - lambda *indices: tir.if_then_else( + lambda *indices: tirx.if_then_else( out_dtype.equal(SQNN_DTYPE_TO_CODE["int8"]), _compute_intn("int8", value, *indices), pass_through_value[indices], @@ -104,7 +104,7 @@ def _dispatch_sim_quantize(value): ) uint8_value = te.compute( data.shape, - lambda *indices: tir.if_then_else( + lambda *indices: tirx.if_then_else( out_dtype.equal(SQNN_DTYPE_TO_CODE["uint8"]), _compute_intn("uint8", value, *indices), int8_value[indices], @@ -112,7 +112,7 @@ def _dispatch_sim_quantize(value): ) int32_value = te.compute( data.shape, - lambda *indices: tir.if_then_else( + lambda *indices: tirx.if_then_else( out_dtype.equal(SQNN_DTYPE_TO_CODE["int32"]), _compute_intn("int32", value, *indices), uint8_value[indices], @@ -163,8 +163,8 @@ def _compute_pass_through(value, *indices): def _compute_intn(value, *indices): assert input_scale is not None and input_zero_point is not None # Use indexmod to handle both scalar and per-channel QNN parameters. - scale_idx = tir.indexmod(indices[axis], topi.shape(input_scale)[0]) - zp_idx = tir.indexmod(indices[axis], topi.shape(input_zero_point)[0]) + scale_idx = tirx.indexmod(indices[axis], topi.shape(input_scale)[0]) + zp_idx = tirx.indexmod(indices[axis], topi.shape(input_zero_point)[0]) return (value[indices] - input_zero_point[zp_idx]) * input_scale[scale_idx] # Use an if chain to dynamically return the proper dequantization based on the input datatype. @@ -181,7 +181,7 @@ def _dispatch_sim_dequantize(value): ) intn_value = te.compute( data.shape, - lambda *indices: tir.if_then_else( + lambda *indices: tirx.if_then_else( intn_condition, _compute_intn(value, *indices), pass_through_value[indices], diff --git a/python/tvm/topi/nn/space_to_depth.py b/python/tvm/topi/nn/space_to_depth.py index 1f0dbcacdf54..8a041b7ae410 100644 --- a/python/tvm/topi/nn/space_to_depth.py +++ b/python/tvm/topi/nn/space_to_depth.py @@ -48,15 +48,15 @@ def space_to_depth(data, block_size, layout="NCHW"): output_shape = [ in_n, in_c * block_size * block_size, - tvm.tir.truncdiv(in_h, block_size), - tvm.tir.truncdiv(in_w, block_size), + tvm.tirx.truncdiv(in_h, block_size), + tvm.tirx.truncdiv(in_w, block_size), ] elif layout == "NHWC": in_n, in_h, in_w, in_c = data.shape output_shape = [ in_n, - tvm.tir.truncdiv(in_h, block_size), - tvm.tir.truncdiv(in_w, block_size), + tvm.tirx.truncdiv(in_h, block_size), + tvm.tirx.truncdiv(in_w, block_size), in_c * block_size * block_size, ] else: @@ -70,10 +70,10 @@ def _get_indices(*indices): return n, c, y, x def _get_pixel(n, c, y, x): - block_offset = tvm.tir.truncdiv(c, in_c) - channel_idx = tvm.tir.truncmod(c, in_c) - x_idx = tvm.tir.truncmod(block_offset, block_size) - y_idx = tvm.tir.truncdiv(block_offset, block_size) + block_offset = tvm.tirx.truncdiv(c, in_c) + channel_idx = tvm.tirx.truncmod(c, in_c) + x_idx = tvm.tirx.truncmod(block_offset, block_size) + y_idx = tvm.tirx.truncdiv(block_offset, block_size) if layout == "NCHW": output = data(n, channel_idx, y_idx + (y * block_size), x_idx + (x * block_size)) diff --git a/python/tvm/topi/nn/upsampling.py b/python/tvm/topi/nn/upsampling.py index ee1cfbce106f..67b1fa5dac6e 100644 --- a/python/tvm/topi/nn/upsampling.py +++ b/python/tvm/topi/nn/upsampling.py @@ -52,7 +52,7 @@ def upsampling( method : {"bilinear", "nearest_neighbor", "bicubic"} Method to be used for upsampling. - output_shape: tvm.tir.container.Array, optional + output_shape: tvm.tirx.container.Array, optional Shape to return. If left None will be inferred (If shape is determined dynamically, pass out_dtype.shape as output_shape) @@ -147,7 +147,7 @@ def upsampling3d( Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - output_shape: tvm.tir.container.Array, optional + output_shape: tvm.tirx.container.Array, optional Shape to return. If left None will be inferred (If shape is determined dynamically, pass out_dtype.shape as output_shape) diff --git a/python/tvm/topi/nn/utils.py b/python/tvm/topi/nn/utils.py index 0198654c5a13..a17a34451549 100644 --- a/python/tvm/topi/nn/utils.py +++ b/python/tvm/topi/nn/utils.py @@ -110,8 +110,8 @@ def infer_stride(data, kernel, out): _, _, IH, IW = data.shape _, _, KH, KW = kernel.shape _, _, OH, OW = out.shape - hstride = (IH - KH) // tvm.te.max(OH - 1, 1) + tvm.tir.Select(OH == 1, 1, 0) - wstride = (IW - KW) // tvm.te.max(OW - 1, 1) + tvm.tir.Select(OW == 1, 1, 0) + hstride = (IH - KH) // tvm.te.max(OH - 1, 1) + tvm.tirx.Select(OH == 1, 1, 0) + wstride = (IW - KW) // tvm.te.max(OW - 1, 1) + tvm.tirx.Select(OW == 1, 1, 0) return get_const_int(hstride), get_const_int(wstride) diff --git a/python/tvm/topi/scan.py b/python/tvm/topi/scan.py index c18376ef63c2..07dc2cddd5dc 100644 --- a/python/tvm/topi/scan.py +++ b/python/tvm/topi/scan.py @@ -21,10 +21,10 @@ import tvm from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from ..te import extern -from ..tir import decl_buffer, generic +from ..tirx import decl_buffer, generic from . import utils from .math import cast diff --git a/python/tvm/topi/scatter.py b/python/tvm/topi/scatter.py index ab108690706a..75a5d1cdbfeb 100644 --- a/python/tvm/topi/scatter.py +++ b/python/tvm/topi/scatter.py @@ -18,11 +18,11 @@ # ruff: noqa: E741 """ScatterND operator""" -from tvm import te, tir # hide redefinition of min and max +from tvm import te, tirx # hide redefinition of min and max from tvm.arith.analyzer import Analyzer from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T -from tvm.tir import expr +from tvm.script.ir_builder import tirx as T +from tvm.tirx import expr def _verify_scatter_nd_inputs(data, indices, updates): @@ -139,11 +139,11 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): elif mode == "mul": out[index] *= updates[i * fused_updates_dimension + j] elif mode == "min": - out[index] = tir.min( + out[index] = tirx.min( out[index], updates[i * fused_updates_dimension + j] ) elif mode == "max": - out[index] = tir.max( + out[index] = tirx.max( out[index], updates[i * fused_updates_dimension + j] ) else: @@ -153,7 +153,7 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): return ib.get() - out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( [data.shape], [data, indices, updates], diff --git a/python/tvm/topi/scatter_elements.py b/python/tvm/topi/scatter_elements.py index 99fd111bd88d..047a882b7900 100644 --- a/python/tvm/topi/scatter_elements.py +++ b/python/tvm/topi/scatter_elements.py @@ -16,9 +16,9 @@ # under the License. """ScatterElements operator""" -from tvm import te, tir +from tvm import te, tirx from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from . import utils from .math import cast @@ -139,10 +139,10 @@ def mean_func(dst_ptr, dst_index, update): dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2 def min_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] = tir.min(dst_ptr[dst_index], update) + dst_ptr[dst_index] = tirx.min(dst_ptr[dst_index], update) def max_func(dst_ptr, dst_index, update): - dst_ptr[dst_index] = tir.max(dst_ptr[dst_index], update) + dst_ptr[dst_index] = tirx.max(dst_ptr[dst_index], update) reduce_func = None if reduction == "update": @@ -162,7 +162,7 @@ def max_func(dst_ptr, dst_index, update): "scatter_elements reduction not in [update, add, mul, mean, min, max]:", reduction ) - out_buf = tir.decl_buffer(data.shape, data.dtype, "out_buf") + out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf") return te.extern( [data.shape], [data, indices, updates], diff --git a/python/tvm/topi/searchsorted.py b/python/tvm/topi/searchsorted.py index c7bed8e38998..36db4676f961 100644 --- a/python/tvm/topi/searchsorted.py +++ b/python/tvm/topi/searchsorted.py @@ -18,7 +18,7 @@ """searchsorted operator""" from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from . import te, utils from .math import cast diff --git a/python/tvm/topi/signal.py b/python/tvm/topi/signal.py index 6f007f5cabc2..982b2c6532a5 100644 --- a/python/tvm/topi/signal.py +++ b/python/tvm/topi/signal.py @@ -19,9 +19,9 @@ from math import pi -from tvm import te, tir +from tvm import te, tirx from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T def stft( @@ -88,31 +88,31 @@ def gen_ir( # https://librosa.org/doc/0.7.2/_modules/librosa/core/spectrum.html#stft with T.parallel(0, output_ptr.shape[0] * output_ptr.shape[1]) as batch_row: with col_loop(0, output_ptr.shape[2]) as col: - batch = tir.floordiv(batch_row, output_ptr.shape[1]) - row = tir.floormod(batch_row, output_ptr.shape[1]) - output[batch, row, col, 0] = tir.Cast(data_ptr.dtype, 0) - output[batch, row, col, 1] = tir.Cast(data_ptr.dtype, 0) + batch = tirx.floordiv(batch_row, output_ptr.shape[1]) + row = tirx.floormod(batch_row, output_ptr.shape[1]) + output[batch, row, col, 0] = tirx.Cast(data_ptr.dtype, 0) + output[batch, row, col, 1] = tirx.Cast(data_ptr.dtype, 0) with T.serial(0, win_length) as wlen: output[batch, row, col, 0] += ( window[wlen] * data[batch, col * hop_length + wlen] - * tir.cos(2 * pi * row * wlen / win_length) + * tirx.cos(2 * pi * row * wlen / win_length) ) output[batch, row, col, 1] -= ( window[wlen] * data[batch, col * hop_length + wlen] - * tir.sin(2 * pi * row * wlen / win_length) + * tirx.sin(2 * pi * row * wlen / win_length) ) with T.If(normalized): with T.Then(): - output[batch, row, col, 0] /= tir.sqrt(tir.const(n_fft, "float32")) - output[batch, row, col, 1] /= tir.sqrt(tir.const(n_fft, "float32")) + output[batch, row, col, 0] /= tirx.sqrt(tirx.const(n_fft, "float32")) + output[batch, row, col, 1] /= tirx.sqrt(tirx.const(n_fft, "float32")) return ib.get() - output_buf = tir.decl_buffer(output_shape, data.dtype, "output_buf") + output_buf = tirx.decl_buffer(output_shape, data.dtype, "output_buf") loop_kind = "vectorize" - if isinstance(output_shape[2], tir.expr.SizeVar): # any_dim + if isinstance(output_shape[2], tirx.expr.SizeVar): # any_dim loop_kind = "serial" return te.extern( @@ -131,7 +131,7 @@ def gen_ir( def dft( re_data: te.Tensor, im_data: te.Tensor, - inverse: tir.IntImm, + inverse: tirx.IntImm, ): """ Computes the discrete Fourier transform of input (calculation along the last axis). @@ -182,14 +182,14 @@ def gen_ir( base_idx = i * n_fft with T.serial(0, n_fft) as n: n_idx = base_idx + n - re_output_ptr[n_idx] = tir.Cast(re_output_ptr.dtype, 0) - im_output_ptr[n_idx] = tir.Cast(im_output_ptr.dtype, 0) + re_output_ptr[n_idx] = tirx.Cast(re_output_ptr.dtype, 0) + im_output_ptr[n_idx] = tirx.Cast(im_output_ptr.dtype, 0) _w = sign * -2 * pi * n / n_fft with T.serial(0, n_fft) as k: k_idx = base_idx + k w = _w * k - cos_w = tir.Cast(re_output_ptr.dtype, tir.cos(w)) - sin_w = tir.Cast(re_output_ptr.dtype, tir.sin(w)) + cos_w = tirx.Cast(re_output_ptr.dtype, tirx.cos(w)) + sin_w = tirx.Cast(re_output_ptr.dtype, tirx.sin(w)) re_output_ptr[n_idx] += ( re_data_ptr[k_idx] * cos_w - im_data_ptr[k_idx] * sin_w ) @@ -197,8 +197,8 @@ def gen_ir( re_data_ptr[k_idx] * sin_w + im_data_ptr[k_idx] * cos_w ) - re_output_ptr[n_idx] *= tir.Cast(re_output_ptr.dtype, factor) - im_output_ptr[n_idx] *= tir.Cast(im_output_ptr.dtype, factor) + re_output_ptr[n_idx] *= tirx.Cast(re_output_ptr.dtype, factor) + im_output_ptr[n_idx] *= tirx.Cast(im_output_ptr.dtype, factor) return ib.get() diff --git a/python/tvm/topi/sort.py b/python/tvm/topi/sort.py index c937e0305c81..b11f960983bf 100644 --- a/python/tvm/topi/sort.py +++ b/python/tvm/topi/sort.py @@ -48,12 +48,12 @@ def sort(data, axis=-1, is_ascend=1): Sorted index tensor. """ - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) - out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + out_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "out_buf", data_alignment=8) out = te.extern( data.shape, [data], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.sort.sort", ins[0], outs[0], axis, is_ascend ), dtype=data.dtype, @@ -111,16 +111,16 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): tvm_out = tvm.runtime.tensor(np.zeros(dshape, dtype=data.dtype), dev) f(tvm_data, tvm_out) """ - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) if valid_count is not None: - valid_count_buf = tvm.tir.decl_buffer( + valid_count_buf = tvm.tirx.decl_buffer( valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4 ) - out_buf = tvm.tir.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8) + out_buf = tvm.tirx.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8) out = te.extern( data.shape, [data, valid_count], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend ), dtype="int32", @@ -130,11 +130,11 @@ def argsort(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"): tag="argsort_nms_cpu", ) else: - out_buf = tvm.tir.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + out_buf = tvm.tirx.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) out = te.extern( data.shape, [data], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.sort.argsort", ins[0], outs[0], axis, is_ascend ), dtype=dtype, @@ -178,7 +178,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): The computed result. """ assert ret_type in ["both", "values", "indices"] - data_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + data_buf = tvm.tirx.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) out_shape = list(get_const_tuple(data.shape)) kvar = tvm.te.size_var("k") if not isinstance(k, int): @@ -187,16 +187,16 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): out_shape[axis] = k out_bufs = [] if ret_type in ["both", "values"]: - out_bufs.append(tvm.tir.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8)) + out_bufs.append(tvm.tirx.decl_buffer(out_shape, data.dtype, "value_buf", data_alignment=8)) if ret_type in ["both", "indices"]: - out_bufs.append(tvm.tir.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8)) + out_bufs.append(tvm.tirx.decl_buffer(out_shape, dtype, "indices_buf", data_alignment=8)) out_shapes = [out_shape] * len(out_bufs) kv = kvar if not isinstance(k, int) else k out = te.extern( out_shapes, [data], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.sort.topk", ins[0], *outs, kv, axis, ret_type, is_ascend ), in_buffers=[data_buf], diff --git a/python/tvm/topi/sparse_reshape.py b/python/tvm/topi/sparse_reshape.py index 55ee1c9c1b5d..2017ced6cb5c 100644 --- a/python/tvm/topi/sparse_reshape.py +++ b/python/tvm/topi/sparse_reshape.py @@ -18,9 +18,9 @@ """Sparse_Reshape operator""" from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T from tvm.te import div, extern, floordiv, floormod -from tvm.tir import Cast, decl_buffer +from tvm.tirx import Cast, decl_buffer def sparse_reshape( diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index a3e736644607..fba3eb4cfa7e 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -99,7 +99,7 @@ def _compute(*idxs): axis_index = 0 for i in range(0, len(idxs)): if i not in real_axis: - dim = tvm.tir.if_then_else(a.shape[len(indices)] != 1, idxs[i], 0) + dim = tvm.tirx.if_then_else(a.shape[len(indices)] != 1, idxs[i], 0) indices.append(dim) axis_index += 1 return a(*indices) @@ -311,37 +311,37 @@ def strided_set(a, v, begin, end, strides=None): raise TypeError("strides should be int32") def _max(a, b): - return tvm.tir.Select(a > b, a, b) + return tvm.tirx.Select(a > b, a, b) if strides is None: - strides = [tvm.tir.const(1, "int32")] * n + strides = [tvm.tirx.const(1, "int32")] * n else: strides = [ - tvm.tir.if_then_else(strides.shape[0] > i, strides[i], tvm.tir.const(1, "int32")) + tvm.tirx.if_then_else(strides.shape[0] > i, strides[i], tvm.tirx.const(1, "int32")) for i in range(n) ] begin = [ - tvm.tir.if_then_else( + tvm.tirx.if_then_else( begin.shape[0] > i, begin[i], - tvm.tir.Select(strides[i] > 0, tvm.tir.const(0, "int32"), a.shape[i]), + tvm.tirx.Select(strides[i] > 0, tvm.tirx.const(0, "int32"), a.shape[i]), ) for i in range(n) ] end = [ - tvm.tir.if_then_else( + tvm.tirx.if_then_else( end.shape[0] > i, end[i], - tvm.tir.Select(strides[i] > 0, a.shape[i] + 1, -(a.shape[i] + 1)), + tvm.tirx.Select(strides[i] > 0, a.shape[i] + 1, -(a.shape[i] + 1)), ) for i in range(n) ] # Convert negative indexes for i in range(n): - begin[i] = tvm.tir.if_then_else(begin[i] < 0, begin[i] + a.shape[i], begin[i]) - end[i] = tvm.tir.if_then_else(end[i] < 0, end[i] + a.shape[i], end[i]) + begin[i] = tvm.tirx.if_then_else(begin[i] < 0, begin[i] + a.shape[i], begin[i]) + end[i] = tvm.tirx.if_then_else(end[i] < 0, end[i] + a.shape[i], end[i]) def _select(*indices): from_val = [] @@ -349,7 +349,7 @@ def _select(*indices): for i in range(n): from_val.append(within_index(begin[i], end[i], strides[i], indices[i])) index_tuple.append(make_idx(begin[i], end[i], strides[i], a.shape[i], indices[i])) - return tvm.tir.if_then_else(tvm.tir.all(*from_val), v(*index_tuple), a(*indices)) + return tvm.tirx.if_then_else(tvm.tirx.all(*from_val), v(*index_tuple), a(*indices)) return te.compute(a.shape, _select, name="strided_set") @@ -1058,12 +1058,12 @@ def trilu(data, k, upper): """ # Make sure datatype is consistent. if k.dtype != "int32": - k = tvm.tir.Cast("int32", k) + k = tvm.tirx.Cast("int32", k) # Check either above or below diagonal depending on upper. - check_op = tvm.tir.GE + check_op = tvm.tirx.GE if upper: - check_op = tvm.tir.LE + check_op = tvm.tirx.LE def _apply_trilu(*indices): row_index = indices[-2] @@ -1072,13 +1072,13 @@ def _apply_trilu(*indices): if row_index.dtype != col_index.dtype: target_type = (col_index + row_index).dtype if row_index.dtype != target_type: - row_index = tvm.tir.Cast(target_type, row_index) + row_index = tvm.tirx.Cast(target_type, row_index) else: - col_index = tvm.tir.Cast(target_type, col_index) + col_index = tvm.tirx.Cast(target_type, col_index) other_indices = indices[:-2] check_position = check_op(row_index, col_index - k) value = data(*other_indices, row_index, col_index) - return tvm.tir.Select(check_position, value, tvm.tir.const(0, data.dtype)) + return tvm.tirx.Select(check_position, value, tvm.tirx.const(0, data.dtype)) return te.compute(data.shape, _apply_trilu, name="trilu", tag=topi.tag.ELEMWISE) diff --git a/python/tvm/topi/unique.py b/python/tvm/topi/unique.py index 1549c3071633..6135f6c8f626 100644 --- a/python/tvm/topi/unique.py +++ b/python/tvm/topi/unique.py @@ -17,12 +17,12 @@ # pylint: disable=invalid-name """Unique operator""" -from tvm import te, tir +from tvm import te, tirx from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T -def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): +def _calc_adjacent_diff_ir(data, output, binop=tirx.Sub): """Low level IR to calculate adjacent difference in an 1-D array. Parameters @@ -37,7 +37,7 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): binop: function, optional A binary associative op to use for calculating adjacent difference. The function takes two - TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + TIR expressions and produce a new TIR expression. By default it uses tvm.tirx.Sub to compute the adjacent difference. """ with IRBuilder() as ib: @@ -48,11 +48,11 @@ def _calc_adjacent_diff_ir(data, output, binop=tir.Sub): with T.Then(): output_ptr[0] = 0 with T.Else(): - output_ptr[i] = tir.Cast(output.dtype, binop(data_ptr[i], data_ptr[i - 1])) + output_ptr[i] = tirx.Cast(output.dtype, binop(data_ptr[i], data_ptr[i - 1])) return ib.get() -def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): +def _calc_adjacent_diff(data, out_dtype="int32", binop=tirx.Sub): """Function calculate adjacent difference in an 1-D array. Parameters @@ -65,7 +65,7 @@ def _calc_adjacent_diff(data, out_dtype="int32", binop=tir.Sub): binop: function, optional A binary associative op to use for calculating difference. The function takes two - TIR expressions and produce a new TIR expression. By default it uses tvm.tir.Sub to + TIR expressions and produce a new TIR expression. By default it uses tvm.tirx.Sub to compute the adjacent difference. Returns diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 2e2b1b01ad16..c660fa65ce6b 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -25,7 +25,7 @@ import tvm from tvm import te from tvm.s_tir import bijective_layout, layout -from tvm.tir import SizeVar +from tvm.tirx import SizeVar from . import cpp, tag @@ -97,7 +97,7 @@ def prod(x): The result value """ if not x: - return tvm.tir.const(1, "int32") + return tvm.tirx.const(1, "int32") res = x[0] for i in range(1, len(x)): res = res * x[i] @@ -119,10 +119,10 @@ def get_const_int(expr): """ if isinstance(expr, Integral): return expr - if not isinstance(expr, tvm.tir.IntImm): + if not isinstance(expr, tvm.tirx.IntImm): ana = tvm.arith.Analyzer() expr = ana.simplify(expr) - if not isinstance(expr, tvm.tir.IntImm): + if not isinstance(expr, tvm.tirx.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -142,10 +142,10 @@ def get_const_float(expr): """ if isinstance(expr, float): return float(expr) - if not isinstance(expr, tvm.tir.FloatImm): + if not isinstance(expr, tvm.tirx.FloatImm): ana = tvm.arith.Analyzer() expr = ana.simplify(expr) - if not isinstance(expr, tvm.tir.FloatImm): + if not isinstance(expr, tvm.tirx.FloatImm): raise ValueError("Expect value to be constant float") return float(expr.value) @@ -165,10 +165,10 @@ def equal_const_int(expr, value): """ if isinstance(expr, Integral): return expr == value - if not isinstance(expr, tvm.tir.IntImm): + if not isinstance(expr, tvm.tirx.IntImm): ana = tvm.arith.Analyzer() expr = ana.simplify(expr) - if not isinstance(expr, tvm.tir.IntImm): + if not isinstance(expr, tvm.tirx.IntImm): return False return expr.value == value @@ -189,12 +189,12 @@ def get_const_tuple(in_tuple): ret = [] ana = None for elem in in_tuple: - if isinstance(elem, tvm.tir.Var): + if isinstance(elem, tvm.tirx.Var): ret.append(elem) - elif not isinstance(elem, tvm.tir.IntImm | int): + elif not isinstance(elem, tvm.tirx.IntImm | int): ana = tvm.arith.Analyzer() if ana is None else ana elem = ana.simplify(elem) - if not isinstance(elem, tvm.tir.IntImm): + if not isinstance(elem, tvm.tirx.IntImm): ret.append(elem) else: ret.append(get_const_int(elem)) @@ -222,13 +222,13 @@ def const_vector(vector, name="const_vector"): vector = np.array(vector) row = vector.shape[0] dtype = str(vector.dtype) - idxm = tvm.tir.indexmod + idxm = tvm.tirx.indexmod def select_array(i): - now = tvm.tir.const(0.0, dtype) + now = tvm.tirx.const(0.0, dtype) for ii in range(row): - now = tvm.tir.Select( - tvm.tir.all(idxm(i, row) == ii), tvm.tir.const(vector[ii], dtype), now + now = tvm.tirx.Select( + tvm.tirx.all(idxm(i, row) == ii), tvm.tirx.const(vector[ii], dtype), now ) return now @@ -271,7 +271,7 @@ def simplify(expr): name="simplify_output", tag="simplify", ) - elif isinstance(expr, tvm.tir.PrimExpr): + elif isinstance(expr, tvm.tirx.PrimExpr): return tvm.arith.Analyzer().simplify(expr) else: return expr @@ -282,7 +282,7 @@ def ravel_index(indices, shape): Parameters ---------- - indices : tuple of int or tvm.tir.IntImm + indices : tuple of int or tvm.tirx.IntImm The input coordinates shape : tuple of int @@ -307,7 +307,7 @@ def unravel_index(idx, shape): Parameters ---------- - idx : int or tvm.tir.IntImm + idx : int or tvm.tirx.IntImm The 1D index shape : tuple of int @@ -315,11 +315,11 @@ def unravel_index(idx, shape): Returns ------- - indices : tuple of int or tvm.tir.IntImm + indices : tuple of int or tvm.tirx.IntImm Corresponding coordinate of the 1D index """ - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod + idxd = tvm.tirx.indexdiv + idxm = tvm.tirx.indexmod indices = [] for i, dim in enumerate(reversed(shape)): if dim == 0: @@ -353,15 +353,15 @@ def const_matrix(matrix, name="const_matrix", attrs=None): """ row, col = matrix.shape dtype = str(matrix.dtype) - idxm = tvm.tir.indexmod + idxm = tvm.tirx.indexmod def select_array(i, j): - now = tvm.tir.const(0.0, dtype) + now = tvm.tirx.const(0.0, dtype) for ii in range(row): for jj in range(col): - now = tvm.tir.Select( - tvm.tir.all(idxm(i, row) == ii, idxm(j, col) == jj), - tvm.tir.const(matrix[ii][jj], dtype), + now = tvm.tirx.Select( + tvm.tirx.all(idxm(i, row) == ii, idxm(j, col) == jj), + tvm.tirx.const(matrix[ii][jj], dtype), now, ) return now @@ -457,10 +457,10 @@ def within_index(b, e, s, i): bool expression that is True is the array position would be selected by the index and False otherwise """ - bc = tvm.tir.Select(s < 0, i <= e, i < b) - ec = tvm.tir.Select(s < 0, i > b, i >= e) + bc = tvm.tirx.Select(s < 0, i <= e, i < b) + ec = tvm.tirx.Select(s < 0, i > b, i >= e) ss = te.if_then_else(s < 0, ((i - e) + (e % te.abs(s)) + 1) % te.abs(s), (i - b) % s) - return tvm.tir.Select(tvm.tir.Or(bc, ec), tvm.tir.const(False), ss.equal(0)) + return tvm.tirx.Select(tvm.tirx.Or(bc, ec), tvm.tirx.const(False), ss.equal(0)) def make_idx(b, e, s, z, i): @@ -492,14 +492,14 @@ def make_idx(b, e, s, z, i): position: Expr int expression that corresponds to an array position in the selection. """ - bc = tvm.tir.Select(s < 0, i <= e, i < b) - ec = tvm.tir.Select(s < 0, i > b, i >= e) + bc = tvm.tirx.Select(s < 0, i <= e, i < b) + ec = tvm.tirx.Select(s < 0, i > b, i >= e) # Clamp to array size - b = tvm.tir.Select(z < b, z - 1, b) + b = tvm.tirx.Select(z < b, z - 1, b) - ss = tvm.tir.if_then_else(s < 0, (b - i) // te.abs(s), (i - b) // s) - return tvm.tir.if_then_else(tvm.tir.Or(bc, ec), 88, ss) + ss = tvm.tirx.if_then_else(s < 0, (b - i) // te.abs(s), (i - b) // s) + return tvm.tirx.if_then_else(tvm.tirx.Or(bc, ec), 88, ss) def is_empty_shape(shape): @@ -520,7 +520,7 @@ def is_empty_shape(shape): def ceil_div(a, b): """Return ceil division of a by b""" - return tvm.tir.indexdiv(a + (b - 1), b) + return tvm.tirx.indexdiv(a + (b - 1), b) def swap(arr, axis): diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 6edabfcf0785..9bdedc35352c 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -20,8 +20,8 @@ import tvm from tvm import te from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T -from tvm.tir import if_then_else +from tvm.script.ir_builder import tirx as T +from tvm.tirx import if_then_else from .. import reduction from ..math import cast @@ -60,9 +60,9 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): # pyl Related index in input data. """ if isinstance(score_threshold, float | int): - score_threshold = tvm.tir.const(score_threshold, dtype=data.dtype) - # id_index_const = tvm.tir.const(id_index, "int32") # Unused - # score_index_const = tvm.tir.const(score_index, "int32") # Unused + score_threshold = tvm.tirx.const(score_threshold, dtype=data.dtype) + # id_index_const = tvm.tirx.const(id_index, "int32") # Unused + # score_index_const = tvm.tirx.const(score_index, "int32") # Unused return ( te.compute((data.shape[0],), lambda i: data.shape[1], name="valid_count"), data, @@ -98,7 +98,7 @@ def nms_inner_loop(i, j, nkeep, num_valid_boxes_local): k = j + 1 + _k with T.If( - tvm.tir.all( + tvm.tirx.all( k < nkeep, out_scores[i, k] > 0, # is the box k still valid? needs_bbox_check_func(i, j, k), @@ -113,9 +113,9 @@ def nms_inner_loop(i, j, nkeep, num_valid_boxes_local): on_new_invalidated_box_func(i, k) with T.serial(0, batch_size) as i: - nkeep = if_then_else(tvm.tir.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) + nkeep = if_then_else(tvm.tirx.all(top_k > 0, top_k < valid_count[i]), top_k, valid_count[i]) - with T.If(tvm.tir.all(iou_threshold > te.const(0), valid_count[i] > te.const(0))): + with T.If(tvm.tirx.all(iou_threshold > te.const(0), valid_count[i] > te.const(0))): with T.Then(): num_valid_boxes_local_buf = T.alloc_buffer((1,), "int32", scope="local") num_valid_boxes_local = T.buffer_proxy(num_valid_boxes_local_buf) @@ -123,7 +123,7 @@ def nms_inner_loop(i, j, nkeep, num_valid_boxes_local): with T.serial(0, nkeep) as j: with T.If( - tvm.tir.all( + tvm.tirx.all( out_scores[i, j] > -1.0, # box is still valid num_valid_boxes_local[0] < max_output_size, # haven't reached max limit ) @@ -154,20 +154,20 @@ def searchsorted_ir(scores_buf, score_thresh_buf, valid_count_buf): elif len(score_threshold.shape) == 1 and score_threshold.shape[0] > 0: score_thresh_scalar = score_thresh_buf[0] else: - score_thresh_scalar = tvm.tir.FloatImm("float32", 0.0) + score_thresh_scalar = tvm.tirx.FloatImm("float32", 0.0) else: score_thresh_scalar = score_threshold binary_search(i, num_boxes, scores_buf, score_thresh_scalar, valid_count_buf) return ib.get() - scores_buf = tvm.tir.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) - searchsorted_buf = tvm.tir.decl_buffer( + scores_buf = tvm.tirx.decl_buffer(scores.shape, scores.dtype, "scores_buf", data_alignment=8) + searchsorted_buf = tvm.tirx.decl_buffer( (batch_classes,), "int32", "searchsorted", data_alignment=8 ) if hasattr(score_threshold, "shape"): - score_thresh_buf = tvm.tir.decl_buffer( + score_thresh_buf = tvm.tirx.decl_buffer( score_threshold.shape, score_threshold.dtype, "score_thresh_buf", data_alignment=8 ) return te.extern( @@ -191,9 +191,9 @@ def searchsorted_ir_scalar(scores_buf, valid_count_buf): elif len(score_threshold.shape) == 1 and score_threshold.shape[0] == 1: score_thresh_tir = score_threshold[0] else: - score_thresh_tir = tvm.tir.FloatImm("float32", 0.0) + score_thresh_tir = tvm.tirx.FloatImm("float32", 0.0) else: - score_thresh_tir = tvm.tir.FloatImm("float32", float(score_threshold)) + score_thresh_tir = tvm.tirx.FloatImm("float32", float(score_threshold)) binary_search(i, num_boxes, scores_buf, score_thresh_tir, valid_count_buf) return ib.get() @@ -236,15 +236,15 @@ def _collect_selected_indices_ir( class_id = i_64 % num_class if isinstance(max_output_boxes_per_class, int): - limit = tvm.tir.min( - num_detections[i], tvm.tir.IntImm("int32", max_output_boxes_per_class) + limit = tvm.tirx.min( + num_detections[i], tvm.tirx.IntImm("int32", max_output_boxes_per_class) ) elif isinstance(max_output_boxes_per_class, te.Tensor): if len(max_output_boxes_per_class.shape) == 0: max_boxes_val = max_output_boxes_per_class[()] else: max_boxes_val = max_output_boxes_per_class[0] - limit = tvm.tir.min(num_detections[i], max_boxes_val) + limit = tvm.tirx.min(num_detections[i], max_boxes_val) else: limit = num_detections[i] @@ -392,18 +392,18 @@ def all_class_non_max_suppression( def _sum_clamped_total(): if isinstance(max_output_boxes_per_class, int): - k_expr = tvm.tir.IntImm("int32", int(max_output_boxes_per_class)) + k_expr = tvm.tirx.IntImm("int32", int(max_output_boxes_per_class)) clamped = te.compute( num_detections.shape, - lambda i: tvm.tir.min(num_detections[i], k_expr), + lambda i: tvm.tirx.min(num_detections[i], k_expr), name="clamped_num", ) return reduction.sum(cast(clamped, "int64"), axis=0) - if isinstance(max_output_boxes_per_class, tvm.tir.IntImm): - k_expr = tvm.tir.Cast("int32", max_output_boxes_per_class) + if isinstance(max_output_boxes_per_class, tvm.tirx.IntImm): + k_expr = tvm.tirx.Cast("int32", max_output_boxes_per_class) clamped = te.compute( num_detections.shape, - lambda i: tvm.tir.min(num_detections[i], k_expr), + lambda i: tvm.tirx.min(num_detections[i], k_expr), name="clamped_num", ) return reduction.sum(cast(clamped, "int64"), axis=0) @@ -428,7 +428,7 @@ def _sum_clamped_total(): clamped = te.compute( num_detections.shape, - lambda i: tvm.tir.min(num_detections[i], kb[i]), + lambda i: tvm.tirx.min(num_detections[i], kb[i]), name="clamped_num", ) return reduction.sum(cast(clamped, "int64"), axis=0) diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index 0fb63dece697..ae1716897069 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -21,7 +21,7 @@ import tvm from tvm import te from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T def _get_boundaries(output, box_idx): @@ -59,7 +59,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): # total area of the figure formed by box a and box b # except for overlapping area u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area - return tvm.tir.Select(u <= 0.0, 0.0, area / u) + return tvm.tirx.Select(u <= 0.0, 0.0, area / u) def binary_search(y, num_boxes, scores, score_threshold, out): @@ -73,7 +73,7 @@ def binary_search(y, num_boxes, scores, score_threshold, out): lo = T.buffer_proxy(lo_buf) hi = T.buffer_proxy(hi_buf) lo[0] = T.int32(0) - hi[0] = tvm.tir.Cast("int32", num_boxes) + hi[0] = tvm.tirx.Cast("int32", num_boxes) with T.While(lo[0] < hi[0]): mid = (hi[0] + lo[0]) >> 1 with T.If(scores[y, mid] > score_threshold): @@ -304,17 +304,17 @@ def _all_class_nms_ir( selected_scores = T.buffer_proxy(selected_scores) if isinstance(iou_threshold, float): - iou_threshold = tvm.tir.FloatImm("float32", iou_threshold) + iou_threshold = tvm.tirx.FloatImm("float32", iou_threshold) elif isinstance(iou_threshold, te.Tensor): if len(iou_threshold.shape) == 0: iou_threshold = iou_threshold() elif len(iou_threshold.shape) == 1 and iou_threshold.shape[0] == 1: iou_threshold = iou_threshold[0] else: - iou_threshold = tvm.tir.FloatImm("float32", 0.5) + iou_threshold = tvm.tirx.FloatImm("float32", 0.5) if isinstance(max_output_size_per_class, int): - max_output_size_per_class = tvm.tir.const(max_output_size_per_class) + max_output_size_per_class = tvm.tirx.const(max_output_size_per_class) elif isinstance(max_output_size_per_class, te.Tensor): if len(max_output_size_per_class.shape) == 0: max_output_size_per_class = max_output_size_per_class() @@ -324,7 +324,7 @@ def _all_class_nms_ir( ): max_output_size_per_class = max_output_size_per_class[0] else: - max_output_size_per_class = tvm.tir.const(1000) + max_output_size_per_class = tvm.tirx.const(1000) def calc_overlap(i, j, k): offset_j = sorted_indices[i, j] * 4 @@ -348,11 +348,11 @@ def on_new_invalidated_box(*_): pass def needs_bbox_check(*_): - return tvm.tir.const(True) + return tvm.tirx.const(True) nms_loop( batch_class, - tvm.tir.IntImm("int32", -1), # top_k + tvm.tirx.IntImm("int32", -1), # top_k iou_threshold, max_output_size_per_class, valid_count, @@ -415,10 +415,10 @@ def run_all_class_nms( num_class = batch_class // batch if return_scores is False: - all_class_num0_buf = tvm.tir.decl_buffer( + all_class_num0_buf = tvm.tirx.decl_buffer( (batch_class, num_boxes), "int32", "all_class_nms0", data_alignment=8 ) - all_class_num1_buf = tvm.tir.decl_buffer( + all_class_num1_buf = tvm.tirx.decl_buffer( (batch_class,), "int32", "all_class_nms1", data_alignment=8 ) extern_inputs = [boxes, sorted_scores, sorted_indices, valid_count] diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index cae6b2d7c6b5..f823e9efca95 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -23,8 +23,8 @@ #include #include #include -#include -#include +#include +#include #include "./scalable_expression.h" #include "const_fold.h" @@ -55,7 +55,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { TVM_FFI_ICHECK(range.defined()); - if (tir::is_one(range->extent)) { + if (tirx::is_one(range->extent)) { this->Bind(var, range->min, allow_override); } else { this->const_int_bound.Bind(var, range, allow_override); @@ -69,7 +69,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) { void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { // decompose value as symbol * scale + offset int64_t offset = 0; - PrimExpr symbol_scale = tir::make_const(value.dtype(), 0); + PrimExpr symbol_scale = tirx::make_const(value.dtype(), 0); auto fcollect_sum = [&](PrimExpr val, int sign) { if (const auto* intimm = val.as()) { @@ -86,7 +86,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { // split out the symbol and non-symbolic part int64_t cscale = 1; - PrimExpr symbol = tir::make_const(value.dtype(), 1); + PrimExpr symbol = tirx::make_const(value.dtype(), 1); auto fcollect_prod = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -94,7 +94,7 @@ void Analyzer::MarkGlobalNonNegValue(const PrimExpr& value) { symbol = symbol * val; } }; - UnpackReduction(symbol_scale, fcollect_prod); + UnpackReduction(symbol_scale, fcollect_prod); if (cscale <= 0) return; // override the constant int bound by marking it as non-negative // NOTE: there might be future opportunities of more bound hint @@ -143,7 +143,7 @@ void ConstraintContext::ExitWithScope() { } bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value >= lower_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); @@ -152,7 +152,7 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) { } bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) { - if (const auto* ptr = expr.as()) { + if (const auto* ptr = expr.as()) { return ptr->value < upper_bound; } auto bd = this->const_int_bound(this->rewrite_simplify(expr)); @@ -173,7 +173,7 @@ bool Analyzer::CanProveEqual(const PrimExpr& lhs, const PrimExpr& rhs) { bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, const PrimExpr& shape) { if (this->CanProve(lhs <= shape, ProofStrength::kSymbolicBound)) return true; // no need to do further attempt if shape is already a constant. - if (tir::is_const_int(shape)) return false; + if (tirx::is_const_int(shape)) return false; // collect constant scale and ignore symbolic part // so 32 * n => cscale = 32 int64_t cscale = 1; @@ -182,7 +182,7 @@ bool Analyzer::CanProveLessEqualThanSymbolicShapeValue(const PrimExpr& lhs, cons cscale *= ptr->value; } }; - UnpackReduction(shape, fcollect); + UnpackReduction(shape, fcollect); PrimExpr const_shape_bound = IntImm(shape.dtype(), std::abs(cscale)); if (this->CanProve(lhs <= const_shape_bound, ProofStrength::kSymbolicBound)) return true; return false; @@ -194,7 +194,7 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { return ptr->value != 0; } PrimExpr simplified = Simplify(expr); - const int64_t* as_int = tir::as_const_int(simplified); + const int64_t* as_int = tirx::as_const_int(simplified); if (as_int && *as_int) return true; if (strength >= ProofStrength::kSymbolicBound) { // NOTE: we intentionally only pattern match common bound predicate i < bound @@ -204,19 +204,19 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { // This strategy can only be called from top-level and not from sub-analyzers. ffi::Optional pos_diff; int lower_bound = 0; - if (const auto* ptr_lt = expr.as()) { + if (const auto* ptr_lt = expr.as()) { pos_diff = ptr_lt->b - ptr_lt->a; lower_bound = 1; } - if (const auto* ptr_le = expr.as()) { + if (const auto* ptr_le = expr.as()) { pos_diff = ptr_le->b - ptr_le->a; lower_bound = 0; } - if (const auto* ptr_gt = expr.as()) { + if (const auto* ptr_gt = expr.as()) { pos_diff = ptr_gt->a - ptr_gt->b; lower_bound = 1; } - if (const auto* ptr_ge = expr.as()) { + if (const auto* ptr_ge = expr.as()) { pos_diff = ptr_ge->a - ptr_ge->b; lower_bound = 0; } @@ -257,7 +257,7 @@ PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) { res = this->canonical_simplify(res); for (int i = 0; i < steps; ++i) { - if (tir::is_const_int(res)) { + if (tirx::is_const_int(res)) { return res; } if (i % 2 == 0) { diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index e23219465376..4bb566f27fd3 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -24,8 +24,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -35,7 +35,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; // a visitor to find the path to the target variable // from a expression. diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 5192bc1ad179..68696b9de10b 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -22,8 +22,8 @@ * \brief Canonical form based simplification. */ #include -#include -#include +#include +#include #include "const_fold.h" #include "pattern_match.h" @@ -33,7 +33,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; class SumExpr; class SplitExpr; @@ -216,7 +216,7 @@ class SplitExpr : public PrimExpr { inline bool SplitExprNode::IndexEqual(const SplitExpr& other) const { if (index.same_as(other->index)) return true; - return tir::ExprDeepEqual()(index, other->index); + return tirx::ExprDeepEqual()(index, other->index); } inline bool SplitExprNode::DivModeCompatibleTo(DivMode mode) const { @@ -900,18 +900,18 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // collect lhs product and constant scale. auto fcollect_lhs = [&](PrimExpr value) { - if (auto* intimm = value.as()) { + if (auto* intimm = value.as()) { lhs_cscale *= intimm->value; } else { lhs_prods.push_back(value); } }; - UnpackReduction(*plhs, fcollect_lhs); + UnpackReduction(*plhs, fcollect_lhs); // collect rhs product and try to eliminate when possible PEqualChecker deep_equal; auto fcollect_rhs = [&](PrimExpr value) { - if (auto* intimm = value.as()) { + if (auto* intimm = value.as()) { rhs_cscale *= intimm->value; } else { // try eliminate from lhs @@ -927,7 +927,7 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, new_rhs = new_rhs * value; } }; - UnpackReduction(*prhs, fcollect_rhs); + UnpackReduction(*prhs, fcollect_rhs); // find gcd of const scales. int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale); lhs_cscale /= cscale_gcd; diff --git a/src/arith/conjunctive_normal_form.cc b/src/arith/conjunctive_normal_form.cc index 47b6156cfa06..92afb242313a 100644 --- a/src/arith/conjunctive_normal_form.cc +++ b/src/arith/conjunctive_normal_form.cc @@ -24,7 +24,7 @@ #include "conjunctive_normal_form.h" #include -#include +#include #include #include @@ -138,10 +138,10 @@ class AndOfOrs { /*! \brief Mapping from PrimExpr to internal Key */ std::unordered_map expr_to_key_; - /*! \brief Cached key representing tir::Bool(true) */ + /*! \brief Cached key representing tirx::Bool(true) */ Key key_true_; - /*! \brief Cached key representing tir::Bool(false) */ + /*! \brief Cached key representing tirx::Bool(false) */ Key key_false_; }; diff --git a/src/arith/conjunctive_normal_form.h b/src/arith/conjunctive_normal_form.h index 84ee972d030e..a173ca587cdb 100644 --- a/src/arith/conjunctive_normal_form.h +++ b/src/arith/conjunctive_normal_form.h @@ -26,7 +26,7 @@ #ifndef TVM_ARITH_CONJUNCTIVE_NORMAL_FORM_H_ #define TVM_ARITH_CONJUNCTIVE_NORMAL_FORM_H_ -#include +#include namespace tvm { namespace arith { diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 4128e43e6e25..8464443118f9 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -25,8 +25,8 @@ #define TVM_ARITH_CONST_FOLD_H_ #include -#include -#include +#include +#include #include #include @@ -110,7 +110,7 @@ inline double GetFoldResultDoubleRepr(float x) { } #define TVM_ARITH_CONST_PROPAGATION(BODY) \ - using tir::FloatImmNode; \ + using tirx::FloatImmNode; \ const IntImmNode* pa = a.as(); \ const IntImmNode* pb = b.as(); \ const FloatImmNode* fa = a.as(); \ @@ -128,7 +128,7 @@ inline double GetFoldResultDoubleRepr(float x) { // specialization of constant folders. template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -152,7 +152,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ TVM_FFI_ICHECK(!((pa && pa->dtype.is_uint() && pa->value == 0U) && (pb && pb->dtype.is_uint() && pb->value > 0U))) @@ -178,7 +178,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -214,7 +214,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -250,7 +250,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -262,7 +262,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa->value == 0) return a; } if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); + if (pb->value == 1) return tirx::make_zero(rtype); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -270,7 +270,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -305,7 +305,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_INDEX_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) { @@ -317,7 +317,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr if (pa->value == 0) return a; } if (pb) { - if (pb->value == 1) return tir::make_zero(rtype); + if (pb->value == 1) return tirx::make_zero(rtype); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -325,7 +325,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value)); @@ -336,7 +336,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value)); @@ -347,7 +347,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value); if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value); @@ -356,7 +356,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value); if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value); @@ -365,7 +365,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value); if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value); @@ -374,7 +374,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value); if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value); @@ -383,7 +383,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value); if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value); @@ -392,7 +392,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value); if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value); @@ -401,7 +401,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return b; @@ -412,7 +412,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { +inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { const IntImmNode* pa = a.as(); const IntImmNode* pb = b.as(); if (pa && pa->value) return a; @@ -423,7 +423,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { } template <> -inline ffi::Optional TryConstFold(PrimExpr a) { +inline ffi::Optional TryConstFold(PrimExpr a) { const IntImmNode* pa = a.as(); if (pa) { return IntImm(DataType::Bool(), !(pa->value)); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 78a456784c8d..8e306ba96650 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -23,8 +23,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -37,7 +37,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { ConstIntBoundNode::RegisterReflection(); } @@ -163,7 +163,7 @@ class ConstIntBoundAnalyzer::Impl Entry VisitExpr(const PrimExpr& expr) final { Entry res = ExprFunctor::VisitExpr(expr); - tir::ExprDeepEqual equal; + tirx::ExprDeepEqual equal; // a linear search over additional info // assume we won't have a lot of conditions for (const BoundInfo& info : additional_info_) { @@ -425,13 +425,13 @@ class ConstIntBoundAnalyzer::Impl // used for index calculation. auto curr_target = Target::Current(); - if (op->op.same_as(tir::builtin::shift_right())) { + if (op->op.same_as(tirx::builtin::shift_right())) { return VisitRightShift(op); - } else if (op->op.same_as(tir::builtin::shift_left())) { + } else if (op->op.same_as(tirx::builtin::shift_left())) { return VisitLeftShift(op); - } else if (op->op.same_as(tir::builtin::bitwise_and())) { + } else if (op->op.same_as(tirx::builtin::bitwise_and())) { return VisitBitwiseAnd(op); - } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasVLA(curr_target)) { + } else if (op->op.same_as(tirx::builtin::vscale()) && TargetHasVLA(curr_target)) { auto kVScaleValues = GetVScaleValues(curr_target); unsigned int max_val = *std::max_element(kVScaleValues.begin(), kVScaleValues.end()); return MakeBound(1, max_val); @@ -807,10 +807,10 @@ class ConstIntBoundAnalyzer::Impl static ffi::Optional FindCeilLog2Arg(const CastNode* op) { if (op->dtype.is_int()) { if (auto as_call = op->value.as()) { - if (as_call->op.same_as(Op::Get("tir.ceil"))) { + if (as_call->op.same_as(Op::Get("tirx.ceil"))) { PrimExpr ceil_arg = as_call->args[0]; if (auto arg_call = ceil_arg.as()) { - if (arg_call->op.same_as(Op::Get("tir.log2"))) { + if (arg_call->op.same_as(Op::Get("tirx.log2"))) { PrimExpr log_arg = arg_call->args[0]; return log_arg; } diff --git a/src/arith/constraint_extract.cc b/src/arith/constraint_extract.cc index b873adcb5ca4..ba735c814e97 100644 --- a/src/arith/constraint_extract.cc +++ b/src/arith/constraint_extract.cc @@ -24,7 +24,7 @@ #include "constraint_extract.h" #include -#include +#include #include "pattern_match.h" diff --git a/src/arith/constraint_extract.h b/src/arith/constraint_extract.h index 815eafeebd62..7430ef960db1 100644 --- a/src/arith/constraint_extract.h +++ b/src/arith/constraint_extract.h @@ -26,7 +26,7 @@ #ifndef TVM_ARITH_CONSTRAINT_EXTRACT_H_ #define TVM_ARITH_CONSTRAINT_EXTRACT_H_ -#include +#include #include diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index 4a0b5f9cf0c3..3fb74940b4e0 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -24,16 +24,16 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; // Linear equation, the components can be undefined. struct LinearEqEntry { @@ -268,7 +268,7 @@ void SplitCommExpr(const PrimExpr& e, std::vector* ret) { ffi::Array DetectClipBound(const PrimExpr& e, const ffi::Array& vars) { std::vector splits; Analyzer analyzer; - SplitCommExpr(analyzer.Simplify(e), &splits); + SplitCommExpr(analyzer.Simplify(e), &splits); std::unordered_map rmap; for (Var v : vars) { rmap[v.get()] = IntervalEntry(); diff --git a/src/arith/domain_touched.cc b/src/arith/domain_touched.cc index dbfa334107ec..1efcedc21850 100644 --- a/src/arith/domain_touched.cc +++ b/src/arith/domain_touched.cc @@ -24,8 +24,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -36,7 +36,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; namespace { diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 3f4048bfd191..6d8e539357a2 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -25,16 +25,16 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include -#include "../tir/transform/ir_utils.h" +#include "../tirx/transform/ir_utils.h" namespace tvm { namespace arith { @@ -86,11 +86,11 @@ IntGroupBounds::IntGroupBounds(PrimExpr coef, ffi::Array lower, IntGroupBounds IntGroupBounds::FromRange(const Range& r) { Analyzer analyzer; - PrimExpr coef = tir::make_const(r->min.dtype(), 1); + PrimExpr coef = tirx::make_const(r->min.dtype(), 1); ffi::Array equal; ffi::Array lower; ffi::Array upper; - if (tir::is_one(r->extent)) { + if (tirx::is_one(r->extent)) { equal.push_back(r->min); } else { lower.push_back(r->min); @@ -105,7 +105,7 @@ IntGroupBounds IntGroupBounds::operator+(const Range& r) { ffi::Array lower; ffi::Array upper; const PrimExpr& coef = operator->()->coef; - if (tir::is_one(r->extent)) { + if (tirx::is_one(r->extent)) { equal.push_back(analyzer.Simplify(r->min * coef)); } else { lower.push_back(analyzer.Simplify(r->min * coef)); @@ -118,11 +118,11 @@ IntGroupBounds IntGroupBounds::operator+(const Range& r) { } IntGroupBounds IntGroupBounds::Substitute(const ffi::Map& subst) const { - auto apply_fun = [&subst](const PrimExpr& e) { return tir::Substitute(e, subst); }; - return IntGroupBounds(tir::Substitute(operator->()->coef, subst), - tir::UpdateArray(operator->()->lower, apply_fun), - tir::UpdateArray(operator->()->equal, apply_fun), - tir::UpdateArray(operator->()->upper, apply_fun)); + auto apply_fun = [&subst](const PrimExpr& e) { return tirx::Substitute(e, subst); }; + return IntGroupBounds(tirx::Substitute(operator->()->coef, subst), + tirx::UpdateArray(operator->()->lower, apply_fun), + tirx::UpdateArray(operator->()->equal, apply_fun), + tirx::UpdateArray(operator->()->upper, apply_fun)); } Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) const { @@ -146,7 +146,7 @@ Range IntGroupBounds::FindBestRange(const ffi::Map& vranges_addl) co uppers.push_back(expr); } - if (lowers.size() == 1 && uppers.size() == 1 && tir::is_one(coef)) { + if (lowers.size() == 1 && uppers.size() == 1 && tirx::is_one(coef)) { return Range(analyzer.Simplify(lowers[0]), analyzer.Simplify(uppers[0] + 1)); } diff --git a/src/arith/int_operator.h b/src/arith/int_operator.h index 6dec9a5502e1..661ea188c468 100644 --- a/src/arith/int_operator.h +++ b/src/arith/int_operator.h @@ -45,21 +45,24 @@ inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_va } template <> -inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, + int64_t max_value) { if ((y > 0) && (x > max_value - y)) return true; if ((y < 0) && (x < min_value - y)) return true; return false; } template <> -inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, + int64_t max_value) { if ((y > 0) && (x < min_value + y)) return true; if ((y < 0) && (x > max_value + y)) return true; return false; } template <> -inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, + int64_t max_value) { if (y == 0) return false; if (y > 0) { if (x < min_value / y) return true; @@ -73,7 +76,8 @@ inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, } template <> -inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, int64_t max_value) { +inline bool WillOverflow(int64_t x, int64_t y, int64_t min_value, + int64_t max_value) { return y == 0; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 25c825cbbc7b..c8e1d73771c0 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -25,9 +25,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -41,10 +41,10 @@ namespace tvm { namespace arith { -using tir::is_one; -using tir::is_zero; -using tir::make_const; -using tir::make_zero; +using tirx::is_one; +using tirx::is_zero; +using tirx::make_const; +using tirx::make_zero; TVM_FFI_STATIC_INIT_BLOCK() { IntervalSetNode::RegisterReflection(); } @@ -95,7 +95,7 @@ struct is_logical_op { #define TVM_DECLARE_LOGICAL_OP(OP) \ template <> \ - struct is_logical_op { \ + struct is_logical_op { \ static const bool value = true; \ }; @@ -140,8 +140,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, con } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - const tir::AddNode* /* op */) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, + const tirx::AddNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -155,8 +155,8 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS } template <> -inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - const tir::SubNode* /* op */) { +inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, + const tirx::SubNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -170,8 +170,8 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - const tir::MulNode* /* op */) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + const tirx::MulNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -192,7 +192,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval PrimExpr max_value = a->HasLowerBound() ? a->min_value * b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::Select; + using tirx::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value * b->min_value; PrimExpr e2 = a->max_value * b->min_value; @@ -204,8 +204,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - const tir::DivNode* /* op */) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + const tirx::DivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -226,7 +226,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval PrimExpr max_value = a->HasLowerBound() ? a->min_value / b->min_value : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::Select; + using tirx::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = a->min_value / b->min_value; PrimExpr e2 = a->max_value / b->min_value; @@ -238,8 +238,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - const tir::ModNode* op) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + const tirx::ModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -267,8 +267,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - const tir::FloorDivNode* /* op */) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + const tirx::FloorDivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -289,7 +289,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int PrimExpr max_value = a->HasLowerBound() ? floordiv(a->min_value, b->min_value) : pos_inf(); return IntervalSet(min_value, max_value); } else if (a->HasUpperBound() && a->HasLowerBound()) { - using tir::Select; + using tirx::Select; PrimExpr sign = b->min_value >= make_zero(b->min_value.dtype().element_of()); PrimExpr e1 = floordiv(a->min_value, b->min_value); PrimExpr e2 = floordiv(a->max_value, b->min_value); @@ -301,8 +301,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } template <> -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - const tir::FloorModNode* op) { +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, + const tirx::FloorModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -315,7 +315,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int TVM_FFI_THROW(InternalError) << "Modular by zero in CombineInterval Mod"; } if (analyzer->CanProveGreaterEqual(divisor, 0)) { - if (divisor.as()) { + if (divisor.as()) { // a mod b = a - (a / b) * b if a_max / b == a_min / b auto qmax = a->HasUpperBound() ? floordiv(a->max_value, divisor) : pos_inf(); auto qmin = a->HasLowerBound() ? floordiv(a->min_value, divisor) : neg_inf(); @@ -329,7 +329,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } } // Enhanced: Use ModularSet analysis for better bounds - if (auto* div_imm = divisor.as()) { + if (auto* div_imm = divisor.as()) { int64_t div_val = div_imm->value; // Analyze the modular properties of the dividend @@ -362,8 +362,8 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - const tir::MaxNode* /* op */) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, + const tirx::MaxNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -373,8 +373,8 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval } template <> -inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - const tir::MinNode* /* op */) { +inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, + const tirx::MinNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -392,7 +392,7 @@ IntervalSet ToIntervalSet(IntSet set) { return IntervalSet::Everything(); } -using namespace tir; +using namespace tirx; // Simplified version of int set evaluator that operates on IntervalSet // We might use better set analysis in the future to replace the intervalset. @@ -506,23 +506,23 @@ class IntervalSetEvaluator : public ExprFunctor { int lanes = static_cast(Downcast(op->lanes)->value); if (vstride > 0) { PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); - auto add_op = tir::Add(op->base, stride_expr); - auto add_node = add_op.as(); + auto add_op = tirx::Add(op->base, stride_expr); + auto add_node = add_op.as(); return Combine(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node); } else { PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); - auto add_op = tir::Add(op->base, stride_expr); - auto add_node = add_op.as(); + auto add_op = tirx::Add(op->base, stride_expr); + auto add_node = add_op.as(); return Combine(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node); } } else { /* Scalable vector */ if (vstride > 0) { - auto add_op = tir::Add(op->base, make_zero(t)); - auto add_node = add_op.as(); + auto add_op = tirx::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node); } else { - auto add_op = tir::Add(op->base, make_zero(t)); - auto add_node = add_op.as(); + auto add_op = tirx::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node); } } @@ -575,7 +575,7 @@ class IntervalSetEvaluator : public ExprFunctor { } IntervalSet VisitExpr_(const CallNode* op) final { - if (op->op.same_as(tir::builtin::vscale())) + if (op->op.same_as(tirx::builtin::vscale())) return IntervalSet(ffi::GetRef(op), ffi::GetRef(op)); return IntervalSet::Everything(); } diff --git a/src/arith/interval_set.h b/src/arith/interval_set.h index b8597db7aa90..471cc7f69274 100644 --- a/src/arith/interval_set.h +++ b/src/arith/interval_set.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 11caef56850b..7e37dc9d489c 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -24,15 +24,15 @@ #include #include -#include -#include +#include +#include namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; -void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tir::PrimFunc& func) { +void IRMutatorWithAnalyzer::MarkBufferMapShapes(const tirx::PrimFunc& func) { // Mark the all the symbolic buffer shape values in the buffer map as positive value. for (auto kv : func->buffer_map) { for (PrimExpr shape : kv.second->shape) { @@ -98,7 +98,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { return constraint_scope_.WithNewScope([&]() -> Stmt { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr real_condition = condition; - static auto op_likely = Op::Get("tir.likely"); + static auto op_likely = Op::Get("tirx.likely"); if (auto call = condition.as()) { if (call->op.same_as(op_likely)) { @@ -139,7 +139,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { return constraint_scope_.WithNewScope([&]() -> Stmt { - if (op->attr_key == tir::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { + if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value); @@ -170,7 +170,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) { PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - static auto op_if_then_else = Op::Get("tir.if_then_else"); + static auto op_if_then_else = Op::Get("tirx.if_then_else"); if (op->op.same_as(op_if_then_else)) { PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr true_value, false_value; diff --git a/src/arith/ir_mutator_with_analyzer.h b/src/arith/ir_mutator_with_analyzer.h index 0f03fef7d25e..178b2a6d34f5 100644 --- a/src/arith/ir_mutator_with_analyzer.h +++ b/src/arith/ir_mutator_with_analyzer.h @@ -27,8 +27,8 @@ #include #include #include -#include -#include +#include +#include #include @@ -44,7 +44,7 @@ namespace arith { * * \sa src/arithmetic/ir_mutator_with_analyzer.cc */ -class IRMutatorWithAnalyzer : public tir::StmtExprMutator { +class IRMutatorWithAnalyzer : public tirx::StmtExprMutator { public: explicit IRMutatorWithAnalyzer(Analyzer* analyzer) : analyzer_(analyzer) {} @@ -52,17 +52,17 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { using StmtExprMutator::VisitStmt_; // override functions that need to populate the context information. - tir::Stmt VisitStmt_(const tir::ForNode* op) override; - tir::Stmt VisitStmt_(const tir::SBlockNode* op) override; - tir::Stmt VisitStmt_(const tir::BindNode* op) override; - tir::Stmt VisitStmt_(const tir::IfThenElseNode* op) override; - tir::Stmt VisitStmt_(const tir::AttrStmtNode* op) override; - tir::Stmt VisitStmt_(const tir::AssertStmtNode* op) override; - tir::Stmt VisitStmt_(const tir::SeqStmtNode* op) override; - PrimExpr VisitExpr_(const tir::LetNode* op) override; - PrimExpr VisitExpr_(const tir::SelectNode* op) override; - PrimExpr VisitExpr_(const tir::CallNode* op) override; - PrimExpr VisitExpr_(const tir::ReduceNode* op) override; + tirx::Stmt VisitStmt_(const tirx::ForNode* op) override; + tirx::Stmt VisitStmt_(const tirx::SBlockNode* op) override; + tirx::Stmt VisitStmt_(const tirx::BindNode* op) override; + tirx::Stmt VisitStmt_(const tirx::IfThenElseNode* op) override; + tirx::Stmt VisitStmt_(const tirx::AttrStmtNode* op) override; + tirx::Stmt VisitStmt_(const tirx::AssertStmtNode* op) override; + tirx::Stmt VisitStmt_(const tirx::SeqStmtNode* op) override; + PrimExpr VisitExpr_(const tirx::LetNode* op) override; + PrimExpr VisitExpr_(const tirx::SelectNode* op) override; + PrimExpr VisitExpr_(const tirx::CallNode* op) override; + PrimExpr VisitExpr_(const tirx::ReduceNode* op) override; protected: /*! @@ -71,7 +71,7 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { * \note call this function before Visit function's body to maximize * simplification efficiency */ - void MarkBufferMapShapes(const tir::PrimFunc& func); + void MarkBufferMapShapes(const tirx::PrimFunc& func); /*! * \brief Use internal bound information to perform inter map simplification of indices. @@ -99,11 +99,11 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator { */ template void WithRecordIterPredicate(PrimExpr condition, FLambda callback) { - auto f_use_itervar = [this](const tir::VarNode* v) { - return iter_vars_.count(ffi::GetRef(v)); + auto f_use_itervar = [this](const tirx::VarNode* v) { + return iter_vars_.count(ffi::GetRef(v)); }; // simple heuristics for detecting predicate - if (tir::UsesVar(condition, f_use_itervar)) { + if (tirx::UsesVar(condition, f_use_itervar)) { iter_predicates_.push_back(condition); callback(); iter_predicates_.pop_back(); diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index e5041b159f8d..35f2bec52919 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -23,14 +23,14 @@ #include "ir_visitor_with_analyzer.h" #include -#include -#include -#include +#include +#include +#include namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; void IRVisitorWithAnalyzer::VisitStmt_(const ForNode* op) { constraint_scope_.WithNewScope([&]() { @@ -75,7 +75,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { void IRVisitorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) { constraint_scope_.WithNewScope([&]() { - if (op->attr_key == tir::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { + if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iv = Downcast(op->node); TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); analyzer_.Bind(iv->var, Range::FromMinExtent(IntImm(op->value->dtype, 0), op->value)); @@ -96,7 +96,7 @@ void IRVisitorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) { void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - static auto op_if_then_else = Op::Get("tir.if_then_else"); + static auto op_if_then_else = Op::Get("tirx.if_then_else"); if (op->op.same_as(op_if_then_else)) { PrimExpr cond = op->args[0]; this->VisitExpr(op->args[0]); diff --git a/src/arith/ir_visitor_with_analyzer.h b/src/arith/ir_visitor_with_analyzer.h index a5455659d0fe..404a14cf20a8 100644 --- a/src/arith/ir_visitor_with_analyzer.h +++ b/src/arith/ir_visitor_with_analyzer.h @@ -28,29 +28,29 @@ #include #include #include -#include -#include +#include +#include namespace tvm { namespace arith { -class IRVisitorWithAnalyzer : public tir::StmtExprVisitor { +class IRVisitorWithAnalyzer : public tirx::StmtExprVisitor { public: PrimExpr Simplify(const PrimExpr& expr) { return analyzer_.Simplify(expr); } using StmtExprVisitor::VisitExpr_; using StmtExprVisitor::VisitStmt_; - void VisitStmt_(const tir::ForNode* op); - void VisitStmt_(const tir::SBlockNode* op); - void VisitStmt_(const tir::BindNode* op); - void VisitStmt_(const tir::IfThenElseNode* op); - void VisitStmt_(const tir::AttrStmtNode* op); - void VisitStmt_(const tir::AssertStmtNode* op); - void VisitStmt_(const tir::SeqStmtNode* op); - void VisitExpr_(const tir::CallNode* op); - void VisitExpr_(const tir::LetNode* op); - void VisitExpr_(const tir::ReduceNode* op); + void VisitStmt_(const tirx::ForNode* op); + void VisitStmt_(const tirx::SBlockNode* op); + void VisitStmt_(const tirx::BindNode* op); + void VisitStmt_(const tirx::IfThenElseNode* op); + void VisitStmt_(const tirx::AttrStmtNode* op); + void VisitStmt_(const tirx::AssertStmtNode* op); + void VisitStmt_(const tirx::SeqStmtNode* op); + void VisitExpr_(const tirx::CallNode* op); + void VisitExpr_(const tirx::LetNode* op); + void VisitExpr_(const tirx::ReduceNode* op); // IRVisitorWithAnalyzer deliberately does not handle Select nodes, // because both sides of a Select node are visited regardless of the diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index d522f6b61749..0e996485a414 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -23,11 +23,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include @@ -39,7 +39,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { IterMarkNode::RegisterReflection(); @@ -422,7 +422,7 @@ class IterMapRewriter : public ExprMutator { static bool IterSplitEqual(const IterSplitExpr& lhs, const IterSplitExpr& rhs, bool check_scale = true) { - tir::ExprDeepEqual equal; + tirx::ExprDeepEqual equal; if (!lhs->source.same_as(rhs->source)) return false; if (!equal(lhs->lower_factor, rhs->lower_factor)) return false; if (check_scale && !equal(lhs->scale, rhs->scale)) return false; @@ -432,7 +432,7 @@ class IterMapRewriter : public ExprMutator { struct IterSumEqual { bool operator()(const IterSumExpr& lhs, const IterSumExpr& rhs) const { - tir::ExprDeepEqual equal; + tirx::ExprDeepEqual equal; if (lhs->args.size() != rhs->args.size()) return false; if (!equal(lhs->base, rhs->base)) return false; for (size_t i = 0; i < lhs->args.size(); ++i) { @@ -800,7 +800,7 @@ class IterMapRewriter : public ExprMutator { for (IterSplitExpr split : expr->args) { int64_t symbol_prod_count = 0; int64_t cscale = 1; - PrimExpr res = tir::make_const(split.dtype(), 1); + PrimExpr res = tirx::make_const(split.dtype(), 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -809,9 +809,9 @@ class IterMapRewriter : public ExprMutator { ++symbol_prod_count; } }; - UnpackReduction(split->scale, fcollect); + UnpackReduction(split->scale, fcollect); if (cscale != 1) { - res = res * tir::make_const(res.dtype(), cscale); + res = res * tirx::make_const(res.dtype(), cscale); } split.CopyOnWrite()->scale = res; items.emplace_back(Item{cscale, symbol_prod_count, split}); @@ -894,7 +894,7 @@ class IterMapRewriter : public ExprMutator { if (match_source.defined() && !match_source.same_as(expr->args[i]->source)) continue; int reduce_size = 0; auto fcollect = [&](const PrimExpr&) { ++reduce_size; }; - UnpackReduction(expr->args[i]->scale, fcollect); + UnpackReduction(expr->args[i]->scale, fcollect); if (base_index == -1 || reduce_size < min_reduce_size) { min_reduce_size = reduce_size; base_index = static_cast(i); @@ -1240,7 +1240,7 @@ class IterMapRewriter : public ExprMutator { PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs); static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) { - tir::ExprDeepEqual equal; + tirx::ExprDeepEqual equal; for (size_t i = 0; i < lhs->args.size(); ++i) { IterSplitExpr lvalue = lhs->args[i]; if (lvalue->source.same_as(rhs->source) && equal(lvalue->lower_factor, rhs->lower_factor) && diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index 1c3233959da0..9aaa81b1bacb 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -24,9 +24,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -37,7 +37,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { ModularSetNode::RegisterReflection(); } @@ -263,9 +263,9 @@ class ModularSetAnalyzer::Impl : public ExprFunctor> which can be // used for index calculation. - if (op->op.same_as(tir::builtin::shift_right())) { + if (op->op.same_as(tirx::builtin::shift_right())) { return VisitRightShift(op); - } else if (op->op.same_as(tir::builtin::bitwise_and())) { + } else if (op->op.same_as(tirx::builtin::bitwise_and())) { return VisitBitwiseAnd(op); } else { return Everything(); diff --git a/src/arith/narrow_predicate_expression.cc b/src/arith/narrow_predicate_expression.cc index 07337ee1e151..da2b9da442ca 100644 --- a/src/arith/narrow_predicate_expression.cc +++ b/src/arith/narrow_predicate_expression.cc @@ -24,15 +24,15 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; /* \brief Given a true expression that includes free parameter, * generate a true expression without the free parameters. @@ -48,7 +48,7 @@ using namespace tir; */ // Utility for generating a known true expression from an expression // with free parameters, and the range of those parameters. -class ExpressionNarrower : public tir::ExprMutator { +class ExpressionNarrower : public tirx::ExprMutator { public: static PrimExpr Apply(PrimExpr expr, ffi::Map free_parameters) { TVM_FFI_ICHECK(expr.dtype().is_bool()) << "Expected boolean expression, but received " << expr; @@ -60,7 +60,7 @@ class ExpressionNarrower : public tir::ExprMutator { explicit ExpressionNarrower(ffi::Map free_parameters) : free_parameters_(free_parameters) {} - using Parent = tir::ExprMutator; + using Parent = tirx::ExprMutator; using Parent::VisitExpr_; enum class Context { diff --git a/src/arith/narrow_predicate_expression.h b/src/arith/narrow_predicate_expression.h index 42a7c2cf038f..8262646caa2d 100644 --- a/src/arith/narrow_predicate_expression.h +++ b/src/arith/narrow_predicate_expression.h @@ -23,7 +23,7 @@ */ #include -#include +#include #ifndef TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ #define TVM_ARITH_NARROW_PREDICATE_EXPRESSION_H_ @@ -50,7 +50,7 @@ namespace arith { * \returns An expression that, if true, implies that the original * expression is also true. */ -PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters); +PrimExpr NarrowPredicateExpression(PrimExpr expr, ffi::Map free_parameters); } // namespace arith } // namespace tvm diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 626d0b9cbab5..e34430343e4b 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -44,7 +44,7 @@ * return (max(x, y) + z).Eval(); * } * - * tvm::tir::Var tx, ty; + * tvm::tirx::Var tx, ty; * arith::PVar c; * arith::PVar v; * // We can match integer and Var, both of which are @@ -65,9 +65,9 @@ #ifndef TVM_ARITH_PATTERN_MATCH_H_ #define TVM_ARITH_PATTERN_MATCH_H_ -#include -#include -#include +#include +#include +#include #include #include @@ -158,7 +158,7 @@ class PEqualChecker { public: bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { if (lhs.same_as(rhs)) return true; - return tir::ExprDeepEqual()(lhs, rhs); + return tirx::ExprDeepEqual()(lhs, rhs); } }; @@ -177,9 +177,9 @@ class PEqualChecker { }; template <> -class PEqualChecker { +class PEqualChecker { public: - bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); } + bool operator()(const tirx::Var& lhs, const tirx::Var& rhs) const { return lhs.same_as(rhs); } }; /*! @@ -369,14 +369,14 @@ class PConstWithTypeLike : public Pattern> { void InitMatch_() const {} bool Match_(const ObjectRef& node) const { - if (const tir::IntImmNode* ptr = node.as()) { + if (const tirx::IntImmNode* ptr = node.as()) { return ptr->value == value_; } else { return false; } } - PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); } + PrimExpr Eval() const { return tirx::make_const(ref_.Eval().dtype(), value_); } private: typename TA::Nested ref_; @@ -405,30 +405,30 @@ class PConstWithTypeLike : public Pattern> { #define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, ) // raise ambiguity error for operator overload of / and % -TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a)); -TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator/, tirx::Div, DivAmbiguityError(a)); +TVM_PATTERN_BINARY_OP_EX(operator%, tirx::Mod, DivAmbiguityError(a)); // arithmetic expressions -TVM_PATTERN_BINARY_OP(operator+, tir::Add); -TVM_PATTERN_BINARY_OP(operator-, tir::Sub); -TVM_PATTERN_BINARY_OP(operator*, tir::Mul); -TVM_PATTERN_BINARY_OP(min, tir::Min); -TVM_PATTERN_BINARY_OP(max, tir::Max); -TVM_PATTERN_BINARY_OP(div, tir::Div); -TVM_PATTERN_BINARY_OP(truncdiv, tir::Div); -TVM_PATTERN_BINARY_OP(truncmod, tir::Mod); -TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv); -TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod); +TVM_PATTERN_BINARY_OP(operator+, tirx::Add); +TVM_PATTERN_BINARY_OP(operator-, tirx::Sub); +TVM_PATTERN_BINARY_OP(operator*, tirx::Mul); +TVM_PATTERN_BINARY_OP(min, tirx::Min); +TVM_PATTERN_BINARY_OP(max, tirx::Max); +TVM_PATTERN_BINARY_OP(div, tirx::Div); +TVM_PATTERN_BINARY_OP(truncdiv, tirx::Div); +TVM_PATTERN_BINARY_OP(truncmod, tirx::Mod); +TVM_PATTERN_BINARY_OP(floordiv, tirx::FloorDiv); +TVM_PATTERN_BINARY_OP(floormod, tirx::FloorMod); // logical expressions -TVM_PATTERN_BINARY_OP(operator>, tir::GT); -TVM_PATTERN_BINARY_OP(operator>=, tir::GE); -TVM_PATTERN_BINARY_OP(operator<, tir::LT); -TVM_PATTERN_BINARY_OP(operator<=, tir::LE); -TVM_PATTERN_BINARY_OP(operator==, tir::EQ); -TVM_PATTERN_BINARY_OP(operator!=, tir::NE); -TVM_PATTERN_BINARY_OP(operator&&, tir::And); -TVM_PATTERN_BINARY_OP(operator||, tir::Or); +TVM_PATTERN_BINARY_OP(operator>, tirx::GT); +TVM_PATTERN_BINARY_OP(operator>=, tirx::GE); +TVM_PATTERN_BINARY_OP(operator<, tirx::LT); +TVM_PATTERN_BINARY_OP(operator<=, tirx::LE); +TVM_PATTERN_BINARY_OP(operator==, tirx::EQ); +TVM_PATTERN_BINARY_OP(operator!=, tirx::NE); +TVM_PATTERN_BINARY_OP(operator&&, tirx::And); +TVM_PATTERN_BINARY_OP(operator||, tirx::Or); /*! * \brief Pattern not expression. @@ -442,7 +442,7 @@ class PNotExpr : public Pattern> { void InitMatch_() const { value_.InitMatch_(); } bool Match_(const ObjectRef& node) const { - if (const tir::NotNode* ptr = node.as()) { + if (const tirx::NotNode* ptr = node.as()) { if (!value_.Match_(ptr->a)) return false; return true; } else { @@ -450,7 +450,7 @@ class PNotExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::Not(value_.Eval()); } + PrimExpr Eval() const { return tirx::Not(value_.Eval()); } private: typename TA::Nested value_; @@ -481,7 +481,7 @@ class PSelectExpr : public Pattern> { } bool Match_(const ObjectRef& node) const { - if (const tir::SelectNode* ptr = node.as()) { + if (const tirx::SelectNode* ptr = node.as()) { if (!condition_.Match_(ptr->condition)) return false; if (!true_value_.Match_(ptr->true_value)) return false; if (!false_value_.Match_(ptr->false_value)) return false; @@ -492,7 +492,7 @@ class PSelectExpr : public Pattern> { } PrimExpr Eval() const { - return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); + return tirx::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval()); } private: @@ -538,7 +538,7 @@ class PCastExpr : public Pattern> { } bool Match_(const ObjectRef& node) const { - if (const tir::CastNode* ptr = node.as()) { + if (const tirx::CastNode* ptr = node.as()) { if (!dtype_.Match_(ptr->dtype)) return false; if (!value_.Match_(ptr->value)) return false; return true; @@ -547,7 +547,7 @@ class PCastExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); } + PrimExpr Eval() const { return tirx::Cast(dtype_.Eval(), value_.Eval()); } private: typename DType::Nested dtype_; @@ -589,7 +589,7 @@ class PRampExpr : public Pattern> { } bool Match_(const ObjectRef& node) const { - if (const tir::RampNode* ptr = node.as()) { + if (const tirx::RampNode* ptr = node.as()) { if (!base_.Match_(ptr->base)) return false; if (!stride_.Match_(ptr->stride)) return false; if (!lanes_.Match_(ptr->lanes)) return false; @@ -599,7 +599,7 @@ class PRampExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); } + PrimExpr Eval() const { return tirx::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); } private: typename TBase::Nested base_; @@ -651,7 +651,7 @@ class PBroadcastExpr : public Pattern> { } bool Match_(const ObjectRef& node) const { - if (const tir::BroadcastNode* ptr = node.as()) { + if (const tirx::BroadcastNode* ptr = node.as()) { if (!value_.Match_(ptr->value)) return false; if (!lanes_.Match_(ptr->lanes)) return false; return true; @@ -660,7 +660,7 @@ class PBroadcastExpr : public Pattern> { } } - PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); } + PrimExpr Eval() const { return tirx::Broadcast(value_.Eval(), lanes_.Eval()); } private: typename TA::Nested value_; @@ -715,10 +715,10 @@ struct PCallExprInitMatchFunctor { }; struct PCallExprMatchFunctor { - const tir::CallNode* call_; + const tirx::CallNode* call_; bool matched_{true}; - explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {} + explicit PCallExprMatchFunctor(const tirx::CallNode* call) : call_(call) {} template void operator()(size_t i, const T& pattern) { @@ -754,7 +754,7 @@ class PCallExpr : public Pattern> { } bool Match_(const ObjectRef& node) const { - if (const tir::CallNode* ptr = node.as()) { + if (const tirx::CallNode* ptr = node.as()) { if (ptr->args.size() != sizeof...(TArgs)) return false; if (!ptr->op.same_as(Op::GetOp())) return false; detail::PCallExprMatchFunctor fmatch(ptr); @@ -779,9 +779,9 @@ class PCallExpr : public Pattern> { #define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName) \ struct OpName { \ static PrimExpr Eval(ffi::Array args) { \ - return tir::Call(args[0].dtype(), GetOp(), args); \ + return tirx::Call(args[0].dtype(), GetOp(), args); \ } \ - static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ + static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ }; \ template \ inline PCallExpr FuncName(const Pattern& a, const Pattern& b) { \ @@ -795,16 +795,16 @@ TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or); TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor); // unary intrinsics -#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ - struct OpName { \ - static PrimExpr Eval(ffi::Array args) { \ - return tir::Call(args[0].dtype(), GetOp(), args); \ - } \ - static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \ - }; \ - template \ - inline PCallExpr FuncName(const Pattern& a) { \ - return PCallExpr(a.derived()); \ +#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName) \ + struct OpName { \ + static PrimExpr Eval(ffi::Array args) { \ + return tirx::Call(args[0].dtype(), GetOp(), args); \ + } \ + static const Op& GetOp() { return tirx::builtin::IntrinOpName(); } \ + }; \ + template \ + inline PCallExpr FuncName(const Pattern& a) { \ + return PCallExpr(a.derived()); \ } TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); @@ -812,9 +812,9 @@ TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not); // if_then_else struct PIfThenElseOp { static PrimExpr Eval(ffi::Array args) { - return tir::Call(args[1].dtype(), GetOp(), args); + return tirx::Call(args[1].dtype(), GetOp(), args); } - static const Op& GetOp() { return tir::builtin::if_then_else(); } + static const Op& GetOp() { return tirx::builtin::if_then_else(); } }; /*! @@ -840,8 +840,8 @@ inline PCallExpr if_then_else(const Pattern // vscale struct PVscaleOp { - static PrimExpr Eval() { return tir::Call(DataType::Int(32), GetOp(), {}); } - static const Op& GetOp() { return tir::builtin::vscale(); } + static PrimExpr Eval() { return tirx::Call(DataType::Int(32), GetOp(), {}); } + static const Op& GetOp() { return tirx::builtin::vscale(); } }; template diff --git a/src/arith/presburger_set.cc b/src/arith/presburger_set.cc index 4be7f8442a55..3c7bd25d860f 100644 --- a/src/arith/presburger_set.cc +++ b/src/arith/presburger_set.cc @@ -28,9 +28,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -46,7 +46,7 @@ namespace arith { #if defined(TVM_MLIR_VERSION) && TVM_MLIR_VERSION >= 150 TVM_FFI_STATIC_INIT_BLOCK() { PresburgerSetNode::RegisterReflection(); } -using namespace tir; +using namespace tirx; static void Update(const PrimExpr& constraint, PresburgerSetNode* intset) { auto& space = intset->space; diff --git a/src/arith/presburger_set.h b/src/arith/presburger_set.h index cf624f757e5f..9e4735004157 100644 --- a/src/arith/presburger_set.h +++ b/src/arith/presburger_set.h @@ -34,7 +34,7 @@ #include #include -#include +#include #include #include diff --git a/src/arith/product_normal_form.h b/src/arith/product_normal_form.h index ed258f207d6d..d3308c07bb2a 100644 --- a/src/arith/product_normal_form.h +++ b/src/arith/product_normal_form.h @@ -24,8 +24,8 @@ #ifndef TVM_ARITH_PRODUCT_NORMAL_FORM_H_ #define TVM_ARITH_PRODUCT_NORMAL_FORM_H_ -#include -#include +#include +#include namespace tvm { namespace arith { @@ -54,10 +54,10 @@ inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) { */ template inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) { - if (const tir::AddNode* node = value.as()) { + if (const tirx::AddNode* node = value.as()) { UnpackSum(node->a, fleaf, sign); UnpackSum(node->b, fleaf, sign); - } else if (const tir::SubNode* node = value.as()) { + } else if (const tirx::SubNode* node = value.as()) { UnpackSum(node->a, fleaf, sign); UnpackSum(node->b, fleaf, -sign); } else { @@ -79,7 +79,7 @@ inline void UnpackSum(const PrimExpr& value, FLeaf fleaf, int sign = 1) { */ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { int64_t cscale = 1; - PrimExpr res = tir::make_const(lhs.dtype(), 1); + PrimExpr res = tirx::make_const(lhs.dtype(), 1); auto fcollect = [&](PrimExpr val) { if (const auto* intimm = val.as()) { cscale *= intimm->value; @@ -87,10 +87,10 @@ inline PrimExpr MulAndNormalize(const PrimExpr& lhs, const PrimExpr& rhs) { res = res * val; } }; - UnpackReduction(lhs, fcollect); - UnpackReduction(rhs, fcollect); + UnpackReduction(lhs, fcollect); + UnpackReduction(rhs, fcollect); if (cscale != 1) { - res = res * tir::make_const(res.dtype(), cscale); + res = res * tirx::make_const(res.dtype(), cscale); } return res; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 7ae2f09e3990..cb96bb07f66c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -25,8 +25,8 @@ #include "rewrite_simplify.h" #include -#include -#include +#include +#include #include #include @@ -42,7 +42,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { RewriteSimplifierStatsNode::RegisterReflection(); } @@ -2270,19 +2270,19 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { op = ret.as(); if (op == nullptr) return ret; - if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) { + if (op->op.same_as(tirx::builtin::likely()) && is_const_int(op->args[0])) { return op->args[0]; - } else if (op->op.same_as(tir::builtin::shift_right())) { + } else if (op->op.same_as(tirx::builtin::shift_right())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] >> op->args[1]; } - } else if (op->op.same_as(tir::builtin::shift_left())) { + } else if (op->op.same_as(tirx::builtin::shift_left())) { if (op->args[0].as() && op->args[1].as()) { // the operator overload will eagerly constant fold. return op->args[0] << op->args[1]; } - } else if (op->op.same_as(Op::Get("tir.ceil"))) { + } else if (op->op.same_as(Op::Get("tirx.ceil"))) { PrimExpr ceil_arg = op->args[0]; if (auto arg_int = op->args[0].as()) { return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); @@ -2291,7 +2291,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } else if (auto arg_call = ceil_arg.as()) { // ceil(log2(cast(n,"float64"))) is used as the implementation of // topi.math.ceil_log2, and appears in iteration bounds. - if (arg_call->op.same_as(Op::Get("tir.log2"))) { + if (arg_call->op.same_as(Op::Get("tirx.log2"))) { PrimExpr log_arg = arg_call->args[0]; if (auto as_float = log_arg.as()) { // ceil(log2(n)) can be simplified, and should produce the @@ -2301,7 +2301,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } } } - } else if (op->op.same_as(Op::Get("tir.clz"))) { + } else if (op->op.same_as(Op::Get("tirx.clz"))) { if (const auto* arg_int = op->args[0].as()) { int bits = arg_int->dtype.bits(); if (arg_int->value == 0) return make_const(op->dtype, bits); @@ -2314,14 +2314,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } } - if (op->op.same_as(tir::builtin::likely())) { + if (op->op.same_as(tirx::builtin::likely())) { // Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } } if (auto match = TryMatchLiteralConstraint(op->args[0])) { return match.value(); } } - if (op->op.same_as(tir::builtin::if_then_else())) { + if (op->op.same_as(tirx::builtin::if_then_else())) { // Simplify nested if_then_else // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr } } else { else_expr } // => if (cond && inner_cond) { inner_then_expr } else { else_expr } @@ -2329,7 +2329,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { const PrimExpr& then_expr = op->args[1]; const PrimExpr& else_expr = op->args[2]; const CallNode* inner_call = then_expr.as(); - if (inner_call != nullptr && inner_call->op.same_as(tir::builtin::if_then_else())) { + if (inner_call != nullptr && inner_call->op.same_as(tirx::builtin::if_then_else())) { const PrimExpr& inner_cond = inner_call->args[0]; const PrimExpr& inner_then_expr = inner_call->args[1]; const PrimExpr& inner_else_expr = inner_call->args[2]; diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 976e490ddfd5..1aed7102f13e 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -39,7 +39,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; /* \brief Usage counters for RewriteSimplifier * diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc index b827dff80105..b0b91b01ec5a 100644 --- a/src/arith/scalable_expression.cc +++ b/src/arith/scalable_expression.cc @@ -24,27 +24,27 @@ #include "scalable_expression.h" -#include -#include +#include +#include #include -#include "../tir/analysis/check_contains.h" -#include "../tir/transform/replace_selected_expr.h" +#include "../tirx/analysis/check_contains.h" +#include "../tirx/transform/replace_selected_expr.h" #include "./pattern_match.h" namespace tvm { namespace arith { bool IsVScaleCall(const PrimExpr& expr) { - if (auto call = expr.as()) { - return call->op.same_as(tir::builtin::vscale()); + if (auto call = expr.as()) { + return call->op.same_as(tirx::builtin::vscale()); } return false; } bool ContainsVscaleCall(const PrimExpr& expr) { - return tir::CheckContains::ExprContains(expr, IsVScaleCall); + return tirx::CheckContains::ExprContains(expr, IsVScaleCall); } PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscale_value) { @@ -55,8 +55,8 @@ PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int vscal return true; }; - return tir::ReplaceSelectedExpr::ReplaceSelectedExprInExpr( - expr, predicate_selector, tir::MakeConstScalar(DataType::Int(32), vscale_value), + return tirx::ReplaceSelectedExpr::ReplaceSelectedExprInExpr( + expr, predicate_selector, tirx::MakeConstScalar(DataType::Int(32), vscale_value), can_replace_inside); } @@ -77,7 +77,7 @@ bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const Pr for (const unsigned int vscale_value : vscale_values) { PrimExpr result = SubstituteVScaleWithKnownValue(expr, vscale_value); result = analyzer->Simplify(result); - const int64_t* as_int = tir::as_const_int(result); + const int64_t* as_int = tirx::as_const_int(result); if (!as_int || *as_int == 0) { can_prove_expr = false; break; diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 88c55576206d..8d6b58351359 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -27,9 +27,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include @@ -132,10 +132,10 @@ void SmithNormalFormDiag(std::vector>* S, std::vector>* S, std::vectorrelations) { - if (const tir::EQNode* eq = equation.as()) { + if (const tirx::EQNode* eq = equation.as()) { // a-b = sum_{i=0}^{n-1} variables[i] * coeff[i] + coeff[n] ffi::Array coeffs = arith::DetectLinearEquation( analyzer_problem.Simplify(eq->a - eq->b), system_to_solve->variables); @@ -362,15 +362,15 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol new_relation = (floormod(Uy[j], std::abs(S[j][j])) == 0); } new_relation = analyzer_problem.Simplify(new_relation); - if (tir::is_const_int(new_relation, 0)) { + if (tirx::is_const_int(new_relation, 0)) { // unable to solve the system. return IntConstraintsTransform(system_to_solve, IntConstraints( /*variables=*/{}, /*ranges=*/{}, - /*relations=*/{tir::make_zero(DataType::Bool())}), + /*relations=*/{tirx::make_zero(DataType::Bool())}), {}, {}); - } else if (!tir::is_const_int(new_relation, 1)) { + } else if (!tirx::is_const_int(new_relation, 1)) { new_relations.push_back(new_relation); } } @@ -402,12 +402,12 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // The j-th variable is just a single value, don't create a tvm variable // S^{-1}_{nxm} Uy_{mxn} if (S[j][j] >= 0) { - PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]); + PrimExpr a = tirx::make_const(Uy[j].dtype(), S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(Uy[j], a))); } else { // This is required because some simplifiers // have problems with dividing by negative numbers - PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]); + PrimExpr a = tirx::make_const(Uy[j].dtype(), -S[j][j]); solution_for_V_inv_x.push_back(analyzer_problem.Simplify(floordiv(-Uy[j], a))); } } @@ -415,9 +415,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // V V^{-1} x = x for (size_t i = 0; i < num_vars; ++i) { - PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype()); + PrimExpr e = tirx::make_zero(system_to_solve->variables[i].dtype()); for (size_t j = 0; j < num_vars; ++j) { - e = e + tir::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; + e = e + tirx::make_const(e.dtype(), V[i][j]) * solution_for_V_inv_x[j]; } e = analyzer_problem.Simplify(e); old_to_new_map.Set(system_to_solve->variables[i], e); @@ -438,10 +438,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); PrimExpr upper_cond = analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); - if (!tir::is_const_int(lower_cond, 1)) { + if (!tirx::is_const_int(lower_cond, 1)) { new_relations.push_back(lower_cond); } - if (!tir::is_const_int(upper_cond, 1)) { + if (!tirx::is_const_int(upper_cond, 1)) { new_relations.push_back(upper_cond); } } diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 6c932ea5221b..64a85d04d70b 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -27,10 +27,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "int_operator.h" @@ -38,7 +38,7 @@ namespace tvm { namespace arith { using namespace tvm::runtime; -using namespace tvm::tir; +using namespace tvm::tirx; struct ExprLess { bool operator()(const PrimExpr& l, const PrimExpr& r) const { @@ -411,7 +411,7 @@ IntConstraints SolveInequalitiesToRange(const IntConstraints& inequalities) { if (analyzer.CanProveGreaterEqual(-best_range->extent, 0)) { // range.extent <= 0 implies the input inequality system is unsolvable return IntConstraints(/*variables=*/{}, /*ranges=*/{}, - /*relations=*/{tir::make_zero(DataType::Bool())}); + /*relations=*/{tirx::make_zero(DataType::Bool())}); } res_ranges.Set(var, best_range); vranges.Set(var, best_range); @@ -496,7 +496,7 @@ IntConstraintsTransform SolveInequalitiesDeskewRange(const IntConstraints& inequ IntConstraints( /*variables=*/{}, /*ranges=*/{}, - /*relations=*/{tir::make_zero(DataType::Bool())}), + /*relations=*/{tirx::make_zero(DataType::Bool())}), {}, {}); } else { // created new_var starts from 0 diff --git a/src/arith/transitive_comparison_analyzer.cc b/src/arith/transitive_comparison_analyzer.cc index 23aaf2140c33..e7deea4cfd56 100644 --- a/src/arith/transitive_comparison_analyzer.cc +++ b/src/arith/transitive_comparison_analyzer.cc @@ -21,8 +21,8 @@ */ #include -#include -#include +#include +#include #include #include @@ -33,7 +33,7 @@ namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; class TransitiveComparisonAnalyzer::Impl { public: @@ -63,7 +63,7 @@ class TransitiveComparisonAnalyzer::Impl { * \param expr The bound expression * \param allow_override Whether to allow override of existing information. */ - void Bind(const tir::Var& var, const PrimExpr& expr, bool allow_override = false); + void Bind(const tirx::Var& var, const PrimExpr& expr, bool allow_override = false); /*! \brief Bind a variable as being within a specified range * @@ -71,7 +71,7 @@ class TransitiveComparisonAnalyzer::Impl { * \param range The known range * \param allow_override Whether to allow override of existing information. */ - void Bind(const tir::Var& var, const Range& expr, bool allow_override = false); + void Bind(const tirx::Var& var, const Range& expr, bool allow_override = false); /*! * \brief Update the internal state to enter constraint. @@ -547,7 +547,7 @@ std::function TransitiveComparisonAnalyzer::EnterConstraint(const PrimEx void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, std::vector* vec) { for (const auto& subexpr : ExtractConstraints(expr, false)) { - if (tir::SideEffect(expr) <= tir::CallEffectKind::kPure) { + if (tirx::SideEffect(expr) <= tirx::CallEffectKind::kPure) { if (auto cmp = FromExpr(subexpr)) { vec->push_back(cmp.value()); } @@ -555,7 +555,7 @@ void TransitiveComparisonAnalyzer::Impl::AddKnown(const PrimExpr& expr, } } -void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& range, +void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const Range& range, bool allow_override) { auto it = prev_bindings_.find(var); if (it != prev_bindings_.end()) { @@ -583,7 +583,7 @@ void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const Range& } } -void TransitiveComparisonAnalyzer::Impl::Bind(const tir::Var& var, const PrimExpr& expr, +void TransitiveComparisonAnalyzer::Impl::Bind(const tirx::Var& var, const PrimExpr& expr, bool allow_override) { Bind(var, Range::FromMinExtent(expr, 1), allow_override); } diff --git a/src/arith/unwrap_vector_expr.cc b/src/arith/unwrap_vector_expr.cc index a73cf89f3671..ee56aa17541c 100644 --- a/src/arith/unwrap_vector_expr.cc +++ b/src/arith/unwrap_vector_expr.cc @@ -25,18 +25,18 @@ #include "unwrap_vector_expr.h" #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; class Scalarizer : public ExprMutator { public: diff --git a/src/arith/unwrap_vector_expr.h b/src/arith/unwrap_vector_expr.h index 9f18964043ff..d17db0366c4b 100644 --- a/src/arith/unwrap_vector_expr.h +++ b/src/arith/unwrap_vector_expr.h @@ -26,7 +26,7 @@ #ifndef TVM_ARITH_UNWRAP_VECTOR_EXPR_H_ #define TVM_ARITH_UNWRAP_VECTOR_EXPR_H_ -#include +#include #include diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc index a3173e990ccc..1524ea9fc249 100644 --- a/src/ir/apply_pass_to_function.cc +++ b/src/ir/apply_pass_to_function.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include #include @@ -36,8 +36,8 @@ namespace transform { namespace { BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any attr_value) { - if (auto tir = func.as()) { - return WithAttr(tir.value(), attr_key, attr_value); + if (auto tirx = func.as()) { + return WithAttr(tirx.value(), attr_key, attr_value); } else if (auto relax = func.as()) { return WithAttr(relax.value(), attr_key, attr_value); } else { @@ -46,8 +46,8 @@ BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, Any attr_v } BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) { - if (auto tir = func.as()) { - return WithoutAttr(tir.value(), attr_key); + if (auto tirx = func.as()) { + return WithoutAttr(tirx.value(), attr_key); } else if (auto relax = func.as()) { return WithoutAttr(relax.value(), attr_key); } else { diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index a4f1c462b275..c2a26cb13678 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -31,7 +31,7 @@ #define TVM_IR_ATTR_FUNCTOR_H_ #include -#include +#include #include @@ -78,40 +78,40 @@ class AttrFunctor { } virtual R VisitAttrDefault_(const Object* node, Args... args) = 0; virtual R VisitAttr_(const ffi::ArrayObj* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::IntImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::FloatImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::StringImmNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; // deep comparison of symbolic integer expressions. - virtual R VisitAttr_(const tir::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::SizeVarNode* op, Args... args) { - return VisitAttr_(static_cast(op), std::forward(args)...); + virtual R VisitAttr_(const tirx::VarNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::SizeVarNode* op, Args... args) { + return VisitAttr_(static_cast(op), std::forward(args)...); } - virtual R VisitAttr_(const tir::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; - virtual R VisitAttr_(const tir::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::AddNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::SubNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::MulNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::DivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::ModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::FloorDivNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::FloorModNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::MinNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::MaxNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::GENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::GTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::LTNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::LENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::EQNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::NENode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::AndNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::OrNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::NotNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::CastNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::CallNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const tirx::SelectNode* op, Args... args) ATTR_FUNCTOR_DEFAULT; private: // initialize the vtable. static FType InitVTable() { - using namespace tir; + using namespace tirx; FType vtable; // Set dispatch ATTR_FUNCTOR_DISPATCH(ffi::ArrayObj); diff --git a/src/ir/env_func.cc b/src/ir/env_func.cc index 774d81fb32cf..e90ea07a0b38 100644 --- a/src/ir/env_func.cc +++ b/src/ir/env_func.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include namespace tvm { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index cd1ac04abb47..4acb0507343e 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include "../support/scalars.h" @@ -48,7 +48,7 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tir::StringImm(value); } +PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { return tirx::StringImm(value); } IntImm::IntImm(DataType dtype, int64_t value, Span span) { TVM_FFI_CHECK(dtype.is_scalar(), ValueError) @@ -190,7 +190,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } Range::Range(PrimExpr begin, PrimExpr end, Span span) - : Range(ffi::make_object(begin, tir::is_zero(begin) ? end : (end - begin), span)) {} + : Range(ffi::make_object(begin, tirx::is_zero(begin) ? end : (end - begin), span)) {} Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { return Range(ffi::make_object(min, extent, span)); diff --git a/src/ir/function.cc b/src/ir/function.cc index 8a5da7dbefa9..801aab314132 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace tvm { @@ -38,8 +38,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ir.BaseFuncWithAttr", [](ffi::RValueRef func_ref, ffi::String key, Any value) -> BaseFunc { BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithAttr(Downcast(std::move(func)), key, value); + if (func->IsInstance()) { + return WithAttr(Downcast(std::move(func)), key, value); } else if (func->IsInstance()) { return WithAttr(Downcast(std::move(func)), key, value); } else if (func->IsInstance()) { @@ -53,8 +53,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](ffi::RValueRef func_ref, ffi::Map attr_map) -> BaseFunc { BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithAttrs(Downcast(std::move(func)), attr_map); + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); } if (const auto f = tvm::ffi::Function::GetGlobal("relax.FuncWithAttrs")) { if (auto ret = (*f)(func, attr_map).cast>()) { @@ -70,8 +70,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ir.BaseFuncWithoutAttr", [](ffi::RValueRef func_ref, ffi::String key) -> BaseFunc { BaseFunc func = *std::move(func_ref); - if (func->IsInstance()) { - return WithoutAttr(Downcast(std::move(func)), key); + if (func->IsInstance()) { + return WithoutAttr(Downcast(std::move(func)), key); } else if (func->IsInstance()) { return WithoutAttr(Downcast(std::move(func)), key); } else { diff --git a/src/ir/op.cc b/src/ir/op.cc index 508d4eafcdf1..0f8ff2ea7d96 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include @@ -39,7 +39,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { OpNode::RegisterReflection(); } using ffi::Any; using ffi::Function; using ffi::PackedArgs; -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; using OpRegistry = AttrRegistry; diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc index 69f3dad32022..2b77e6aa0f3e 100644 --- a/src/relax/analysis/analysis.cc +++ b/src/relax/analysis/analysis.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { diff --git a/src/relax/analysis/collect_call_map.cc b/src/relax/analysis/collect_call_map.cc index 85099d88ff57..9575e4f2cee9 100644 --- a/src/relax/analysis/collect_call_map.cc +++ b/src/relax/analysis/collect_call_map.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 954240c19189..43128b0ab9e0 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -85,7 +85,7 @@ class CompileTimeCollector : ExprVisitor { } support::OrderedSet known_relax_vars_; - std::unordered_set known_tir_vars_; + std::unordered_set known_tir_vars_; }; } // namespace diff --git a/src/relax/analysis/detect_recursion.cc b/src/relax/analysis/detect_recursion.cc index 7b2a5f516e92..da978804de4b 100644 --- a/src/relax/analysis/detect_recursion.cc +++ b/src/relax/analysis/detect_recursion.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 38752128b87f..307ce09d0fb0 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -26,22 +26,22 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include "../../support/array.h" namespace tvm { namespace relax { -using namespace tir; +using namespace tirx; /********** Helper Functions **********/ /*! \brief Checks if a transformation is bijective affine over the given ranges */ static bool IsBijectiveAffine(const IndexMap& m, const ffi::Array& ranges) { - ffi::Map input_iters; + ffi::Map input_iters; TVM_FFI_ICHECK_EQ(m->initial_indices.size(), ranges.size()); for (size_t i = 0; i < ranges.size(); i++) { input_iters.Set(m->initial_indices[i], ranges[i]); @@ -61,7 +61,7 @@ static bool IsBijectiveAffine(const IndexMap& m, const ffi::Array& ranges */ class IndexAnalyzer : public ExprVisitor { public: - ffi::Array Analyze(const arith::IterSumExpr& expr) { + ffi::Array Analyze(const arith::IterSumExpr& expr) { VisitExpr(expr); return iterators_; } @@ -85,15 +85,15 @@ class IndexAnalyzer : public ExprVisitor { } void VisitIterMark(const arith::IterMark& op) { - if (const auto* var = op->source.as()) - iterators_.push_back(ffi::GetRef(var)); + if (const auto* var = op->source.as()) + iterators_.push_back(ffi::GetRef(var)); else VisitExpr(op->source); VisitExpr(op->extent); } private: - ffi::Array iterators_; + ffi::Array iterators_; }; /*! @@ -111,13 +111,13 @@ class IndexAnalyzer : public ExprVisitor { * SpatialLayout(A[s0, constant, r0, s1]) = {s0, null, null, s1} * SpatialLayout(A[s0 * c + s1]) = undefined */ -using SpatialLayout = ffi::Array>; +using SpatialLayout = ffi::Array>; static SpatialLayout GetSpatialLayout(const arith::IterMapResult& iter_map_result) { TVM_FFI_ICHECK(!iter_map_result->indices.empty()); SpatialLayout result; for (const arith::IterSumExpr& index : iter_map_result->indices) { IndexAnalyzer index_analyzer; - ffi::Array iter_vars = index_analyzer.Analyze(index); + ffi::Array iter_vars = index_analyzer.Analyze(index); if (iter_vars.size() >= 2) { LOG(WARNING) << "[LayoutInference] Unable to get spatial layout of access: " << arith::NormalizeIterMapToExpr(index); @@ -152,7 +152,7 @@ static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const SpatialLayo * (ignoring reduction dimensions). It checks that the order of spatial iter vars in spatial layout * of a buffer access is same as the order of spatial iter vars in block domain. */ -using VarToBlockIndexMap = std::unordered_map; +using VarToBlockIndexMap = std::unordered_map; static bool IsSequentialAccess(const SpatialLayout& iterators, const VarToBlockIndexMap& iter_to_block_index) { int last_value = -1; @@ -174,7 +174,7 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { // Create a new shape expression. ffi::Array t1_initial_indices = - t1->initial_indices.Map([](tir::Var i) -> PrimExpr { return i; }); + t1->initial_indices.Map([](tirx::Var i) -> PrimExpr { return i; }); arith::Analyzer analyzer; auto t0_output = t0->MapIndices(t1_initial_indices, &analyzer); for (size_t i = 0; i < t0_output.size(); ++i) { @@ -212,7 +212,7 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { * source spatial layout. * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4) */ -using VarSet = std::unordered_set; +using VarSet = std::unordered_set; static ffi::Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, const IndexMap& src_transformation, const SpatialLayout& tgt_spatial_layout) { @@ -244,13 +244,13 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr auto final_indices_it = final_indices.begin(); while (final_indices_it != final_indices.end()) { // Collect all the vars used in this final index. - ffi::Array used_vars = tir::UndefinedVars(*final_indices_it); + ffi::Array used_vars = tirx::UndefinedVars(*final_indices_it); TVM_FFI_ICHECK(!used_vars.empty()) - << "IndexMap expression must always contain tir::Var nodes but found none in: " + << "IndexMap expression must always contain tirx::Var nodes but found none in: " << *final_indices_it; bool has_undefined_vars = std::any_of(used_vars.begin(), used_vars.end(), - [&initial_indices_var_set](const tir::Var& v) { + [&initial_indices_var_set](const tirx::Var& v) { return initial_indices_var_set.count(v) == 0; }); @@ -266,7 +266,7 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr // "H4h" -> "H*4+h" ) and the buffer we are trying to infer the transformation of has 'h' // dimension, but not 'H'. So, it is dependent on undefined var 'H' and defined var 'h'. bool depends_on_initial_indices = std::any_of(used_vars.begin(), used_vars.end(), - [&initial_indices_var_set](const tir::Var& v) { + [&initial_indices_var_set](const tirx::Var& v) { return initial_indices_var_set.count(v) != 0; }); if (depends_on_initial_indices) { @@ -296,7 +296,7 @@ static ffi::Optional InferLayoutTransformation(const SpatialLayout& sr continue; } - auto new_dim = tir::Var("d"); + auto new_dim = tirx::Var("d"); initial_indices.insert(initial_indices_it, new_dim); final_indices.insert(final_indices_it, new_dim); } @@ -457,7 +457,7 @@ class BlockAnalyzer : public StmtExprVisitor { spatial_dom_.Set(v->var, v->dom); continue; } - if (v->iter_type == tir::kCommReduce) continue; + if (v->iter_type == tirx::kCommReduce) continue; LOG(WARNING) << "[LayoutInference] Cannot compute block spatial domain in presence of " "unknown block iter_type : " << v->iter_type; @@ -522,7 +522,7 @@ class BlockAnalyzer : public StmtExprVisitor { private: bool can_transform_block_; IndexMap write_transformation_; - ffi::Map spatial_dom_; + ffi::Map spatial_dom_; arith::Analyzer arith_analyzer_; SBlock block_; @@ -608,7 +608,7 @@ class PrimFuncAnalyzer : public StmtExprVisitor { std::unordered_map, ObjectPtrHash, ObjectPtrEqual> block_to_buffer_; }; -ffi::Map> SuggestLayoutTransforms( +ffi::Map> SuggestLayoutTransforms( const PrimFunc& prim_func, ffi::Array write_buffer_transformations) { // No changes to the PrimFunc are required if no transformations on output buffers. if (write_buffer_transformations.empty()) return {}; @@ -620,7 +620,7 @@ ffi::Map> SuggestLayoutTransform TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("relax.analysis.suggest_layout_transforms", - [](PrimFunc fn, ffi::Array write_buffer_transformations) { + [](PrimFunc fn, ffi::Array write_buffer_transformations) { return SuggestLayoutTransforms(fn, write_buffer_transformations); }); } diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 101e8e8b7410..1152e7670c34 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -27,8 +27,8 @@ #include #include #include -#include -#include +#include +#include namespace tvm { namespace relax { @@ -115,9 +115,9 @@ StructInfo StructInfoFromType(const Type& type) { //-------------------------- class WellDefinedEraser : public StructInfoMutator, public ExprMutatorBase, - public tir::ExprMutator { + public tirx::ExprMutator { public: - WellDefinedEraser(std::function(const tir::Var& var)> f_shape_var_map, + WellDefinedEraser(std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map, arith::Analyzer* ana) : f_shape_var_map_(f_shape_var_map), f_var_map_(f_var_map), ana_(ana) {} @@ -202,12 +202,12 @@ class WellDefinedEraser : public StructInfoMutator, } using relax::ExprMutatorBase::VisitExpr_; - using tir::ExprMutator::VisitExpr_; + using tirx::ExprMutator::VisitExpr_; // connect things up PrimExpr VisitPrimExpr(const PrimExpr& expr) { // apply eager simplification - PrimExpr val = tir::ExprMutator::VisitExpr(expr); + PrimExpr val = tirx::ExprMutator::VisitExpr(expr); if (!val.same_as(expr)) { return ana_->Simplify(val); } else { @@ -228,10 +228,10 @@ class WellDefinedEraser : public StructInfoMutator, return ret.value_or(ffi::GetRef(var)); } - PrimExpr VisitExpr_(const tir::VarNode* var) final { + PrimExpr VisitExpr_(const tirx::VarNode* var) final { ffi::Optional ret; if (f_shape_var_map_ != nullptr) { - ret = f_shape_var_map_(ffi::GetRef(var)); + ret = f_shape_var_map_(ffi::GetRef(var)); } has_undefined_ = has_undefined_ || !ret.defined(); @@ -250,14 +250,14 @@ class WellDefinedEraser : public StructInfoMutator, private: bool has_undefined_ = false; - std::function(const tir::Var& var)> f_shape_var_map_; + std::function(const tirx::Var& var)> f_shape_var_map_; std::function(const Var& var)> f_var_map_; arith::Analyzer* ana_; }; StructInfo EraseToWellDefined( const StructInfo& info, - std::function(const tir::Var& var)> f_shape_var_map, + std::function(const tirx::Var& var)> f_shape_var_map, std::function(const Var& var)> f_var_map, arith::Analyzer* ana) { if (ana == nullptr) { arith::Analyzer inst; @@ -267,13 +267,13 @@ StructInfo EraseToWellDefined( } } -StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, +StructInfo EraseToWellDefined(const StructInfo& info, ffi::Map shape_var_map, ffi::Map var_map, arith::Analyzer* ana) { - std::function(const tir::Var& var)> f_shape_var_map = nullptr; + std::function(const tirx::Var& var)> f_shape_var_map = nullptr; std::function(const Var& var)> f_var_map = nullptr; if (!shape_var_map.empty()) { - f_shape_var_map = [&](const tir::Var& var) -> ffi::Optional { + f_shape_var_map = [&](const tirx::Var& var) -> ffi::Optional { auto it = shape_var_map.find(var); if (it != shape_var_map.end()) return (*it).second; return std::nullopt; @@ -295,7 +295,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( "relax.analysis.EraseToWellDefined", - [](const StructInfo& info, ffi::Map shape_var_map, + [](const StructInfo& info, ffi::Map shape_var_map, ffi::Map var_map) { return EraseToWellDefined(info, shape_var_map, var_map); }); } @@ -880,7 +880,7 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { // Whether to populate map in params. bool populate_mapping_{true}; // for simplicity, we make these fields public so the user can access them. - ffi::Map shape_var_map_; + ffi::Map shape_var_map_; ffi::Map var_map_; using StructInfoBaseChecker::ShapeMatchCheck; @@ -891,8 +891,8 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { return StructInfoBaseChecker::PrimValueMatchCheck(param, arg); } - if (auto* ptr = param.as()) { - auto var = ffi::GetRef(ptr); + if (auto* ptr = param.as()) { + auto var = ffi::GetRef(ptr); auto it = shape_var_map_.find(var); // not populated if (it == shape_var_map_.end()) { @@ -1194,16 +1194,16 @@ class TIRVarsDetector : public StructInfoVisitor { }; explicit TIRVarsDetector(VarType collection_type) : collection_type(collection_type) {} - ffi::Array GetTIRVars() const { return tir_vars_; } + ffi::Array GetTIRVars() const { return tir_vars_; } private: void VisitPrimExpr(PrimExpr expr) { if (collection_type == VarType::Definition) { - if (auto opt = expr.as()) { + if (auto opt = expr.as()) { RecordTIRVar(opt.value()); } } else if (collection_type == VarType::Usage) { - for (const tir::Var& tir_var : tir::UndefinedVars(expr)) { + for (const tirx::Var& tir_var : tirx::UndefinedVars(expr)) { RecordTIRVar(tir_var); } } else { @@ -1236,26 +1236,26 @@ class TIRVarsDetector : public StructInfoVisitor { } } - void RecordTIRVar(const tir::Var& tir_var) { + void RecordTIRVar(const tirx::Var& tir_var) { auto insert_res = used_tir_vars_dedup_.insert(tir_var.get()); if (insert_res.second) { tir_vars_.push_back(tir_var); } } - ffi::Array tir_vars_; - std::unordered_set used_tir_vars_dedup_; + ffi::Array tir_vars_; + std::unordered_set used_tir_vars_dedup_; VarType collection_type; }; -ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array TIRVarsInStructInfo(const StructInfo& sinfo) { TIRVarsDetector detector(TIRVarsDetector::VarType::Usage); detector(sinfo); return detector.GetTIRVars(); } -ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { +ffi::Array DefinableTIRVarsInStructInfo(const StructInfo& sinfo) { TIRVarsDetector detector(TIRVarsDetector::VarType::Definition); detector(sinfo); return detector.GetTIRVars(); @@ -1318,29 +1318,29 @@ TVM_FFI_STATIC_INIT_BLOCK() { class SymbolicVarCollector : public relax::ExprVisitor, public relax::StructInfoVisitor, - public tir::ExprVisitor { + public tirx::ExprVisitor { public: - static ffi::Array Free(const Expr& expr) { + static ffi::Array Free(const Expr& expr) { SymbolicVarCollector collector; collector.VisitExpr(expr); - ffi::Array ret{collector.free_symbolic_var_.begin(), - collector.free_symbolic_var_.end()}; + ffi::Array ret{collector.free_symbolic_var_.begin(), + collector.free_symbolic_var_.end()}; return ret; } - static ffi::Array Defined(const Expr& expr) { + static ffi::Array Defined(const Expr& expr) { SymbolicVarCollector collector; collector.VisitExpr(expr); - ffi::Array ret{collector.defined_symbolic_var_.begin(), - collector.defined_symbolic_var_.end()}; + ffi::Array ret{collector.defined_symbolic_var_.begin(), + collector.defined_symbolic_var_.end()}; return ret; } private: using relax::ExprVisitor::VisitExpr; using relax::ExprVisitor::VisitExpr_; - using tir::ExprVisitor::VisitExpr; - using tir::ExprVisitor::VisitExpr_; + using tirx::ExprVisitor::VisitExpr; + using tirx::ExprVisitor::VisitExpr_; // Possible mode of visitor, used as bit-flags enum VisitMode { @@ -1424,17 +1424,17 @@ class SymbolicVarCollector : public relax::ExprVisitor, void VisitStructInfoExprField(const PrimExpr& expr) final { if (mode_ & VisitMode::kProvideDefinition) { - if (auto var = expr.as()) { + if (auto var = expr.as()) { defined_symbolic_var_.insert(var.value()); } } if (mode_ & VisitMode::kRequireDefinition) { - tir::ExprVisitor::VisitExpr(expr); + tirx::ExprVisitor::VisitExpr(expr); } } - void VisitExpr_(const tir::VarNode* op) final { - tir::Var var = ffi::GetRef(op); + void VisitExpr_(const tirx::VarNode* op) final { + tirx::Var var = ffi::GetRef(op); // default mode, check defined. if (defined_symbolic_var_.count(var) == 0) { free_symbolic_var_.insert(var); @@ -1452,15 +1452,17 @@ class SymbolicVarCollector : public relax::ExprVisitor, /*! \brief The current visit mode. */ VisitMode mode_ = VisitMode::kRequireDefinition; /*! \brief The set of defined symbolic vars. */ - std::unordered_set defined_symbolic_var_; + std::unordered_set defined_symbolic_var_; /*! \brief The set of free/undefined symbolic vars. */ - std::unordered_set free_symbolic_var_; + std::unordered_set free_symbolic_var_; }; -ffi::Array DefinedSymbolicVars(const Expr& expr) { +ffi::Array DefinedSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Defined(expr); } -ffi::Array FreeSymbolicVars(const Expr& expr) { return SymbolicVarCollector::Free(expr); } +ffi::Array FreeSymbolicVars(const Expr& expr) { + return SymbolicVarCollector::Free(expr); +} TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc index 19ecb4ca68da..276126ce406e 100644 --- a/src/relax/analysis/tir_op_pattern_kind.cc +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -21,20 +21,20 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { namespace relax { -using namespace tir; +using namespace tirx; class PatternKindAnalyzer : public StmtExprVisitor { public: - explicit PatternKindAnalyzer(const tir::PrimFunc& func) { - for (const tir::Var& param : func->params) { + explicit PatternKindAnalyzer(const tirx::PrimFunc& func) { + for (const tirx::Var& param : func->params) { ffi::Optional param_buf = func->buffer_map.Get(param); if (param_buf.defined()) { param_buffers_.insert(param_buf.value()); @@ -130,9 +130,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { // Step 4. Checking if the block contains reduce axis by looking into block iterators. bool has_reduction = false; - ffi::Array reduce_vars; + ffi::Array reduce_vars; for (const IterVar& it : op->iter_vars) { - if (it->iter_type == tir::IterVarType::kCommReduce) { + if (it->iter_type == tirx::IterVarType::kCommReduce) { has_reduction = true; reduce_vars.push_back(it->var); } @@ -222,9 +222,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { * A[i, j] = B[i - j] is injective since the load index vars are only i, j */ static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) { - std::unordered_set vars; + std::unordered_set vars; for (const PrimExpr& store_index : store->indices) { - if (const auto* v = store_index.as()) { + if (const auto* v = store_index.as()) { vars.insert(v); } else { return false; @@ -232,7 +232,8 @@ class PatternKindAnalyzer : public StmtExprVisitor { } for (const PrimExpr& load_index : load->indices) { // return false if there are vars used in load indices but not in store indices. - if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) { + if (tirx::UsesVar(load_index, + [&vars](const tirx::VarNode* var) { return !vars.count(var); })) { return false; } } @@ -246,9 +247,9 @@ class PatternKindAnalyzer : public StmtExprVisitor { * Store = A[i, j] and Load = B[i, j + k] allow data reuse. */ static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) { - std::unordered_set vars; + std::unordered_set vars; for (const PrimExpr& index : store->indices) { - if (const auto* v = index.as()) { + if (const auto* v = index.as()) { vars.insert(v); } else { return false; @@ -256,7 +257,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { } for (const PrimExpr& index : load->indices) { PreOrderVisit(index, [&](const ObjectRef& node) { - if (const auto* v = node.as()) { + if (const auto* v = node.as()) { if (vars.count(v)) { vars.erase(v); } @@ -269,7 +270,7 @@ class PatternKindAnalyzer : public StmtExprVisitor { static PrimExpr RemoveCast(PrimExpr e) { for (;;) { - if (const auto* cast = e.as()) { + if (const auto* cast = e.as()) { e = cast->value; } else { break; @@ -281,15 +282,15 @@ class PatternKindAnalyzer : public StmtExprVisitor { /*! \brief Checking if the stmt is multiply add. E.g. C[i, j] += A[i, k] * B[j, k] */ static bool IsFMA(const Stmt& body) { if (const auto* store = body.as()) { - if (const auto* add = RemoveCast(store->value).as()) { - if (const auto* mul = RemoveCast(add->b).as()) { - const auto* store_lhs = RemoveCast(add->a).as(); + if (const auto* add = RemoveCast(store->value).as()) { + if (const auto* mul = RemoveCast(add->b).as()) { + const auto* store_lhs = RemoveCast(add->a).as(); if (!store_lhs || !store->buffer.same_as(store_lhs->buffer) || !IsSameArray(store->indices, store_lhs->indices)) { return false; } - const auto* lhs = RemoveCast(mul->a).as(); - const auto* rhs = RemoveCast(mul->b).as(); + const auto* lhs = RemoveCast(mul->a).as(); + const auto* rhs = RemoveCast(mul->b).as(); if (!lhs || !rhs) { return false; } @@ -309,10 +310,11 @@ class PatternKindAnalyzer : public StmtExprVisitor { * A[i] = sum(B[i, j + k]) is not pure reduce * pooling is not pure reduce */ - static bool IsPureReducePattern(ffi::Array reduce_loops, ffi::Array indices) { + static bool IsPureReducePattern(ffi::Array reduce_loops, + ffi::Array indices) { for (const PrimExpr& e : indices) { int id = -1; - if (UsesVar(e, [&](const tir::VarNode* var) { + if (UsesVar(e, [&](const tirx::VarNode* var) { for (size_t i = 0; i < reduce_loops.size(); ++i) { if (reduce_loops[i].get() == var) { id = i; @@ -386,7 +388,7 @@ bool HasReshapePattern(const PrimFunc& func) { int n_iter = block_iter.size(); for (int i = 0; i < n_iter; ++i) { // To detect the reshape pattern, we require each block iter to be data-parallel. - if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) { + if (block_iter[i]->iter_type != tirx::IterVarType::kDataPar) { return; } } @@ -402,7 +404,7 @@ bool HasReshapePattern(const PrimFunc& func) { return; } - ffi::Map var_range; + ffi::Map var_range; for (const IterVar& v : block->iter_vars) { ana_.Bind(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); var_range.Set(v->var, Range::FromMinExtent(v->dom->min, v->dom->extent)); @@ -476,8 +478,8 @@ bool HasReshapePattern(const PrimFunc& func) { if (nontrivial_indices.defined()) { DataType dtype = !block->iter_vars.empty() ? block->iter_vars[0]->var->dtype : DataType::Int(64); - tir::Var fused_var("fused", dtype); - ffi::Map inverse_indices_map; + tirx::Var fused_var("fused", dtype); + ffi::Map inverse_indices_map; PrimExpr stride = IntImm(dtype, /*value=*/1); for (int i = static_cast(block->iter_vars.size()) - 1; i >= 0; --i) { inverse_indices_map.Set( diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 2c1a42fbe843..a88a27c4bdc9 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -71,7 +71,7 @@ #include #include #include -#include +#include #include #include @@ -86,7 +86,7 @@ namespace relax { /*! \brief Helper to implement well formed check.*/ class WellFormedChecker : public relax::ExprVisitor, public relax::StructInfoVisitor, - public tir::ExprVisitor { + public tirx::ExprVisitor { public: static bool Check(ffi::Variant obj, bool check_struct_info) { WellFormedChecker well_formed_checker = @@ -116,8 +116,8 @@ class WellFormedChecker : public relax::ExprVisitor, : mod_(std::move(mod)), check_struct_info_(check_struct_info), cur_visited_func_(nullptr) {} using relax::ExprVisitor::VisitExpr_; - using tir::ExprVisitor::VisitExpr; - using tir::ExprVisitor::VisitExpr_; + using tirx::ExprVisitor::VisitExpr; + using tirx::ExprVisitor::VisitExpr_; // Possible mode of visitor enum class VisitMode { @@ -435,7 +435,7 @@ class WellFormedChecker : public relax::ExprVisitor, } std::unordered_set previous_var_set = var_set_; - std::unordered_set previous_symbolic_var_set = symbolic_var_set_; + std::unordered_set previous_symbolic_var_set = symbolic_var_set_; this->VisitSeqExpr(op->true_branch.get()); var_set_ = previous_var_set; symbolic_var_set_ = previous_symbolic_var_set; @@ -449,7 +449,7 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitExpr_(const ShapeExprNode* op) final { for (PrimExpr expr : op->values) { // check if the symbolic vars in the expr are defined, e.g, 2 * m - tir::ExprVisitor::VisitExpr(expr); + tirx::ExprVisitor::VisitExpr(expr); if (!expr.dtype().is_int()) { Malformed(Diagnostic::Error(expr) << "Shape expressions must be of integer type, but got " << expr.dtype()); @@ -482,7 +482,7 @@ class WellFormedChecker : public relax::ExprVisitor, is_lambda = true; recur_vars_.insert(binding->var); } - if (binding->value->IsInstance()) { + if (binding->value->IsInstance()) { Malformed(Diagnostic::Error(binding->value) << "Inline PrimFunc is disallowed in Relax IR."); } else { this->VisitExpr(binding->value); @@ -549,8 +549,8 @@ class WellFormedChecker : public relax::ExprVisitor, CheckStructInfo(var); } - void VisitExpr_(const tir::VarNode* op) final { - tir::Var var = ffi::GetRef(op); + void VisitExpr_(const tirx::VarNode* op) final { + tirx::Var var = ffi::GetRef(op); // default mode, check defined. if (symbolic_var_set_.count(var) == 0) { this->Malformed(Diagnostic::Error(var) << "Symbolic Var " << var << " is not defined."); @@ -605,14 +605,14 @@ class WellFormedChecker : public relax::ExprVisitor, void VisitStructInfoExprField(const PrimExpr& expr) final { if (mode_ == VisitMode::kMatchVarDef) { // populate symbolic var in first occurrence - if (auto* op = expr.as()) { - auto var = ffi::GetRef(op); + if (auto* op = expr.as()) { + auto var = ffi::GetRef(op); if (symbolic_var_set_.count(var) == 0) { symbolic_var_set_.insert(var); } } } else { - tir::ExprVisitor::VisitExpr(expr); + tirx::ExprVisitor::VisitExpr(expr); } } @@ -652,9 +652,9 @@ class WellFormedChecker : public relax::ExprVisitor, std::unordered_set var_set_; std::unordered_set recur_vars_; std::unordered_set dataflow_var_set_; - std::unordered_set symbolic_var_set_; + std::unordered_set symbolic_var_set_; std::unordered_map param_var_func_map_; - std::unordered_map symbolic_var_func_map_; + std::unordered_map symbolic_var_func_map_; tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); diff --git a/src/relax/backend/adreno/annotate_custom_storage.cc b/src/relax/backend/adreno/annotate_custom_storage.cc index eda0148779a0..0741ed5e818d 100644 --- a/src/relax/backend/adreno/annotate_custom_storage.cc +++ b/src/relax/backend/adreno/annotate_custom_storage.cc @@ -120,7 +120,7 @@ * 2: CollectProducerScopeInfo: Visitor does finalizes the scope for each input and output based * on consumer scope information. It does evaluating mutiple consumer cases and conflicts. * 3: DefineVDevice: Pass does injects hint_on_device for each argument. It also tries to update - * out StructInfo containing VDevice information. This update for tir calls is straight forward + * out StructInfo containing VDevice information. This update for tirx calls is straight forward * as sinfo_args in CallNode is meant for this purpose. This sinfo_args for other calls by * design is invalid as we do this by "FInferStructInfo". * Another issue we have with "FInferStructInfo" per op is they can't decide this @@ -242,7 +242,7 @@ #include #include #include -#include +#include #include @@ -255,7 +255,7 @@ namespace relax { namespace backend { namespace adreno { -using tvm::tir::Buffer; +using tvm::tirx::Buffer; static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); @@ -342,9 +342,9 @@ class CollectConsumerScopeInfo : public ExprVisitor { if (call->op == call_tir_op) { gv = Downcast(call->args[0]); - tir::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); - op_attrs = ExtractAttrs(pfunc); - op_pattern = ExtractPattern(pfunc); + tirx::PrimFunc pfunc = Downcast(mod_->Lookup(gv)); + op_attrs = ExtractAttrs(pfunc); + op_pattern = ExtractPattern(pfunc); func_args = Downcast(call->args[1]); } else { op_attrs = {call->attrs}; @@ -424,7 +424,7 @@ class CollectConsumerScopeInfo : public ExprVisitor { std::string Scope(ffi::Array shape) { // currently we support only textures been made from 5d tensors // 5d requirement is not limitation of textures in general, it is limitation how - // we are representing memory scopes/layout and flattening of textures in tir + // we are representing memory scopes/layout and flattening of textures in tirx if (shape.size() == 5 && shape[4].as()->value == 4) { for (auto ind : shape) { if (!ind.as()) { diff --git a/src/relax/backend/adreno/fold_vdevice_scope_change.cc b/src/relax/backend/adreno/fold_vdevice_scope_change.cc index f5898021de6d..2516b41e242e 100644 --- a/src/relax/backend/adreno/fold_vdevice_scope_change.cc +++ b/src/relax/backend/adreno/fold_vdevice_scope_change.cc @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include diff --git a/src/relax/backend/contrib/codegen_json/codegen_json.h b/src/relax/backend/contrib/codegen_json/codegen_json.h index 6f63350dd06a..64aebf1b34f0 100644 --- a/src/relax/backend/contrib/codegen_json/codegen_json.h +++ b/src/relax/backend/contrib/codegen_json/codegen_json.h @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relax/backend/contrib/utils.h b/src/relax/backend/contrib/utils.h index e13ab6c65e8d..989e485f8d04 100644 --- a/src/relax/backend/contrib/utils.h +++ b/src/relax/backend/contrib/utils.h @@ -46,7 +46,7 @@ namespace backend { inline std::vector GetIntShape(const ffi::Array& shape) { std::vector ret; for (const auto& dim : shape) { - const int64_t* pval = tir::as_const_int(dim); + const int64_t* pval = tirx::as_const_int(dim); ret.push_back(pval ? *pval : -1); } return ret; diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index 879d49a22355..e110e0930f5e 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -22,8 +22,8 @@ #include #include #include -#include -#include +#include +#include #include "../../s_tir/meta_schedule/module_equality.h" @@ -49,16 +49,16 @@ using s_tir::meta_schedule::ModuleHash; * Then we will have a ExtractedTask for all three functions, whose weight * is 5 + 3 + 2 = 10. */ -class BlockCounter : public tir::StmtVisitor { +class BlockCounter : public tirx::StmtVisitor { public: - static size_t GetSBlockCount(const tir::PrimFunc& func) { + static size_t GetSBlockCount(const tirx::PrimFunc& func) { BlockCounter counter; counter(func->body); return counter.count; } private: - void VisitStmt_(const tir::SBlockNode* op) final { + void VisitStmt_(const tirx::SBlockNode* op) final { ++count; StmtVisitor::VisitStmt_(op); } @@ -96,7 +96,7 @@ class TaskExtractor : public ExprVisitor { void VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); - // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // TODO(@tvm-team): When we differentiate the call for tirx function and packed function, // this logic should be changed accordingly. if (!call->op.same_as(call_tir_op)) { // Since the Relax function is of A-normal form, the arguments of this call cannot be another @@ -105,13 +105,13 @@ class TaskExtractor : public ExprVisitor { } const GlobalVar& global_var = Downcast(call->args[0]); - const tir::PrimFunc& func = Downcast(mod_->Lookup(global_var)); + const tirx::PrimFunc& func = Downcast(mod_->Lookup(global_var)); IRModule mod = (*normalize_mod_func_)(func).cast(); size_t weight = 1; auto it = func2task_.find(mod); if (it != func2task_.end()) { it->second->weight += 1; - const tir::PrimFunc& alt_func = Downcast(it->first->Lookup("main")); + const tirx::PrimFunc& alt_func = Downcast(it->first->Lookup("main")); // When anchor-block based equality is used, tuning tasks "nn_conv2d_add_nn_relu" and // "nn_conv2d_add_add_nn_relu", for example, can be identified as equal. Thus, one of them // will be selected to tune by the code below. diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 8736d91af25c..13dc02fde4a1 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index 8316820b1f08..126c732cb4c0 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -28,10 +28,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -86,23 +86,23 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { PrimExpr RegListGet(int64_t slot) const { // use 128 bits to represent any - return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), - {reg_anylist_handle_, ConstInt32(slot)}); + return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + {reg_anylist_handle_, ConstInt32(slot)}); } PrimExpr ConstListGet(int64_t slot) const { // use 128 bits to represent any - return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), - {const_anylist_handle_, ConstInt32(slot)}); + return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + {const_anylist_handle_, ConstInt32(slot)}); } PrimExpr FuncListGet(int64_t slot) const { // use 128 bits to represent any - return tir::Call(DataType::Handle(), tir::builtin::anylist_getitem(), - {func_anylist_handle_, ConstInt32(slot)}); + return tirx::Call(DataType::Handle(), tirx::builtin::anylist_getitem(), + {func_anylist_handle_, ConstInt32(slot)}); } - void EmitStmt(tir::Stmt stmt) { + void EmitStmt(tirx::Stmt stmt) { TVM_FFI_ICHECK(!stmt_stack_.empty()); stmt_stack_.back().emplace_back(stmt); } @@ -114,20 +114,20 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (dst_anylist_slot >= 0) { all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; } - all_args.push_back(tir::StringImm(name)); + all_args.push_back(tirx::StringImm(name)); for (PrimExpr arg : args) { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tir::Evaluate( - tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_packed(), all_args))); + this->EmitStmt(tirx::Evaluate( + tirx::Call(DataType::Int(32), tirx::builtin::anylist_setitem_call_packed(), all_args))); } else { - this->EmitStmt( - tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), all_args))); + this->EmitStmt(tirx::Evaluate( + tirx::Call(DataType::Int(32), tirx::builtin::tvm_call_packed(), all_args))); } } - void EmitCallCPacked(const tir::PrimFunc& prim_func, const ffi::Array& args, + void EmitCallCPacked(const tirx::PrimFunc& prim_func, const ffi::Array& args, int64_t dst_anylist_slot = -1) { ffi::Optional gsymbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(gsymbol.has_value()) << "All functions must have global symbol at this phase"; @@ -136,20 +136,20 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (dst_anylist_slot >= 0) { all_args = {reg_anylist_handle_, ConstInt32(dst_anylist_slot)}; } - all_args.push_back(tir::StringImm(gsymbol.value())); + all_args.push_back(tirx::StringImm(gsymbol.value())); for (PrimExpr arg : args) { all_args.push_back(arg); } if (dst_anylist_slot >= 0) { - this->EmitStmt(tir::Evaluate( - tir::Call(DataType::Int(32), tir::builtin::anylist_setitem_call_cpacked(), all_args))); + this->EmitStmt(tirx::Evaluate( + tirx::Call(DataType::Int(32), tirx::builtin::anylist_setitem_call_cpacked(), all_args))); } else { - this->EmitStmt( - tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_call_cpacked(), all_args))); + this->EmitStmt(tirx::Evaluate( + tirx::Call(DataType::Int(32), tirx::builtin::tvm_call_cpacked(), all_args))); } } - tir::PrimFunc Codegen(const Function& func) { + tirx::PrimFunc Codegen(const Function& func) { ffi::Optional gsymbol = func->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(gsymbol.has_value()) << "there should be no local functions in Relax VM codegen phase. " @@ -158,10 +158,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { stmt_stack_ = {}; registers_num_ = 0; var_map_.clear(); - ctx_ptr_ = tir::Var("ctx_ptr", DataType::Handle()); - reg_anylist_handle_ = tir::Var("r", DataType::Handle()); - func_anylist_handle_ = tir::Var("f", DataType::Handle()); - const_anylist_handle_ = tir::Var("c", DataType::Handle()); + ctx_ptr_ = tirx::Var("ctx_ptr", DataType::Handle()); + reg_anylist_handle_ = tirx::Var("r", DataType::Handle()); + func_anylist_handle_ = tirx::Var("f", DataType::Handle()); + const_anylist_handle_ = tirx::Var("c", DataType::Handle()); ffi::Array param_names; for (Var param : func->params) { @@ -177,7 +177,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } size_t ret_reg = NewRegister(); - tir::Stmt body = WithNewScope([&]() { + tirx::Stmt body = WithNewScope([&]() { ffi::Optional ret = ExprFunctor::VisitExpr(func->body); if (ret.defined()) { this->EmitCallPacked("vm.builtin.copy", {ret.value()}, ret_reg); @@ -190,10 +190,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { builder_->EndFunction(gsymbol.value()); Type ret_type = VoidType(); - ffi::Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, - func_anylist_handle_}; + ffi::Array tir_params = {ctx_ptr_, reg_anylist_handle_, const_anylist_handle_, + func_anylist_handle_}; ffi::String tir_func_name = system_lib_prefix_.value_or("") + "__vmtir__" + gsymbol.value(); - tir::PrimFunc tir_func(tir_params, body, ret_type, {}); + tirx::PrimFunc tir_func(tir_params, body, ret_type, {}); tir_func = WithAttr(tir_func, "global_symbol", tir_func_name); registers_num_ = 0; var_map_.clear(); @@ -228,8 +228,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { Call call = ffi::GetRef(call_node); if (call_node->op == null_value_op_) { - return tir::Call(DataType::Handle(), tir::builtin::reinterpret(), - {IntImm(DataType::Int(64), 0)}); + return tirx::Call(DataType::Handle(), tirx::builtin::reinterpret(), + {IntImm(DataType::Int(64), 0)}); } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { @@ -262,18 +262,18 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - cond_value = tir::Call(DataType::Bool(), tir::builtin::tvm_call_packed(), - {tir::StringImm("vm.builtin.read_if_cond"), cond_value}); + cond_value = tirx::Call(DataType::Bool(), tirx::builtin::tvm_call_packed(), + {tirx::StringImm("vm.builtin.read_if_cond"), cond_value}); - tir::Stmt true_branch = WithNewScope([&]() { + tirx::Stmt true_branch = WithNewScope([&]() { PrimExpr true_value = this->VisitExpr(op->true_branch).value(); this->EmitCallPacked("vm.builtin.copy", {true_value}, merge_register); }); - tir::Stmt false_branch = WithNewScope([&]() { + tirx::Stmt false_branch = WithNewScope([&]() { PrimExpr false_value = this->VisitExpr(op->false_branch).value(); this->EmitCallPacked("vm.builtin.copy", {false_value}, merge_register); }); - this->EmitStmt(tir::IfThenElse(cond_value, true_branch, false_branch)); + this->EmitStmt(tirx::IfThenElse(cond_value, true_branch, false_branch)); return RegListGet(merge_register); } @@ -350,7 +350,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } else if (func.as()) { *kind = VMFuncInfo::FuncKind::kVMTIRFunc; return gvar->name_hint; - } else if (func.as()) { + } else if (func.as()) { *kind = VMFuncInfo::FuncKind::kPackedFunc; return gvar->name_hint; } else { @@ -368,15 +368,15 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } // Lookup PrimFunc in the same module // We can do direct PrimFunc call in such cases - ffi::Optional LookupPrimFunc(const ffi::String& name) { + ffi::Optional LookupPrimFunc(const ffi::String& name) { if (!ctx_mod_->ContainGlobalVar(name)) return std::nullopt; GlobalVar gvar = ctx_mod_->GetGlobalVar(name); auto it = ctx_mod_->functions.find(gvar); if (it != ctx_mod_->functions.end()) { BaseFunc func = (*it).second; - if (auto* prim_func = func.as()) { - return ffi::GetRef(prim_func); + if (auto* prim_func = func.as()) { + return ffi::GetRef(prim_func); } } return std::nullopt; @@ -418,7 +418,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { auto vdevice = GetGlobalVDevice(ctx_mod_, vdevice_index); if (vdevice.defined()) { - args.push_back(tir::StringImm(vdevice.value()->memory_scope)); + args.push_back(tirx::StringImm(vdevice.value()->memory_scope)); } this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg); @@ -429,12 +429,12 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { PrimExpr arg = this->VisitExpr(call_node->args[0]).value(); // Check the arg is a register. - const auto* tir_call = arg.as(); + const auto* tir_call = arg.as(); TVM_FFI_ICHECK(tir_call != nullptr); - TVM_FFI_ICHECK(tir_call->op == tir::builtin::anylist_getitem()); + TVM_FFI_ICHECK(tir_call->op == tirx::builtin::anylist_getitem()); TVM_FFI_ICHECK(tir_call->args.size() == 2); TVM_FFI_ICHECK(tir_call->args[0].same_as(reg_anylist_handle_)); - const auto* p_dst_reg = tir_call->args[1].as(); + const auto* p_dst_reg = tir_call->args[1].as(); TVM_FFI_ICHECK(p_dst_reg != nullptr); TVM_FFI_ICHECK(p_dst_reg->dtype == DataType::Int(32)); @@ -470,7 +470,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { if (symbol.has_value() && kind == VMFuncInfo::FuncKind::kPackedFunc) { // primfunc in the same module. // use cpacked to directly invoke without named based lookup - if (ffi::Optional prim_func = LookupPrimFunc(symbol.value())) { + if (ffi::Optional prim_func = LookupPrimFunc(symbol.value())) { this->EmitCallCPacked(prim_func.value(), args, dst_reg); } else { this->EmitCallPacked(symbol.value(), args, dst_reg); @@ -488,10 +488,10 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } template - tir::Stmt WithNewScope(const FLambda& callback) { + tirx::Stmt WithNewScope(const FLambda& callback) { stmt_stack_.push_back({}); callback(); - tir::Stmt stmt = tir::SeqStmt::Flatten(stmt_stack_.back()); + tirx::Stmt stmt = tirx::SeqStmt::Flatten(stmt_stack_.back()); stmt_stack_.pop_back(); return stmt; } @@ -506,20 +506,20 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { /*! \brief Internal ExecBuilder. */ relax::ExecBuilder builder_; /*! \brief List to ctx_ptr */ - tir::Var ctx_ptr_; + tirx::Var ctx_ptr_; /*! \brief List to store temp object registers */ - tir::Var reg_anylist_handle_; + tirx::Var reg_anylist_handle_; /*! \brief List to store closures */ - tir::Var func_anylist_handle_; + tirx::Var func_anylist_handle_; /*! \brief List to store constants */ - tir::Var const_anylist_handle_; + tirx::Var const_anylist_handle_; /*! * \brief Total number of virtual registers allocated. * \note The first two registers are reserved for special registers. */ int64_t registers_num_ = 0; /*! \brief Stack to build up statements */ - std::vector> stmt_stack_; + std::vector> stmt_stack_; /*! \brief Map from var to Expr. */ std::unordered_map> var_map_; /*! \brief the context module. */ diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index ac30fbed1e16..870a3c7e7cc4 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -29,7 +29,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 4cca5d7b6c8b..f7173a804e2b 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -27,10 +27,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { namespace relax { @@ -69,7 +69,7 @@ struct MatchShapeTodoItem { /*! \brief Slot map used for shape lowering. */ using PrimExprSlotMap = - std::unordered_map; + std::unordered_map; // Collector to collect PrimExprSlotMap class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { @@ -304,8 +304,8 @@ class VMShapeLowerMutator void PopulateSlotInfo() { for (auto& kv : slot_map_) { auto* slot = kv.second; - if (!slot->expr.as()) { - ffi::Array dep_vars = tir::UndefinedVars(slot->expr); + if (!slot->expr.as()) { + ffi::Array dep_vars = tirx::UndefinedVars(slot->expr); for (auto var : dep_vars) { auto it = slot_map_.find(var); TVM_FFI_ICHECK(it != slot_map_.end()) @@ -463,7 +463,7 @@ class VMShapeLowerMutator // the value is not yet computed TVM_FFI_ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not computed"; - if (expr.as()) { + if (expr.as()) { // It is a var we will populate it in this round. slot->value_computed = true; @@ -567,38 +567,38 @@ class VMShapeLowerMutator if (to_compute.size() == 0) return 0; TVM_FFI_ICHECK_GT(heap_size_->value, 0); // construct a PrimFunc that compute the shape. - tir::Var heap("heap", DataType::Handle()); + tirx::Var heap("heap", DataType::Handle()); ffi::Array buffer_shape{heap_size_}; - tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); - ffi::Map buffer_map; + tirx::Buffer buffer = tirx::decl_buffer(buffer_shape, ShapeDType(), "H", "global"); + ffi::Map buffer_map; buffer_map.Set(heap, buffer); - auto var_map = [&](const tir::Var& var) -> ffi::Optional { + auto var_map = [&](const tirx::Var& var) -> ffi::Optional { auto it = slot_map_.find(var); TVM_FFI_ICHECK(it != slot_map_.end()); - return tir::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); + return tirx::BufferLoad(buffer, {IntImm(ShapeDType(), it->second->index)}); }; - ffi::Array seq; + ffi::Array seq; for (PrimExprSlot* slot : to_compute) { TVM_FFI_ICHECK(!slot->value_computed); slot->value_computed = true; - PrimExpr value = tir::Substitute(slot->expr, var_map); - seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); + PrimExpr value = tirx::Substitute(slot->expr, var_map); + seq.push_back(tirx::BufferStore(buffer, value, {IntImm(ShapeDType(), slot->index)})); } - tir::Stmt body = tir::SeqStmt::Flatten(seq); - ffi::Array params{heap}; + tirx::Stmt body = tirx::SeqStmt::Flatten(seq); + ffi::Array params{heap}; Type ret_type = VoidType(); // TODO(relax-team): Consider attach the target attribute to // the shape_func to indicate that this is a host function // This could require us to attach target to the relax function here. - tir::PrimFunc shape_func(params, body, ret_type, buffer_map); + tirx::PrimFunc shape_func(params, body, ret_type, buffer_map); if (!shape_func->attrs.GetAttr(tvm::attr::kTarget).has_value()) { // kTarget and kIsHostFunc are mutually exclusive shape_func = - WithAttr(std::move(shape_func), tvm::tir::attr::kIsHostFunc, true); + WithAttr(std::move(shape_func), tvm::tirx::attr::kIsHostFunc, true); } GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func"); builder_->Emit(Call(shape_func_var, {shape_heap_}), "_"); diff --git a/src/relax/distributed/axis_group_graph.cc b/src/relax/distributed/axis_group_graph.cc index b53f68e44d4b..47d768d32e1d 100644 --- a/src/relax/distributed/axis_group_graph.cc +++ b/src/relax/distributed/axis_group_graph.cc @@ -28,7 +28,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, arith::Analyzer* analyzer) { if (index.as()) { @@ -55,7 +55,7 @@ Var GetShardingVarFromIndex(PrimExpr index, ffi::Map var_range, } return ffi::GetRef(source_var); } -} // namespace tir +} // namespace tirx } // namespace tvm namespace tvm { @@ -347,10 +347,10 @@ inline int GetNumOutput(Call call) { } } -void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tir::PrimFunc& func, +void BuildAxisGraphCallTIR(const Var& output_var, const Call& call, const tirx::PrimFunc& func, distributed::AxisGroupGraph* axis_group_graph) { - auto tir_var_axis_group_list = tir::BufferAxisGraphExtractor::GetTIRVarAxisGraph(func); - ffi::Map input_var_to_relax_expr; + auto tir_var_axis_group_list = tirx::BufferAxisGraphExtractor::GetTIRVarAxisGraph(func); + ffi::Map input_var_to_relax_expr; ffi::Array input_list = Downcast(call->args[1])->fields; input_list.push_back(output_var); for (int i = 0; i < static_cast(input_list.size()); i++) { diff --git a/src/relax/distributed/transform/legalize_redistribute.cc b/src/relax/distributed/transform/legalize_redistribute.cc index 88fdb14ffd0c..510927883752 100644 --- a/src/relax/distributed/transform/legalize_redistribute.cc +++ b/src/relax/distributed/transform/legalize_redistribute.cc @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include "../../../s_tir/schedule/transform.h" #include "../../op/ccl/ccl.h" diff --git a/src/relax/distributed/transform/lower_distir.cc b/src/relax/distributed/transform/lower_distir.cc index 49fc366a5360..ee94f143c84a 100644 --- a/src/relax/distributed/transform/lower_distir.cc +++ b/src/relax/distributed/transform/lower_distir.cc @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include "../../../s_tir/schedule/transform.h" #include "../../op/ccl/ccl.h" diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 66cd520cebb0..41747a283702 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -27,12 +27,12 @@ #include #include #include -#include +#include #include "../../../s_tir/schedule/transform.h" #include "utils.h" namespace tvm { -namespace tir { +namespace tirx { using namespace tvm::relax::distributed; using s_tir::ReplaceBuffer; @@ -349,7 +349,7 @@ class DistributedBufferCompactor : StmtExprMutator { std::string add_allreduce_kind_; }; -} // namespace tir +} // namespace tirx } // namespace tvm namespace tvm { @@ -410,11 +410,11 @@ class LowerTIRToLocalView : public ExprMutator { sharding_specs.push_back(ShardingSpec(sinfo->device_mesh, sinfo->placement)); } GlobalVar gvar = Downcast(val->args[0]); - tir::PrimFunc prim_func = MatchPrimFunc(builder_->GetContextIRModule(), gvar).value(); - tir::PrimFunc new_prim_func; + tirx::PrimFunc prim_func = MatchPrimFunc(builder_->GetContextIRModule(), gvar).value(); + tirx::PrimFunc new_prim_func; std::string allreduce_kind; std::tie(new_prim_func, allreduce_kind) = - tir::DistributedBufferCompactor::DistBufferCompact(sharding_specs, prim_func); + tirx::DistributedBufferCompactor::DistBufferCompact(sharding_specs, prim_func); auto new_gvar = builder_->AddFunction(new_prim_func, gvar->name_hint); Call call = Downcast(this->VisitExpr(binding->value)); ObjectPtr new_call_node = ffi::make_object(*call.get()); diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc index 1123b1db25b5..47748469f945 100644 --- a/src/relax/distributed/transform/propagate_sharding.cc +++ b/src/relax/distributed/transform/propagate_sharding.cc @@ -158,7 +158,7 @@ class AxisGroupGraphBuilder : public ExprVisitor { CollectAxisGraphReshape(binding, val, axis_group_graph_); static const Op& call_tir_op = Op::Get("relax.call_tir"); if (val->op.same_as(call_tir_op)) { - if (ffi::Optional func = MatchPrimFunc(mod_, val->args[0])) { + if (ffi::Optional func = MatchPrimFunc(mod_, val->args[0])) { BuildAxisGraphCallTIR(binding->var, ffi::GetRef(val), func.value(), axis_group_graph_); } @@ -439,7 +439,7 @@ class DistributedIRBuilder : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); FBuildAxisGraph f = [&](const Var& var, const Call& call, AxisGroupGraph* axis_group_graph) { - ffi::Optional prim_func = + ffi::Optional prim_func = MatchPrimFunc(this->builder_->GetContextIRModule(), call->args[0]); TVM_FFI_ICHECK(prim_func); return BuildAxisGraphCallTIR(var, call, prim_func.value(), axis_group_graph); diff --git a/src/relax/distributed/transform/utils.h b/src/relax/distributed/transform/utils.h index 963efc15f6a0..d0effbb318a4 100644 --- a/src/relax/distributed/transform/utils.h +++ b/src/relax/distributed/transform/utils.h @@ -33,12 +33,12 @@ namespace distributed { * \brief Pattern match op to a TIR function and look it up. * \return The TIR function, or nullopt if pattern match fails. */ -inline ffi::Optional MatchPrimFunc(const IRModule& mod_, const Expr& op) { +inline ffi::Optional MatchPrimFunc(const IRModule& mod_, const Expr& op) { const GlobalVar& global_var = Downcast(op); // NOTE: as check works for nullptr(returns null) ffi::Optional base_func = mod_->functions.Get(global_var); - if (auto* pfunc = base_func.as()) { - return ffi::GetRef(pfunc); + if (auto* pfunc = base_func.as()) { + return ffi::GetRef(pfunc); } return std::nullopt; } diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index c2af644fba26..78098ff3e94f 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include @@ -89,7 +89,7 @@ class BlockBuilderImpl : public BlockBuilderNode { StructInfo finfo; if (func->struct_info_.defined()) { finfo = GetStructInfo(func); - } else if (auto* prim_func = func.as()) { + } else if (auto* prim_func = func.as()) { // NOTE: use a slightly different struct info than checked type // in PrimFunc so handle can turn into Tensor. // TODO(relax-team): add fine-grained PrimFunc struct info signature generation. @@ -206,9 +206,9 @@ class BlockBuilderImpl : public BlockBuilderNode { // defined in parameter struct info annotations. The implementation // is correct (since we will simply erase all relax Vars in EraseToWellDefined), // but can be further improved. - ffi::Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + ffi::Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); for (const auto& kv : var_map) { - const tir::Var& shape_var = kv.first; + const tirx::Var& shape_var = kv.first; const PrimExpr& shape_expr = kv.second; auto it = shape_var_map.find(shape_var); if (it == shape_var_map.end()) { @@ -343,7 +343,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // // TODO(relax-team) tracks the var defined also through match-cast. /*! \brief set of defined symbolic vars, value as themself. */ - ffi::Map shape_var_map; + ffi::Map shape_var_map; }; /*! \brief A stack to store block frames. */ @@ -468,7 +468,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // shape vars as defined when calling BeginScope(params) class StructInfoVarCollector : public StructInfoVisitor { public: - static ffi::Map Collect(const StructInfo& struct_info) { + static ffi::Map Collect(const StructInfo& struct_info) { StructInfoVarCollector collector; collector(struct_info); return collector.shape_var_map_; @@ -479,8 +479,8 @@ class BlockBuilderImpl : public BlockBuilderNode { if (const auto* shape_expr = op->shape.as()) { for (const PrimExpr& s : shape_expr->values) { // Only collect single var defined shape. Ignore something like `R.Tensor((m + 1, n + 1)) - if (const auto* var = s.as()) { - shape_var_map_.Set(ffi::GetRef(var), s); + if (const auto* var = s.as()) { + shape_var_map_.Set(ffi::GetRef(var), s); } } } @@ -489,8 +489,8 @@ class BlockBuilderImpl : public BlockBuilderNode { void VisitStructInfo_(const ShapeStructInfoNode* op) final { for (const PrimExpr& s : op->values.value_or(ffi::Array())) { // Only collect single var defined shape. Ignore something like `R.Shape((m + 1, n + 1)) - if (const auto* var = s.as()) { - shape_var_map_.Set(ffi::GetRef(var), s); + if (const auto* var = s.as()) { + shape_var_map_.Set(ffi::GetRef(var), s); } } } @@ -498,14 +498,14 @@ class BlockBuilderImpl : public BlockBuilderNode { void VisitStructInfo_(const PrimStructInfoNode* op) final { // Only collect single var defined shape. Ignore something like `R.Prim(value=m + 1)` if (op->value.defined()) { - if (auto var = op->value.as()) { + if (auto var = op->value.as()) { shape_var_map_.Set(var.value(), op->value.value()); } } } private: - ffi::Map shape_var_map_; + ffi::Map shape_var_map_; }; }; @@ -865,7 +865,7 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor ffi::Optional { + auto f_shape_var_map = [curr_scope](tirx::Var var) -> ffi::Optional { auto it = curr_scope->shape_var_map.find(var); if (it != curr_scope->shape_var_map.end()) return (*it).second; return std::nullopt; diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index b13fb84105b7..4a46a89cef74 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -211,7 +211,7 @@ static std::optional TryValidate( auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); necessary_condition = analyzer->Simplify(necessary_condition); - const auto* known = tir::as_const_int(necessary_condition); + const auto* known = tirx::as_const_int(necessary_condition); if (known && *known && is_sufficient) { // The condition passes, and the expression provided is both diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index a95b51745de0..868c8a1971ec 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 3c0e57dc073d..5b61ff780e07 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -32,7 +32,7 @@ #include #include #include -#include +#include #include #include @@ -458,8 +458,8 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { } auto sort_key = [](PrimExpr expr) -> ffi::String { - if (const auto* equal = expr.as()) { - if (const auto* var = equal->a.as()) { + if (const auto* equal = expr.as()) { + if (const auto* var = equal->a.as()) { return var->name_hint; } } @@ -481,7 +481,7 @@ static bool ShapeEqual(Analyzer* analyzer, const ffi::Array& lhs, const ffi::Array& rhs) { if (lhs.size() != rhs.size()) return false; for (size_t i = 0; i < lhs.size(); ++i) - if (!tir::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false; + if (!tirx::is_one(analyzer->Simplify(lhs[i] == rhs[i]))) return false; return true; } diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index 7fcc5e8b9b62..004a856fd602 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -24,7 +24,7 @@ #include #include -#include +#include namespace tvm { namespace relax { @@ -38,7 +38,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_FFI_STATIC_INIT_BLOCK() { RXPlaceholderOpNode::RegisterReflection(); } -te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name) { +te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name) { auto n = ffi::make_object(); n->name = name; n->value = value; @@ -69,7 +69,7 @@ te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::s "match_cast " << "to constrain the shape before passing into te_tensor"; n->shape = shape_expr->values.Map( - [&tir_var_map](const PrimExpr& e) { return tir::Substitute(e, tir_var_map); }); + [&tir_var_map](const PrimExpr& e) { return tirx::Substitute(e, tir_var_map); }); n->dtype = tensor_sinfo->dtype; return te::PlaceholderOp(n).output(0); } diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index f09dcb7f8230..d7c998729a10 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -67,7 +67,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { * shape of the input Expr. * \param name The name of the created tensor. */ -te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name); +te::Tensor TETensor(Expr value, ffi::Map tir_var_map, std::string name); } // namespace relax } // namespace tvm diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index b011327e8db1..2fbd573a5f7f 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -584,8 +584,8 @@ Function::Function(ffi::Array params, Expr body, ffi::Optional // used if they were defined by the function's parameters. auto f_shape_var_map = [&] { auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); - std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); - return [lookup = std::move(lookup)](const tir::Var& var) -> ffi::Optional { + std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); + return [lookup = std::move(lookup)](const tirx::Var& var) -> ffi::Optional { if (lookup.count(var)) { return var; } else { diff --git a/src/relax/ir/tir_pattern.cc b/src/relax/ir/tir_pattern.cc index d579aea632bc..4d00d17976dd 100644 --- a/src/relax/ir/tir_pattern.cc +++ b/src/relax/ir/tir_pattern.cc @@ -25,7 +25,7 @@ namespace relax { TVM_FFI_STATIC_INIT_BLOCK() { MatchResultNode::RegisterReflection(); } MatchResult::MatchResult(TIRPattern pattern, ffi::Array symbol_values, - ffi::Array matched_buffers) { + ffi::Array matched_buffers) { auto n = ffi::make_object(); n->pattern = std::move(pattern); n->symbol_values = std::move(symbol_values); diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc index d787f906d2ca..3fcdeeab834a 100644 --- a/src/relax/ir/transform.cc +++ b/src/relax/ir/transform.cc @@ -235,13 +235,13 @@ class DataflowBlockMutator : public ExprMutator { BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final { // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock ffi::Map global_scope_vars; - ffi::Map symbolic_vars; + ffi::Map symbolic_vars; for (const Binding& binding : n->bindings) { Var var = binding->var; if (const auto* match_cast = binding.as()) { auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); - for (const tir::VarNode* var : collected_vars) { - symbolic_vars.Set(var->name_hint, ffi::GetRef(var)); + for (const tirx::VarNode* var : collected_vars) { + symbolic_vars.Set(var->name_hint, ffi::GetRef(var)); } } if (!var.as()) { @@ -258,9 +258,9 @@ class DataflowBlockMutator : public ExprMutator { Var var = binding->var; if (const auto* match_cast = binding.as()) { auto collected_vars = SymbolicVarCollector::Collect(match_cast->struct_info); - for (const tir::VarNode* var : collected_vars) { + for (const tirx::VarNode* var : collected_vars) { if (symbolic_vars.count(var->name_hint) > 0) { - tir::Var old_var = symbolic_vars[var->name_hint]; + tirx::Var old_var = symbolic_vars[var->name_hint]; TVM_FFI_ICHECK(var == old_var.get()) << "Error: DataflowBlock Pass should not rewrite any Symbolic Var."; symbolic_vars.erase(var->name_hint); @@ -282,7 +282,7 @@ class DataflowBlockMutator : public ExprMutator { private: class SymbolicVarCollector : public StructInfoVisitor { public: - static std::unordered_set Collect(const StructInfo& info) { + static std::unordered_set Collect(const StructInfo& info) { SymbolicVarCollector collector; collector.VisitStructInfo(info); return std::move(collector.symbolic_vars_); @@ -290,13 +290,13 @@ class DataflowBlockMutator : public ExprMutator { private: void VisitStructInfoExprField(const PrimExpr& expr) final { - if (const tir::VarNode* sym_var = expr.as()) { + if (const tirx::VarNode* sym_var = expr.as()) { symbolic_vars_.insert(sym_var); } } private: - std::unordered_set symbolic_vars_; + std::unordered_set symbolic_vars_; }; std::function pass_func_; diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc index bf384e863443..d9f617de14bf 100644 --- a/src/relax/op/nn/attention.cc +++ b/src/relax/op/nn/attention.cc @@ -125,7 +125,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) { } auto diag_equal_or_broadcast = [&](PrimExpr v1, PrimExpr v2, ffi::String m1, ffi::String m2, ffi::String dim) { - if (analyzer->CanProve(v1 != v2) && !tir::is_one(v2)) { + if (analyzer->CanProve(v1 != v2) && !tirx::is_one(v2)) { ctx->ReportFatal(Diagnostic::Error(call) << "The " << m1 << " " << dim << " and the " << m2 << " " << dim << " should be the same or broadcastable. However, the " << dim << " of " diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc index 648fc50c89c2..81c9fdf313d3 100644 --- a/src/relax/op/nn/convolution.cc +++ b/src/relax/op/nn/convolution.cc @@ -331,9 +331,9 @@ InferLayoutOutput InferLayoutConv2d( Layout desired_data_layout = (*it).second[0]; Layout desired_weight_layout = (*it).second[1]; Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0]; - tir::Layout input_layout(attrs->data_layout, DataType::Int(64)); - tir::Layout kernel_layout(attrs->kernel_layout, DataType::Int(64)); - tir::Layout out_layout(attrs->out_layout, DataType::Int(64)); + tirx::Layout input_layout(attrs->data_layout, DataType::Int(64)); + tirx::Layout kernel_layout(attrs->kernel_layout, DataType::Int(64)); + tirx::Layout out_layout(attrs->out_layout, DataType::Int(64)); if ((desired_data_layout.ndim() == input_layout.ndim()) && (desired_weight_layout.ndim() == kernel_layout.ndim()) && diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc index 3e963cfd145f..79fb3b4e5b7b 100644 --- a/src/relax/op/nn/pooling.cc +++ b/src/relax/op/nn/pooling.cc @@ -271,7 +271,7 @@ InferLayoutOutput InferLayoutPool2d( ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { - tir::Layout in_layout(attrs->layout, DataType::Int(64)); + tirx::Layout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetStructInfo(call->args[0]); TensorStructInfo data_sinfo = data_si.as().value(); @@ -668,7 +668,7 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D( LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]); ObjectPtr new_attrs = ffi::make_object(*attrs); if (layout->layout.ndim() != layout->layout.ndim_primal()) { - tir::Layout in_layout(attrs->layout, DataType::Int(64)); + tirx::Layout in_layout(attrs->layout, DataType::Int(64)); auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout); auto data_si = GetStructInfo(call->args[0]); TensorStructInfo data_sinfo = data_si.as().value(); diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index e78b75fe14ed..82fcf3750d4e 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -38,7 +38,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { } bool EqualConstInt(const PrimExpr& lhs, int64_t value) { - if (const int64_t* pvalue = tir::as_const_int(lhs)) { + if (const int64_t* pvalue = tirx::as_const_int(lhs)) { return pvalue[0] == value; } return false; @@ -46,12 +46,12 @@ bool EqualConstInt(const PrimExpr& lhs, int64_t value) { bool EqualCheck(const PrimExpr& lhs, const PrimExpr& rhs) { PrimExpr diff = lhs - rhs; - if (const int64_t* pdiff = tir::as_const_int(diff)) { + if (const int64_t* pdiff = tirx::as_const_int(diff)) { return pdiff[0] == 0; } tvm::arith::Analyzer ana; diff = ana.Simplify(diff); - if (const int64_t* pdiff = tir::as_const_int(diff)) { + if (const int64_t* pdiff = tirx::as_const_int(diff)) { return pdiff[0] == 0; } return false; diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index d084f0c0eb10..02be653dc1df 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -189,12 +189,12 @@ bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_l ffi::Array shape) { bool can_prove = true; try { - tir::BijectiveLayout todesired(input_layout, desired_layout); + tirx::BijectiveLayout todesired(input_layout, desired_layout); ffi::Array desired_shape = todesired.ForwardShape(shape); ffi::Array back_shape = todesired.BackwardShape(desired_shape); arith::Analyzer analyzer; for (size_t i = 0; i < shape.size(); ++i) { - if (tir::is_const_int(shape[i])) { + if (tirx::is_const_int(shape[i])) { if (!analyzer.CanProveEqual(shape[i], back_shape[i])) { can_prove = false; break; diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 19e0398892ab..724b0b62ba9e 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -525,21 +525,21 @@ inline ffi::Array GetCompletePadding3D(ffi::Array padding) { /*! * \brief Check if the given tensor layout can be converted to the given target layout. - * If convertible, return the tensor layout and the bijective conversion in tir::Layout and - * tir::BijectiveLayout accordingly. + * If convertible, return the tensor layout and the bijective conversion in tirx::Layout and + * tirx::BijectiveLayout accordingly. * \param call The context Call to the operator. * \param ctx The error reporting context. * \param tensor_layout The tensor layout to be checked * \param tgt_layout The target layout to be matched * \param tensor_name The name of the input tensor - * \return The tensor layout and the bijective conversion in tir::Layout and tir::BijectiveLayout + * \return The tensor layout and the bijective conversion in tirx::Layout and tirx::BijectiveLayout * accordingly. */ -inline std::pair CheckTensorLayout( +inline std::pair CheckTensorLayout( const Call& call, const BlockBuilder& ctx, const ffi::String& tensor_layout, const ffi::String& tgt_layout, const ffi::String& tensor_name) { - tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); - tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); + tirx::Layout _tensor_layout(tensor_layout, DataType::Int(64)); + tirx::BijectiveLayout tensor2tgt(_tensor_layout, tirx::Layout(tgt_layout, DataType::Int(64))); if (!tensor2tgt.defined()) { ctx->ReportFatal(Diagnostic::Error(call) << call->op << " requires the given " << tensor_name << " layout to be convertible from " << tgt_layout @@ -561,7 +561,7 @@ inline std::pair CheckTensorLayout( inline ffi::Optional CheckNdimPerLayoutAndGetShape(const Call& call, const BlockBuilder& ctx, const TensorStructInfo& sinfo, - const tir::Layout& layout) { + const tirx::Layout& layout) { if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", layout " << layout << " requires the input to be " diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index b4654ef95c24..b227991ae5cb 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -25,9 +25,9 @@ #include "inspect.h" #include -#include -#include -#include +#include +#include +#include #include @@ -87,20 +87,20 @@ std::tuple GetTensorArgInfoWithIndex(const Cal DataType GetTensorDataType(const Call& call) { return GetTensorArgInfo(call)->dtype; } -tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType field_dtype) { - tir::Var dlpack_handle("dlpack_handle", DataType::Handle()); +tirx::PrimFunc GetDLTensorField(tirx::builtin::TVMStructFieldKind field, DataType field_dtype) { + tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); - tir::Var value("value", field_dtype); + tirx::Var value("value", field_dtype); - tir::Stmt body = - tir::SeqStmt({tir::Bind(value, tir::Call(field_dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), field)})), - tir::Evaluate(tvm::ret(value))}); + tirx::Stmt body = + tirx::SeqStmt({tirx::Bind(value, tirx::Call(field_dtype, tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), field)})), + tirx::Evaluate(tvm::ret(value))}); - DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); + DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host", true}}); - tir::PrimFunc func(ffi::Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); + tirx::PrimFunc func(ffi::Array{dlpack_handle}, body, PrimType(field_dtype), {}, attrs); FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)}, PrimStructInfo(field_dtype)); @@ -140,8 +140,8 @@ Expr LegalizeTensorDtypeCode(const BlockBuilder& bb, const Call& call) { auto field_dtype = Downcast(call->struct_info_)->dtype; Expr arg = call->args[0]; - tir::PrimFunc getter = - GetDLTensorField(tir::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_dtype); + tirx::PrimFunc getter = + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeCode, field_dtype); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_code"); return Call(gvar_getter, {arg}); @@ -178,8 +178,8 @@ Expr LegalizeTensorDtypeBits(const BlockBuilder& bb, const Call& call) { auto field_dtype = Downcast(call->struct_info_)->dtype; Expr arg = call->args[0]; - tir::PrimFunc getter = - GetDLTensorField(tir::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_dtype); + tirx::PrimFunc getter = + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeBits, field_dtype); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_bits"); return Call(gvar_getter, {arg}); @@ -216,8 +216,8 @@ Expr LegalizeTensorDtypeLanes(const BlockBuilder& bb, const Call& call) { auto field_dtype = Downcast(call->struct_info_)->dtype; Expr arg = call->args[0]; - tir::PrimFunc getter = - GetDLTensorField(tir::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_dtype); + tirx::PrimFunc getter = + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorTypeLanes, field_dtype); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_dtype_lanes"); return Call(gvar_getter, {arg}); @@ -254,8 +254,8 @@ Expr LegalizeTensorNDim(const BlockBuilder& bb, const Call& call) { auto field_dtype = Downcast(call->struct_info_)->dtype; Expr arg = call->args[0]; - tir::PrimFunc getter = - GetDLTensorField(tir::builtin::TVMStructFieldKind::kDLTensorNDim, field_dtype); + tirx::PrimFunc getter = + GetDLTensorField(tirx::builtin::TVMStructFieldKind::kDLTensorNDim, field_dtype); GlobalVar gvar_getter = bb->AddFunction(getter, "_get_tensor_ndim"); return Call(gvar_getter, {arg}); @@ -295,37 +295,38 @@ StructInfo InferStructInfoTensorShape(const Call& call, const BlockBuilder&) { Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { auto field_dtype = Downcast(call->struct_info_)->dtype; - tir::PrimFunc getter = [&]() -> tir::PrimFunc { - tir::Var dlpack_handle("dlpack_handle", DataType::Handle()); - tir::Var axis("axis", DataType::Int(64)); + tirx::PrimFunc getter = [&]() -> tirx::PrimFunc { + tirx::Var dlpack_handle("dlpack_handle", DataType::Handle()); + tirx::Var axis("axis", DataType::Int(64)); - tir::Var ndim("ndim", DataType::Int(32)); + tirx::Var ndim("ndim", DataType::Int(32)); - tir::Buffer shape_buffer = tir::decl_buffer({ndim}, field_dtype, "shape"); + tirx::Buffer shape_buffer = tirx::decl_buffer({ndim}, field_dtype, "shape"); - tir::Var extent("extent", field_dtype); + tirx::Var extent("extent", field_dtype); - tir::Stmt body = tir::SeqStmt( - {tir::AssertStmt(0 <= axis, tir::StringImm("RuntimeError"), - {tir::StringImm("Specified axis may not be negative")}), - tir::Bind(ndim, tir::Call(ndim->dtype, tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), - tir::builtin::TVMStructFieldKind::kDLTensorNDim)})), - tir::AssertStmt( - axis < tvm::cast(axis->dtype, ndim), tir::StringImm("RuntimeError"), - {tir::StringImm("Specified axis may not be larger than the tensor's dimensionality")}), - tir::Bind(shape_buffer->data, - tir::Call(DataType::Handle(), tir::builtin::tvm_struct_get(), - {dlpack_handle, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), - tir::builtin::TVMStructFieldKind::kDLTensorShape)})), - tir::DeclBuffer(shape_buffer), tir::Bind(extent, tir::BufferLoad(shape_buffer, {axis})), - tir::Evaluate(tvm::ret(extent))}); + tirx::Stmt body = tirx::SeqStmt( + {tirx::AssertStmt(0 <= axis, tirx::StringImm("RuntimeError"), + {tirx::StringImm("Specified axis may not be negative")}), + tirx::Bind(ndim, tirx::Call(ndim->dtype, tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tirx::builtin::TVMStructFieldKind::kDLTensorNDim)})), + tirx::AssertStmt( + axis < tvm::cast(axis->dtype, ndim), tirx::StringImm("RuntimeError"), + {tirx::StringImm( + "Specified axis may not be larger than the tensor's dimensionality")}), + tirx::Bind(shape_buffer->data, + tirx::Call(DataType::Handle(), tirx::builtin::tvm_struct_get(), + {dlpack_handle, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), + tirx::builtin::TVMStructFieldKind::kDLTensorShape)})), + tirx::DeclBuffer(shape_buffer), tirx::Bind(extent, tirx::BufferLoad(shape_buffer, {axis})), + tirx::Evaluate(tvm::ret(extent))}); - DictAttrs attrs({{"tir.is_scheduled", true}, {"tir.is_host", true}}); + DictAttrs attrs({{"tirx.is_scheduled", true}, {"tirx.is_host", true}}); - tir::PrimFunc func({dlpack_handle, axis}, body, PrimType(field_dtype), {}, attrs); + tirx::PrimFunc func({dlpack_handle, axis}, body, PrimType(field_dtype), {}, attrs); FuncStructInfo sinfo( {TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)}, @@ -367,9 +368,9 @@ StructInfo InferStructInfoTensorStride(const Call& call, const BlockBuilder&) { // As of 2024-03-14, Relax does not have an explicit // representation for striding in `TensorStructInfo`. The // `FLegalize` function for most operators is implemented in terms - // of `topi`, and is then converted from TE to `tir::PrimFunc` - // using `tvm::tir::CreatePrimFunc`. The `te::Tensor` is - // converted to a `tir::Buffer` in `RewriteStageToBlock`, and uses + // of `topi`, and is then converted from TE to `tirx::PrimFunc` + // using `tvm::tirx::CreatePrimFunc`. The `te::Tensor` is + // converted to a `tirx::Buffer` in `RewriteStageToBlock`, and uses // the default empty list for the strides. The empty strides // represent a compact data array. // diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index fc6ec6b8aa04..a66c3fba3979 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -703,7 +703,7 @@ TVM_REGISTER_OP("relax.index_tensor") /* relax.layout_transform */ -Expr layout_transform(Expr x, tir::IndexMap index_map, ffi::Optional pad_value, +Expr layout_transform(Expr x, tirx::IndexMap index_map, ffi::Optional pad_value, ffi::Optional> axis_separators, ffi::Optional> input_axis_separators) { ObjectPtr attrs = ffi::make_object(); @@ -724,7 +724,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) { TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); const auto* attrs = call->attrs.as(); - tir::IndexMap index_map = attrs->index_map; + tirx::IndexMap index_map = attrs->index_map; ffi::Optional optional_pad_value = attrs->pad_value; // Check pad_value has same dtype as input. @@ -1350,7 +1350,7 @@ InferLayoutOutput InferLayoutSqueeze( } else { axis.reserve(ndim); for (int i = 0; i < ndim; ++i) { - if (tir::is_one(shape->values[i])) { + if (tirx::is_one(shape->values[i])) { axis.push_back(Integer(i)); } } @@ -1974,7 +1974,7 @@ InferLayoutOutput InferLayoutTile( // Same dimension: reorder repeats according to layout transformation. // If len(repeats) < ndim, it's padded with 1s at the beginning. for (int i = 0; i < ndim; ++i) { - const tir::LayoutAxis& axis = existing_layout_obj[i]; + const tirx::LayoutAxis& axis = existing_layout_obj[i]; int pos_in_initial = initial_layout.IndexOf(axis); TVM_FFI_ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout"; // If len(repeats) < ndim, repeats are right-aligned. @@ -1996,7 +1996,7 @@ InferLayoutOutput InferLayoutTile( } // Repeats for existing dimensions need to be permuted. for (int i = 0; i < ndim; ++i) { - const tir::LayoutAxis& axis = existing_layout_obj[i]; + const tirx::LayoutAxis& axis = existing_layout_obj[i]; int pos_in_initial = initial_layout.IndexOf(axis); TVM_FFI_ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout"; new_repeats.push_back(attrs->repeats[pos_in_initial + num_new_dims]); diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 84d53addcc69..260d27f1ef1d 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -72,7 +72,7 @@ Expr flatten(Expr x); * \param input axis_separators Array of values for input buffer. * \return The transformed result. */ -Expr layout_transform(Expr x, tir::IndexMap index_map, ffi::Optional pad_value, +Expr layout_transform(Expr x, tirx::IndexMap index_map, ffi::Optional pad_value, ffi::Optional> axis_separators, ffi::Optional> input_axis_separators = std::nullopt); diff --git a/src/relax/op/tensor/sampling.cc b/src/relax/op/tensor/sampling.cc index ca5635baa74b..f54ce63e3788 100644 --- a/src/relax/op/tensor/sampling.cc +++ b/src/relax/op/tensor/sampling.cc @@ -122,8 +122,8 @@ StructInfo InferStructInfoMultinomialFromUniform(const Call& call, const BlockBu << n << "` and the given sample_indices tensor has batch size `" << sample_indices_shape->values[0] << "`"); } - if (!tir::is_one(uniform_sample_shape->values[1]) || - !tir::is_one(sample_indices_shape->values[1])) { + if (!tirx::is_one(uniform_sample_shape->values[1]) || + !tirx::is_one(sample_indices_shape->values[1])) { ctx->ReportFatal(Diagnostic::Error(call) << "Multinomial_from_uniform op requires the input uniform_sample and " "sample_indices to be 2D tensors with the second dimension being 1. " diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index fcbd30cd3c4a..84ad94c3887e 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -77,7 +77,7 @@ std::tuple)>> auto lower_bounds = func->GetAttr>("tir_var_lower_bound"); if (upper_bounds || lower_bounds) { - ffi::Map name_lookup; + ffi::Map name_lookup; for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) { name_lookup.Set(tir_var->name_hint, tir_var); symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var); diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc index f324af5a425d..ff0aaa95d49a 100644 --- a/src/relax/transform/alter_op_impl.cc +++ b/src/relax/transform/alter_op_impl.cc @@ -32,19 +32,19 @@ #include #include #include -#include +#include #include #include "../../te/operation/create_primfunc.h" namespace tvm { namespace relax { -using namespace tir; +using namespace tirx; static constexpr const char* kOperatorName = "operator_name"; /*! \brief Construct ranges from shape dimensions */ static ffi::Array ConstructRangeFromShape(const ffi::Array& shape) { - return shape.Map([](const PrimExpr& dim) { return Range(tir::make_zero(dim.dtype()), dim); }); + return shape.Map([](const PrimExpr& dim) { return Range(tirx::make_zero(dim.dtype()), dim); }); } static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { @@ -81,7 +81,7 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) { class AlterOpImplMutator : public ExprMutator { public: AlterOpImplMutator( - const IRModule& mod, const ffi::Map& op_impl_map, + const IRModule& mod, const ffi::Map& op_impl_map, const ffi::Map>& op_buffer_transforms_, const ffi::Map>>>& axis_separators_, const ffi::Map>>>& @@ -110,7 +110,7 @@ class AlterOpImplMutator : public ExprMutator { Expr VisitExpr_(const CallNode* op) final { auto call = Downcast(ExprMutator::VisitExpr_(op)); - // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // TODO(@tvm-team): When we differentiate the call for tirx function and packed function, // this logic should be changed accordingly. if (!call->op.same_as(call_tir_op_)) return call; @@ -119,8 +119,8 @@ class AlterOpImplMutator : public ExprMutator { // Get operator name from callee TVM_FFI_ICHECK(call->args[0]->IsInstance()); - const tir::PrimFunc& old_func = - Downcast(mod_->Lookup(Downcast(call->args[0]))); + const tirx::PrimFunc& old_func = + Downcast(mod_->Lookup(Downcast(call->args[0]))); ffi::Optional maybe_op_kind = old_func->attrs.GetAttr(kOperatorName); // If the callee does not have kOperatorName attribute or no replacement is requested for @@ -214,8 +214,8 @@ class AlterOpImplMutator : public ExprMutator { // Create dynamic shapes for input and output tensors ffi::Array dyn_padded_shape, dyn_old_shape; for (int i = 0; i < t_shape; i++) { - tir::Var var1("p" + std::to_string(i), old_shape[i].dtype()); - tir::Var var2("i" + std::to_string(i), old_shape[i].dtype()); + tirx::Var var1("p" + std::to_string(i), old_shape[i].dtype()); + tirx::Var var2("i" + std::to_string(i), old_shape[i].dtype()); dyn_padded_shape.push_back(var1); dyn_old_shape.push_back(var2); } @@ -225,7 +225,7 @@ class AlterOpImplMutator : public ExprMutator { // Output tensor of remove_pad op te::Tensor output_tensor = te::compute( dyn_old_shape, - [&placeholder_tensor](const ffi::Array& indices) { + [&placeholder_tensor](const ffi::Array& indices) { return placeholder_tensor(indices); }, "output", topi::kElementWise); @@ -257,7 +257,7 @@ class AlterOpImplMutator : public ExprMutator { auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges, &analyzer); - if (tir::is_zero(padding_predicate)) { + if (tirx::is_zero(padding_predicate)) { return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator); } else { auto padded_expr = builder_->Normalize( @@ -433,7 +433,7 @@ class AlterOpImplMutator : public ExprMutator { namespace transform { Pass AlterOpImpl( - const ffi::Map& op_impl_map, + const ffi::Map& op_impl_map, const ffi::Map>& op_buffer_transforms_, const ffi::Map>>>& axis_separators_, const ffi::Map>>>& diff --git a/src/relax/transform/annotate_tir_op_pattern.cc b/src/relax/transform/annotate_tir_op_pattern.cc index f5b1061b6708..104427a0065e 100644 --- a/src/relax/transform/annotate_tir_op_pattern.cc +++ b/src/relax/transform/annotate_tir_op_pattern.cc @@ -25,12 +25,12 @@ #include #include #include -#include +#include namespace tvm { namespace relax { -tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) { +tirx::PrimFunc AnnotateOpPattern(tirx::PrimFunc f) { if (f->HasNonzeroAttr("op_pattern")) { return f; } else { @@ -42,10 +42,10 @@ tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) { namespace transform { Pass AnnotateTIROpPattern() { - auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) { + auto pass_func = [=](tirx::PrimFunc f, IRModule m, PassContext ctx) { return AnnotateOpPattern(std::move(f)); }; - return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); + return tirx::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/transform/attach_attr_layout_free_buffers.cc b/src/relax/transform/attach_attr_layout_free_buffers.cc index e1e2e0ca26b8..52138f86c1c0 100644 --- a/src/relax/transform/attach_attr_layout_free_buffers.cc +++ b/src/relax/transform/attach_attr_layout_free_buffers.cc @@ -25,7 +25,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { @@ -80,16 +80,16 @@ class AttrAttacher : public ExprMutator { layout_free_buffers.push_back(i); } } - // Attach the layout free buffers to the tir::PrimFunc - tir::PrimFunc func = WithAttr(Downcast(mod_->Lookup(gv)), "layout_free_buffers", - layout_free_buffers); + // Attach the layout free buffers to the tirx::PrimFunc + tirx::PrimFunc func = WithAttr(Downcast(mod_->Lookup(gv)), + "layout_free_buffers", layout_free_buffers); // Renew defs func = s_tir::RenewDefs(func); - // Add the updated tir::PrimFunc in the IRModule - // Note the blockbuilder would automatically combine the same tir function + // Add the updated tirx::PrimFunc in the IRModule + // Note the blockbuilder would automatically combine the same tirx function // So we don't need to worry about the duplicate insertion GlobalVar new_gv = builder_->AddFunction(func, gv->name_hint); - // Create a new call node with the updated tir::PrimFunc + // Create a new call node with the updated tirx::PrimFunc auto n = ffi::make_object(*op); n->args = {new_gv, Tuple(call_tir_args)}; return Call(n); @@ -104,7 +104,7 @@ namespace transform { Pass AttachAttrLayoutFreeBuffers() { auto pass_func = [=](IRModule mod, PassContext pc) { return AttrAttacher::Transform(mod); }; auto pass = CreateModulePass(pass_func, 0, "_AttachAttrLayoutFreeBuffers", {}); - // Apply DeadCodeElimination to remove unused tir::PrimFunc + // Apply DeadCodeElimination to remove unused tirx::PrimFunc return tvm::transform::Sequential({pass, DeadCodeElimination()}, "AttachAttrLayoutFreeBuffers"); } diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc index 0079b504989a..0f94fa1fa0c5 100644 --- a/src/relax/transform/attach_global_symbol.cc +++ b/src/relax/transform/attach_global_symbol.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { @@ -47,10 +47,10 @@ Pass AttachGlobalSymbol() { ffi::Optional new_name; BaseFunc new_func; - if (auto* prim_func = func.as()) { + if (auto* prim_func = func.as()) { new_name = c_prefix + gvar->name_hint; new_func = - WithAttr(ffi::GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); + WithAttr(ffi::GetRef(prim_func), tvm::attr::kGlobalSymbol, new_name); } else if (auto* relax_func = func.as()) { new_name = gvar->name_hint; new_func = WithAttr(ffi::GetRef(relax_func), tvm::attr::kGlobalSymbol, new_name); diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 35d393a918bc..2ac49f2a3746 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include @@ -32,7 +32,7 @@ namespace tvm { namespace relax { void MatchSymbolicVar(const Expr& arg, const Expr& constant, - ffi::Map* symbolic_var_map, arith::Analyzer* analyzer_) { + ffi::Map* symbolic_var_map, arith::Analyzer* analyzer_) { auto opt_arg_sinfo = MatchStructInfo(arg); TVM_FFI_ICHECK(opt_arg_sinfo) << "The struct info of the bound parameter is expected to be TensorStructInfo, but got: " @@ -68,11 +68,11 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, for (int i = 0; i < arg_sinfo->ndim; ++i) { const PrimExpr& const_dim = const_shape->values[i]; - TVM_FFI_ICHECK(tir::is_const_int(const_dim)); - if (const auto* shape_var = arg_shape->values[i].as()) { - auto it = symbolic_var_map->find(ffi::GetRef(shape_var)); + TVM_FFI_ICHECK(tirx::is_const_int(const_dim)); + if (const auto* shape_var = arg_shape->values[i].as()) { + auto it = symbolic_var_map->find(ffi::GetRef(shape_var)); if (it == symbolic_var_map->end()) { - symbolic_var_map->Set(ffi::GetRef(shape_var), const_dim); + symbolic_var_map->Set(ffi::GetRef(shape_var), const_dim); } else { TVM_FFI_ICHECK(analyzer_->CanProveEqual((*it).second, const_dim)) << "The shape of the bound parameter is expected to be " << (*it).second @@ -82,7 +82,7 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant, } } -std::tuple, ffi::Map> NormalizeBindings( +std::tuple, ffi::Map> NormalizeBindings( const Function& func, const ffi::Map& untyped_params) { TVM_FFI_ICHECK(func.defined()); TVM_FFI_ICHECK(untyped_params.defined()); @@ -144,7 +144,7 @@ std::tuple, ffi::Map> NormalizeBindings( } arith::Analyzer analyzer; - ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); + ffi::Map symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer); // for (const auto& [bind_param, bind_expr] : relax_var_remap) { // MatchSymbolicVar(bind_param, bind_expr, &symbolic_var_map, &analyzer); diff --git a/src/relax/transform/bind_symbolic_vars.cc b/src/relax/transform/bind_symbolic_vars.cc index b7d69186f4b5..92823a690d9d 100644 --- a/src/relax/transform/bind_symbolic_vars.cc +++ b/src/relax/transform/bind_symbolic_vars.cc @@ -32,24 +32,24 @@ namespace tvm { namespace relax { Function FunctionBindSymbolicVars( - Function func, ffi::Map, PrimExpr> obj_remap) { + Function func, ffi::Map, PrimExpr> obj_remap) { // Early bail-out if no updates need to be made. if (obj_remap.empty()) { return func; } - ffi::Array old_symbolic_vars = DefinedSymbolicVars(func); + ffi::Array old_symbolic_vars = DefinedSymbolicVars(func); // Map from string to the variable(s) with that name. - std::unordered_map> string_lookup; - std::unordered_set symbolic_var_set; + std::unordered_map> string_lookup; + std::unordered_set symbolic_var_set; for (const auto& var : old_symbolic_vars) { string_lookup[var->name_hint].push_back(var); symbolic_var_set.insert(var.get()); } // Replacement map to be used when rewriting the function. - ffi::Map var_remap; + ffi::Map var_remap; for (const auto& [key, replacement] : obj_remap) { if (auto opt = key.as()) { ffi::String string_key = opt.value(); @@ -66,7 +66,7 @@ Function FunctionBindSymbolicVars( TVM_FFI_ICHECK(!var_remap.count(var)) << "Remap of variable " << var << " was defined multiple times"; var_remap.Set(var, replacement); - } else if (auto opt = key.as()) { + } else if (auto opt = key.as()) { auto var = opt.value(); TVM_FFI_ICHECK(!var_remap.count(var)) @@ -77,7 +77,7 @@ Function FunctionBindSymbolicVars( var_remap.Set(var, replacement); } else { TVM_FFI_THROW(InternalError) - << "Expected symbolic variable to be a tir::Var or a string name, " + << "Expected symbolic variable to be a tirx::Var or a string name, " << "but " << key << " was of type " << key.GetTypeKey(); } } @@ -95,7 +95,7 @@ Function FunctionBindSymbolicVars( namespace { IRModule ModuleBindSymbolicVars( - IRModule mod, ffi::Map, PrimExpr> binding_map) { + IRModule mod, ffi::Map, PrimExpr> binding_map) { std::unordered_set used; IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { @@ -103,24 +103,24 @@ IRModule ModuleBindSymbolicVars( auto func = opt.value(); // Collect bindings that are used by this function. - auto func_binding_map = [&]() -> ffi::Map, PrimExpr> { + auto func_binding_map = [&]() -> ffi::Map, PrimExpr> { std::unordered_set var_names; - std::unordered_set vars; + std::unordered_set vars; for (const auto& var : DefinedSymbolicVars(func)) { var_names.insert(var->name_hint); vars.insert(var.get()); } - ffi::Map, PrimExpr> out; + ffi::Map, PrimExpr> out; for (const auto& [key, replacement] : binding_map) { bool used_by_function = false; if (auto opt = key.as()) { used_by_function = var_names.count(opt.value()); - } else if (auto ptr = key.as()) { + } else if (auto ptr = key.as()) { used_by_function = vars.count(ptr); } else { TVM_FFI_THROW(InternalError) - << "Expected symbolic variable to be a tir::Var " + << "Expected symbolic variable to be a tirx::Var " << "or a string name, but " << key << " was of type " << key.GetTypeKey(); } if (used_by_function) { @@ -162,7 +162,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace transform { -Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, +Pass BindSymbolicVars(ffi::Map, PrimExpr> binding_map, ffi::Optional func_name) { auto pass_func = [=](IRModule mod, PassContext context) -> IRModule { if (func_name) { diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 48efd50d482c..513cffcfa58f 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include "utils.h" diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 98fd075f55d5..104c4e0597cb 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -30,7 +30,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { @@ -118,13 +118,14 @@ class SymbolicVarCanonicalizer : public ExprMutator { if (known_values_.empty()) { return expr; } - PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> ffi::Optional { - if (auto it = known_values_.find(var); it != known_values_.end()) { - return it->second.expr; - } else { - return std::nullopt; - } - }); + PrimExpr output = + tirx::Substitute(expr, [this](const tirx::Var& var) -> ffi::Optional { + if (auto it = known_values_.find(var); it != known_values_.end()) { + return it->second.expr; + } else { + return std::nullopt; + } + }); if (output.same_as(expr)) { return expr; } @@ -139,7 +140,7 @@ class SymbolicVarCanonicalizer : public ExprMutator { MatchCast source; }; - std::unordered_map known_values_; + std::unordered_map known_values_; }; struct CanonicalizationPlan { diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc index 151205c4c1e4..c82cf60c3547 100644 --- a/src/relax/transform/compute_prim_value.cc +++ b/src/relax/transform/compute_prim_value.cc @@ -21,9 +21,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include namespace tvm { namespace relax { @@ -39,21 +39,21 @@ class PrimValueComputeInjector : public ExprMutator { Expr VisitExpr_(const PrimValueNode* op) override { auto node = Downcast(ExprMutator::VisitExpr_(op)); - if (node->value->IsInstance() || node->value->IsInstance()) { + if (node->value->IsInstance() || node->value->IsInstance()) { return node; } auto ret_dtype = node->value->dtype; - auto param_vars = tir::UndefinedVars(node->value); - tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value})); + auto param_vars = tirx::UndefinedVars(node->value); + tirx::Stmt body = tirx::Evaluate(tirx::Call(ret_dtype, tirx::builtin::ret(), {node->value})); - tir::PrimFunc func(param_vars, body, PrimType(ret_dtype), {}, - DictAttrs({{tir::attr::kIsHostFunc, true}})); + tirx::PrimFunc func(param_vars, body, PrimType(ret_dtype), {}, + DictAttrs({{tirx::attr::kIsHostFunc, true}})); func = s_tir::RenewDefs(func); auto callee = builder_->AddFunction(func, "compute_symbolic_expr"); - return relax::Call(callee, param_vars.Map([](const tir::Var& tir_var) -> relax::Expr { + return relax::Call(callee, param_vars.Map([](const tirx::Var& tir_var) -> relax::Expr { return relax::PrimValue(tir_var); })); } diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 62d029a3347b..7c26b1f6f7e8 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include "../op/tensor/manipulate.h" #include "infer_layout_utils.h" @@ -36,8 +36,8 @@ namespace tvm { namespace relax { -using tir::IndexMap; -using tir::Layout; +using tirx::IndexMap; +using tirx::Layout; using LayoutCb = tvm::relax::transform::LayoutCb; /*! @@ -94,14 +94,14 @@ class LayoutConvertMutator : public ExprMutator { } IndexMap LayoutIndexMap(int ndim, const Layout& src_layout, const Layout& desired_layout) { - tir::BijectiveLayout todesired(src_layout, desired_layout); + tirx::BijectiveLayout todesired(src_layout, desired_layout); ffi::Optional inverse_index_map; - ffi::Array initial_indices; + ffi::Array initial_indices; ffi::Array initial_indices_expr; initial_indices.reserve(ndim); for (int i = 0; i < ndim; ++i) { - auto var = tvm::tir::Var("i" + std::to_string(i), DataType::Int(32)); + auto var = tvm::tirx::Var("i" + std::to_string(i), DataType::Int(32)); initial_indices.push_back(var); initial_indices_expr.push_back(var); } diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index c672d0462be8..7d6b39fcbaab 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include "utils.h" @@ -693,79 +693,79 @@ FindInplaceOpportunities(const DataflowBlock& block, const ffi::Array& inpu } // Replace buffers in a PrimFunc according to the mapping. -tir::Stmt RemapBuffers(const tir::Stmt& stmt, - const ffi::Map& buffer_map) { - class BufferMapper : public tir::StmtExprMutator { +tirx::Stmt RemapBuffers(const tirx::Stmt& stmt, + const ffi::Map& buffer_map) { + class BufferMapper : public tirx::StmtExprMutator { public: - explicit BufferMapper(const ffi::Map& buffer_map) + explicit BufferMapper(const ffi::Map& buffer_map) : buffer_map_(buffer_map) {} - tir::Stmt Remap(const tir::Stmt& stmt) { return VisitStmt(stmt); } + tirx::Stmt Remap(const tirx::Stmt& stmt) { return VisitStmt(stmt); } - PrimExpr VisitExpr_(const tir::BufferLoadNode* op) final { - auto node = Downcast(tir::StmtExprMutator::VisitExpr_(op)); + PrimExpr VisitExpr_(const tirx::BufferLoadNode* op) final { + auto node = Downcast(tirx::StmtExprMutator::VisitExpr_(op)); auto* node_cow = node.CopyOnWrite(); node_cow->buffer = AttemptRemap(node->buffer); return node; } - tir::Stmt VisitStmt_(const tir::BufferStoreNode* op) final { - auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + tirx::Stmt VisitStmt_(const tirx::BufferStoreNode* op) final { + auto node = Downcast(tirx::StmtExprMutator::VisitStmt_(op)); auto* node_cow = node.CopyOnWrite(); node_cow->buffer = AttemptRemap(node->buffer); return node; } - tir::Stmt VisitStmt_(const tir::DeclBufferNode* op) final { - auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + tirx::Stmt VisitStmt_(const tirx::DeclBufferNode* op) final { + auto node = Downcast(tirx::StmtExprMutator::VisitStmt_(op)); auto* node_cow = node.CopyOnWrite(); node_cow->buffer = AttemptRemap(node->buffer); return node; } - tir::Stmt VisitStmt_(const tir::AllocBufferNode* op) final { - auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + tirx::Stmt VisitStmt_(const tirx::AllocBufferNode* op) final { + auto node = Downcast(tirx::StmtExprMutator::VisitStmt_(op)); auto* node_cow = node.CopyOnWrite(); node_cow->buffer = AttemptRemap(node->buffer); return node; } - tir::Stmt VisitStmt_(const tir::SBlockNode* op) final { - auto node = Downcast(tir::StmtExprMutator::VisitStmt_(op)); + tirx::Stmt VisitStmt_(const tirx::SBlockNode* op) final { + auto node = Downcast(tirx::StmtExprMutator::VisitStmt_(op)); auto* node_cow = node.CopyOnWrite(); // need the lambdas because class methods are not first-class (how ironic) node_cow->alloc_buffers = - node->alloc_buffers.Map([this](const tir::Buffer& b) { return AttemptRemap(b); }); + node->alloc_buffers.Map([this](const tirx::Buffer& b) { return AttemptRemap(b); }); node_cow->reads = - node->reads.Map([this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node->reads.Map([this](const tirx::BufferRegion& br) { return VisitBufferRegion(br); }); node_cow->writes = - node->writes.Map([this](const tir::BufferRegion& br) { return VisitBufferRegion(br); }); + node->writes.Map([this](const tirx::BufferRegion& br) { return VisitBufferRegion(br); }); node_cow->match_buffers = node->match_buffers.Map( - [this](const tir::MatchBufferRegion& mbr) { return VisitMatchBufferRegion(mbr); }); + [this](const tirx::MatchBufferRegion& mbr) { return VisitMatchBufferRegion(mbr); }); return node; } private: - tir::Buffer AttemptRemap(const tir::Buffer& buffer) { + tirx::Buffer AttemptRemap(const tirx::Buffer& buffer) { if (buffer_map_.count(buffer)) { return buffer_map_.at(buffer); } return buffer; } - tir::BufferRegion VisitBufferRegion(tir::BufferRegion region) { + tirx::BufferRegion VisitBufferRegion(tirx::BufferRegion region) { auto* region_cow = region.CopyOnWrite(); region_cow->buffer = AttemptRemap(region_cow->buffer); return region; } - tir::MatchBufferRegion VisitMatchBufferRegion(tir::MatchBufferRegion region) { + tirx::MatchBufferRegion VisitMatchBufferRegion(tirx::MatchBufferRegion region) { auto* region_cow = region.CopyOnWrite(); region_cow->buffer = AttemptRemap(region_cow->buffer); return region; } - const ffi::Map& buffer_map_; + const ffi::Map& buffer_map_; }; BufferMapper mapper(buffer_map); @@ -875,9 +875,9 @@ class ModuleInplaceTransformer : public ExprMutator { auto inline_legal_op_name = legal_op->name_hint + "_inplace"; auto mod = builder_->GetContextIRModule(); - auto old_primfunc = Downcast(mod->Lookup(legal_op)); + auto old_primfunc = Downcast(mod->Lookup(legal_op)); - tir::Stmt new_body = old_primfunc->body; + tirx::Stmt new_body = old_primfunc->body; size_t num_outs = inplace_indices.size(); size_t num_params = old_primfunc->params.size(); @@ -889,8 +889,8 @@ class ModuleInplaceTransformer : public ExprMutator { // 2. For each output var, replace its instances with the corresponding inplace index var // 3. Do the same for the *buffer vars* corresponding to the output vars // 4. Remove the output vars from the param list and buffer map - ffi::Map buffer_subst_map; - ffi::Map var_subst_map; + ffi::Map buffer_subst_map; + ffi::Map var_subst_map; for (size_t i = 0; i < num_outs; i++) { // we will substitute output i with the corresponding param indicated by inplace indices auto output_var = old_primfunc->params[num_params - num_outs + i]; @@ -907,7 +907,7 @@ class ModuleInplaceTransformer : public ExprMutator { // apply substitutions new_body = RemapBuffers(new_body, buffer_subst_map); new_body = - tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> ffi::Optional { + tirx::Substitute(new_body, [&var_subst_map](const tirx::Var& v) -> ffi::Optional { if (var_subst_map.count(v)) { return var_subst_map.at(v); } @@ -922,11 +922,11 @@ class ModuleInplaceTransformer : public ExprMutator { // now get rid of the last num_outputs arguments // (couldn't do earlier or else it would have thrown off the indexing) - ffi::Array new_params(old_primfunc->params.begin(), - old_primfunc->params.begin() + (num_params - num_outs)); + ffi::Array new_params(old_primfunc->params.begin(), + old_primfunc->params.begin() + (num_params - num_outs)); - tir::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type, new_buffer_map, - old_primfunc->attrs, old_primfunc->span); + tirx::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type, new_buffer_map, + old_primfunc->attrs, old_primfunc->span); // note: this might be a good time to get rid of the old legalized function, but we don't do it // now because later ops might need the same one. Instead, we will clean up at the end diff --git a/src/relax/transform/decompose_ops.cc b/src/relax/transform/decompose_ops.cc index 1b6fc77c48d8..23c021a8ce10 100644 --- a/src/relax/transform/decompose_ops.cc +++ b/src/relax/transform/decompose_ops.cc @@ -155,7 +155,7 @@ Expr TensorToShape(const Call& call_node, const BlockBuilder& builder) { // ffi::Array), we define symbolic variables and returns them as a ShapeExpr. ffi::Array shape_var; for (int i = 0; i < sinfo->ndim; i++) { - shape_var.push_back(tir::Var("x", DataType::Int(64))); + shape_var.push_back(tirx::Var("x", DataType::Int(64))); } // bind symbolic variables to the shape tuple relax::Var var("y", ShapeStructInfo(shape_var)); diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 5194941c26c9..e68b9232e493 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -25,8 +25,8 @@ #include #include #include -#include -#include +#include +#include namespace tvm { namespace relax { @@ -87,12 +87,12 @@ class ConstantFolder : public ExprMutator { * \brief Pattern match op to a TIR function and look it up. * \return The TIR function, or nullopt if pattern match fails. */ - ffi::Optional MatchPrimFunc(const Expr& op) { + ffi::Optional MatchPrimFunc(const Expr& op) { const GlobalVar& global_var = Downcast(op); // NOTE: as check works for nullptr(returns null) ffi::Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); - if (auto* pfunc = base_func.as()) { - return ffi::GetRef(pfunc); + if (auto* pfunc = base_func.as()) { + return ffi::GetRef(pfunc); } return std::nullopt; } @@ -101,7 +101,7 @@ class ConstantFolder : public ExprMutator { * \brief Get a cached build version of func * \return The cached func, nullopt if func cannot be built. */ - ffi::Optional GetCachedBuild(tir::PrimFunc func) { + ffi::Optional GetCachedBuild(tirx::PrimFunc func) { // TODO(tvm-team): consider another way of bulk extract and build PrimFunc once // would be helpful for future cases where PrimFunc recursively call into each other Target eval_cpu_target{"llvm"}; @@ -117,7 +117,7 @@ class ConstantFolder : public ExprMutator { // already scheduled to only work on GPU, we will need to skip this in the const folder for // now // TODO(Hongyi): further check and narrow the scope of foldable function - const auto pf = tvm::ffi::Function::GetGlobalRequired("tir.build"); + const auto pf = tvm::ffi::Function::GetGlobalRequired("tirx.build"); func = WithAttr(func, tvm::attr::kGlobalSymbol, ffi::String("tir_function")); ffi::Module rt_module = pf(func, eval_cpu_target).cast(); build_func = rt_module->GetFunction("tir_function"); @@ -193,7 +193,7 @@ class ConstantFolder : public ExprMutator { // Try constant evaluate a call_tir with a single tensor output. // Returns std::nullopt on failure. - ffi::Optional ConstEvaluateCallTIR(tir::PrimFunc tir_func, + ffi::Optional ConstEvaluateCallTIR(tirx::PrimFunc tir_func, ffi::Array arr_args, ffi::Shape shape, DataType ret_type) { // obtain function from the cache. @@ -225,7 +225,7 @@ class ConstantFolder : public ExprMutator { // Try constant evaluate a call_tir with tuple outputs (multiple output tensors). // Returns std::nullopt on failure. - ffi::Optional ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func, + ffi::Optional ConstEvaluateCallTIRTuple(tirx::PrimFunc tir_func, ffi::Array arr_args, const TupleStructInfoNode* tuple_sinfo) { ffi::Optional func = GetCachedBuild(tir_func); @@ -269,7 +269,7 @@ class ConstantFolder : public ExprMutator { ffi::Optional VisitCallTIR(Call call) { // call_tir needs to have at least two arguments TVM_FFI_ICHECK_GE(call->args.size(), 2); - ffi::Optional func = MatchPrimFunc(call->args[0]); + ffi::Optional func = MatchPrimFunc(call->args[0]); TVM_FFI_ICHECK(call->args[1].as()) << "call_tir.args[1] must be Tuple"; ffi::Optional> arr_args = MatchConstArrayArgs(call->args[1].as()->fields); @@ -412,7 +412,7 @@ class ConstantFolder : public ExprMutator { } // cache for function build, via structural equality - std::unordered_map, ffi::StructuralHash, + std::unordered_map, ffi::StructuralHash, ffi::StructuralEqual> func_build_cache_; }; diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 45888419b4be..106f8c5bb534 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -34,9 +34,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include @@ -204,7 +204,7 @@ class GraphCreator : public ExprVisitor { const auto* op = call->op.as(); if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) { const GlobalVar& global_var = Downcast(call->args[0]); - tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); + tirx::PrimFunc func = Downcast(mod_->Lookup(global_var)); // Override args for call_tir args = Downcast(call->args[1])->fields; @@ -562,7 +562,7 @@ class FunctionCreator : public ExprMutator { /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); ffi::Array free_vars = - FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); + FreeSymbolicVars(function).Map([](const tirx::Var& var) -> PrimExpr { return var; }); if (!free_vars.empty()) { params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); arguments_.push_back(ShapeExpr(free_vars)); @@ -648,10 +648,10 @@ class FunctionCreator : public ExprMutator { return std::all_of(tuple->fields.begin(), tuple->fields.end(), [this](const Expr& e) { return IsInlinableConstants(e); }); } else if (const auto* prim_value = expr.as()) { - return tvm::tir::UndefinedVars(prim_value->value).empty(); + return tvm::tirx::UndefinedVars(prim_value->value).empty(); } else if (const auto* shape_expr = expr.as()) { return std::all_of(shape_expr->values.begin(), shape_expr->values.end(), - [](const PrimExpr& e) { return tvm::tir::UndefinedVars(e).empty(); }); + [](const PrimExpr& e) { return tvm::tirx::UndefinedVars(e).empty(); }); } return false; } @@ -1362,7 +1362,7 @@ IRModule FuseOpsByPattern(const tvm::ffi::Array& patte } else { for (const auto& gv : mod->GetGlobalVars()) { const auto& base_func = mod->Lookup(gv); - if (base_func->IsInstance()) { + if (base_func->IsInstance()) { continue; } const FunctionNode* function = base_func.as(); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 3f739cd243e4..2128510cdc95 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -23,15 +23,15 @@ #include #include #include -#include +#include #include #include -#include "../../tir/ir/functor_common.h" +#include "../../tirx/ir/functor_common.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Match symbolic vars according to the given PrimExpr, and update the var_remap. @@ -39,7 +39,7 @@ namespace tir { */ class SymbolicMatcher : ExprFunctor { public: - explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) + explicit SymbolicMatcher(arith::Analyzer* analyzer, ffi::Map* var_remap) : analyzer_(analyzer), var_remap_(var_remap) {} void Match(const ffi::Array& params, const ffi::Array& args) { @@ -152,7 +152,7 @@ class SymbolicMatcher : ExprFunctor* var_remap_; + ffi::Map* var_remap_; PrimExpr must_prove_ = Bool(true); }; @@ -291,11 +291,11 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { private: /*! \brief Mapping from src buffer to tgt buffer. */ - ffi::Map buffer_remap_; - /*! \brief Mapping from src tir var to tgt var. */ - ffi::Map var_remap_; + ffi::Map buffer_remap_; + /*! \brief Mapping from src tirx var to tgt var. */ + ffi::Map var_remap_; - ffi::Array UnionAccessRegion(const ffi::Array& regions) const { + ffi::Array UnionAccessRegion(const ffi::Array& regions) const { // For now we only allow Buffer access the same elements. // e.g. `[A[vi, vj], A[vi, vj]]` is a legal pattern but need to union to `A[vi, vj]` // However, `A[vi, vj], A[vi, vj + 1]` is not allow for now. @@ -341,10 +341,10 @@ class FuseTIRBufferSubstitutor : private StmtExprMutator { }; /*! \brief A mutator which detect block name duplication and deduplicate the names. */ -class SBlockNameDeduplicator : public tir::StmtMutator { +class SBlockNameDeduplicator : public tirx::StmtMutator { private: Stmt VisitStmt_(const SBlockNode* op) final { - SBlock block = Downcast(tir::StmtMutator::VisitStmt_(op)); + SBlock block = Downcast(tirx::StmtMutator::VisitStmt_(op)); ffi::String name = GetUniqueName(block->name_hint); @@ -409,7 +409,7 @@ class SBlockNameDeduplicator : public tir::StmtMutator { std::unordered_map name_count_; }; -} // namespace tir +} // namespace tirx namespace relax { @@ -435,7 +435,7 @@ static ffi::Array GetInplaceOutputIndices(const ffi::Array& in class RelaxToTIRVarMapCollector : public ExprVisitor { public: explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {} - static ffi::Map Collect(const IRModule& mod, const Function& func) { + static ffi::Map Collect(const IRModule& mod, const Function& func) { RelaxToTIRVarMapCollector visitor(mod); visitor(func->body); return visitor.relax_to_tir_var_map_; @@ -459,7 +459,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place) { GlobalVar gv = Downcast(call->args[0]); - tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); + tirx::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); const auto& buffer_map = prim_func_->buffer_map; const auto& tir_args = prim_func_->params; @@ -490,7 +490,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { // If the `expr` is already seen (present in the map), validate whether the mapped buffer is // structurally equal to the `new_buf` passed - auto ValidateBufferCompatibility = [this](tir::Buffer new_buf, Expr expr) { + auto ValidateBufferCompatibility = [this](tirx::Buffer new_buf, Expr expr) { if (auto it = relax_to_tir_var_map_.find(expr); it != relax_to_tir_var_map_.end()) { TVM_FFI_ICHECK(ffi::StructuralEqual()((*it).second, new_buf)) << "Inconsistent buffers " << (*it).second << " and " << new_buf @@ -519,7 +519,7 @@ class RelaxToTIRVarMapCollector : public ExprVisitor { private: /*! \brief The IRModule */ const IRModule& mod_; - ffi::Map relax_to_tir_var_map_; + ffi::Map relax_to_tir_var_map_; Var current_var_; }; @@ -531,8 +531,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param gv The global var of relax subfunction to be fused into one PrimFunc * \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call) */ - static std::pair> GetFusedTIR(const IRModule& mod, - const GlobalVar& gv) { + static std::pair> GetFusedTIR(const IRModule& mod, + const GlobalVar& gv) { FusedTIRConstructor visitor(mod, gv->name_hint); BaseFunc f = mod->Lookup(gv); TVM_FFI_ICHECK(f->IsInstance()) @@ -554,15 +554,15 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const FunctionNode* func) final { auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, ffi::GetRef(func)); - std::vector> prim_func_params; + std::vector> prim_func_params; for (const Var& relax_param : func->params) { size_t size_before = prim_func_params.size(); CollectPrimFuncParams(relax_param, &prim_func_params, relax_to_tir_var_map.Get(relax_param)); - auto param_buffers = [&]() -> ffi::Array { - ffi::Array out; + auto param_buffers = [&]() -> ffi::Array { + ffi::Array out; for (size_t i = size_before; i < prim_func_params.size(); i++) { - if (auto buf = prim_func_params[i].as()) { + if (auto buf = prim_func_params[i].as()) { out.push_back(buf.value()); } } @@ -577,13 +577,13 @@ class FusedTIRConstructor : public ExprVisitor { // std::stable_sort is used instead of std::sort. std::stable_sort(prim_func_params.begin(), prim_func_params.end(), [](const auto& a, const auto& b) { - bool a_is_var = a.template as(); - bool b_is_var = b.template as(); + bool a_is_var = a.template as(); + bool b_is_var = b.template as(); return a_is_var < b_is_var; }); for (const auto& param : prim_func_params) { - if (auto opt = param.as()) { + if (auto opt = param.as()) { auto buffer = opt.value(); // Differentiate buffer name and param name by adding prefix // `p_` to the buffer name. Every symbol should be unique in @@ -591,7 +591,7 @@ class FusedTIRConstructor : public ExprVisitor { // printed, it's more readable when done explicitly. Since // Buffer is used more than param it gets the name with better // readability. - tir::Var param = tir::Var("p_" + buffer->name, PrimType(DataType::Handle())); + tirx::Var param = tirx::Var("p_" + buffer->name, PrimType(DataType::Handle())); func_info_.params.push_back(param); func_info_.buffer_map.Set(param, buffer); } @@ -606,11 +606,11 @@ class FusedTIRConstructor : public ExprVisitor { TVM_FFI_ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; - const ffi::Array& buffers = (*it).second; + const ffi::Array& buffers = (*it).second; // map of input buffers to indices (helpful for detecting in-place inputs) - std::unordered_map buffer_to_idx; - std::unordered_map input_to_idx; + std::unordered_map buffer_to_idx; + std::unordered_map input_to_idx; for (size_t i = 0; i < func_info_.params.size(); i++) { input_to_idx[func_info_.params[i]] = i; } @@ -635,7 +635,8 @@ class FusedTIRConstructor : public ExprVisitor { continue; } - tir::Var param = tir::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle())); + tirx::Var param = + tirx::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle())); out_idx++; func_info_.buffer_map.Set(param, buffers[i]); func_info_.params.push_back(param); @@ -644,7 +645,7 @@ class FusedTIRConstructor : public ExprVisitor { // Step 4. Append symbolic vars for (const auto& param : prim_func_params) { - if (auto var = param.as()) { + if (auto var = param.as()) { func_info_.params.push_back(var.value()); } } @@ -680,17 +681,17 @@ class FusedTIRConstructor : public ExprVisitor { // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); - tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); + tirx::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication - tir::PrimFunc prim_func = s_tir::RenewDefs(prim_func_); + tirx::PrimFunc prim_func = s_tir::RenewDefs(prim_func_); // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block // TODO(Siyuan): support un-schedulable functions. - TVM_FFI_ICHECK(prim_func->body->IsInstance()) + TVM_FFI_ICHECK(prim_func->body->IsInstance()) << "Only schedulable functions (whose body is the root block) can be fused"; - const tir::SBlockRealize& root_realize = Downcast(prim_func->body); - const tir::SBlock& root_block = root_realize->block; + const tirx::SBlockRealize& root_realize = Downcast(prim_func->body); + const tirx::SBlock& root_block = root_realize->block; // Step 4. Add all the original alloc_buffers and body to the fused function. func_info_.alloc_buffers.insert(func_info_.alloc_buffers.end(), @@ -713,7 +714,7 @@ class FusedTIRConstructor : public ExprVisitor { size_t num_params = prim_func->params.size(); TVM_FFI_ICHECK_GE(num_params, args.size()); for (size_t i = 0; i < args.size(); ++i) { - const tir::Var& param = prim_func->params[num_params - args.size() + i]; + const tirx::Var& param = prim_func->params[num_params - args.size() + i]; func_info_.symbolic_var_matcher.Match(param, args[i]); } } else { @@ -745,7 +746,7 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const TupleNode* tuple) final { ExprVisitor::VisitExpr_(tuple); - ffi::Array buffers; + ffi::Array buffers; for (const Expr& expr : tuple->fields) { auto it = func_info_.expr2buffers.find(expr); if (it != func_info_.expr2buffers.end()) { @@ -800,16 +801,16 @@ class FusedTIRConstructor : public ExprVisitor { } /*! \brief Map old TIR func param buffer to new buffer, and then update `buffer_subst_map` */ - void MapArgsToBuffer(const ffi::Array args, const ffi::Array& buffers) { + void MapArgsToBuffer(const ffi::Array args, const ffi::Array& buffers) { size_t buffer_idx = 0; for (const Expr& arg : args) { if (const auto* v = arg.as()) { auto it = func_info_.expr2buffers.find(ffi::GetRef(v)); // Substitute the buffer with the already allocated one if it is an intermediate var if (it != func_info_.expr2buffers.end()) { - for (const tir::Buffer& target_buffer : (*it).second) { + for (const tirx::Buffer& target_buffer : (*it).second) { TVM_FFI_ICHECK_LT(buffer_idx, buffers.size()); - const tir::Buffer& buffer = buffers[buffer_idx]; + const tirx::Buffer& buffer = buffers[buffer_idx]; func_info_.symbolic_var_matcher.Match(buffer->shape, target_buffer->shape); func_info_.buffer_subst_map.Set(buffer, target_buffer); buffer_idx++; @@ -826,9 +827,9 @@ class FusedTIRConstructor : public ExprVisitor { * \param func The old TIR PrimFunc * \param output_size The number of output params. All output params are at the end of param list. */ - void MapInputBuffer(const tir::PrimFunc& func, const relax::Expr& args) { + void MapInputBuffer(const tirx::PrimFunc& func, const relax::Expr& args) { ffi::Array arg_list; - ffi::Array buffer_list; + ffi::Array buffer_list; if (const auto* arg_tuple = args.as()) { arg_list = arg_tuple->fields; } else { @@ -837,25 +838,25 @@ class FusedTIRConstructor : public ExprVisitor { TVM_FFI_ICHECK_GE(func->params.size(), arg_list.size()); for (size_t i = 0; i < arg_list.size(); ++i) { - const tir::Var& param = func->params[i]; - const tir::Buffer& buffer = func->buffer_map.at(param); + const tirx::Var& param = func->params[i]; + const tirx::Buffer& buffer = func->buffer_map.at(param); buffer_list.push_back(buffer); } MapArgsToBuffer(arg_list, buffer_list); } - static ffi::Array GetPrimFuncOutputParams(const tir::PrimFunc& func, - const ffi::Array& output_indices) { + static ffi::Array GetPrimFuncOutputParams(const tirx::PrimFunc& func, + const ffi::Array& output_indices) { size_t n = func->params.size(); int symbolic_var_index = -1; size_t output_size = output_indices.size(); TVM_FFI_ICHECK_GE(n, output_size); - ffi::Array ret; + ffi::Array ret; for (auto idx : output_indices) { int i = idx.IntValue(); - const tir::Var& param = func->params[static_cast(i)]; + const tirx::Var& param = func->params[static_cast(i)]; if (param->dtype.is_int() || param->dtype.is_uint()) { if (symbolic_var_index == -1) symbolic_var_index = i; } else if (param->dtype.is_handle()) { @@ -882,7 +883,7 @@ class FusedTIRConstructor : public ExprVisitor { * \param func The old TIR PrimFunc * \param output_shapes The shape of output params. */ - void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func, + void AllocateIntermediateBuffer(const CallNode* call, const tirx::PrimFunc& func, const ffi::Array>& output_shapes) { bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace")); @@ -890,7 +891,7 @@ class FusedTIRConstructor : public ExprVisitor { int num_inputs = Downcast(call->args[1])->fields.size(); size_t output_size = output_shapes.size(); TVM_FFI_ICHECK_GE(n, output_size); - ffi::Array output_buffers; + ffi::Array output_buffers; ffi::Array output_idxs; if (is_inplace) { const auto* attrs = call->attrs.as(); @@ -902,11 +903,11 @@ class FusedTIRConstructor : public ExprVisitor { } } - ffi::Array output_params = GetPrimFuncOutputParams(func, output_idxs); + ffi::Array output_params = GetPrimFuncOutputParams(func, output_idxs); auto input_buffers = func_info_.expr2buffers.Get(call->args[1]); for (size_t i = 0; i < output_size; ++i) { - const tir::Var& param = output_params[i]; - const tir::Buffer& buffer = func->buffer_map.at(param); + const tirx::Var& param = output_params[i]; + const tirx::Buffer& buffer = func->buffer_map.at(param); // if this is an inplace output, do not do an intermediate allocation if (output_idxs[i].IntValue() < num_inputs) { @@ -932,10 +933,10 @@ class FusedTIRConstructor : public ExprVisitor { return unique_name; }; // Update buffer with new symbolic shape according to the sinfo - auto n = ffi::make_object(*buffer.get()); + auto n = ffi::make_object(*buffer.get()); n->shape = output_shapes[i]; n->name = unify_name_hints(); - tir::Buffer new_buffer(n); + tirx::Buffer new_buffer(n); func_info_.alloc_buffers.push_back(new_buffer); output_buffers.push_back(new_buffer); @@ -954,8 +955,8 @@ class FusedTIRConstructor : public ExprVisitor { * \param out The vector into which to collect the params/buffers */ static void CollectPrimFuncParams(const Var& relax_param, - std::vector>* out, - const ffi::Optional& tir_buffer_param) { + std::vector>* out, + const ffi::Optional& tir_buffer_param) { auto struct_info = GetStructInfo(relax_param); TVM_FFI_CHECK(!struct_info.as(), InternalError) @@ -965,30 +966,30 @@ class FusedTIRConstructor : public ExprVisitor { auto name_hint = relax_param->name_hint(); if (const auto* tensor = struct_info.as()) { - // Case 1. The relax param is a Tensor, we directly create a tir var and buffer + // Case 1. The relax param is a Tensor, we directly create a tirx var and buffer const auto* shape_expr = tensor->shape.as(); TVM_FFI_ICHECK(shape_expr) << "FuseTIR expects all Tensor parameters have a known shape."; DataType dtype = tensor->dtype; - tir::Buffer buffer; + tirx::Buffer buffer; if (tir_buffer_param.defined()) { - buffer = - tir::decl_buffer(shape_expr->values, dtype, name_hint, tir_buffer_param.value().scope(), - tir_buffer_param.value()->axis_separators); + buffer = tirx::decl_buffer(shape_expr->values, dtype, name_hint, + tir_buffer_param.value().scope(), + tir_buffer_param.value()->axis_separators); } else { - buffer = tir::decl_buffer(shape_expr->values, dtype, name_hint); + buffer = tirx::decl_buffer(shape_expr->values, dtype, name_hint); } out->push_back(std::move(buffer)); } else if (const auto* prim_value = struct_info.as()) { - // Case 2. The relax param is a scalar, we directly create a tir var - TVM_FFI_ICHECK(prim_value->value->IsInstance()); - out->push_back(Downcast(prim_value->value)); + // Case 2. The relax param is a scalar, we directly create a tirx var + TVM_FFI_ICHECK(prim_value->value->IsInstance()); + out->push_back(Downcast(prim_value->value)); } else if (const auto* shape_expr = struct_info.as()) { - // Case 3. The relax param is a tuple of scalars, each represented as a tir var + // Case 3. The relax param is a tuple of scalars, each represented as a tirx var for (const auto& var : shape_expr->values.value()) { - TVM_FFI_ICHECK(var->IsInstance()); - out->push_back(Downcast(var)); + TVM_FFI_ICHECK(var->IsInstance()); + out->push_back(Downcast(var)); } } else { TVM_FFI_THROW(TypeError) << "The param type of PrimFunc is expected to be " @@ -1001,25 +1002,26 @@ class FusedTIRConstructor : public ExprVisitor { * \brief Construct fused TIR func with collected FuseFuncInfo * \return The fused TIR */ - tir::PrimFunc ConstructFunc() { + tirx::PrimFunc ConstructFunc() { ffi::Map attr_map; - attr_map.Set(tir::attr::kNoAlias, true); - tir::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, func_info_.symbolic_var_remap); + attr_map.Set(tirx::attr::kNoAlias, true); + tirx::FuseTIRBufferSubstitutor subst(func_info_.buffer_subst_map, + func_info_.symbolic_var_remap); TVM_FFI_ICHECK(func_info_.global_name != "fused"); // Remove output buffers from func_info_.alloc_buffers - ffi::Array alloc_buffers; - for (const tir::Buffer& buf : func_info_.alloc_buffers) { + ffi::Array alloc_buffers; + for (const tirx::Buffer& buf : func_info_.alloc_buffers) { if (func_info_.output_buffers.count(buf.get()) == 0) { alloc_buffers.push_back(subst.SubstituteAllocatedBuffer(buf)); } } - tir::Stmt body = tir::SBlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies)); + tirx::Stmt body = tirx::SBlockNameDeduplicator()(tirx::SeqStmt::Flatten(func_info_.bodies)); body = subst.Substitute(body); - body = tir::SBlock({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); - body = tir::SBlockRealize({}, Bool(true), Downcast(body)); - tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, - DictAttrs(attr_map)); + body = tirx::SBlock({}, {}, {}, "root", std::move(body), std::nullopt, alloc_buffers); + body = tirx::SBlockRealize({}, Bool(true), Downcast(body)); + tirx::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map, + DictAttrs(attr_map)); // Renew function defs to prevent using the same symbolic vars in different functions return s_tir::RenewDefs(func); } @@ -1050,22 +1052,22 @@ class FusedTIRConstructor : public ExprVisitor { * \brief The map from each dataflow var (intermediate var) to the corresponding buffers * allocated in the fused func */ - ffi::Map> expr2buffers; + ffi::Map> expr2buffers; /*! \brief The buffers to allocate in the fused func*/ - ffi::Array alloc_buffers; + ffi::Array alloc_buffers; /*! \brief The bodies of the original funcs, which is also the body of the fused func. */ - ffi::Array bodies; + ffi::Array bodies; /*! \brief The params of the fused function*/ - ffi::Array params; + ffi::Array params; /*! * \brief The map from buffer in original functions to corresponding buffer in the fused * function */ - ffi::Map buffer_subst_map; + ffi::Map buffer_subst_map; /*! \brief The `buffer_map` in the fused function*/ - ffi::Map buffer_map; + ffi::Map buffer_map; /*! \brief The output buffers in the function buffer_map*/ - std::unordered_set output_buffers; + std::unordered_set output_buffers; /*! \brief The name of the fused function */ std::string global_name = "fused"; @@ -1075,7 +1077,7 @@ class FusedTIRConstructor : public ExprVisitor { * `symbolic_var_matcher`, and must be before it in the struct * order. */ - ffi::Map symbolic_var_remap; + ffi::Map symbolic_var_remap; /*! \brief The map from symbolic var to its value in the fused function * @@ -1086,8 +1088,8 @@ class FusedTIRConstructor : public ExprVisitor { arith::Analyzer analyzer; /*! \brief The map from symbolic var to its corresponding var in the fused function */ - tir::SymbolicMatcher symbolic_var_matcher = - tir::SymbolicMatcher(&analyzer, &symbolic_var_remap); + tirx::SymbolicMatcher symbolic_var_matcher = + tirx::SymbolicMatcher(&analyzer, &symbolic_var_remap); }; /*! \brief The IRModule */ @@ -1096,8 +1098,8 @@ class FusedTIRConstructor : public ExprVisitor { ffi::String func_name_; /*! \brief The helper info to fuse TIR prim_func */ FuseFuncInfo func_info_; - /*! \brief The tir function after fusion*/ - tir::PrimFunc fused_tir_; + /*! \brief The tirx function after fusion*/ + tirx::PrimFunc fused_tir_; /*! \brief Indices of inputs that are used for in-place computation */ std::unordered_set inplace_indices_; }; @@ -1189,7 +1191,7 @@ class TIRFuseMutator : public ExprMutator { using ExprMutator::VisitExpr_; - // Get shape from call tir + // Get shape from call tirx static Expr GetCallTIRShape(StructInfo sinfo) { if (auto* tuple = sinfo.as()) { ffi::Array fields = tuple->fields.Map([&](StructInfo x) { return GetCallTIRShape(x); }); @@ -1251,17 +1253,17 @@ class TIRFuseMutator : public ExprMutator { TVM_FFI_ICHECK(shape->values.defined()) << "FuseTIR requires all shape input has struct_info value."; for (const PrimExpr& prim_value : shape->values.value()) { - TVM_FFI_ICHECK(prim_value->IsInstance()) - << "All shape inputs are expected to be single tir var."; + TVM_FFI_ICHECK(prim_value->IsInstance()) + << "All shape inputs are expected to be single tirx var."; tir_vars.push_back(prim_value); } } else if (const auto* prim_value = sinfo.as()) { TVM_FFI_ICHECK(prim_value->value.defined()) << "FuseTIR requires all R.Prim arguments to have a known value."; PrimExpr expr = prim_value->value.value(); - TVM_FFI_ICHECK(expr->IsInstance()) + TVM_FFI_ICHECK(expr->IsInstance()) << "FuseTIR currently requires all R.Prim " - "arguments to provide a single tir::Var."; + "arguments to provide a single tirx::Var."; tir_vars.push_back(expr); } else { diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 90e0619f5dd5..4cb8bf389bd0 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -365,7 +365,7 @@ class BackwardBindingGenerator : private ExprVisitor { << "Differentiation of call_tir op without registering corresponding gradient " "function is not supported yet."; } else if (call_op == Op::Get("relax.call_tir_with_grad")) { - // tir gradient registering + // tirx gradient registering auto te_grad_name = call->attrs.as()->te_grad_name; const auto grad_func = tvm::ffi::Function::GetGlobalRequired(te_grad_func_prefix + te_grad_name); diff --git a/src/relax/transform/infer_layout_utils.cc b/src/relax/transform/infer_layout_utils.cc index b33beaa4e513..e559b43235e8 100644 --- a/src/relax/transform/infer_layout_utils.cc +++ b/src/relax/transform/infer_layout_utils.cc @@ -25,8 +25,8 @@ namespace tvm { namespace relax { -using tir::IterVar; -using tir::Layout; +using tirx::IterVar; +using tirx::Layout; std::string TransposeSubLayoutStrLike(const std::string ref_str, const std::string& src_str, const std::string& desired_str) { diff --git a/src/relax/transform/infer_layout_utils.h b/src/relax/transform/infer_layout_utils.h index e5524d3435ad..70a99f7d8ef1 100644 --- a/src/relax/transform/infer_layout_utils.h +++ b/src/relax/transform/infer_layout_utils.h @@ -49,7 +49,7 @@ namespace tvm { namespace relax { -using tir::Layout; +using tirx::Layout; /*! * \brief A layout decision node that holds the layout decision of the tensor. diff --git a/src/relax/transform/kill_after_last_use.cc b/src/relax/transform/kill_after_last_use.cc index 969319063c7b..83348c04628e 100644 --- a/src/relax/transform/kill_after_last_use.cc +++ b/src/relax/transform/kill_after_last_use.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 528978da5406..314e38cb0bee 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -73,10 +73,10 @@ class LazyInputMutator : public ExprMutator { auto array_externally_visible_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo))); - std::unordered_set externally_visible_vars(array_externally_visible_vars.begin(), - array_externally_visible_vars.end()); + std::unordered_set externally_visible_vars(array_externally_visible_vars.begin(), + array_externally_visible_vars.end()); StructInfo new_ret_struct_info = EraseToWellDefined( - func->ret_struct_info, [&](const tir::Var& var) -> ffi::Optional { + func->ret_struct_info, [&](const tirx::Var& var) -> ffi::Optional { if (externally_visible_vars.count(var)) { return var; } else { diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index 723c2814038a..db1ced529e98 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include @@ -97,7 +97,7 @@ class LegalizeMutator : public ExprMutator { // Avoid accidental sharing of TIR variables in the legalized // PrimFuncs, when kernels for multiple devices are generated // from the same PrimFunc. - output = tir::transform::ConvertSSA()(output); + output = tirx::transform::ConvertSSA()(output); } return output; @@ -194,7 +194,7 @@ class LegalizeMutator : public ExprMutator { } auto base_func = builder_->GetContextIRModule()->Lookup(gvar.value()); - auto opt_prim_func = base_func.as(); + auto opt_prim_func = base_func.as(); if (!opt_prim_func) { // The call is to something other than a PrimFunc. It may be // another Relax function, in which case the legalization of its diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 0e9cc204ca06..d462680cd133 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -64,11 +64,11 @@ struct BaseCollectInfo { * model weights, and computed tensors that require neither model * weights nor runtime arguments (e.g. `R.zeros([16], "float16")`). */ - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> requires_compile_time_param; /*! \brief Variables that are required at runtime */ - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> required_at_runtime; protected: @@ -95,13 +95,13 @@ struct BaseCollectInfo { Function MakeCompileTimeFunctionHelper(const ffi::Array params, const ffi::Array& bindings, - const ffi::Array& output_symbolic_vars, + const ffi::Array& output_symbolic_vars, const ffi::Array& outputs) const { ffi::Array output_var_binding; ffi::Array output_exprs; if (output_symbolic_vars.size()) { output_exprs.push_back( - ShapeExpr(output_symbolic_vars.Map([](tir::Var var) -> PrimExpr { return var; }))); + ShapeExpr(output_symbolic_vars.Map([](tirx::Var var) -> PrimExpr { return var; }))); } for (const auto& var : outputs) { @@ -138,17 +138,17 @@ struct GlobalCollectInfo : public BaseCollectInfo { // The cross-function mapping between variables. ffi::Map var_remap; // The cross-function between between TIR variables. - ffi::Map tir_var_remap; - ffi::Array GetPropagatedSymbolicVariables() const { + ffi::Map tir_var_remap; + ffi::Array GetPropagatedSymbolicVariables() const { auto vars_from_original_params = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); - auto vars_from_transformed_params = [&]() -> std::unordered_set { + auto vars_from_transformed_params = [&]() -> std::unordered_set { auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo))); return {tir_vars.begin(), tir_vars.end()}; }(); - ffi::Array output; + ffi::Array output; for (const auto& tir_var : vars_from_original_params) { if (required_at_runtime.count(tir_var) && !vars_from_transformed_params.count(tir_var)) { output.push_back(tir_var); @@ -181,23 +181,23 @@ struct LocalCollectInfo : public BaseCollectInfo { orig_func->params.begin() + num_runtime_params); } - ffi::Array GetPropagatedSymbolicVariables() const { + ffi::Array GetPropagatedSymbolicVariables() const { auto vars_from_any_param = DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo))); - auto vars_from_runtime_params = [&]() -> std::unordered_set { + auto vars_from_runtime_params = [&]() -> std::unordered_set { auto tir_var_vec = DefinableTIRVarsInStructInfo(TupleStructInfo(GetRuntimeInputs().Map(GetStructInfo))); return {tir_var_vec.begin(), tir_var_vec.end()}; }(); - auto vars_from_transformed_params = [&]() -> std::unordered_set { + auto vars_from_transformed_params = [&]() -> std::unordered_set { auto tir_var_vec = DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo))); return {tir_var_vec.begin(), tir_var_vec.end()}; }(); - ffi::Array output; + ffi::Array output; for (const auto& tir_var : vars_from_any_param) { if (required_at_runtime.count(tir_var) && !vars_from_runtime_params.count(tir_var) && !vars_from_transformed_params.count(tir_var)) { @@ -227,23 +227,23 @@ struct LocalCollectInfo : public BaseCollectInfo { // removed with CanonicalizeBindings. ffi::Array params = GetRuntimeInputs(); auto propagated_tir_vars = [&]() { - ffi::Array local_tir_vars = GetPropagatedSymbolicVariables(); + ffi::Array local_tir_vars = GetPropagatedSymbolicVariables(); if (!global_info) { return local_tir_vars; } // When global lifting is enabled, the compile-time outputs are the global outputs, but the // variables in the global outputs to the local variables. - ffi::Map reverse_map; + ffi::Map reverse_map; for (const auto& var : local_tir_vars) { if (auto it = global_info->tir_var_remap.find(var); it != global_info->tir_var_remap.end()) { - reverse_map.Set(Downcast((*it).second), var); + reverse_map.Set(Downcast((*it).second), var); } } - ffi::Array global_tir_vars = global_info->GetPropagatedSymbolicVariables(); - global_tir_vars = global_tir_vars.Map([&](const tir::Var& var) { + ffi::Array global_tir_vars = global_info->GetPropagatedSymbolicVariables(); + global_tir_vars = global_tir_vars.Map([&](const tirx::Var& var) { if (auto it = reverse_map.find(var); it != reverse_map.end()) { - return Downcast((*it).second); + return Downcast((*it).second); } else { // This is the case when the some of the outputs of the shared transform is not used in // this function. @@ -254,7 +254,7 @@ struct LocalCollectInfo : public BaseCollectInfo { }(); if (propagated_tir_vars.size()) { ShapeStructInfo shape_sinfo( - propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; })); + propagated_tir_vars.Map([](tirx::Var var) -> PrimExpr { return var; })); Var shape_expr("vars_from_compile_time_params", shape_sinfo); params.push_back(shape_expr); } @@ -380,7 +380,7 @@ class BaseLiftableBindingCollector : public ExprVisitor { return true; } - std::unordered_set, ObjectPtrHash, ObjectPtrEqual> liftable_vars_; + std::unordered_set, ObjectPtrHash, ObjectPtrEqual> liftable_vars_; bool is_in_dataflow_block_{false}; }; @@ -391,12 +391,12 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { visitor(func); visitor.info_.orig_func = func; - auto set_union = [&](std::unordered_set, ObjectPtrHash, + auto set_union = [&](std::unordered_set, ObjectPtrHash, ObjectPtrEqual>& target_set, - const std::unordered_set, ObjectPtrHash, - ObjectPtrEqual>& source_set, + const std::unordered_set, + ObjectPtrHash, ObjectPtrEqual>& source_set, const ffi::Map& var_remap, - const ffi::Map& tir_var_remap) { + const ffi::Map& tir_var_remap) { // In-place update the set in global info by unioning with the local set, variable // mappings are applied. for (const auto& relax_or_tir_var : source_set) { @@ -407,11 +407,11 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { target_set.insert(Downcast(relax_or_tir_var)); } } else { - if (auto it = tir_var_remap.find(Downcast(relax_or_tir_var)); + if (auto it = tir_var_remap.find(Downcast(relax_or_tir_var)); it != tir_var_remap.end()) { - target_set.insert(Downcast((*it).second)); + target_set.insert(Downcast((*it).second)); } else { - target_set.insert(Downcast(relax_or_tir_var)); + target_set.insert(Downcast(relax_or_tir_var)); } } } @@ -509,7 +509,7 @@ class LocalLiftableBindingCollector : public BaseLiftableBindingCollector { /*! \brief Visitor to find the correspondence between parameters in multiple functions. */ class ParamRemapper : private ExprFunctor { public: - static std::pair, ffi::Map> GetParamMapping( + static std::pair, ffi::Map> GetParamMapping( const ffi::Array& functions) { ParamRemapper mapper; if (functions.size()) { @@ -558,14 +558,14 @@ class ParamRemapper : private ExprFunctor { } ffi::Map var_remap_; - ffi::Map tir_var_remap_; + ffi::Map tir_var_remap_; }; class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { public: static GlobalCollectInfo Collect(const ffi::Array& functions, const ffi::Map& var_remap, - const ffi::Map& tir_var_remap) { + const ffi::Map& tir_var_remap) { GlobalLiftableBindingCollector collector(var_remap, tir_var_remap); TVM_FFI_ICHECK(functions.size()); for (const auto& func : functions) { @@ -613,7 +613,7 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { private: GlobalLiftableBindingCollector(const ffi::Map& var_remap, - const ffi::Map tir_var_remap) + const ffi::Map tir_var_remap) : var_remap_(var_remap), tir_var_remap_(tir_var_remap) {} void VisitBinding(const Binding& binding) override { TVM_FFI_ICHECK(!binding->IsInstance()) @@ -637,7 +637,7 @@ class GlobalLiftableBindingCollector : public BaseLiftableBindingCollector { // visits the bindings. ffi::Map var_remap_; // The cross-function between between TIR variables. - ffi::Map tir_var_remap_; + ffi::Map tir_var_remap_; std::vector unified_bindings_; // The mapping between the unified bindings and the original bindings in different functions. // The unified binding is the binding with all variables replaced by the unified variables as diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index d1eb58125a37..397abb60446a 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -71,7 +71,7 @@ class Mutator : public ExprMutator { }(); PrimExpr nbytes = [&]() -> PrimExpr { - PrimExpr nbytes = tir::make_const(DataType::Int(64), dtype->value.bytes()); + PrimExpr nbytes = tirx::make_const(DataType::Int(64), dtype->value.bytes()); for (const auto& dim : shape) { nbytes *= dim; } @@ -88,7 +88,7 @@ class Mutator : public ExprMutator { if (vdevice.defined()) { std::string dev_kind = vdevice.value()->target->kind->name; - PrimExpr dev_size = tir::make_const(DataType::Int(64), 1); + PrimExpr dev_size = tirx::make_const(DataType::Int(64), 1); if (vdevice.value()->memory_scope != "global") { auto device_size_handler = tvm::ffi::Function::GetGlobal(std::string("DeviceGetMemSize.") + dev_kind); diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 86c5b83a8aab..21d2ff7ace64 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -58,7 +58,7 @@ #include #include #include -#include +#include #include "../../support/arena.h" #include "utils.h" diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 06ac38a77dda..b0a2cba41ba9 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include "../../s_tir/meta_schedule/module_equality.h" #include "../../s_tir/meta_schedule/trace_apply.h" @@ -57,7 +57,7 @@ class MetaScheduleTuner { return mod; } - tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { + tirx::PrimFunc TuneTIR(tirx::PrimFunc f, transform::PassContext ctx) { static ffi::Function tune_tir_func = tvm::ffi::Function::GetGlobalRequired("tvm.s_tir.meta_schedule.tune_tir"); tune_tir_func(normalize_mod_func_(f), target_, work_dir_, max_trials_global_); @@ -101,8 +101,8 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ for (const auto& iter : mod->functions) { GlobalVar gv = iter.first; BaseFunc base_func = iter.second; - if (const auto* prim_func_node = base_func.as()) { - tir::PrimFunc prim_func = ffi::GetRef(prim_func_node); + if (const auto* prim_func_node = base_func.as()) { + tirx::PrimFunc prim_func = ffi::GetRef(prim_func_node); IRModule tir_mod = (*normalize_mod_func_)(prim_func).cast(); if (ffi::Optional opt_record = @@ -126,15 +126,15 @@ Pass MetaScheduleApplyDatabase(ffi::Optional work_dir, bool enable_ IRModule new_mod = sch->mod(); TVM_FFI_ICHECK_EQ(new_mod->functions.size(), 1); BaseFunc new_base_func = (*new_mod->functions.begin()).second; - TVM_FFI_ICHECK(new_base_func->IsInstance()); - tir::PrimFunc tuned_prim_func = Downcast(new_base_func); + TVM_FFI_ICHECK(new_base_func->IsInstance()); + tirx::PrimFunc tuned_prim_func = Downcast(new_base_func); // maintain the original attributes - tir::PrimFunc new_prim_func = tir::PrimFunc(/*params=*/tuned_prim_func->params, - /*body=*/tuned_prim_func->body, - /*ret_type=*/tuned_prim_func->ret_type, - /*buffer_map=*/tuned_prim_func->buffer_map, - /*attrs=*/prim_func->attrs); - new_prim_func = WithAttr(std::move(new_prim_func), tir::attr::kIsScheduled, true); + tirx::PrimFunc new_prim_func = tirx::PrimFunc(/*params=*/tuned_prim_func->params, + /*body=*/tuned_prim_func->body, + /*ret_type=*/tuned_prim_func->ret_type, + /*buffer_map=*/tuned_prim_func->buffer_map, + /*attrs=*/prim_func->attrs); + new_prim_func = WithAttr(std::move(new_prim_func), tirx::attr::kIsScheduled, true); result.Set(gv, new_prim_func); continue; } else if (enable_warning) { @@ -168,16 +168,16 @@ Pass MetaScheduleTuneIRMod(ffi::Map params, ffi::S Pass MetaScheduleTuneTIR(ffi::String work_dir, Integer max_trials_global) { Target target = Target::Current(false); - ffi::TypedFunction pass_func = - [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { + ffi::TypedFunction pass_func = + [=](tirx::PrimFunc f, IRModule mod, PassContext ctx) { return MetaScheduleTuner(target, work_dir, max_trials_global, max_trials_global, std::nullopt) .TuneTIR(f, ctx); }; - return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, - /*pass name*/ "MetaScheduleTuneTIR", - /*required*/ {}, - /*traceable*/ true); + return tirx::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, + /*pass name*/ "MetaScheduleTuneTIR", + /*required*/ {}, + /*traceable*/ true); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index b59cb61d2853..4f937fa3ff5e 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -85,7 +85,7 @@ std::optional AnalyzeCallee(Function func) { // symbolic variables. We still want to remove the relax variable // to reduce computational steps in the parent, but we need to // provide the symbolic variables the other steps. - auto defined_tir_params = [&]() -> PSet { + auto defined_tir_params = [&]() -> PSet { auto param_sinfo = TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); })); auto arr = DefinableTIRVarsInStructInfo(param_sinfo); @@ -93,7 +93,7 @@ std::optional AnalyzeCallee(Function func) { }(); // Use an array to define the order of the symbolic variables - ffi::Array free_tir_vars; + ffi::Array free_tir_vars; for (const auto& tir_var : FreeSymbolicVars(func->body)) { if (!defined_tir_params.count(tir_var)) { free_tir_vars.push_back(tir_var); diff --git a/src/relax/transform/replace_global_vars.cc b/src/relax/transform/replace_global_vars.cc index 48548de887cd..1216219288be 100644 --- a/src/relax/transform/replace_global_vars.cc +++ b/src/relax/transform/replace_global_vars.cc @@ -28,7 +28,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 4fa00d9612ad..9e0ecf44b8a7 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -53,9 +53,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -88,7 +88,7 @@ struct LiftedFunctionRewritePlan { std::unordered_map outputs; // The corresponding binding vars in the original function of the inputs of the lifted function std::vector inputs; - // The tir vars in the original function that are propagated to the lifted function + // The tirx vars in the original function that are propagated to the lifted function ffi::Optional propogated_tir_vars = std::nullopt; }; @@ -110,7 +110,7 @@ class FuncBuilder : public ExprMutator { * \brief Mark a TIR variable as the ShapeExpr input of the new function. * \param var The variable to mark as input */ - void MarkShapeExprInput(const tir::VarNode* var) { shape_expr_inputs_.push_back(var); } + void MarkShapeExprInput(const tirx::VarNode* var) { shape_expr_inputs_.push_back(var); } /*! * \brief Mark a variable as the output of the new function. The variable must be the LHS of an * existing binding in the new function. @@ -128,8 +128,8 @@ class FuncBuilder : public ExprMutator { if (shape_expr_inputs_.size()) { ffi::Array tir_vars; for (const auto* var : shape_expr_inputs_) { - auto new_var = ffi::GetRef(var).copy_with_suffix(""); - tir_var_remap_.Set(ffi::GetRef(var), new_var); + auto new_var = ffi::GetRef(var).copy_with_suffix(""); + tir_var_remap_.Set(ffi::GetRef(var), new_var); tir_vars.push_back(new_var); } shape_expr = Var("shape_expr", ShapeStructInfo(tir_vars)); @@ -165,13 +165,13 @@ class FuncBuilder : public ExprMutator { return func; } - PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tir::Substitute(expr, tir_var_remap_); } + PrimExpr VisitPrimExpr(const PrimExpr& expr) { return tirx::Substitute(expr, tir_var_remap_); } support::OrderedSet inputs_; support::OrderedSet outputs_; - support::OrderedSet shape_expr_inputs_; + support::OrderedSet shape_expr_inputs_; std::vector bindings_; - ffi::Map tir_var_remap_; + ffi::Map tir_var_remap_; }; // Collect the storage objects that are used as the function output @@ -250,7 +250,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { func->attrs.GetAttr(attr::kNumInput).value_or(Integer(func->params.size())); auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func); for (int i = 0; i < static_cast(func->params.size()); ++i) { - ffi::Array symbolic_vars = DefinableTIRVarsInStructInfo( + ffi::Array symbolic_vars = DefinableTIRVarsInStructInfo( Downcast(func->params[i]->struct_info_.value())); if (i < num_inputs.IntValue()) { for (const auto& symbolic_var : symbolic_vars) { @@ -366,13 +366,13 @@ class CUDAGraphRewritePlanner : public ExprVisitor { const auto* call_gv = call->op.as(); bool call_prim_func = - call_gv ? mod_->Lookup(ffi::GetRef(call_gv))->IsInstance() + call_gv ? mod_->Lookup(ffi::GetRef(call_gv))->IsInstance() : false; // Check whether the call can be lifted to the capture function. It requires all the arguments // to be static and the call to be a kernel launch or a pure operation (e.g. memory view). std::vector args; - std::vector tir_vars; + std::vector tir_vars; bool is_all_static = [&]() { if (!IsStatic(call->args, &args, &tir_vars)) { return false; @@ -419,7 +419,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } void MarkAsFuncInput(const std::vector& vars, - const std::vector& tir_vars = {}) { + const std::vector& tir_vars = {}) { if (current_block_scope_.capture_builder == nullptr) { return; } @@ -429,7 +429,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { current_block_scope_.capture_builder->MarkInput(var); } } - for (const tir::VarNode* tir_var : tir_vars) { + for (const tirx::VarNode* tir_var : tir_vars) { current_block_scope_.capture_builder->MarkShapeExprInput(tir_var); } } @@ -459,7 +459,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final { std::vector args; - std::vector tir_vars; + std::vector tir_vars; if (IsStatic(tuple->fields, &args, &tir_vars)) { AddStaticBinding(binding, false); MarkAsFuncInput(args, tir_vars); @@ -483,10 +483,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor { bool IsStatic(const PrimExpr& expr, [[maybe_unused]] std::vector* vars_collector = nullptr, - std::vector* tir_vars_collector = nullptr) { + std::vector* tir_vars_collector = nullptr) { bool is_static = true; - tir::PostOrderVisit(expr, [&](const ObjectRef& e) { - if (auto var = e.as()) { + tirx::PostOrderVisit(expr, [&](const ObjectRef& e) { + if (auto var = e.as()) { if (!capture_symbolic_vars_.count(var)) { is_static = false; return; @@ -500,7 +500,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } bool IsStatic(const Expr& expr, std::vector* vars_collector = nullptr, - std::vector* tir_vars_collector = nullptr) { + std::vector* tir_vars_collector = nullptr) { if (expr->IsInstance() || expr->IsInstance() || expr->IsInstance() || expr->IsInstance()) { return true; @@ -528,7 +528,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { template bool IsStatic(const ffi::Array& exprs, std::vector* vars_collector = nullptr, - std::vector* tir_vars_collector = nullptr) { + std::vector* tir_vars_collector = nullptr) { bool result = true; for (const auto& expr : exprs) { // If vars_collector is provided, we will collect all the vars in the exprs and we should @@ -542,7 +542,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } bool IsStatic(const StructInfo& sinfo, std::vector* vars_collector = nullptr, - std::vector* tir_vars_collector = nullptr) { + std::vector* tir_vars_collector = nullptr) { if (const auto* tensor_sinfo = sinfo.as()) { if (auto shape = tensor_sinfo->GetShape()) { return IsStatic(shape.value(), vars_collector, tir_vars_collector); @@ -618,7 +618,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { std::unordered_set static_vars_; // Symbolic variables that are allowed to be captured. This can come from symbolic shapes of // weights or hints in the function annotations. - std::unordered_set capture_symbolic_vars_; + std::unordered_set capture_symbolic_vars_; // Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs // of the lifted function when its binding is used outside. std::unordered_map binding_to_region_; @@ -805,10 +805,11 @@ class CUDAGraphRewriter : public ExprMutator { const auto& shape_expr = plan->func->params.back(); auto symbolic_params = Downcast(shape_expr->struct_info_.value())->values.value(); - ffi::Map tir_var_remap; + ffi::Map tir_var_remap; TVM_FFI_ICHECK_EQ(symbolic_params.size(), propogated_tir_vars->values.size()); for (int i = 0; i < static_cast(symbolic_params.size()); ++i) { - tir_var_remap.Set(Downcast(symbolic_params[i]), propogated_tir_vars->values[i]); + tir_var_remap.Set(Downcast(symbolic_params[i]), + propogated_tir_vars->values[i]); } call_sinfo = Bind(call_sinfo, tir_var_remap); } diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index 7cd0e46cb328..abc802d5b482 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -25,8 +25,8 @@ #include #include #include -#include -#include +#include +#include #include @@ -35,13 +35,13 @@ namespace tvm { namespace relax { -std::vector GetUsedTensorArgIndices(const tir::PrimFunc& fn, size_t num_args) { +std::vector GetUsedTensorArgIndices(const tirx::PrimFunc& fn, size_t num_args) { std::vector indices; for (size_t i = 0; i < num_args; ++i) { if (auto buffer = fn->buffer_map.Get(fn->params[i])) { auto buffer_var = buffer.value()->data; - if (tir::UsesVar(fn->body, - [=](const tir::VarNode* var) { return var == buffer_var.get(); })) { + if (tirx::UsesVar(fn->body, + [=](const tirx::VarNode* var) { return var == buffer_var.get(); })) { indices.push_back(i); } } @@ -85,7 +85,7 @@ class DataflowReshapeRewriter : public ExprMutator { // relax.reshape op, which will be lowered to calls of the ExternFunc // vm.builtin.reshape in the VMBuiltinLower pass. - auto prim_fn = Downcast(mod_->Lookup(Downcast(call->args[0]))); + auto prim_fn = Downcast(mod_->Lookup(Downcast(call->args[0]))); auto arg_tuple = Downcast(call->args[1])->fields; auto used_tensor_arg_indices = GetUsedTensorArgIndices(prim_fn, arg_tuple.size()); @@ -109,9 +109,9 @@ class DataflowReshapeRewriter : public ExprMutator { bool IsCallingTIRReshape(const CallNode* call, Expr inp) { const GlobalVar& global_var = Downcast(call->args[0]); - const auto* func = mod_->functions.Get(global_var).value().as(); + const auto* func = mod_->functions.Get(global_var).value().as(); TVM_FFI_ICHECK_NOTNULL(func); - if (!HasReshapePattern(ffi::GetRef(func))) { + if (!HasReshapePattern(ffi::GetRef(func))) { return false; } diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index d02e62fc4d2a..14b64c34d6e2 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -174,7 +174,7 @@ class CodeGenRunner : ExprMutator { std::unordered_map> target_functions; for (const auto& entry : mod->functions) { - if (entry.second->IsInstance()) { + if (entry.second->IsInstance()) { continue; } PostOrderVisit(entry.second, [&target_functions](Expr e) { diff --git a/src/relax/transform/specialize_primfunc_based_on_callsite.cc b/src/relax/transform/specialize_primfunc_based_on_callsite.cc index a0dcdaf4a13b..d39adefbada7 100644 --- a/src/relax/transform/specialize_primfunc_based_on_callsite.cc +++ b/src/relax/transform/specialize_primfunc_based_on_callsite.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include @@ -37,7 +37,7 @@ namespace tvm { namespace relax { -using tvm::tir::Buffer; +using tvm::tirx::Buffer; static ffi::Array GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) { auto shape = tensor_sinfo->GetShape(); @@ -78,9 +78,9 @@ class SpecializeTIRCallArgs : ExprMutator { private: Expr SpecializeTirPrimFunc(Call call) { auto gv = Downcast(call->args[0]); - auto pfunc = Downcast(mod_->Lookup(gv)); + auto pfunc = Downcast(mod_->Lookup(gv)); auto args = Downcast(call->args[1])->fields; - ffi::Map> param_map; + ffi::Map> param_map; for (size_t i = 0; i < args.size(); ++i) { auto sinfo = GetStructInfo(args[i]); @@ -99,8 +99,8 @@ class SpecializeTIRCallArgs : ExprMutator { name = std::string({static_cast('A' + i)}); } - const Buffer& buffer = tir::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), - tensor_sinfo->dtype, name, scope); + const Buffer& buffer = tirx::decl_buffer(GetShapeFromTensorStructInfo(tensor_sinfo), + tensor_sinfo->dtype, name, scope); param_map.Set(pfunc->params[i], buffer); } ffi::String scope = "global"; @@ -111,7 +111,7 @@ class SpecializeTIRCallArgs : ExprMutator { scope = sinfo->vdevice.value()->memory_scope; } const Buffer& buffer = - tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); + tirx::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, "ret_val", scope); param_map.Set(pfunc->params[pfunc->params.size() - 1], buffer); } else { TVM_FFI_ICHECK(out_sinfo->IsInstance()) @@ -132,8 +132,8 @@ class SpecializeTIRCallArgs : ExprMutator { scope = sinfo->vdevice.value()->memory_scope; } - const Buffer& buffer = tir::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, - "ret_val_" + std::to_string(index), scope); + const Buffer& buffer = tirx::decl_buffer(GetShapeFromTensorStructInfo(sinfo), sinfo->dtype, + "ret_val_" + std::to_string(index), scope); param_map.Set(pfunc->params[args.size() + index], buffer); index++; } diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index f850f681aaec..ff0e961d8f8d 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -28,9 +28,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include "../../s_tir/schedule/ir_comparator.h" @@ -41,7 +41,7 @@ static const constexpr char* kCSource = "c_source"; static const constexpr char* kCSourceFmt = "c_source_fmt"; static const constexpr char* kCSourceFmtCuda = "cu"; -namespace tir { +namespace tirx { using relax::FCodegen; using relax::MatchResult; @@ -53,7 +53,7 @@ using s_tir::TensorizeComparator; class ForMatcher : public TensorizeComparator { public: using SymbolMap = std::unordered_map; - explicit ForMatcher(const tir::PrimFunc& pattern, const ffi::Array& pattern_vars) + explicit ForMatcher(const tirx::PrimFunc& pattern, const ffi::Array& pattern_vars) : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) { for (const auto& pattern_var : pattern_vars) { this->pattern_vars_.insert(pattern_var); @@ -101,7 +101,7 @@ class ForMatcher : public TensorizeComparator { // special case for pattern vars const auto* lhs_ptr = lhs.as(); if (lhs_ptr == nullptr) { - if (lhs->IsInstance() || lhs->IsInstance()) { + if (lhs->IsInstance() || lhs->IsInstance()) { ffi::Optional value = QueryEvaluatedSymbols(ffi::GetRef(op)); if (value.defined()) { if (!analyzer_.CanProveEqual(lhs, value.value())) return false; @@ -180,7 +180,7 @@ class ForMatcher : public TensorizeComparator { return TensorizeComparator::VisitExpr(lhs, rhs); } - bool VisitExpr_(const tir::AddNode* add, const PrimExpr& other) final { + bool VisitExpr_(const tirx::AddNode* add, const PrimExpr& other) final { const auto* rhs = other.as(); if (rhs == nullptr) return false; { @@ -206,7 +206,7 @@ class ForMatcher : public TensorizeComparator { return false; } - bool VisitExpr_(const tir::MulNode* mul, const PrimExpr& other) final { + bool VisitExpr_(const tirx::MulNode* mul, const PrimExpr& other) final { const auto* rhs = other.as(); if (rhs == nullptr) return false; { @@ -232,7 +232,7 @@ class ForMatcher : public TensorizeComparator { return false; } - bool VisitExpr_(const tir::CallNode* call, const PrimExpr& other) final { + bool VisitExpr_(const tirx::CallNode* call, const PrimExpr& other) final { const auto* rhs = other.as(); if (rhs == nullptr) return false; const auto* lhs_op = call->op.as(); @@ -246,7 +246,7 @@ class ForMatcher : public TensorizeComparator { return true; } - bool VisitStmt_(const tir::ForNode* op, const Stmt& other) final { + bool VisitStmt_(const tirx::ForNode* op, const Stmt& other) final { const auto* rhs = other.as(); loop_stack_lhs_.push_back(ffi::GetRef(op)); loop_stack_rhs_.push_back(ffi::GetRef(rhs)); @@ -269,7 +269,7 @@ class ForMatcher : public TensorizeComparator { return VisitStmt(op->body, rhs->body); } - bool VisitStmt_(const tir::SBlockNode* op, const Stmt& other) final { + bool VisitStmt_(const tirx::SBlockNode* op, const Stmt& other) final { const auto* rhs = other.as(); // Check block equality. // All iter vars and buffer regions including the order should match. @@ -369,7 +369,7 @@ class ForMatcher : public TensorizeComparator { arith::Analyzer analyzer_; std::vector loop_stack_lhs_, loop_stack_rhs_; - tir::PrimFunc pattern_; + tirx::PrimFunc pattern_; std::unordered_set pattern_vars_; }; @@ -389,7 +389,7 @@ class TIRPatternMatcher { // Find an op that matches this block bool BlockPatternMatch(const For& top) { for (const TIRPattern& pattern : patterns_) { - tir::PrimFunc pattern_func = pattern; + tirx::PrimFunc pattern_func = pattern; ffi::Array pattern_symbolic_vars; int buffer_count = pattern_func->buffer_map.size(); for (int i = buffer_count; i < static_cast(pattern_func->params.size()); i++) { @@ -655,7 +655,7 @@ std::pair> SplitFunctions( PrimFunc func2 = PrimFunc(new_params2, body2, func->ret_type, new_buffer_map2, func->attrs); return {func1, func2}; } -} // namespace tir +} // namespace tirx namespace relax { void StringReplace(std::string* subject, const std::string& search, const std::string& replace) { @@ -665,11 +665,11 @@ void StringReplace(std::string* subject, const std::string& search, const std::s } } -tvm::BaseFunc CodegenWithLibrary(const tir::PrimFuncNode* pf, ffi::String global_symbol) { - using namespace tvm::tir; +tvm::BaseFunc CodegenWithLibrary(const tirx::PrimFuncNode* pf, ffi::String global_symbol) { + using namespace tvm::tirx; ffi::Optional library_code = pf->attrs.GetAttr(kLibraryKernel); if (!library_code.has_value()) { - return ffi::GetRef(pf); + return ffi::GetRef(pf); } std::string source = library_code.value(); StringReplace(&source, "{global_symbol}", global_symbol); @@ -719,15 +719,15 @@ class SplitMutator : public ExprMutator { if (gv_ptr == nullptr) return call; GlobalVar gv = ffi::GetRef(gv_ptr); // retrieve the function from the module and split it - tir::PrimFunc func = Downcast(mod_->Lookup(gv)); + tirx::PrimFunc func = Downcast(mod_->Lookup(gv)); std::vector> arg_partition; // split the function into two functions, one for the library kernel and one for the rest. - std::pair> split_funcs = - tir::SplitFunctions(func, &arg_partition, patterns_, fcodegen_); + std::pair> split_funcs = + tirx::SplitFunctions(func, &arg_partition, patterns_, fcodegen_); if (!split_funcs.second.defined()) { // no need to split, the function itself a library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(split_funcs.first.get(), gv->name_hint); - if (lib_func->IsInstance()) return ffi::GetRef(op); + if (lib_func->IsInstance()) return ffi::GetRef(op); // Update the function in the module with the library kernel TVM_FFI_ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); @@ -737,8 +737,8 @@ class SplitMutator : public ExprMutator { new_call->args = {lib_func, call->args[1]}; return Call(new_call); } - tir::PrimFunc func1 = s_tir::RenewDefs(split_funcs.first); - tir::PrimFunc func2 = s_tir::RenewDefs(split_funcs.second.value()); + tirx::PrimFunc func1 = s_tir::RenewDefs(split_funcs.first); + tirx::PrimFunc func2 = s_tir::RenewDefs(split_funcs.second.value()); TVM_FFI_ICHECK(arg_partition.size() == 2); // emit the first call to the library kernel ffi::Array args1; @@ -747,10 +747,10 @@ class SplitMutator : public ExprMutator { } // replace the function in the module with the library kernel tvm::BaseFunc lib_func = CodegenWithLibrary(func1.get(), gv->name_hint); - if (lib_func->IsInstance()) return ffi::GetRef(op); + if (lib_func->IsInstance()) return ffi::GetRef(op); TVM_FFI_ICHECK(lib_func->IsInstance()); builder_->UpdateFunction(gv, lib_func); - tir::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); + tirx::Buffer intermediate_buffer = func1->buffer_map.at(func1->params.back()); DataType dtype = intermediate_buffer->dtype; Call call1(call_dps_packed_, {lib_func, Tuple(args1)}, call->attrs, {TensorStructInfo(ShapeExpr(intermediate_buffer->shape), dtype)}); diff --git a/src/relax/transform/split_layout_rewrite_preproc.cc b/src/relax/transform/split_layout_rewrite_preproc.cc index 535cf361497b..da25f5f10d64 100644 --- a/src/relax/transform/split_layout_rewrite_preproc.cc +++ b/src/relax/transform/split_layout_rewrite_preproc.cc @@ -27,13 +27,13 @@ #include #include #include -#include +#include #include #include namespace tvm { -namespace tir { +namespace tirx { class SplitPrimFuncLayoutRewrite : public StmtMutator { public: explicit SplitPrimFuncLayoutRewrite(const PrimFunc& func) : original_func_(func) {} @@ -236,7 +236,7 @@ class SplitPrimFuncLayoutRewrite : public StmtMutator { /*! \brief The original primfunc*/ PrimFunc original_func_; }; -} // namespace tir +} // namespace tirx namespace relax { class SplitLayoutRewritePreproc : public ExprMutator { @@ -246,9 +246,9 @@ class SplitLayoutRewritePreproc : public ExprMutator { // Step 1: Split the primfunc into preproc and compute for (auto [gv, func] : mod->functions) { - if (func->IsInstance()) { - tir::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast(func)); - auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast(func)); + if (func->IsInstance()) { + tirx::SplitPrimFuncLayoutRewrite tir_rewriter(Downcast(func)); + auto [preproc_func, compute_func] = tir_rewriter.Transform(Downcast(func)); if (preproc_func.defined()) { mutator.split_funcs_.emplace(gv.get(), std::make_tuple(preproc_func.value(), compute_func)); @@ -274,7 +274,7 @@ class SplitLayoutRewritePreproc : public ExprMutator { static const Op& call_tir_op = Op::Get("relax.call_tir"); Call call = Downcast(ExprMutator::VisitExpr_(op)); - // Step 1: Skip call to other than `tir.call_tir` + // Step 1: Skip call to other than `tirx.call_tir` if (!call->op.same_as(call_tir_op)) { return call; } @@ -302,9 +302,9 @@ class SplitLayoutRewritePreproc : public ExprMutator { ffi::Array preproc_sinfo_list; for (const auto& info : rewrite_infos) { preproc_args.push_back(call_tir_args[info.buffer_index]); - tir::Buffer rewritten_buffer = info.post_rewrite_buffer; + tirx::Buffer rewritten_buffer = info.post_rewrite_buffer; for (const auto& shape_expr : rewritten_buffer->shape) { - TVM_FFI_ICHECK(shape_expr.as()) + TVM_FFI_ICHECK(shape_expr.as()) << "Currently does not support rewrite buffer with " "dynamic shape."; } @@ -332,9 +332,9 @@ class SplitLayoutRewritePreproc : public ExprMutator { } private: - std::unordered_map> split_funcs_; + std::unordered_map> split_funcs_; std::unordered_map> + std::vector> rewrite_infos_; }; diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index b4008ec2cdad..30d116f12566 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -71,7 +71,7 @@ #include #include #include -#include +#include #include #include @@ -116,7 +116,7 @@ class StorageTokenNode : public Object { /*! \brief Get the constant number of bytes that this token requires, or -1 if the number of bytes * is symbolic */ int64_t const_bytes() const { - const int64_t* const_val = tir::as_const_int(bytes); + const int64_t* const_val = tirx::as_const_int(bytes); if (const_val) { return *const_val; } else { @@ -138,7 +138,7 @@ class StorageToken : public ObjectRef { ffi::Optional vdevice = std::nullopt) { // Compute the tensor size from the shape. int64_t const_coeff = dtype.bytes() * dtype.lanes(); - PrimExpr size = tir::make_const(DataType::Int(64), 1); + PrimExpr size = tirx::make_const(DataType::Int(64), 1); bool size_computed = false; if (vdevice.defined()) { @@ -172,7 +172,7 @@ class StorageToken : public ObjectRef { } } - size = tir::make_const(DataType::Int(64), const_coeff) * size; + size = tirx::make_const(DataType::Int(64), const_coeff) * size; ObjectPtr n = ffi::make_object(); n->bytes = size; @@ -258,7 +258,7 @@ class TokenAllocatorMixed { TVM_FFI_ICHECK_GE(available_size, 0); TVM_FFI_ICHECK_GE(size, available_size); // Enlarge the token size. - available_token->bytes = tir::make_const(DataType::Int(64), size); + available_token->bytes = tirx::make_const(DataType::Int(64), size); available_token->ref_counter = prototype->ref_counter; pool.erase(mid); return available_token; @@ -408,7 +408,7 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { * \param dom_map The domain map of the TIR variables. */ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana, - ffi::Map* dom_map) { + ffi::Map* dom_map) { // Use the attribute-annotated TIR var bounds as the TIR var values for // memory planning. // NOTE: we only apply the annotated bounds to the TIR variables that @@ -435,8 +435,8 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana, for (const ffi::String& var_name : non_negative_var_attr_raw) { non_negative_var_attr.insert(var_name); } - ffi::Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); - for (const tir::Var& tir_var : var_in_signature) { + ffi::Array var_in_signature = TIRVarsInStructInfo(GetStructInfo(func)); + for (const tirx::Var& tir_var : var_in_signature) { auto it_upper = var_upper_bound_attr.find(tir_var->name_hint); auto it_lower = var_lower_bound_attr.find(tir_var->name_hint); @@ -468,7 +468,7 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana, * cannot be determined, we keep the dimension unchanged. */ ffi::Array GetUpperBoundShape(ffi::Array shape, arith::Analyzer* ana, - const ffi::Map& dom_map) { + const ffi::Map& dom_map) { // Use the upper bounds of TIR vars as their values. ffi::Array upper_bounded_shape; upper_bounded_shape.reserve(shape.size()); @@ -616,7 +616,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { if (func_it == ctx_mod_->functions.end()) { return false; } - return (*func_it).second->IsInstance(); + return (*func_it).second->IsInstance(); } /*! @@ -725,7 +725,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { /*! \brief The arithmetic analyzer. */ arith::Analyzer* analyzer_; /*! \brief The domain map of dynamic TIR variables for analysis. */ - ffi::Map dom_map_; + ffi::Map dom_map_; /*! \brief The mapping from each token to the binding block where it is created. */ std::unordered_map token2block_; /*! \brief The mapping from each token to the Exprs that are using this token. */ @@ -994,7 +994,7 @@ class StorageAllocationRewriter : public ExprMutator { /*! \brief The arithmetic analyzer. */ arith::Analyzer ana_; /*! \brief The domain map of dynamic TIR variables for analysis. */ - ffi::Map dom_map_; + ffi::Map dom_map_; /*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */ bool plan_dynamic_output_; /*! @@ -1047,8 +1047,9 @@ PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType d struct Shape { const ffi::Array& shape; int64_t operator[](size_t i) const { - TVM_FFI_ICHECK(tir::as_const_int(shape[i])) << "Dymamic shapes not suported over texture now"; - return *tir::as_const_int(shape[i]); + TVM_FFI_ICHECK(tirx::as_const_int(shape[i])) + << "Dymamic shapes not suported over texture now"; + return *tirx::as_const_int(shape[i]); } int size() { return this->shape.size(); } }; @@ -1056,7 +1057,7 @@ PrimExpr GetTextureMemorySizeFromVDevice(ffi::Array pshape, DataType d size_t size = runtime::GetTextureMemorySize(shape, dtype.bytes() * 8, dtype.lanes(), vdevice->memory_scope, image_row_align); - return tir::make_const(DataType::Int(64), size); + return tirx::make_const(DataType::Int(64), size); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc index b9345744320c..f9fd6c12232a 100644 --- a/src/relax/transform/to_non_dataflow.cc +++ b/src/relax/transform/to_non_dataflow.cc @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 5d216f3f8425..2eef6c92cdef 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -223,7 +223,7 @@ class VarReplacer : public ExprMutator { * \details This mutator is used to prevent the same symbolic var from being used in different * functions, which is malformed. */ -class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { +class SymbolicVarRenewMutator : public ExprMutator, tirx::ExprMutator { public: static Function Renew(const Function& function) { SymbolicVarRenewMutator mutator; @@ -234,21 +234,21 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { protected: using relax::ExprMutator::VisitExpr; using relax::ExprMutator::VisitExpr_; - using tir::ExprMutator::VisitExpr_; + using tirx::ExprMutator::VisitExpr_; - PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); } + PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tirx::ExprMutator::VisitExpr(expr); } // TODO(Siyuan): enhance the method to the following steps: - // 1. Visit and replace all tir::Vars at the definition point + // 1. Visit and replace all tirx::Vars at the definition point // 2. Revisit the function again and update the use side. - PrimExpr VisitExpr_(const tir::VarNode* op) final { - auto it = var_map_.find(ffi::GetRef(op)); + PrimExpr VisitExpr_(const tirx::VarNode* op) final { + auto it = var_map_.find(ffi::GetRef(op)); if (it != var_map_.end()) { return (*it).second; } else { - auto n = ffi::make_object(*op); - tir::Var v(n); - var_map_.Set(ffi::GetRef(op), v); + auto n = ffi::make_object(*op); + tirx::Var v(n); + var_map_.Set(ffi::GetRef(op), v); return v; } } @@ -275,11 +275,11 @@ class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator { } } - ffi::Map var_map_; + ffi::Map var_map_; }; /*! - * \brief Copy a function while renewing the relax Vars and the tir Vars. + * \brief Copy a function while renewing the relax Vars and the tirx Vars. * \details All variables that are bound inside the original function would be copied to satisfy * the restriction in the well-formed check: Variables in Relax must be bound exactly once. */ diff --git a/src/relax/utils.cc b/src/relax/utils.cc index ae4c953fd007..14367a67b583 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include namespace tvm { namespace relax { @@ -32,7 +32,7 @@ namespace relax { class ExprBinder : public ExprMutator { public: explicit ExprBinder(const tvm::ffi::Map& args_map, - const tvm::ffi::Map& symbolic_var_map) + const tvm::ffi::Map& symbolic_var_map) : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {} private: @@ -77,7 +77,7 @@ class ExprBinder : public ExprMutator { } PrimExpr VisitPrimExpr(const PrimExpr& expr) final { - auto new_expr = tir::Substitute(expr, symbolic_var_map_); + auto new_expr = tirx::Substitute(expr, symbolic_var_map_); if (!expr.same_as(new_expr)) { arith::Analyzer analyzer; new_expr = analyzer.Simplify(new_expr); @@ -87,7 +87,7 @@ class ExprBinder : public ExprMutator { private: const tvm::ffi::Map& args_map_; - const tvm::ffi::Map& symbolic_var_map_; + const tvm::ffi::Map& symbolic_var_map_; }; /*! @@ -98,22 +98,22 @@ class ExprBinder : public ExprMutator { * \return The result expr after bind params */ Expr Bind(const Expr& expr, const tvm::ffi::Map& binds, - const tvm::ffi::Map& symbolic_var_map) { + const tvm::ffi::Map& symbolic_var_map) { return ExprBinder(binds, symbolic_var_map).VisitExpr(expr); } StructInfo Bind(const StructInfo& sinfo, - const tvm::ffi::Map& symbolic_var_map) { + const tvm::ffi::Map& symbolic_var_map) { return ExprBinder({}, symbolic_var_map).VisitExprDepStructInfoField(sinfo); } -tvm::ffi::Map InferSymbolicVarMap( +tvm::ffi::Map InferSymbolicVarMap( const tvm::ffi::Map& relax_var_remap, arith::Analyzer* analyzer) { - tvm::ffi::Map tir_var_remap; + tvm::ffi::Map tir_var_remap; auto bind_from_prim_expr = [&tir_var_remap](const PrimExpr& var_shape, const PrimExpr& expr_shape) { - if (auto var = var_shape.as()) { + if (auto var = var_shape.as()) { tir_var_remap.Set(var.value(), expr_shape); } }; diff --git a/src/runtime/metadata.h b/src/runtime/metadata.h index 88ca4f609d7b..adc578848516 100644 --- a/src/runtime/metadata.h +++ b/src/runtime/metadata.h @@ -49,11 +49,11 @@ inline ffi::String get_name_mangled(const ffi::String& module_name, const ffi::S namespace launch_param { /*! \brief A tag to specify whether or not dynamic shared memory is used */ -constexpr const char* kUseDynamicSharedMemoryTag = "tir.use_dyn_shared_memory"; +constexpr const char* kUseDynamicSharedMemoryTag = "tirx.use_dyn_shared_memory"; /*! \brief A tag to specify whether or not use programatic dependent launch */ -constexpr const char* kUseProgramaticDependentLaunch = "tir.use_programtic_dependent_launch"; +constexpr const char* kUseProgramaticDependentLaunch = "tirx.use_programtic_dependent_launch"; /*! \brief A tag to specify whether or not use cooperative launch */ -constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch"; +constexpr const char* kUseCooperativeLaunch = "tirx.use_cooperative_launch"; } // namespace launch_param diff --git a/src/runtime/vm/attn_backend.cc b/src/runtime/vm/attn_backend.cc index 8bb57103d305..e2a2c5232550 100644 --- a/src/runtime/vm/attn_backend.cc +++ b/src/runtime/vm/attn_backend.cc @@ -31,7 +31,7 @@ std::unique_ptr ConvertPagedPrefillFunc(ffi::Array a return nullptr; } ffi::String backend_name = args[0].cast(); - if (backend_name == "tir") { + if (backend_name == "tirx") { TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); @@ -53,7 +53,7 @@ std::unique_ptr ConvertRaggedPrefillFunc(ffi::Array return nullptr; } ffi::String backend_name = args[0].cast(); - if (backend_name == "tir") { + if (backend_name == "tirx") { TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); @@ -82,7 +82,7 @@ std::unique_ptr ConvertPagedDecodeFunc(ffi::Array arg return nullptr; } ffi::String backend_name = args[0].cast(); - if (backend_name == "tir") { + if (backend_name == "tirx") { TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); @@ -104,7 +104,7 @@ std::unique_ptr ConvertPagedPrefillTreeMaskFunc(ffi::A return nullptr; } ffi::String backend_name = args[0].cast(); - if (backend_name == "tir") { + if (backend_name == "tirx") { TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); @@ -119,7 +119,7 @@ std::unique_ptr ConvertRaggedPrefillTreeMaskFunc( return nullptr; } ffi::String backend_name = args[0].cast(); - if (backend_name == "tir") { + if (backend_name == "tirx") { TVM_FFI_ICHECK_EQ(args.size(), 2); ffi::Function attn_func = args[1].cast(); return std::make_unique(std::move(attn_func), attn_kind); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 38991d7714d7..16a2c5ea95eb 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -627,7 +627,7 @@ ffi::Optional VirtualMachineImpl::GetClosureInternal(const ffi::Strin << "Cannot support closure with function kind " << static_cast(finfo.kind); ffi::Optional tir_func = GetFuncFromImports("__vmtir__" + finfo.name); TVM_FFI_ICHECK(tir_func.has_value()) - << "Cannot find underlying compiled tir function of VMTIRFunc " << finfo.name; + << "Cannot find underlying compiled tirx function of VMTIRFunc " << finfo.name; auto impl = ffi::Function([this, finfo, tir_func](ffi::PackedArgs args, ffi::Any* rv) { // Per convention, ctx ptr is a VirtualMachine* VirtualMachine* ctx_ptr = static_cast(args[0].cast()); diff --git a/src/s_tir/analysis/calculate_allocated_memory.cc b/src/s_tir/analysis/calculate_allocated_memory.cc index 03a74ac0e207..57e6d690f419 100644 --- a/src/s_tir/analysis/calculate_allocated_memory.cc +++ b/src/s_tir/analysis/calculate_allocated_memory.cc @@ -18,7 +18,7 @@ */ /*! - * \file tir/analysis/calculate_allocated_memory.cc + * \file tirx/analysis/calculate_allocated_memory.cc * \brief Calculate allocated memory per memory scope required by PrimFuncs. */ #include @@ -26,10 +26,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -37,7 +37,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; std::string GetStorageScope(const Var& var) { auto* ptr = var->type_annotation.as(); @@ -112,7 +112,7 @@ tvm::ffi::Map > CalculateAlloca const IRModule& mod) { tvm::ffi::Map > results; for (const auto& kv : mod->functions) { - if (auto prim_func = kv.second.as()) { + if (auto prim_func = kv.second.as()) { ffi::String func_name = kv.first->name_hint; auto alloc_buffer_result = AllocBufferCalculator()(prim_func.value()); results.Set(func_name, alloc_buffer_result); @@ -166,7 +166,7 @@ int64_t GetVTCMCapacity(Target target, const tvm::transform::PassContext& pass_c auto value = target->GetAttr("vtcm-capacity").value()->value; if (value > 0) return value; } - return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; + return pass_ctx->GetConfig("tirx.vtcm_capacity", Integer(0)).value()->value; } ffi::Array GetVTCMCompactionPasses() { @@ -178,10 +178,10 @@ ffi::Array GetVTCMCompactionPasses() { pass_list.push_back(s_tir::transform::LowerMatchBuffer()); pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop(true)); - pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tirx::transform::FlattenBuffer()); + pass_list.push_back(tirx::transform::Simplify()); + pass_list.push_back(tirx::transform::VectorizeLoop(true)); + pass_list.push_back(tirx::transform::StorageRewrite()); return pass_list; } diff --git a/src/s_tir/analysis/estimate_flops.cc b/src/s_tir/analysis/estimate_flops.cc index 8a1812fb4089..508523f91e0d 100644 --- a/src/s_tir/analysis/estimate_flops.cc +++ b/src/s_tir/analysis/estimate_flops.cc @@ -17,14 +17,14 @@ * under the License. */ #include -#include -#include +#include +#include #include "tvm/arith/analyzer.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; int32_t DataType2Int(const tvm::DataType& dtype) { static_assert(sizeof(DLDataType) == sizeof(int32_t), "Incorrect size of DLDataType"); diff --git a/src/s_tir/analysis/find_anchor_sblock.cc b/src/s_tir/analysis/find_anchor_sblock.cc index ad04aaf284f3..9b3fc58c5d03 100644 --- a/src/s_tir/analysis/find_anchor_sblock.cc +++ b/src/s_tir/analysis/find_anchor_sblock.cc @@ -24,11 +24,11 @@ #include #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { Stmt GetEnclosingLoop(const SBlockNode* block, Stmt func_body) { struct GetRootSeqStmt : public StmtVisitor { @@ -118,5 +118,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/s_tir/analysis/identify_memcpy.cc b/src/s_tir/analysis/identify_memcpy.cc index c7edf89c06c8..b23e7cee0221 100644 --- a/src/s_tir/analysis/identify_memcpy.cc +++ b/src/s_tir/analysis/identify_memcpy.cc @@ -18,7 +18,7 @@ */ /*! - * \file tir/analysis/identify_memcpy.cc + * \file tirx/analysis/identify_memcpy.cc * \brief Check if a loop nest is equivalent to memcpy */ @@ -26,9 +26,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -39,7 +39,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; std::variant IdentifyMemCpyImpl(const For& loop, arith::Analyzer* analyzer) { diff --git a/src/s_tir/analysis/is_pure_function.cc b/src/s_tir/analysis/is_pure_function.cc index 325a71e48f55..db970000cae1 100644 --- a/src/s_tir/analysis/is_pure_function.cc +++ b/src/s_tir/analysis/is_pure_function.cc @@ -23,14 +23,14 @@ */ #include #include -#include -#include +#include +#include -#include "../../tir/ir/tir_visitor_with_path.h" +#include "../../tirx/ir/tir_visitor_with_path.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using AccessPath = ffi::reflection::AccessPath; diff --git a/src/s_tir/analysis/oob_checker.cc b/src/s_tir/analysis/oob_checker.cc index 18cb2418e497..1e82198fc4ef 100644 --- a/src/s_tir/analysis/oob_checker.cc +++ b/src/s_tir/analysis/oob_checker.cc @@ -22,14 +22,14 @@ */ #include -#include +#include #include "../../arith/ir_visitor_with_analyzer.h" #include "../schedule/error.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; namespace transform { struct OOBLocation { Buffer buf; @@ -112,7 +112,7 @@ class OOBCheckerVisitor final : public arith::IRVisitorWithAnalyzer { }; tvm::transform::Pass OOBChecker() { - auto pass_func = [=](tir::PrimFunc func, IRModule mod, tvm::transform::PassContext ctx) { + auto pass_func = [=](tirx::PrimFunc func, IRModule mod, tvm::transform::PassContext ctx) { OOBCheckerVisitor checker; checker(func->body); if (checker.errors.size() > 0) { @@ -123,7 +123,7 @@ tvm::transform::Pass OOBChecker() { } return func; }; - return tir::transform::CreatePrimFuncPass(pass_func, 0, "s_tir.analysis.OOBChecker", {}); + return tirx::transform::CreatePrimFuncPass(pass_func, 0, "s_tir.analysis.OOBChecker", {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/s_tir/analysis/sblock_access_region_detector.cc b/src/s_tir/analysis/sblock_access_region_detector.cc index aaba79827eb8..6a50bd1f2c5a 100644 --- a/src/s_tir/analysis/sblock_access_region_detector.cc +++ b/src/s_tir/analysis/sblock_access_region_detector.cc @@ -24,15 +24,15 @@ #include #include -#include -#include +#include +#include #include #include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Detect which regions of tensors in this block are read or written to. Regions are sorted @@ -419,5 +419,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("s_tir.analysis.GetSBlockReadWriteRegion", GetSBlockReadWriteRegion); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc index 9726f58a738b..06a8b9ac49ea 100644 --- a/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc +++ b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc @@ -23,14 +23,14 @@ */ #include -#include -#include +#include +#include #include "../../runtime/thread_storage_scope.h" #include "../../support/arena.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Detect the lowest common ancestor(LCA) position of Buffer access. @@ -345,5 +345,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("s_tir.analysis.detect_buffer_access_lca", DetectBufferAccessLCA); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/s_tir/analysis/verify_gpu_code.cc b/src/s_tir/analysis/verify_gpu_code.cc index 64b56d7178fc..bd7b7c92ba7c 100644 --- a/src/s_tir/analysis/verify_gpu_code.cc +++ b/src/s_tir/analysis/verify_gpu_code.cc @@ -27,17 +27,17 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class GPUCodeVerifier : public StmtExprVisitor { public: @@ -96,7 +96,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { + if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index 374ff8007283..195168a35b31 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -23,18 +23,18 @@ #include #include -#include -#include +#include +#include #include "../../../arith/ir_mutator_with_analyzer.h" #include "../../../runtime/texture.h" -#include "../../../tir/transform/ir_utils.h" +#include "../../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { namespace backend { namespace adreno { -using namespace tvm::tir; +using namespace tvm::tirx; using runtime::ApplyTexture2DFlattening; using runtime::DefaultTextureLayoutSeparator; using runtime::IsTextureStorage; @@ -102,8 +102,8 @@ Pass InjectTextureAlloc() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return TextureAllocInjector::Inject(std::move(f)); }; - return tir::transform::CreatePrimFuncPass(pass_func, 0, "s_tir.backend.adreno.InjectTextureAlloc", - {}); + return tirx::transform::CreatePrimFuncPass(pass_func, 0, + "s_tir.backend.adreno.InjectTextureAlloc", {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index 26786d303b37..2d0322d91097 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -26,9 +26,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include @@ -40,7 +40,7 @@ namespace tvm { namespace s_tir { namespace backend { namespace adreno { -using namespace tvm::tir; +using namespace tvm::tirx; using arith::IRVisitorWithAnalyzer; using runtime::ApplyTexture2DFlattening; using runtime::DefaultTextureLayoutSeparator; @@ -146,8 +146,8 @@ class TextureFlattener : public TextureLoweringBase { PrimExpr row_offset = SimplifyOffset(row_dims, row_indices); PrimExpr col_offset = SimplifyOffset(col_dims, col_indices); PrimExpr depth_offset = SimplifyOffset(depth_dims, depth_indices); - PrimExpr channel_size = IntImm(DataType::Int(32, 1), - *tir::as_const_int(buffer->shape.back()) * buffer->dtype.bits()); + PrimExpr channel_size = IntImm( + DataType::Int(32, 1), *tirx::as_const_int(buffer->shape.back()) * buffer->dtype.bits()); args.push_back(row_offset); args.push_back(col_offset); args.push_back(depth_offset); @@ -173,8 +173,8 @@ Pass TextureFlatten() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return TextureFlattenHandler(std::move(f)); }; - return tir::transform::CreatePrimFuncPass(pass_func, 0, "s_tir.backend.adreno.TextureFlatten", - {}); + return tirx::transform::CreatePrimFuncPass(pass_func, 0, "s_tir.backend.adreno.TextureFlatten", + {}); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/s_tir/data_layout.cc b/src/s_tir/data_layout.cc index fb64ee00ff24..bee4c2e31fce 100644 --- a/src/s_tir/data_layout.cc +++ b/src/s_tir/data_layout.cc @@ -29,19 +29,19 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include namespace tvm { -namespace tir { -using tir::IterVar; -using tir::IterVarNode; -using tir::Var; +namespace tirx { +using tirx::IterVar; +using tirx::IterVarNode; +using tirx::Var; TVM_FFI_STATIC_INIT_BLOCK() { LayoutNode::RegisterReflection(); @@ -133,7 +133,7 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) TVM_FFI_ICHECK_EQ(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " before dimension " << c; IterVar axis(Range(IntImm(dtype, 0), Var(std::string(1, c), dtype)), - Var(std::string(1, c), dtype), tir::kDataPar); + Var(std::string(1, c), dtype), tirx::kDataPar); if (!in_packing) { node->axes.push_back(axis); } else { @@ -145,7 +145,7 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) std::stringstream name; name << factor << c; IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(name.str(), dtype), - tir::kDataPar); + tirx::kDataPar); if (!in_packing) { node->axes.push_back(axis); } else { @@ -176,7 +176,7 @@ Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) } std::string grouped_name = ss.str(); IterVar grouped_axis(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(grouped_name, dtype), - tir::kDataPar); + tirx::kDataPar); node->axes.push_back(grouped_axis); in_packing = false; @@ -240,13 +240,13 @@ ffi::Array Layout::UnpackIterVar(IterVar packed_iter) { } else if (ch >= 'a' && ch <= 'z') { TVM_FFI_ICHECK(factor != 0) << "Invalid Factor Size"; result.push_back(IterVar(Range(IntImm(dtype, 0), IntImm(dtype, factor)), - Var(std::string(1, ch), dtype), tir::kDataPar)); + Var(std::string(1, ch), dtype), tirx::kDataPar)); final_factor *= factor; factor = 0; } else if (ch >= 'A' && ch <= 'Z') { TVM_FFI_ICHECK(factor == 0) << "Can't have non-zero factors for primal axis"; result.push_back(IterVar(Range(IntImm(dtype, 0), Var(std::string(1, ch), dtype)), - Var(std::string(1, ch), dtype), tir::kDataPar)); + Var(std::string(1, ch), dtype), tirx::kDataPar)); } } @@ -266,7 +266,7 @@ IterVar Layout::PackIterVar(ffi::Array iter_vars) { } return IterVar(Range(IntImm(dtype, 0), IntImm(dtype, extent)), Var(name.str(), dtype), - tir::kDataPar); + tirx::kDataPar); } int32_t Layout::FactorOf(const LayoutAxis& axis) const { @@ -444,12 +444,12 @@ inline ffi::Array TransformIndex(const ffi::Array& src_index const ffi::Array& transform_rule) { arith::Analyzer ana; ffi::Array result; - std::unordered_map bind_map; + std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } for (PrimExpr rule : transform_rule) { - result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tirx::Substitute(rule, bind_map))); } return result; } @@ -482,7 +482,7 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape // for major-axis, bind the corresponding size // for minor-axis, simply bind it as 0, so that we can reuse forward/backward_rule, // e.g., (C * 16 + c) / 32 - std::unordered_map bind_map; + std::unordered_map bind_map; for (size_t i = 0; i < src_shape.size(); ++i) { PrimExpr orig_shape = src_shape[i]; IterVar orig_axis = src_axis[i]; @@ -516,7 +516,7 @@ inline ffi::Array TransformShape(const ffi::Array& src_shape if (layout.size() != 1 || !LayoutAxis::Get(layout[0]).IsPrimal()) { result.push_back(axis->dom->extent); } else { - result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tirx::Substitute(rule, bind_map))); } } @@ -602,5 +602,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("s_tir.BijectiveLayoutForwardShape", &BijectiveLayout::ForwardShape) .def_method("s_tir.BijectiveLayoutBackwardShape", &BijectiveLayout::BackwardShape); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/s_tir/meta_schedule/arg_info.cc b/src/s_tir/meta_schedule/arg_info.cc index bfc9c1ee527d..f873f0a8a377 100644 --- a/src/s_tir/meta_schedule/arg_info.cc +++ b/src/s_tir/meta_schedule/arg_info.cc @@ -27,22 +27,22 @@ namespace meta_schedule { /*! * \brief Find the entry function of the given IRModule, i.e, functions marked by - * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * `tirx::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. * \param mod The IRModule to find the entry function. * \return The entry function. */ -inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { - // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` +inline tirx::PrimFunc FindEntryFunc(const IRModule& mod) { + // Priority 1: PrimFunc marked as `tirx::attr::kIsEntryFunc` int num_prim_func = 0; - const tir::PrimFuncNode* main_func = nullptr; - const tir::PrimFuncNode* last_func = nullptr; + const tirx::PrimFuncNode* main_func = nullptr; + const tirx::PrimFuncNode* last_func = nullptr; for (const auto& kv : mod->functions) { GlobalVar gv = kv.first; BaseFunc base_func = kv.second; - if (const auto* func = base_func.as()) { + if (const auto* func = base_func.as()) { last_func = func; - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - return ffi::GetRef(func); + if (func->HasNonzeroAttr(tirx::attr::kIsEntryFunc)) { + return ffi::GetRef(func); } if (gv->name_hint == "main") { main_func = func; @@ -52,7 +52,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } // Priority 2: PrimFunc whose name is `main` if (main_func != nullptr) { - return ffi::GetRef(main_func); + return ffi::GetRef(main_func); } // Priority 3: The only PrimFunc in the IRModule if (num_prim_func == 0) { @@ -60,10 +60,10 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { } if (num_prim_func > 1) { TVM_FFI_THROW(ValueError) << "Multiple PrimFuncs exist in the IRModule, but none of them are " - "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + "annotated with `kIsEntryFunc`, i.e. `tirx.is_entry_func`" << mod; } - return ffi::GetRef(last_func); + return ffi::GetRef(last_func); } /******** ArgInfo ********/ @@ -88,13 +88,13 @@ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { throw; } -ffi::Array ArgInfo::FromPrimFunc(const tir::PrimFunc& func) { +ffi::Array ArgInfo::FromPrimFunc(const tirx::PrimFunc& func) { using support::AsVector; ffi::Array result; result.reserve(func->params.size()); - for (const tir::Var& arg : func->params) { - if (ffi::Optional _buffer = func->buffer_map.Get(arg)) { - tir::Buffer buffer = _buffer.value(); + for (const tirx::Var& arg : func->params) { + if (ffi::Optional _buffer = func->buffer_map.Get(arg)) { + tirx::Buffer buffer = _buffer.value(); result.push_back(TensorInfo(/*dtype=*/buffer->dtype, /*shape=*/AsVector(buffer->shape))); } else { diff --git a/src/s_tir/meta_schedule/database/database_utils.cc b/src/s_tir/meta_schedule/database/database_utils.cc index 6ef1f254ae9f..2d31bf799b29 100644 --- a/src/s_tir/meta_schedule/database/database_utils.cc +++ b/src/s_tir/meta_schedule/database/database_utils.cc @@ -81,7 +81,7 @@ void JSONDumps(Any json_obj, std::ostringstream& os) { JSONDumps(kv.second, os); } os << "}"; - } else if (json_obj.as()) { + } else if (json_obj.as()) { JSONDumps(ffi::String(ffi::json::Stringify( ffi::ToJSONGraph(json_obj, ffi::json::Object{{"tvm_version", TVM_VERSION}}), /*indent=*/2)), diff --git a/src/s_tir/meta_schedule/extracted_task.cc b/src/s_tir/meta_schedule/extracted_task.cc index d1eaa4a9d69e..55b2a02e18de 100644 --- a/src/s_tir/meta_schedule/extracted_task.cc +++ b/src/s_tir/meta_schedule/extracted_task.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include "../../te/operation/create_primfunc.h" #include "./utils.h" diff --git a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc index 038d65217f50..82f1c2983184 100644 --- a/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/s_tir/meta_schedule/feature_extractor/per_store_feature.cc @@ -18,7 +18,7 @@ */ #include #include -#include +#include #include #include @@ -32,7 +32,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using support::NDIntSet; @@ -87,7 +87,7 @@ std::vector GetBufferShape(const Buffer& buffer, arith::Analyzer* analy */ int64_t GetPragmaAutoUnroll(const ForNode* loop) { if (ffi::Optional auto_unroll = - GetAnn(loop, tir::attr::pragma_auto_unroll_max_step)) { + GetAnn(loop, tirx::attr::pragma_auto_unroll_max_step)) { return auto_unroll.value()->value; } return -1; @@ -302,7 +302,7 @@ Pass SimplifyForFeatureExtraction() { n->body = Simplifier::Run(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.SimplifyForFeatureExtraction", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.SimplifyForFeatureExtraction", {}); } /*! @@ -318,11 +318,11 @@ tvm::transform::Sequential PassListForPerStoreFeature() { s_tir::transform::PlanAndUpdateBufferAllocationLocation(), s_tir::transform::ConvertBlocksToOpaque(), s_tir::transform::CompactBufferAllocation(), - tir::transform::Simplify(), + tirx::transform::Simplify(), s_tir::transform::LowerAutoCopy(), s_tir::transform::UnifyThreadBinding(), s_tir::transform::LowerMatchBuffer(), - tir::transform::Simplify(), + tirx::transform::Simplify(), }); } diff --git a/src/s_tir/meta_schedule/module_equality.cc b/src/s_tir/meta_schedule/module_equality.cc index fff1a88c3386..c7e2e8b55621 100644 --- a/src/s_tir/meta_schedule/module_equality.cc +++ b/src/s_tir/meta_schedule/module_equality.cc @@ -21,7 +21,7 @@ #include #include #include -#include +#include #include @@ -53,20 +53,20 @@ class ModuleEqualityIgnoreTensor : public ModuleEquality { // on the extracted anchor blocks. class ModuleEqualityAnchorBlock : public ModuleEquality { size_t Hash(IRModule mod) const { - auto anchor_block = tir::FindAnchorBlock(mod); + auto anchor_block = tirx::FindAnchorBlock(mod); if (anchor_block) { - return ffi::StructuralHash::Hash(ffi::GetRef(anchor_block), + return ffi::StructuralHash::Hash(ffi::GetRef(anchor_block), /*map_free_vars=*/false, /*skip_tensor_content=*/true); } return ModuleEqualityIgnoreTensor().Hash(mod); } bool Equal(IRModule lhs, IRModule rhs) const { - auto anchor_block_lhs = tir::FindAnchorBlock(lhs); - auto anchor_block_rhs = tir::FindAnchorBlock(rhs); + auto anchor_block_lhs = tirx::FindAnchorBlock(lhs); + auto anchor_block_rhs = tirx::FindAnchorBlock(rhs); if (anchor_block_lhs && anchor_block_rhs) { - return tvm::ffi::StructuralEqual::Equal(ffi::GetRef(anchor_block_lhs), - ffi::GetRef(anchor_block_rhs), + return tvm::ffi::StructuralEqual::Equal(ffi::GetRef(anchor_block_lhs), + ffi::GetRef(anchor_block_rhs), /*map_free_vars=*/false, /*skip_tensor_content=*/true); } diff --git a/src/s_tir/meta_schedule/module_equality.h b/src/s_tir/meta_schedule/module_equality.h index f3e887271f3f..f3c7a24b6247 100644 --- a/src/s_tir/meta_schedule/module_equality.h +++ b/src/s_tir/meta_schedule/module_equality.h @@ -47,7 +47,7 @@ class ModuleEquality { * - "anchor-block": Apply equality testing and hashing on the anchor block extracted from a * given module. The "ignore-tensor" varint is used for the extracted blocks * or in case no anchor block is found. - * For the definition of the anchor block, see tvm/tir/analysis.h. + * For the definition of the anchor block, see tvm/tirx/analysis.h. * \return An owning pointer to the created instance */ static std::unique_ptr Create(const std::string& mod_eq_name); diff --git a/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc b/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc index 86f3c5c1c1cc..507f49a0cef7 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_compute_location.cc @@ -95,7 +95,7 @@ std::vector MutateComputeLocationNode::Fin if (inst->kind.same_as(inst_sample_compute_location)) { // Step 1. Extract the instruction input and the old decision. TVM_FFI_ICHECK_EQ(inputs.size(), 1); - tir::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); + tirx::StmtSRef block_sref = sch->GetSRef(Downcast(inputs[0])); int old_decision = Downcast(decision)->value; // Step 2. Collect all the compute_at locations. diff --git a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc index 3365b7f64f3e..fee7ee8e961d 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_parallel.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_parallel.cc @@ -27,7 +27,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Check if the instruction is annotation with `meta_schedule_parallel` diff --git a/src/s_tir/meta_schedule/mutator/mutate_unroll.cc b/src/s_tir/meta_schedule/mutator/mutate_unroll.cc index 13d80f322fe6..c7895312b6a9 100644 --- a/src/s_tir/meta_schedule/mutator/mutate_unroll.cc +++ b/src/s_tir/meta_schedule/mutator/mutate_unroll.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Check if an instruction is annotate with diff --git a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc index 2222289d5f1f..3e740d747c9b 100644 --- a/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc +++ b/src/s_tir/meta_schedule/postproc/disallow_async_strided_mem_copy.cc @@ -24,7 +24,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Check if an IRModule has any async strided mem copies. */ struct AsyncStridedMemCopyFinder : private StmtExprVisitor { @@ -136,11 +136,11 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { for (const auto& kv : mod->functions) { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; - if (const auto* prim_func = base_func.as()) { + if (const auto* prim_func = base_func.as()) { IRModule lowered{ffi::UnsafeInit()}; try { auto pass_list = ffi::Array(); - pass_list.push_back(tir::transform::BindTarget(this->target)); + pass_list.push_back(tirx::transform::BindTarget(this->target)); pass_list.push_back(s_tir::transform::LowerInitBlock()); pass_list.push_back(s_tir::transform::PlanAndUpdateBufferAllocationLocation()); pass_list.push_back(s_tir::transform::ConvertBlocksToOpaque()); @@ -148,16 +148,16 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode { pass_list.push_back(s_tir::transform::LowerMatchBuffer()); pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16ComputeLegalize()); - pass_list.push_back(tir::transform::NarrowDataType(32)); - pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tirx::transform::FlattenBuffer()); + pass_list.push_back(tirx::transform::BF16ComputeLegalize()); + pass_list.push_back(tirx::transform::NarrowDataType(32)); + pass_list.push_back(tirx::transform::Simplify()); pass_list.push_back(s_tir::transform::InjectVirtualThread()); pass_list.push_back(s_tir::transform::InjectDoubleBuffer()); - pass_list.push_back(tir::transform::VectorizeLoop(true)); - pass_list.push_back(tir::transform::StorageRewrite()); - tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", - ffi::String(g_var->name_hint)); + pass_list.push_back(tirx::transform::VectorizeLoop(true)); + pass_list.push_back(tirx::transform::StorageRewrite()); + tirx::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", + ffi::String(g_var->name_hint)); IRModule mod = IRModule(ffi::Map({{GlobalVar(g_var->name_hint), f}})); lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); diff --git a/src/s_tir/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/s_tir/meta_schedule/postproc/disallow_dynamic_loop.cc index 10a71b7f6602..48efc72ee1a7 100644 --- a/src/s_tir/meta_schedule/postproc/disallow_dynamic_loop.cc +++ b/src/s_tir/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Check if an IRModule has any dynamic loop. */ struct DynamicExtentFinder : private StmtVisitor { diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc index 2b2242d4f061..9a52f6fc1980 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Parse instruction: sch.bind(..., axis) @@ -89,20 +89,20 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { size_t GetMaxUsedDtypeBytes(SBlock block) { size_t max_bytes = 1; - static auto q_multiply_shift_per_axis = Op::Get("tir.q_multiply_shift_per_axis"); - static auto q_multiply_shift = Op::Get("tir.q_multiply_shift"); + static auto q_multiply_shift_per_axis = Op::Get("tirx.q_multiply_shift_per_axis"); + static auto q_multiply_shift = Op::Get("tirx.q_multiply_shift"); - tir::PostOrderVisit(block->body, [&](const ObjectRef& obj) { - if (const auto* store = obj.as()) { + tirx::PostOrderVisit(block->body, [&](const ObjectRef& obj) { + if (const auto* store = obj.as()) { max_bytes = std::max(max_bytes, static_cast(store->value->dtype.bytes())); - } else if (const auto* load = obj.as()) { + } else if (const auto* load = obj.as()) { max_bytes = std::max(max_bytes, static_cast(load->dtype.bytes())); - } else if (const auto* call = obj.as()) { + } else if (const auto* call = obj.as()) { if (call->op.same_as(q_multiply_shift_per_axis) || call->op.same_as(q_multiply_shift)) { // q_multiply_shift uses 64 bit multiply max_bytes = std::max(max_bytes, 8); } - } else if (const auto* cast = obj.as()) { + } else if (const auto* cast = obj.as()) { max_bytes = std::max(max_bytes, cast->dtype.bytes()); } }); diff --git a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc index 6f2fe7741cd2..701a80a085f9 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_layout.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_layout.cc @@ -26,7 +26,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Collect the block and index where the buffer is read. diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index d095d9861242..ebaa58660e3a 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Check whether the loop has any annotation @@ -402,8 +402,8 @@ void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const return; } - sch->Annotate(loop, tir::attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); - sch->Annotate(loop, tir::attr::pragma_unroll_explicit, + sch->Annotate(loop, tirx::attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, tirx::attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc b/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc index 67e15602c22e..5996fb7bf2d2 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_reduction_block.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief The visitor that finds all the reduction block to be decomposed */ struct ReductionBlockFinder : private StmtVisitor { @@ -72,7 +72,7 @@ struct ReductionBlockFinder : private StmtVisitor { for (int i = 0; i < n; ++i) { IterVar iter_var = block->iter_vars[i]; PrimExpr binding = realize->iter_values[i]; - if (iter_var->iter_type == tir::kCommReduce) { + if (iter_var->iter_type == tirx::kCommReduce) { if (UsesVar(binding, f_find)) { return false; } @@ -135,11 +135,11 @@ class RewriteReductionBlockNode : public PostprocNode { bool RewriteReductionBlockNode::Apply(const s_tir::Schedule& sch) { for (;;) { - std::vector> results = + std::vector> results = s_tir::ReductionBlockFinder::Find(sch->state()); int rewritten = 0; for (const auto& kv : results) { - const tir::StmtSRef& block_sref = kv.first; + const tirx::StmtSRef& block_sref = kv.first; const ffi::String& global_var_name = kv.second; int decompose_point = s_tir::FindDecomposePoint(block_sref); if (decompose_point == -1) { diff --git a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc index 0090b3a95b4a..78ee945cdfd6 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc @@ -32,13 +32,13 @@ using s_tir::LoopRV; using s_tir::SBlockRV; void CollectTensorizationJobs( - const s_tir::Schedule& sch, const ffi::String& func_name, const tir::PrimFuncNode* func, + const s_tir::Schedule& sch, const ffi::String& func_name, const tirx::PrimFuncNode* func, bool vectorize_init_loop, std::vector>>* jobs) { - tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { - if (const auto* block = obj.as()) { - tir::StmtSRef block_sref = sch->GetSRef(block); - std::string block_name = block_sref->StmtAs()->name_hint; + tirx::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) { + if (const auto* block = obj.as()) { + tirx::StmtSRef block_sref = sch->GetSRef(block); + std::string block_name = block_sref->StmtAs()->name_hint; if (ffi::Optional intrin_name = s_tir::GetAnn(block_sref, s_tir::attr::meta_schedule_auto_tensorize)) { if (intrin_name.value() != "") { @@ -91,7 +91,7 @@ bool RewriteTensorizeNode::Apply(const s_tir::Schedule& sch) { for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; BaseFunc base_func = kv.second; - if (const tir::PrimFuncNode* prim_func = base_func.as()) { + if (const tirx::PrimFuncNode* prim_func = base_func.as()) { CollectTensorizationJobs(sch, g_var->name_hint, prim_func, vectorize_init_loop, &jobs); } } diff --git a/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc b/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc index cf29cb503d98..59bcb0d05066 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_unbound_block.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Find all the blocks that are not bound */ class UnboundBlockFinder : private StmtVisitor { @@ -129,10 +129,10 @@ bool RewriteUnboundBlockNode::Apply(const s_tir::Schedule& sch) { auto get_factor = [t = this->max_threads_per_block_](int max_extent) -> ExprRV { return Integer(std::min(t, max_extent)); }; - std::vector> unbound_blocks = + std::vector> unbound_blocks = s_tir::UnboundBlockFinder::Find(sch->state()); for (const auto& kv : unbound_blocks) { - tir::StmtSRef block_sref = kv.first; + tirx::StmtSRef block_sref = kv.first; ffi::String global_var_name = kv.second; SBlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name); BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); diff --git a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc index f05c28fca084..bed996f7f0db 100644 --- a/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/s_tir/meta_schedule/postproc/verify_gpu_code.cc @@ -19,13 +19,13 @@ #include #include #include -#include +#include #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class ThreadExtentChecker : private StmtVisitor { public: @@ -136,7 +136,7 @@ class VerifyGPUCodeNode : public PostprocNode { bool Verify(const IRModule& mod) const { for (const auto& kv : mod->functions) { - if (auto prim_func = kv.second.as()) { + if (auto prim_func = kv.second.as()) { if (!s_tir::VerifyGPUCode(prim_func.value(), this->target_constraints_)) { return false; } @@ -150,7 +150,7 @@ class VerifyGPUCodeNode : public PostprocNode { for (const auto& kv : mod->functions) { const GlobalVar& g_var = kv.first; const BaseFunc& base_func = kv.second; - if (const auto* prim_func = base_func.as()) { + if (const auto* prim_func = base_func.as()) { if (!s_tir::ThreadExtentChecker::Check(prim_func->body, thread_warp_size_)) { return false; } @@ -165,31 +165,31 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(s_tir::transform::LiftThreadBinding()); pass_list.push_back(s_tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(s_tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tirx::transform::Simplify()); pass_list.push_back(s_tir::transform::LowerAutoCopy()); pass_list.push_back(s_tir::transform::UnifyThreadBinding()); pass_list.push_back(s_tir::transform::LowerMatchBuffer()); pass_list.push_back(s_tir::transform::InjectSoftwarePipeline()); pass_list.push_back(s_tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16ComputeLegalize()); - pass_list.push_back(tir::transform::NarrowDataType(32)); - pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tirx::transform::FlattenBuffer()); + pass_list.push_back(tirx::transform::BF16ComputeLegalize()); + pass_list.push_back(tirx::transform::NarrowDataType(32)); + pass_list.push_back(tirx::transform::Simplify()); // Phase 2 - pass_list.push_back(tir::transform::VectorizeLoop(true)); + pass_list.push_back(tirx::transform::VectorizeLoop(true)); pass_list.push_back(s_tir::transform::InjectVirtualThread()); pass_list.push_back(s_tir::transform::InjectDoubleBuffer()); - pass_list.push_back(tir::transform::StorageRewrite()); + pass_list.push_back(tirx::transform::StorageRewrite()); pass_list.push_back(s_tir::transform::MergeSharedMemoryAllocations()); - pass_list.push_back(tir::transform::LowerIntrin()); + pass_list.push_back(tirx::transform::LowerIntrin()); // Convert Function to IRModule tvm::transform::PassContext pass_ctx = tvm::transform::PassContext::Current(); - tir::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", - ffi::String(g_var->name_hint)); + tirx::PrimFunc f = WithAttr(ffi::GetRef(prim_func), "global_symbol", + ffi::String(g_var->name_hint)); f = WithAttr(f, tvm::attr::kTarget, this->target_); // Required for LowerIntrin - bool noalias = pass_ctx->GetConfig("tir.noalias", true).value(); + bool noalias = pass_ctx->GetConfig("tirx.noalias", true).value(); if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", true); + f = WithAttr(std::move(f), "tirx.noalias", true); } IRModule mod = IRModule(ffi::Map({{GlobalVar(g_var->name_hint), f}})); diff --git a/src/s_tir/meta_schedule/schedule/cpu/winograd.cc b/src/s_tir/meta_schedule/schedule/cpu/winograd.cc index 6c2839877e94..9bb8f83529e5 100644 --- a/src/s_tir/meta_schedule/schedule/cpu/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/cpu/winograd.cc @@ -25,7 +25,7 @@ namespace tvm { namespace s_tir { namespace meta_schedule { -using namespace tvm::tir; +using namespace tvm::tirx; using s_tir::ExprRV; using s_tir::LoopRV; using s_tir::SBlockRV; @@ -34,7 +34,7 @@ using s_tir::Schedule; static ffi::Array ScheduleDataPack(s_tir::Schedule sch, s_tir::SBlockRV block, std::vector tiled, std::vector unrolled) { - using namespace tvm::tir; + using namespace tvm::tirx; TVM_FFI_ICHECK_EQ(tiled.size(), 2); TVM_FFI_ICHECK_EQ(unrolled.size(), 4); ffi::Array factors{ffi::UnsafeInit()}; diff --git a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc index 5c661cc8ad28..a8ddcdf92880 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/thread_bind.cc @@ -18,7 +18,7 @@ */ #include #include -#include +#include #include #include @@ -30,7 +30,7 @@ namespace tvm { namespace s_tir { namespace meta_schedule { -using namespace tvm::tir; +using namespace tvm::tirx; using s_tir::ExprRV; using s_tir::GetLoopIterType; using s_tir::GetLoops; diff --git a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc index 8ac211b338ed..5a75e000d6d9 100644 --- a/src/s_tir/meta_schedule/schedule/cuda/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/cuda/winograd.cc @@ -28,7 +28,7 @@ namespace tvm { namespace s_tir { namespace meta_schedule { -using namespace tvm::tir; +using namespace tvm::tirx; using s_tir::ExprRV; using s_tir::LoopRV; using s_tir::SBlockRV; @@ -38,7 +38,7 @@ static ffi::Array ScheduleDataPack(s_tir::Schedule sch, s_tir::SB std::vector tiled, std::vector unrolled) { // This method is used for NHWC layout only. Will likely be refactored into a more schedule - using namespace tvm::tir; + using namespace tvm::tirx; TVM_FFI_ICHECK_EQ(tiled.size(), 2); TVM_FFI_ICHECK_EQ(unrolled.size(), 4); ffi::Array factors{ffi::UnsafeInit()}; diff --git a/src/s_tir/meta_schedule/schedule/generic/winograd.cc b/src/s_tir/meta_schedule/schedule/generic/winograd.cc index 4a5e25dbac11..27e509d4b6b7 100644 --- a/src/s_tir/meta_schedule/schedule/generic/winograd.cc +++ b/src/s_tir/meta_schedule/schedule/generic/winograd.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { namespace meta_schedule { -using namespace tvm::tir; +using namespace tvm::tirx; using s_tir::ExprRV; using s_tir::LoopRV; using s_tir::SBlockRV; diff --git a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc index d618661ea072..e66e6400c87c 100644 --- a/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/s_tir/meta_schedule/schedule_rule/add_rfactor.cc @@ -82,7 +82,7 @@ ScheduleRule ScheduleRule::AddRFactor(int max_jobs_per_core, ffi::Array AddRFactorNode::Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) { - tir::StmtSRef block_sref = sch->GetSRef(block_rv); + tirx::StmtSRef block_sref = sch->GetSRef(block_rv); if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_parallel_extent_, max_parallel_basic_)) { return {sch}; diff --git a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc index 3096da373fba..5606301eb5b7 100644 --- a/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/s_tir/meta_schedule/schedule_rule/auto_inline.cc @@ -48,8 +48,8 @@ enum class InlineType : int32_t { kInlineIntoProducer = 2, }; -bool IsInSpatialPrimFunc(const s_tir::Schedule& sch, const tir::StmtSRef& block_sref) { - using namespace tvm::tir; +bool IsInSpatialPrimFunc(const s_tir::Schedule& sch, const tirx::StmtSRef& block_sref) { + using namespace tvm::tirx; const StmtSRefNode* sref = block_sref.get(); for (; sref->parent != nullptr; sref = sref->parent) { } @@ -117,7 +117,7 @@ class AutoInlineNode : public ScheduleRuleNode { inline InlineType AutoInlineNode::CheckInline(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) { - using namespace tvm::tir; + using namespace tvm::tirx; StmtSRef block_sref = sch->GetSRef(block_rv); bool is_pure_sptial = IsInSpatialPrimFunc(sch, block_sref); ScheduleState state = sch->state(); @@ -129,7 +129,7 @@ inline InlineType AutoInlineNode::CheckInline(const s_tir::Schedule& sch, } // Cond 2. For a block that generates a constant tensor, ignore all other conditions if (inline_const_tensor && block->reads.empty()) { - ffi::Array consumer_srefs = GetConsumers(state, block_sref); + ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { return InlineType::kInlineIntoConsumer; } @@ -164,16 +164,16 @@ inline InlineType AutoInlineNode::CheckInline(const s_tir::Schedule& sch, if (ann.value() == "disable") return InlineType::kNoInline; } // Last cond: Check inline into the consumers or the spatial producer - tir::StmtSRef scope_block = s_tir::GetScopeRoot(sch->state(), block_sref, - /*require_stage_pipeline=*/false); + tirx::StmtSRef scope_block = s_tir::GetScopeRoot(sch->state(), block_sref, + /*require_stage_pipeline=*/false); if (into_consumer) { - ffi::Array consumer_srefs = GetConsumers(state, block_sref); + ffi::Array consumer_srefs = GetConsumers(state, block_sref); if (!consumer_srefs.empty() && CanComputeInline(state, block_sref)) { return InlineType::kInlineIntoConsumer; } } if (into_producer) { - ffi::Array producer_srefs = GetProducers(state, block_sref); + ffi::Array producer_srefs = GetProducers(state, block_sref); if (producer_srefs.size() == 1 && s_tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) && CanReverseComputeInline(state, block_sref) && diff --git a/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc index 12e83f6c078b..873ebe606b4f 100644 --- a/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/s_tir/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -56,7 +56,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { if (max_threads_per_block == -1 || warp_size == -1) { return {sch}; } - const tir::StmtSRef& block_sref = sch->GetSRef(block_rv); + const tirx::StmtSRef& block_sref = sch->GetSRef(block_rv); if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block, warp_size)) { return {sch}; @@ -136,7 +136,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { bool InThreadScope(const s_tir::Schedule& sch, const s_tir::SBlockRV& block) { const ffi::Array& axes = sch->GetLoops(block); for (const s_tir::LoopRV& loop_rv : axes) { - const tir::For& loop = sch->Get(loop_rv); + const tirx::For& loop = sch->Get(loop_rv); runtime::ThreadScope thread_scope = s_tir::GetThreadScope(loop.get()); if (s_tir::IsThreadIdx(thread_scope)) { return true; @@ -222,9 +222,9 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // - if there are multiple consumers, they must not share a common loop, and the case is not // fusible; // - If the lowest common ancestor is a loop, the target block is also the first consumer. - const tir::StmtSRef& lca_sref = + const tirx::StmtSRef& lca_sref = s_tir::GetSRefLowestCommonAncestor(s_tir::SBlockRVs2StmtSRefs(sch, consumers)); - if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { + if (consumers.size() > 1 && lca_sref->StmtAs() != nullptr) { return std::make_tuple(false, s_tir::LoopRV{ffi::UnsafeInit()}, s_tir::SBlockRV{ffi::UnsafeInit()}, s_tir::LoopRV{ffi::UnsafeInit()}); } @@ -255,12 +255,12 @@ class CrossThreadReductionNode : public ScheduleRuleNode { */ int GetComputePosition(const s_tir::Schedule& sch, const ffi::Array& block_loops, const ffi::Array& tgt_block_loops, - const tir::StmtSRef& lca_sref) { + const tirx::StmtSRef& lca_sref) { int n_block_loop = static_cast(block_loops.size()); int n_tgt_block_loop = static_cast(tgt_block_loops.size()); for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) { - if (s_tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) { + if (s_tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tirx::IterVarType::kDataPar) { return i - 1; } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) { // If the lowest common ancestor is a loop, the compute location of the input block should diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc index 613fcfb43cc8..a281610ecc9b 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -30,7 +30,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; std::vector GetReadBufferNDims(const StmtSRef& block_sref) { const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); @@ -58,7 +58,7 @@ using s_tir::IsWriteCache; using s_tir::LoopRV; using s_tir::SBlockRV; using s_tir::Schedule; -using tir::IterVarType; +using tirx::IterVarType; TVM_FFI_STATIC_INIT_BLOCK() { MultiLevelTilingNode::RegisterReflection(); } @@ -335,14 +335,14 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { // @see src/meta_schedule/schedule_rule/schedule_rule.cc // check the reduce loop contains exactly 3 for loops // therefore it matches the notation array size in the following code - tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back()); - const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref); - ffi::Array seq = Downcast(r_for_loop->body)->seq; + tirx::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back()); + const tirx::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref); + ffi::Array seq = Downcast(r_for_loop->body)->seq; if (seq.size() != 3) { return {state}; } for (auto& stmt : seq) { - if (!stmt.as()) { + if (!stmt.as()) { return {state}; } } @@ -366,7 +366,7 @@ std::vector MultiLevelTilingNode::AddAsyncPipeline(State state) const { void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, const s_tir::SBlockRV& block) const { // Filter out invalid vector lanes according to the data type. - const tir::SBlockNode* block_node = (*sch)->GetSRef(block)->StmtAs(); + const tirx::SBlockNode* block_node = (*sch)->GetSRef(block)->StmtAs(); TVM_FFI_ICHECK_EQ(block_node->writes.size(), 1); const runtime::DataType dtype = block_node->writes[0]->buffer->dtype; std::function f_filter = nullptr; diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h index adfd3ebdb9ac..0b52cf2dfbbb 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.h @@ -31,7 +31,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Get the buffer dimensions for all the read buffers of a block, but marks the reduction * buffers' dimensions as -1 diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 37b6d14f2a85..b514eca63608 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include @@ -36,7 +36,7 @@ using s_tir::GetSBlockVarTypes; using s_tir::LoopRV; using s_tir::SBlockRV; using s_tir::Schedule; -using tir::IterVarType; +using tirx::IterVarType; struct TensorCoreIntrinGroup { ffi::String init_intrin; @@ -64,7 +64,7 @@ TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig( TVM_FFI_CHECK(config.count(key_name), ValueError) << key_name << " is not set."; *intrin_name = config.at(key_name); // Check the existence of the intrin - tir::TensorIntrin::Get(*intrin_name); + tirx::TensorIntrin::Get(*intrin_name); }; TensorCoreIntrinGroup intrin_group; f_initialize_intrin("init", &intrin_group.init_intrin); @@ -225,7 +225,7 @@ ffi::Array MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch, ffi::Optional mapping_info = s_tir::GetAutoTensorizeMappingInfo( sch->state(), sch->GetSRef(block_rv), - tir::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); + tirx::TensorIntrin::Get(intrin_groups[i].compute_intrin).value()->desc); if (mapping_info.defined()) { intrin_group_to_mapping_info.emplace(i, mapping_info.value()); } @@ -444,12 +444,12 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa // Get the shape of the wmma accumulator auto [frag_shape_m, frag_shape_n] = [&]() { - tir::SBlock intrin_block = - Downcast( - tir::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) + tirx::SBlock intrin_block = + Downcast( + tirx::TensorIntrin::Get(state->intrin_group.init_intrin).value()->desc->body) ->block; - tir::For loop_m = Downcast(intrin_block->body); - tir::For loop_n = Downcast(loop_m->body); + tirx::For loop_m = Downcast(intrin_block->body); + tirx::For loop_n = Downcast(loop_m->body); return std::make_tuple(loop_m->extent, loop_n->extent); }(); @@ -494,31 +494,31 @@ std::vector MultiLevelTilingTensorCoreNode::TransformIntermediateOutputLa int num_higher_dims = buffer_ndim - 2; auto index_map = - tir::IndexMap::FromFunc(buffer_ndim, - // frag_shape_m and frag_shape_n are structural bindings that cannot - // not be automatically captured until c++20 - [&, frag_shape_m = frag_shape_m, - frag_shape_n = frag_shape_n](const ffi::Array& indices) { - ffi::Array result; - result.reserve(indices.size() + 4); - for (int i = 0; i < num_higher_dims; ++i) { - result.push_back(indices[i]); - } - const auto& m = indices[num_higher_dims]; - const auto& n = indices[num_higher_dims + 1]; - auto accum_m = floormod(m, frag_shape_m); - auto accum_n = floormod(n, frag_shape_n); - auto outer_m = floordiv(m, frag_shape_m); - auto outer_n = floordiv(n, frag_shape_n); - - result.push_back(floordiv(outer_m, warp_num_frag_m)); - result.push_back(floordiv(outer_n, warp_num_frag_n)); - result.push_back(floormod(outer_m, warp_num_frag_m)); - result.push_back(floormod(outer_n, warp_num_frag_n)); - result.push_back(accum_m); - result.push_back(accum_n); - return result; - }); + tirx::IndexMap::FromFunc(buffer_ndim, + // frag_shape_m and frag_shape_n are structural bindings that cannot + // not be automatically captured until c++20 + [&, frag_shape_m = frag_shape_m, + frag_shape_n = frag_shape_n](const ffi::Array& indices) { + ffi::Array result; + result.reserve(indices.size() + 4); + for (int i = 0; i < num_higher_dims; ++i) { + result.push_back(indices[i]); + } + const auto& m = indices[num_higher_dims]; + const auto& n = indices[num_higher_dims + 1]; + auto accum_m = floormod(m, frag_shape_m); + auto accum_n = floormod(n, frag_shape_n); + auto outer_m = floordiv(m, frag_shape_m); + auto outer_n = floordiv(n, frag_shape_n); + + result.push_back(floordiv(outer_m, warp_num_frag_m)); + result.push_back(floordiv(outer_n, warp_num_frag_n)); + result.push_back(floormod(outer_m, warp_num_frag_m)); + result.push_back(floormod(outer_n, warp_num_frag_n)); + result.push_back(accum_m); + result.push_back(accum_n); + return result; + }); sch->TransformLayout(state->block_rv, 0, s_tir::BufferIndexType::kWrite, index_map, /*pad_value=*/std::nullopt, /*assume_injective_transform=*/true); @@ -621,9 +621,9 @@ std::vector MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore( const s_tir::SBlockRV cache_read = state->read_reuse.at(i); // Inline the reindex / padding block sch->ComputeInline(sch->GetProducers(cache_read)[0]); - const tir::SBlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); - tir::Buffer cache_read_buffer = - s_tir::GetNthAccessBuffer(sch->state(), ffi::GetRef(cache_read_block), 0, + const tirx::SBlockNode* cache_read_block = sch->GetSRef(cache_read)->StmtAs(); + tirx::Buffer cache_read_buffer = + s_tir::GetNthAccessBuffer(sch->state(), ffi::GetRef(cache_read_block), 0, s_tir::BufferIndexType::kWrite); const DataType& dtype = cache_read_buffer->dtype; if (dtype.is_float16()) { @@ -771,12 +771,12 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( TensorCoreStateNode* state, const ffi::String& intrin_name) const { SBlockRV block_rv = state->block_rv; const s_tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info; - tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); + tirx::StmtSRef block_sref = state->sch->GetSRef(state->block_rv); // Add reindex stages - const tir::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + const tirx::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); // Hold the reference of the block before reindex - const tir::SBlock block_before_reindex = ffi::GetRef(block); + const tirx::SBlock block_before_reindex = ffi::GetRef(block); if (block->reads.size() != 2 || block->writes.size() != 1) { // only matmul-like computation is allowed return std::nullopt; @@ -792,12 +792,12 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // The index map defines the mapping for the computation block. We need to extract the sub index // map to transform the load and store block. TVM_FFI_ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present - const tir::IndexMap& index_map = mapping_info->mappings[0]; + const tirx::IndexMap& index_map = mapping_info->mappings[0]; // Find the correspondence between block iters and the iters in the index map. - std::unordered_map lhs_to_index_map_src; - std::unordered_map rhs_to_index_map_tgt; - std::unordered_set unmapped_index_map_src; + std::unordered_map lhs_to_index_map_src; + std::unordered_map rhs_to_index_map_tgt; + std::unordered_set unmapped_index_map_src; TVM_FFI_ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); for (int i = 0; i < static_cast(mapping_info->lhs_iters.size()); ++i) { lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i]; @@ -811,43 +811,44 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( static_cast(mapping_info->rhs_iters.size()); TVM_FFI_ICHECK_GE(offset, 0); for (int i = 0; i < offset; ++i) { - const tir::VarNode* var_ptr = index_map->final_indices[i].as(); + const tirx::VarNode* var_ptr = index_map->final_indices[i].as(); TVM_FFI_ICHECK(var_ptr != nullptr); - unmapped_index_map_src.insert(ffi::GetRef(var_ptr)); + unmapped_index_map_src.insert(ffi::GetRef(var_ptr)); } for (int i = offset; i < static_cast(index_map->final_indices.size()); ++i) { rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; } - auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) { - std::vector sub_index_map_src; + auto f_get_sub_index_map = [&](const tirx::Buffer& lhs_buffer, const tirx::Region& lhs_region) { + std::vector sub_index_map_src; std::vector sub_index_map_tgt; - const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; + const tirx::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; for (const Range& range : lhs_region) { - TVM_FFI_ICHECK(tir::is_one(range->extent)); - const tir::VarNode* var_ptr = range->min.as(); + TVM_FFI_ICHECK(tirx::is_one(range->extent)); + const tirx::VarNode* var_ptr = range->min.as(); TVM_FFI_ICHECK(var_ptr != nullptr); - const tir::Var& lhs_representer = lhs_to_index_map_src[ffi::GetRef(var_ptr)]; + const tirx::Var& lhs_representer = lhs_to_index_map_src[ffi::GetRef(var_ptr)]; sub_index_map_src.push_back(lhs_representer); if (unmapped_index_map_src.count(lhs_representer)) { sub_index_map_tgt.push_back(lhs_representer); } } for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) { - const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); + const tirx::VarNode* var = + mapping_info->rhs_buffer_indices[rhs_buffer][i].as(); TVM_FFI_ICHECK(var != nullptr); - sub_index_map_tgt.push_back(rhs_to_index_map_tgt[ffi::GetRef(var)]); + sub_index_map_tgt.push_back(rhs_to_index_map_tgt[ffi::GetRef(var)]); } - return tir::IndexMap(sub_index_map_src, sub_index_map_tgt); + return tirx::IndexMap(sub_index_map_src, sub_index_map_tgt); }; - std::unordered_set visited_buffers; + std::unordered_set visited_buffers; - ffi::Map buffer_sub_index_map; // cache of the sub index map - // associated with each buffer + ffi::Map buffer_sub_index_map; // cache of the sub index map + // associated with each buffer auto f_transform_buffer_layout = [&](s_tir::BufferIndexType index_type, int buffer_index) { - const tir::Buffer& lhs_buffer = s_tir::GetNthAccessBuffer( + const tirx::Buffer& lhs_buffer = s_tir::GetNthAccessBuffer( state->sch->state(), block_before_reindex, buffer_index, index_type); if (visited_buffers.count(lhs_buffer)) { return; @@ -855,8 +856,8 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( visited_buffers.insert(lhs_buffer); // Refresh block pointer (block sref is not invalidated) block = TVM_SREF_TO_SBLOCK(block_sref); - const tir::BufferRegion& reindexed_buffer_region = s_tir::GetNthAccessBufferRegion( - state->sch->state(), ffi::GetRef(block), buffer_index, index_type); + const tirx::BufferRegion& reindexed_buffer_region = s_tir::GetNthAccessBufferRegion( + state->sch->state(), ffi::GetRef(block), buffer_index, index_type); auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region); buffer_sub_index_map.Set(lhs_buffer, sub_index_map); state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, @@ -873,7 +874,7 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( // Transform the layout of current block and reindex blocks auto f_transform_reindex_block_layout = [&](const SBlockRV& block_rv, s_tir::BufferIndexType buffer_type) { - tir::Buffer buffer = + tirx::Buffer buffer = s_tir::GetNthAccessBuffer(state->sch->state(), state->sch->Get(block_rv), 0, buffer_type); const auto& sub_index_map = buffer_sub_index_map.at(buffer); state->sch->TransformBlockLayout(block_rv, sub_index_map); diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 5acc528e7d40..5d67c19a4d56 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -65,10 +65,10 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { std::pair, ffi::Array> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, SBlockRV block_rv, LoopRV loop_rv, int n_tiles) const { - const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); - const tir::StmtSRef block_sref = sch->GetSRef(block_rv); - const tir::SBlockNode* block_node = block_sref->StmtAs(); - const tir::SBlockRealize block_realize = s_tir::GetSBlockRealize(sch->state(), block_sref); + const tirx::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); + const tirx::StmtSRef block_sref = sch->GetSRef(block_rv); + const tirx::SBlockNode* block_node = block_sref->StmtAs(); + const tirx::SBlockRealize block_realize = s_tir::GetSBlockRealize(sch->state(), block_sref); TVM_FFI_ICHECK(block_node && block_node->writes.size() == 1); const auto out_dtype = block_node->writes[0]->buffer->dtype; diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc index a790a8fa0bed..5a33056d01ff 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc @@ -52,7 +52,7 @@ class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode { protected: ffi::Array Apply(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv) final { - auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc; + auto desc_func = tirx::TensorIntrin::Get(intrin_name).value()->desc; if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) { TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized."; return {sch}; @@ -106,7 +106,7 @@ ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin( ffi::Optional> vector_load_lens, ffi::Optional> reuse_read, ffi::Optional> reuse_write) { - TVM_FFI_ICHECK(tir::TensorIntrin::Get(intrin_name).defined()) + TVM_FFI_ICHECK(tirx::TensorIntrin::Get(intrin_name).defined()) << "Provided tensor intrinsic " << intrin_name << " is not registered."; auto node = MultiLevelTilingInitCommon( structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); diff --git a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 6d7fd3063f59..8a21ddee9ed1 100644 --- a/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; bool IsRootBlock(const Schedule& sch, const SBlockRV& block_rv) { StmtSRef block_sref = sch->GetSRef(block_rv); diff --git a/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc b/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc index f00c1c87b4a4..9aacbc259a8e 100644 --- a/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/s_tir/meta_schedule/schedule_rule/random_compute_location.cc @@ -70,7 +70,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { private: bool CheckConditions(const s_tir::Schedule sch, const s_tir::SBlockRV& block_rv) const { - tir::StmtSRef block_sref = sch->GetSRef(block_rv); + tirx::StmtSRef block_sref = sch->GetSRef(block_rv); TVM_SREF_TO_SBLOCK(block_sref); // Cond 1. The block is not the root block. @@ -85,7 +85,7 @@ class RandomComputeLocationNode : public ScheduleRuleNode { } // Cond 3 & 4. The block has at least one outer loop, and the outermost loop has only one child // block. - ffi::Array loop_srefs = s_tir::GetLoops(block_sref); + ffi::Array loop_srefs = s_tir::GetLoops(block_sref); if (loop_srefs.empty()) { return false; } diff --git a/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc b/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc index f237498726c4..aa280b2e5b91 100644 --- a/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/s_tir/meta_schedule/schedule_rule/schedule_rule.cc @@ -66,7 +66,7 @@ ffi::Array ScheduleRule::DefaultLLVM() { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/ffi::Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tirx.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -102,7 +102,7 @@ ffi::Array ScheduleRule::DefaultX86(const ffi::String& type) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/ffi::Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tirx.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64)), @@ -288,7 +288,7 @@ ffi::Array ScheduleRule::DefaultHexagon() { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/ffi::Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tirx.exp"}), ScheduleRule::MultiLevelTilingWideVector( /*structure=*/"SRSRS", /*vector_length_in_bits=*/1024, @@ -317,17 +317,17 @@ ffi::Array ScheduleRule::DefaultRISCV(const int vlen) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/ffi::Array{"tir.exp"})); + /*disallow_op=*/ffi::Array{"tirx.exp"})); rules.push_back(ScheduleRule::AddRFactor( /*max_jobs_per_core=*/16, /*max_innermost_factor=*/Integer(64))); auto current_target = tvm::Target::Current(); const auto reg_rvv_intrinsics = - tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrinsics"); + tvm::ffi::Function::GetGlobalRequired("tirx.tensor_intrin.register_rvv_isa_intrinsics"); const auto rvv_kernels_inventory = reg_rvv_intrinsics(current_target, /* inventory_only */ true) .cast>(); for (const auto& intrin : rvv_kernels_inventory) { - if (!tir::TensorIntrin::Get(intrin.first, /*allow_missing*/ true)) { + if (!tirx::TensorIntrin::Get(intrin.first, /*allow_missing*/ true)) { // on demand intrinsic register reg_rvv_intrinsics(current_target, /* inventory_only */ false); } @@ -427,7 +427,7 @@ ffi::Array ScheduleRule::DefaultARM(const ffi::String& type) { /*disallow_if_then_else=*/true, /*require_injective=*/true, /*require_ordered=*/true, - /*disallow_op=*/ffi::Array{"tir.exp"}), + /*disallow_op=*/ffi::Array{"tirx.exp"}), ScheduleRule::AddRFactor( /*max_jobs_per_core=*/8, /*max_innermost_factor=*/Integer(32)), diff --git a/src/s_tir/meta_schedule/trace_apply.cc b/src/s_tir/meta_schedule/trace_apply.cc index ff373bdff468..ecaa141e2b24 100644 --- a/src/s_tir/meta_schedule/trace_apply.cc +++ b/src/s_tir/meta_schedule/trace_apply.cc @@ -19,8 +19,8 @@ #include "trace_apply.h" #include -#include -#include +#include +#include #include #include @@ -36,7 +36,7 @@ namespace tvm { namespace s_tir { namespace meta_schedule { -using namespace tir; +using namespace tirx; using s_tir::GetSBlockNames; using s_tir::Instruction; using s_tir::InstructionKind; diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index f2938c9068e9..847adc2591da 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -43,7 +43,7 @@ #include #include #include -#include +#include #include #include @@ -298,9 +298,9 @@ inline std::string Concat(const ffi::Array& strs, const std::string * \param global_var_name The global variable name * \return The SBlockRV */ -inline s_tir::SBlockRV GetRVFromSRef(const s_tir::Schedule& sch, const tir::StmtSRef& block_sref, +inline s_tir::SBlockRV GetRVFromSRef(const s_tir::Schedule& sch, const tirx::StmtSRef& block_sref, const ffi::String& global_var_name) { - const tir::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); + const tirx::SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); return sch->GetSBlock(block->name_hint, global_var_name); } @@ -586,7 +586,7 @@ inline double Sum(const ffi::Array& arr) { } /*! \brief Collecting all the blocks */ -class SBlockCollector : public tir::StmtVisitor { +class SBlockCollector : public tirx::StmtVisitor { public: static ffi::Array Collect(const s_tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) { // @@ -597,7 +597,7 @@ class SBlockCollector : public tir::StmtVisitor { /*! \brief Entry point */ ffi::Array Run() { std::vector results; - auto f_collect = [this, &results](tir::PrimFunc func, ffi::String func_name) { + auto f_collect = [this, &results](tirx::PrimFunc func, ffi::String func_name) { func_name_ = func_name; block_names_.clear(); blocks_to_collect_.clear(); @@ -609,14 +609,14 @@ class SBlockCollector : public tir::StmtVisitor { if (sch_->func_working_on().defined()) { GlobalVar gv = sch_->func_working_on().value(); - tir::PrimFunc func = Downcast(sch_->mod()->functions[gv]); + tirx::PrimFunc func = Downcast(sch_->mod()->functions[gv]); f_collect(func, gv->name_hint); } else { for (const auto& [gv, base_func] : sch_->mod()->functions) { // `gv->name_hint` is the name of the function // `base_func` can be PrimFunc or relax::Function - if (const auto* func = base_func.as()) { - f_collect(ffi::GetRef(func), gv->name_hint); + if (const auto* func = base_func.as()) { + f_collect(ffi::GetRef(func), gv->name_hint); } } } @@ -626,8 +626,8 @@ class SBlockCollector : public tir::StmtVisitor { explicit SBlockCollector(const s_tir::Schedule& sch, const ffi::Function f_block_filter = nullptr) : sch_(sch), f_block_filter_(f_block_filter) {} /*! \brief Override the Stmt visiting behaviour */ - void VisitStmt_(const tir::SBlockNode* block) override { - tir::StmtVisitor::VisitStmt_(block); + void VisitStmt_(const tirx::SBlockNode* block) override { + tirx::StmtVisitor::VisitStmt_(block); TVM_FFI_ICHECK(block_names_.count(block->name_hint) == 0) << "Duplicated block name " << block->name_hint << " in function " << func_name_ << " not supported!"; @@ -637,7 +637,7 @@ class SBlockCollector : public tir::StmtVisitor { // Otherwise collect all blocks. Bool collect_block = Bool(true); if (f_block_filter_ != nullptr) { - collect_block = f_block_filter_(ffi::GetRef(block)).cast(); + collect_block = f_block_filter_(ffi::GetRef(block)).cast(); } if (collect_block) { blocks_to_collect_.push_back(block->name_hint); diff --git a/src/s_tir/sblock_dependence_info.cc b/src/s_tir/sblock_dependence_info.cc index bbfa691cee2f..2d75d24137fe 100644 --- a/src/s_tir/sblock_dependence_info.cc +++ b/src/s_tir/sblock_dependence_info.cc @@ -22,7 +22,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { SBlockDependenceInfoNode::RegisterReflection(); } @@ -103,5 +103,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/s_tir/sblock_scope.cc b/src/s_tir/sblock_scope.cc index dc8d12d0ac56..87c636480c88 100644 --- a/src/s_tir/sblock_scope.cc +++ b/src/s_tir/sblock_scope.cc @@ -21,7 +21,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { StmtSRefNode::RegisterReflection(); @@ -210,5 +210,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("s_tir.SBlockScopeGetDepsByDst", &SBlockScopeNode::GetDepsByDst); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/s_tir/schedule/analysis.h b/src/s_tir/schedule/analysis.h index e059a912489d..609e1c84dee4 100644 --- a/src/s_tir/schedule/analysis.h +++ b/src/s_tir/schedule/analysis.h @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include @@ -35,7 +35,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Verification ********/ /*! @@ -616,7 +616,7 @@ bool HasOp(const Stmt& stmt, const ffi::Array& ops); * \brief Checks if the given AST statement contains if-then-else, including * 1) IfThenElse statement * 2) Select expression - * 3) The operator `tir.if_then_else` + * 3) The operator `tirx.if_then_else` * 4) non-constant-true SBlock predicates * \param stmt The AST statement to be checked * \return A boolean indicating whether the statement contains the if-then-else pattern @@ -689,7 +689,7 @@ bool IsSpatialPrimFunc(const PrimFunc& func); * \return A boolean indicating whether the operation is beneficial. */ bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // - const tir::StmtSRef& block_sref, // + const tirx::StmtSRef& block_sref, // int64_t max_parallel_extent, // int64_t max_parallel_basic); @@ -740,9 +740,9 @@ PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) class TensorizeInfoNode : public Object { public: /*! \brief Maps loops in a target block to the ones in an intrinsic description */ - ffi::Map loop_map; + ffi::Map loop_map; /*! \brief Maps loops in an intrinsic description to its index, outer to inner */ - ffi::Map desc_loop_indexer; + ffi::Map desc_loop_indexer; /*! \brief Optional padded extents of the block iters when padding is needed to match the * intrinsic description */ @@ -775,8 +775,8 @@ class TensorizeInfo : public ObjectRef { * \return TensorizeInfo structure if a valid mapping is found, std::nullopt otherwise */ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, + const tirx::StmtSRef& block_sref, + const tirx::PrimFunc& desc_func, bool allow_padding); /*!\brief Necessary information used to perform transformations for tensorization */ @@ -843,7 +843,7 @@ ffi::Optional GetAutoTensorizeMappingInfo(const Schedu * \return true if basic conditions are met. */ bool CheckAutoTensorizeApplicable(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv, - const tir::PrimFunc& desc_func); + const tirx::PrimFunc& desc_func); } // namespace s_tir } // namespace tvm diff --git a/src/s_tir/schedule/analysis/analysis.cc b/src/s_tir/schedule/analysis/analysis.cc index e210881b2bb6..9819bbb24f45 100644 --- a/src/s_tir/schedule/analysis/analysis.cc +++ b/src/s_tir/schedule/analysis/analysis.cc @@ -24,7 +24,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; TVM_FFI_STATIC_INIT_BLOCK() { TensorizeInfoNode::RegisterReflection(); @@ -1355,7 +1355,7 @@ bool HasIfThenElse(const Stmt& stmt) { has_branch = true; } else if (const auto* call = obj.as()) { // Case 3: Call the `if_then_else` operator - static const Op& op_if_then_else = Op::Get("tir.if_then_else"); + static const Op& op_if_then_else = Op::Get("tirx.if_then_else"); if (call->op.same_as(op_if_then_else)) { has_branch = true; } @@ -1593,24 +1593,24 @@ bool IsSpatialPrimFunc(const PrimFunc& func) { } std::pair GetCumulativeSpaceAndReductionLength(const s_tir::ScheduleState& self, - const tir::StmtSRef& block_sref) { - ffi::Array loops = GetLoops(block_sref); + const tirx::StmtSRef& block_sref) { + ffi::Array loops = GetLoops(block_sref); int64_t cum_space_len = 1, cum_reduce_len = 1; /* * Return (-1, -1) if * 1. there is some loop with type other than kDataPar and kCommReduce; * 2. there is some loop which is dynamic. */ - for (const tir::StmtSRef& loop_sref : loops) { - tir::IterVarType type = GetLoopIterType(loop_sref); - if (type == tir::kDataPar) { + for (const tirx::StmtSRef& loop_sref : loops) { + tirx::IterVarType type = GetLoopIterType(loop_sref); + if (type == tirx::kDataPar) { const int64_t* extent = GetLoopIntExtent(loop_sref); if (extent && *extent != -1) { cum_space_len *= *extent; } else { return std::make_pair(-1, -1); } - } else if (type == tir::kCommReduce) { + } else if (type == tirx::kCommReduce) { const int64_t* extent = GetLoopIntExtent(loop_sref); if (extent && *extent != -1) { cum_reduce_len *= *extent; @@ -1625,11 +1625,11 @@ std::pair GetCumulativeSpaceAndReductionLength(const s_tir::Sc } bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // - const tir::StmtSRef& block_sref, // + const tirx::StmtSRef& block_sref, // int64_t max_parallel_extent, // int64_t max_parallel_basic) { const SBlockNode* block = TVM_SREF_TO_SBLOCK(block_sref); - ffi::Array loops = GetLoops(block_sref); + ffi::Array loops = GetLoops(block_sref); // Cond 1. The block must have at lease one write buffer if (block->writes.size() == 0) { @@ -1646,9 +1646,9 @@ bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // } // Cond 3. Every the loop axis must be either spatial axis or reduction axis. - for (const tir::StmtSRef& loop_sref : loops) { - const tir::IterVarType& type = GetLoopIterType(loop_sref); - if (type != tir::kDataPar && type != tir::kCommReduce) { + for (const tirx::StmtSRef& loop_sref : loops) { + const tirx::IterVarType& type = GetLoopIterType(loop_sref); + if (type != tirx::kDataPar && type != tirx::kCommReduce) { return false; } } @@ -1658,7 +1658,7 @@ bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // bool has_reduction_loop = false; for (size_t i = 0; i < loops.size(); ++i) { // Cond 4. - if (GetLoopIterType(loops[i]) == tir::kCommReduce) { + if (GetLoopIterType(loops[i]) == tirx::kCommReduce) { has_reduction_loop = true; } @@ -1670,7 +1670,7 @@ bool NeedsRFactorOrCrossThreadReduction(const s_tir::ScheduleState& self, // return false; } } else { - const auto* block_realize = loop_i->body.as(); + const auto* block_realize = loop_i->body.as(); if (!block_realize || block_realize->block.get() != block) { return false; } @@ -1712,9 +1712,9 @@ struct TensorIntrinDescInfo { */ const SBlockRealizeNode* desc_block = nullptr; /*! \brief The loops of the description function, in the order from outer loops to inner ones. */ - std::vector desc_loops; + std::vector desc_loops; /*! \brief The loop variables. */ - std::unordered_set desc_loop_vars; + std::unordered_set desc_loop_vars; }; /*! @@ -1745,7 +1745,7 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, } return true; }; - tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); + tirx::PostOrderVisit(desc_scope_realize->block->body, f_visit); std::reverse(info.desc_loops.begin(), info.desc_loops.end()); TVM_FFI_ICHECK(info.desc_block); } @@ -1753,22 +1753,22 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer, } ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& self, - const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, + const tirx::StmtSRef& block_sref, + const tirx::PrimFunc& desc_func, bool allow_padding) { arith::Analyzer analyzer; - const tir::SBlockRealize& block = GetSBlockRealize(self, block_sref); + const tirx::SBlockRealize& block = GetSBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); // Step 2. Collect loops from block_sref - const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); + const tirx::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); TVM_SREF_TO_SBLOCK(scope_sref); - std::vector block_loops; - std::unordered_set block_loop_vars; + std::vector block_loops; + std::unordered_set block_loop_vars; { - for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { - const auto* loop = loop_sref->StmtAs(); - if (loop == nullptr || loop->body->IsInstance()) { + for (const tirx::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr || loop->body->IsInstance()) { break; } block_loops.push_back(loop); @@ -1820,7 +1820,7 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; - const tir::ForNode* desc_loop = nullptr; + const tirx::ForNode* desc_loop = nullptr; IterVarType iter_type_desc = iter_types_desc[i_desc]; for (int i = 0, n = desc_loops.size(); i < n; ++i) { // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars @@ -1854,8 +1854,8 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& // Step 3.3. Find the corresponding loop of the target block for (int i = 0, n = block_loops.size(); i < n; ++i) { // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars - const tir::ForNode* block_loop = block_loops[i]; - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + const tirx::ForNode* block_loop = block_loops[i]; + const tirx::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; // Skip i-th loop if it has already been mapped if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue; @@ -1886,13 +1886,13 @@ ffi::Optional GetTensorizeLoopMapping(const s_tir::ScheduleState& } } - ret->loop_map.Set(block_loop_sref, ffi::GetRef(desc_loop)); + ret->loop_map.Set(block_loop_sref, ffi::GetRef(desc_loop)); break; } } for (int i = 0, n = desc_loops.size(); i < n; ++i) { - ret->desc_loop_indexer.Set(ffi::GetRef(desc_loops[i]), Integer(i)); + ret->desc_loop_indexer.Set(ffi::GetRef(desc_loops[i]), Integer(i)); } if (!block_index_to_padding.empty()) { if (!allow_padding) { @@ -2106,8 +2106,8 @@ class AutoTensorizeMappingProposer { std::unordered_map lhs_feasible_vars_; }; -bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func, +bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tirx::StmtSRef& block_sref, + const tirx::PrimFunc& desc_func, AutoTensorizeComparator* extractor) { // Step 1. Analyze desc_func, extract its block, loops and loop vars // Step 2. Check if `desc_block` matches `block` @@ -2120,14 +2120,14 @@ bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRe } bool CheckAutoTensorizeApplicable(const s_tir::Schedule& sch, const s_tir::SBlockRV& block_rv, - const tir::PrimFunc& desc_func) { + const tirx::PrimFunc& desc_func) { AutoTensorizeComparator extractor(sch->state()->mod); return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor); } ffi::Optional GetAutoTensorizeMappingInfo( - const s_tir::ScheduleState& self, const tir::StmtSRef& block_sref, - const tir::PrimFunc& desc_func) { + const s_tir::ScheduleState& self, const tirx::StmtSRef& block_sref, + const tirx::PrimFunc& desc_func) { AutoTensorizeComparator extractor(self->mod); if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) { return std::nullopt; diff --git a/src/s_tir/schedule/analysis/layout.cc b/src/s_tir/schedule/analysis/layout.cc index 2b353399377f..8f2885529bff 100644 --- a/src/s_tir/schedule/analysis/layout.cc +++ b/src/s_tir/schedule/analysis/layout.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Calculate the strides of the buffer @@ -100,7 +100,7 @@ class SplitExprCollector { private: void Visit(const arith::IterSplitExpr& expr) { - if (const auto* var = expr->source->source.as()) { + if (const auto* var = expr->source->source.as()) { const int64_t* lower_factor = as_const_int(expr->lower_factor); const int64_t* extent = as_const_int(expr->extent); if (lower_factor == nullptr || extent == nullptr) { diff --git a/src/s_tir/schedule/analysis/reducer.cc b/src/s_tir/schedule/analysis/reducer.cc index 6ba93769c6ce..9cd2795e3d18 100644 --- a/src/s_tir/schedule/analysis/reducer.cc +++ b/src/s_tir/schedule/analysis/reducer.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Pattern Matcher ********/ @@ -606,7 +606,7 @@ class NoMatchedReducerError : public ScheduleError { ffi::String DetailRenderTemplate() const final { std::ostringstream os; os << "No matched reducer for identity " << identities_ << " and combiner " << combiners_ - << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " + << "In this case rfactor cannot be applied. You can check tvm::tirx::ReducerRegistry for " "default reducers or registering new reducers."; return os.str(); } diff --git a/src/s_tir/schedule/analysis/verify.cc b/src/s_tir/schedule/analysis/verify.cc index b31f624c3bca..9ecb9711fa93 100644 --- a/src/s_tir/schedule/analysis/verify.cc +++ b/src/s_tir/schedule/analysis/verify.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class SRefTreeVerifier : public StmtVisitor { public: diff --git a/src/s_tir/schedule/concrete_schedule.cc b/src/s_tir/schedule/concrete_schedule.cc index 9d5068c61b62..51189b725427 100644 --- a/src/s_tir/schedule/concrete_schedule.cc +++ b/src/s_tir/schedule/concrete_schedule.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, @@ -912,7 +912,7 @@ SBlockRV ConcreteScheduleNode::Blockize(const ffi::Array& blocks, void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); - s_tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin).value(), + s_tir::Tensorize(state_, this->GetSRef(loop_rv), tirx::TensorIntrin::Get(intrin).value(), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); @@ -921,7 +921,7 @@ void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const ffi::String& i void ConcreteScheduleNode::Tensorize(const SBlockRV& block_rv, const ffi::String& intrin, bool preserve_unit_iters) { TVM_TIR_SCHEDULE_BEGIN(); - s_tir::Tensorize(state_, this->GetSRef(block_rv), tir::TensorIntrin::Get(intrin).value(), + s_tir::Tensorize(state_, this->GetSRef(block_rv), tirx::TensorIntrin::Get(intrin).value(), preserve_unit_iters); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_); diff --git a/src/s_tir/schedule/concrete_schedule.h b/src/s_tir/schedule/concrete_schedule.h index ba058637dc97..84b934ff7263 100644 --- a/src/s_tir/schedule/concrete_schedule.h +++ b/src/s_tir/schedule/concrete_schedule.h @@ -27,7 +27,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class ConcreteScheduleNode : public ScheduleNode { friend class Schedule; diff --git a/src/s_tir/schedule/error.cc b/src/s_tir/schedule/error.cc index 1a8ceff888b9..8e984469cd56 100644 --- a/src/s_tir/schedule/error.cc +++ b/src/s_tir/schedule/error.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; ffi::String ScheduleError::RenderReport(const ffi::String& primitive) const { IRModule mod = this->mod(); diff --git a/src/s_tir/schedule/error.h b/src/s_tir/schedule/error.h index 8cd7b891af1a..965ba143f1b6 100644 --- a/src/s_tir/schedule/error.h +++ b/src/s_tir/schedule/error.h @@ -25,7 +25,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Error that happens during TensorIR scheduling */ class ScheduleError : public tvm::runtime::Error { diff --git a/src/s_tir/schedule/instruction.cc b/src/s_tir/schedule/instruction.cc index cb9f357de31b..29fe3c2d88dd 100644 --- a/src/s_tir/schedule/instruction.cc +++ b/src/s_tir/schedule/instruction.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; TVM_FFI_STATIC_INIT_BLOCK() { InstructionKindNode::RegisterReflection(); diff --git a/src/s_tir/schedule/instruction_traits.h b/src/s_tir/schedule/instruction_traits.h index 47968f4bc34e..067af3772126 100644 --- a/src/s_tir/schedule/instruction_traits.h +++ b/src/s_tir/schedule/instruction_traits.h @@ -30,7 +30,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Register an InstructionKind using a trait class diff --git a/src/s_tir/schedule/ir_comparator.cc b/src/s_tir/schedule/ir_comparator.cc index fe3193d58412..385cf8fa891a 100644 --- a/src/s_tir/schedule/ir_comparator.cc +++ b/src/s_tir/schedule/ir_comparator.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Tensorize Comparator ********/ diff --git a/src/s_tir/schedule/ir_comparator.h b/src/s_tir/schedule/ir_comparator.h index 9ec96f05b8aa..0fb2aae64d8d 100644 --- a/src/s_tir/schedule/ir_comparator.h +++ b/src/s_tir/schedule/ir_comparator.h @@ -28,7 +28,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using ExprComparator = ExprFunctor; using StmtComparator = StmtFunctor; diff --git a/src/s_tir/schedule/primitive.h b/src/s_tir/schedule/primitive.h index 094d1f405fe8..63213bb4abdd 100644 --- a/src/s_tir/schedule/primitive.h +++ b/src/s_tir/schedule/primitive.h @@ -26,7 +26,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Schedule: Sampling ********/ /*! @@ -99,7 +99,7 @@ TVM_DLL std::vector SamplePerfectTile( */ TVM_DLL std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, + const tirx::StmtSRef& loop_sref, int32_t n_split, int32_t max_innermost_factor, ffi::Optional>* decision); /*! * \brief Sample the factors to a partitioned tile for a specific loop @@ -137,7 +137,7 @@ TVM_DLL std::vector SamplePartitionedTile( */ TVM_DLL std::vector SamplePartitionedTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_split, int32_t partition_pos, + const tirx::StmtSRef& loop_sref, int32_t n_split, int32_t partition_pos, int32_t innerpart_factor, ffi::Optional>* decision); /*! * \brief Sample a compute-at location of the given block @@ -147,9 +147,9 @@ TVM_DLL std::vector SamplePartitionedTile( * \param decision The sampling decision * \return The sampled loop where the input block is to be computed at */ -TVM_DLL tir::StmtSRef SampleComputeLocation( +TVM_DLL tirx::StmtSRef SampleComputeLocation( s_tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, - const tir::StmtSRef& block_sref, ffi::Optional* decision); + const tirx::StmtSRef& block_sref, ffi::Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! diff --git a/src/s_tir/schedule/primitive/annotate.cc b/src/s_tir/schedule/primitive/annotate.cc index 0a8d985e2632..7aab39e41126 100644 --- a/src/s_tir/schedule/primitive/annotate.cc +++ b/src/s_tir/schedule/primitive/annotate.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; void Annotate(ScheduleState self, const StmtSRef& sref, const ffi::String& ann_key, const Any& ann_val) { diff --git a/src/s_tir/schedule/primitive/annotate_buffer_access.cc b/src/s_tir/schedule/primitive/annotate_buffer_access.cc index 9962cc47fc30..50d752518bfb 100644 --- a/src/s_tir/schedule/primitive/annotate_buffer_access.cc +++ b/src/s_tir/schedule/primitive/annotate_buffer_access.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class AnnotateRegionRewriter : public StmtExprMutator { public: diff --git a/src/s_tir/schedule/primitive/block_annotate.cc b/src/s_tir/schedule/primitive/block_annotate.cc index f6e9fa8bab37..1b169dbe2a0a 100644 --- a/src/s_tir/schedule/primitive/block_annotate.cc +++ b/src/s_tir/schedule/primitive/block_annotate.cc @@ -18,14 +18,14 @@ */ #include #include -#include +#include -#include "../../../tir/transform/ir_utils.h" +#include "../../../tirx/transform/ir_utils.h" #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class StorageAlignAxisOutOfRangeError : public ScheduleError { public: diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc b/src/s_tir/schedule/primitive/blockize_tensorize.cc index 95074147a02a..2357bc354778 100644 --- a/src/s_tir/schedule/primitive/blockize_tensorize.cc +++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc @@ -19,18 +19,18 @@ #include -#include "../../../tir/ir/data_type_rewriter.h" -#include "../../../tir/transform/simplify.h" +#include "../../../tirx/ir/data_type_rewriter.h" +#include "../../../tirx/transform/simplify.h" #include "../ir_comparator.h" #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; template bool UsesVar(const T& x, const Var& var) { - return tir::UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; }); + return tirx::UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; }); } Range RangeFromExtent(const PrimExpr& extent) { @@ -104,7 +104,7 @@ ffi::Array> TrivialSubspaceDivision( var_set.insert(var.get()); } return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool { - return tir::UsesVar(expr, [&var_set](const VarNode* var) { + return tirx::UsesVar(expr, [&var_set](const VarNode* var) { return var_set.count(var); // }); }; diff --git a/src/s_tir/schedule/primitive/cache_index.cc b/src/s_tir/schedule/primitive/cache_index.cc index a876dab015ad..1467a91acd68 100644 --- a/src/s_tir/schedule/primitive/cache_index.cc +++ b/src/s_tir/schedule/primitive/cache_index.cc @@ -18,13 +18,13 @@ */ #include -#include "../../../tir/transform/replace_selected_expr.h" +#include "../../../tirx/transform/replace_selected_expr.h" #include "../utils.h" #include "cache_index_helpers.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Helper Functions/Classes ********/ diff --git a/src/s_tir/schedule/primitive/cache_index_helpers.cc b/src/s_tir/schedule/primitive/cache_index_helpers.cc index 752e05856802..c37ed260bac9 100644 --- a/src/s_tir/schedule/primitive/cache_index_helpers.cc +++ b/src/s_tir/schedule/primitive/cache_index_helpers.cc @@ -26,11 +26,11 @@ #include "cache_index_helpers.h" #include // For the arith::Analyzer::Simplify() method simplifying terms -#include // For the ExprDeepEqual analysis -#include -#include -#include -#include +#include // For the ExprDeepEqual analysis +#include +#include +#include +#include #include // For std::find_if #include // For the hashtable datatype @@ -38,7 +38,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { // cache_ is a static variable of the class ComputationsDoneBy, and C++ requires to define here // such static attribute, otherwise it causes a linking error. @@ -488,5 +488,5 @@ void InsertVectorToSortedSemanticComputations(std::vector #include -#include // For the ExprDeepEqual analysis -#include -#include -#include -#include // For the class StmtExprVisitor +#include // For the ExprDeepEqual analysis +#include +#include +#include +#include // For the class StmtExprVisitor #include #include // For pairs datatype @@ -41,7 +41,7 @@ #include "../../../support/ordered_map.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief A computation table is a hashtable which associates to each expression being computed @@ -163,7 +163,7 @@ void InsertVectorToSortedSemanticComputations(std::vector& vec_to_add, bool identify_equiv_terms, size_t increase_count = 1); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_S_TIR_SCHEDULE_PRIMITIVE_CACHE_INDEX_HELPERS_H_ diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc index b2886f158f14..baa80b38a9c4 100644 --- a/src/s_tir/schedule/primitive/cache_read_write.cc +++ b/src/s_tir/schedule/primitive/cache_read_write.cc @@ -19,13 +19,13 @@ #include -#include "../../../tir/analysis/var_use_def_analysis.h" -#include "../../../tir/transform/ir_utils.h" +#include "../../../tirx/analysis/var_use_def_analysis.h" +#include "../../../tirx/transform/ir_utils.h" #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Error Classes ********/ diff --git a/src/s_tir/schedule/primitive/compute_at.cc b/src/s_tir/schedule/primitive/compute_at.cc index 5c476213ddb2..f099796bad4c 100644 --- a/src/s_tir/schedule/primitive/compute_at.cc +++ b/src/s_tir/schedule/primitive/compute_at.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using support::NDIntSet; diff --git a/src/s_tir/schedule/primitive/compute_inline.cc b/src/s_tir/schedule/primitive/compute_inline.cc index 17f804514db4..d2b28b1a1f5a 100644 --- a/src/s_tir/schedule/primitive/compute_inline.cc +++ b/src/s_tir/schedule/primitive/compute_inline.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; static const char kErrBodyInline[] = R"(The body of the inlined block should be in form of 'A[f(i, j, k, ...)] = g(i, j, k, ...)', @@ -679,7 +679,7 @@ class ReverseComputeInliner : public BaseInliner { } const BufferStoreNode* producer_store = nullptr; - if (const auto* producer_if = producer_block_->body.as()) { + if (const auto* producer_if = producer_block_->body.as()) { if (producer_if->else_case.defined()) { return false; } @@ -1287,7 +1287,7 @@ SBlock ReductionEpilogueFuser::CreateFusedReductionBlock( }; // Identity element for reduction (assumed to be 0 for addition-based reductions) - PrimExpr identity_elem = tir::make_zero(epilogue_output_buffer_->dtype); + PrimExpr identity_elem = tirx::make_zero(epilogue_output_buffer_->dtype); // Substitute reduction buffer load with identity element InitSubstituter init_subst(inlined_buffer_, identity_elem); diff --git a/src/s_tir/schedule/primitive/decompose_padding.cc b/src/s_tir/schedule/primitive/decompose_padding.cc index e1dbb32f4c60..7d57cde1373b 100644 --- a/src/s_tir/schedule/primitive/decompose_padding.cc +++ b/src/s_tir/schedule/primitive/decompose_padding.cc @@ -18,12 +18,12 @@ */ #include -#include "../../../tir/transform/ir_utils.h" +#include "../../../tirx/transform/ir_utils.h" #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Information used to create new padding block */ struct PaddingSBlockInfo { @@ -97,7 +97,7 @@ class PaddingInfoAnalyzer { return false; } const CallNode* if_then_else = store->value.as(); - if (!if_then_else || !if_then_else->op.same_as(tir::builtin::if_then_else())) { + if (!if_then_else || !if_then_else->op.same_as(tirx::builtin::if_then_else())) { SetError("Value of BufferStore expect to be constrained by a padding predicate"); return false; } diff --git a/src/s_tir/schedule/primitive/for_kind.cc b/src/s_tir/schedule/primitive/for_kind.cc index fe9ae79893f9..1ea03efa8dca 100644 --- a/src/s_tir/schedule/primitive/for_kind.cc +++ b/src/s_tir/schedule/primitive/for_kind.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class WrongBlockIterTypeError : public ScheduleError { public: @@ -124,7 +124,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind PreOrderVisit(loop, [&](const ObjectRef& node) { if (const auto* realize = node.as()) { // If this block doesn't have corresponding StmtSRef in the schedule state, it must be a block - // inside `tir.init()`. We don't check the condition for such blocks. + // inside `tirx.init()`. We don't check the condition for such blocks. if (!self->stmt2ref.count(realize->block.get())) { return false; } diff --git a/src/s_tir/schedule/primitive/get_block_loop.cc b/src/s_tir/schedule/primitive/get_block_loop.cc index b2427851f2b5..d64cf9c689c9 100644 --- a/src/s_tir/schedule/primitive/get_block_loop.cc +++ b/src/s_tir/schedule/primitive/get_block_loop.cc @@ -21,7 +21,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; ffi::Array GetSBlocks(const ScheduleState& self, const ffi::String& name, const GlobalVar& gv) { diff --git a/src/s_tir/schedule/primitive/hide_buffer_access.cc b/src/s_tir/schedule/primitive/hide_buffer_access.cc index 08482525f9df..10db5c256ace 100644 --- a/src/s_tir/schedule/primitive/hide_buffer_access.cc +++ b/src/s_tir/schedule/primitive/hide_buffer_access.cc @@ -18,12 +18,12 @@ */ #include -#include "../../../tir/transform/ir_utils.h" +#include "../../../tirx/transform/ir_utils.h" #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Error Classes ********/ diff --git a/src/s_tir/schedule/primitive/layout_transformation.cc b/src/s_tir/schedule/primitive/layout_transformation.cc index 3fd210e91409..b505e952cc04 100644 --- a/src/s_tir/schedule/primitive/layout_transformation.cc +++ b/src/s_tir/schedule/primitive/layout_transformation.cc @@ -29,7 +29,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Planning stage prior to rewriting in TransformLayoutRewriter * @@ -57,12 +57,12 @@ using namespace tvm::tir; * of those write stages writes to all pre-transformation indices * following a row-major traversal. These write stage is rewritten to * be row-major traversals of the post-transformation indices, with a - * `tir::if_then_else` call to write either the specified `pad_value` + * `tirx::if_then_else` call to write either the specified `pad_value` * into padding or the computed value into non-padding. * * 4. EpiloguePlan. The transformation introduces padding, has at * least one write stage for the transformed buffer, but no write - * stage can be rewritten to use `tir::if_then_else`. The + * stage can be rewritten to use `tirx::if_then_else`. The * transformation still requires the `pad_value` to be written into * the padding, so a new block is inserted after the last write stage * to explicitly fill the padding. @@ -117,7 +117,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { // contribute, but the first and last must. std::vector dependent_loopnest; - // Whether the padding could be represented as a tir::if_then_else + // Whether the padding could be represented as a tirx::if_then_else // node. This requires that the surrounding loop iterators // iterate over all pre-transformation buffer axes, that there are // no data dependencies between loop iterations, and that @@ -705,7 +705,7 @@ class TransformLayoutPlanner : private StmtExprVisitor { * \brief Collect blocks that are part of root block to be passed to ScheduleState::Replace for SRef * reuse */ -class ReuseBlocksCollector : public tir::StmtVisitor { +class ReuseBlocksCollector : public tirx::StmtVisitor { public: static ffi::Map Collect(SBlock result, ffi::Map new_block_to_old) { @@ -723,7 +723,7 @@ class ReuseBlocksCollector : public tir::StmtVisitor { : new_block_to_old_(new_block_to_old) {} /*! \brief Override the Stmt visiting behaviour */ - void VisitStmt_(const tir::SBlockNode* block) override { + void VisitStmt_(const tirx::SBlockNode* block) override { SBlock block_ref = ffi::GetRef(block); auto it = new_block_to_old_.find(block_ref); if (it != new_block_to_old_.end()) { @@ -1136,7 +1136,7 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const ffi::Arrayfinal_indices.Map([&](PrimExpr index) { if (auto* ptr = index.as()) { TVM_FFI_ICHECK(index_dtype.has_value()); - return tir::make_const(*index_dtype, ptr->value); + return tirx::make_const(*index_dtype, ptr->value); } else { return SubstituteWithDataTypeLegalization(index, [&](const Var& var) { return var_map.Get(var); }); diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc b/src/s_tir/schedule/primitive/loop_transformation.cc index 702a32852378..ae35d1c91c93 100644 --- a/src/s_tir/schedule/primitive/loop_transformation.cc +++ b/src/s_tir/schedule/primitive/loop_transformation.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */ class BlockPredicateAppender : public StmtMutator { @@ -600,7 +600,7 @@ class BlockMutator : public StmtExprMutator { } // Update all instances of old iter_vars in the block with new iter_vars - auto block_stmt = tir::Substitute(new_block, var_map); + auto block_stmt = tirx::Substitute(new_block, var_map); return block_stmt; } @@ -623,7 +623,7 @@ class BlockMutator : public StmtExprMutator { if (!op->loop_var.same_as(new_var)) { // If the partioned loop contains nested for loop, then create new iteration variable instance - res.CopyOnWrite()->body = tir::Substitute(res->body, {{op->loop_var, new_var}}); + res.CopyOnWrite()->body = tirx::Substitute(res->body, {{op->loop_var, new_var}}); res.CopyOnWrite()->loop_var = new_var; } return res; @@ -672,7 +672,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref for (int i = 0; i < n; i++) { extent_value = analyzer.Simplify(factors[i]); Var new_loop_var = loop->loop_var.copy_with_suffix(std::to_string(i)).copy_with_dtype(dtype); - Stmt loop_body = tir::Substitute(loop->body, {{loop->loop_var, new_loop_var}}); + Stmt loop_body = tirx::Substitute(loop->body, {{loop->loop_var, new_loop_var}}); // Create new block with new reference to each variable/stmt/expr in the existing block loop_body = BlockMutator(new_loop_var, min_value, extent_value)(std::move(loop_body)); @@ -691,7 +691,7 @@ ffi::Array LoopPartition(ScheduleState self, const StmtSRef& loop_sref // Create common block with all the partitioned blocks as its children blocks SBlockRealize common({}, make_const(DataType::Bool(), 1), - SBlock({}, {}, {}, block_name + "_common", tir::SeqStmt(block_partitions))); + SBlock({}, {}, {}, block_name + "_common", tirx::SeqStmt(block_partitions))); // Replace existing loop with the newly created common block self->Replace(loop_sref, common, {}); diff --git a/src/s_tir/schedule/primitive/pad_einsum.cc b/src/s_tir/schedule/primitive/pad_einsum.cc index 5ddeabf2e5e2..c8d19fee31ec 100644 --- a/src/s_tir/schedule/primitive/pad_einsum.cc +++ b/src/s_tir/schedule/primitive/pad_einsum.cc @@ -17,13 +17,13 @@ * under the License. */ -#include +#include #include "../utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Check if buffer indices are all Vars and expr @@ -50,7 +50,7 @@ ffi::Optional> CheckTrivialBufferAccess(const BufferRegion& buff ffi::Array indices; indices.reserve(buffer_region->region.size()); for (const Range& range : buffer_region->region) { - if (!tir::is_one(range->extent)) { + if (!tirx::is_one(range->extent)) { return std::nullopt; } if (range->min->IsInstance()) { diff --git a/src/s_tir/schedule/primitive/read_write_at.cc b/src/s_tir/schedule/primitive/read_write_at.cc index f560de517bcd..78524351d4d3 100644 --- a/src/s_tir/schedule/primitive/read_write_at.cc +++ b/src/s_tir/schedule/primitive/read_write_at.cc @@ -25,7 +25,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using support::NDIntSet; diff --git a/src/s_tir/schedule/primitive/reduction.cc b/src/s_tir/schedule/primitive/reduction.cc index 8a7056f4896b..c09a392aed1e 100644 --- a/src/s_tir/schedule/primitive/reduction.cc +++ b/src/s_tir/schedule/primitive/reduction.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief A helper class to create a new scope that contains decomposed init body @@ -746,7 +746,7 @@ class BaseBlockCreator { } ffi::Array stmts; for (int i = 0; i < n_buffers_; ++i) { - stmts.push_back(tir::Bind(let_vars[i], stored_values[i])); + stmts.push_back(tirx::Bind(let_vars[i], stored_values[i])); } for (const auto& store : buf_stores) { stmts.push_back(store); diff --git a/src/s_tir/schedule/primitive/reorder_block_iter_var.cc b/src/s_tir/schedule/primitive/reorder_block_iter_var.cc index 5b0def9a7248..1e82baf60b88 100644 --- a/src/s_tir/schedule/primitive/reorder_block_iter_var.cc +++ b/src/s_tir/schedule/primitive/reorder_block_iter_var.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief The reorder index is not a valid permutation of diff --git a/src/s_tir/schedule/primitive/rolling_buffer.cc b/src/s_tir/schedule/primitive/rolling_buffer.cc index ccf85f894b21..76f67d979ed6 100644 --- a/src/s_tir/schedule/primitive/rolling_buffer.cc +++ b/src/s_tir/schedule/primitive/rolling_buffer.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; namespace { @@ -217,7 +217,7 @@ class RollingBufferInfoCollector { // to be the rolling axis ffi::Optional roll_iter_var; int roll_axis = 0; - for (const tir::StmtSRef& loop_sref : loop_srefs) { + for (const tirx::StmtSRef& loop_sref : loop_srefs) { auto loop_var = loop_sref->StmtAs()->loop_var; auto it{std::find_if( diff --git a/src/s_tir/schedule/primitive/sampling.cc b/src/s_tir/schedule/primitive/sampling.cc index 94f2784e13f9..273f0d844192 100644 --- a/src/s_tir/schedule/primitive/sampling.cc +++ b/src/s_tir/schedule/primitive/sampling.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; struct PrimeTable { /*! \brief The table contains prime numbers in [2, kMaxPrime) */ @@ -309,7 +309,7 @@ std::vector SamplePerfectTile(support::LinearCongruentialEngine::TRandS std::vector SamplePerfectTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, + const tirx::StmtSRef& loop_sref, int32_t n_splits, int32_t max_innermost_factor, ffi::Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); @@ -370,7 +370,7 @@ TVM_DLL std::vector SamplePartitionedTile( std::vector SamplePartitionedTile( support::LinearCongruentialEngine::TRandState* rand_state, // - const tir::StmtSRef& loop_sref, int32_t n_splits, int32_t partition_pos, + const tirx::StmtSRef& loop_sref, int32_t n_splits, int32_t partition_pos, int32_t innerpart_factor, ffi::Optional>* decision) { const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); const int64_t* extent = GetLoopIntExtent(loop); @@ -418,9 +418,9 @@ std::vector SamplePartitionedTile( return result; } -tir::StmtSRef SampleComputeLocation(s_tir::ScheduleState self, - support::LinearCongruentialEngine::TRandState* rand_state, - const StmtSRef& block_sref, ffi::Optional* decision) { +tirx::StmtSRef SampleComputeLocation(s_tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const StmtSRef& block_sref, ffi::Optional* decision) { // Step 1. Collect all possible compute-at locations. auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref); TVM_FFI_ICHECK_EQ(location_srefs.size(), location_indices.size()); diff --git a/src/s_tir/schedule/schedule.cc b/src/s_tir/schedule/schedule.cc index 4ca8d4d15ed3..a14ff2aa8aa9 100644 --- a/src/s_tir/schedule/schedule.cc +++ b/src/s_tir/schedule/schedule.cc @@ -21,7 +21,7 @@ #include "./utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; diff --git a/src/s_tir/schedule/state.cc b/src/s_tir/schedule/state.cc index 20c319d88453..83f0cae533ff 100644 --- a/src/s_tir/schedule/state.cc +++ b/src/s_tir/schedule/state.cc @@ -22,7 +22,7 @@ #include "./utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; TVM_FFI_STATIC_INIT_BLOCK() { ScheduleStateNode::RegisterReflection(); } @@ -817,7 +817,7 @@ class ChildReplacer : private StmtMutator { int seq_index_; }; -void ScheduleStateNode::Replace(const tir::StmtSRef& _src_sref, const Stmt& tgt_stmt, +void ScheduleStateNode::Replace(const tirx::StmtSRef& _src_sref, const Stmt& tgt_stmt, const ffi::Map& _block_sref_reuse) { if (this->debug_mask != 0) { const StmtNode* src_stmt = _src_sref->stmt; diff --git a/src/s_tir/schedule/trace.cc b/src/s_tir/schedule/trace.cc index 6114a275a542..a63fb15f64a8 100644 --- a/src/s_tir/schedule/trace.cc +++ b/src/s_tir/schedule/trace.cc @@ -22,7 +22,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; TVM_FFI_STATIC_INIT_BLOCK() { TraceNode::RegisterReflection(); } @@ -67,7 +67,7 @@ ffi::Array TranslateInputRVs(const ffi::Array& inputs, } const Object* dst = it->second; TVM_FFI_CHECK(dst->IsInstance(), TypeError) - << "Expect 'tir.Var', but gets: " << dst->GetTypeKey(); + << "Expect 'tirx.Var', but gets: " << dst->GetTypeKey(); return ffi::GetRef(static_cast(dst)); }; diff --git a/src/s_tir/schedule/traced_schedule.cc b/src/s_tir/schedule/traced_schedule.cc index 68541fc26ddc..e43df0835ca3 100644 --- a/src/s_tir/schedule/traced_schedule.cc +++ b/src/s_tir/schedule/traced_schedule.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, int debug_mask, ScheduleErrorRenderLevel error_render_level, diff --git a/src/s_tir/schedule/traced_schedule.h b/src/s_tir/schedule/traced_schedule.h index fd0027ac8d91..038e808c1cf7 100644 --- a/src/s_tir/schedule/traced_schedule.h +++ b/src/s_tir/schedule/traced_schedule.h @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class TracedScheduleNode : public ConcreteScheduleNode { friend class Schedule; diff --git a/src/s_tir/schedule/transform.cc b/src/s_tir/schedule/transform.cc index cfc3ef4d8831..a401eba5e784 100644 --- a/src/s_tir/schedule/transform.cc +++ b/src/s_tir/schedule/transform.cc @@ -19,12 +19,12 @@ #include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "./utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Annotation ********/ @@ -310,7 +310,7 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, const ffi::String& intrin_name, bool allow_padding) { ffi::Optional opt_tensorize_info = GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv), - tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding); + tirx::TensorIntrin::Get(intrin_name).value()->desc, allow_padding); if (!opt_tensorize_info) return std::nullopt; const TensorizeInfoNode* info = opt_tensorize_info.value().get(); if (info->block_iter_paddings.defined()) { @@ -372,8 +372,8 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, sch->ComputeInline(consumer); } } - // Construct a mapping from tir loops back to LoopRVs - ffi::Map loop2rv; + // Construct a mapping from tirx loops back to LoopRVs + ffi::Map loop2rv; { ffi::Array loop_rvs = sch->GetLoops(block_rv); for (const LoopRV& loop_rv : loop_rvs) { @@ -382,14 +382,14 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, } // Split the loops arith::Analyzer analyzer; - std::unordered_set inner_loops; + std::unordered_set inner_loops; std::vector reorder_suffix; reorder_suffix.resize(info->loop_map.size()); for (const auto& kv : info->loop_map) { // Extract mapping (block_loop => desc_loop) - const tir::StmtSRef& block_loop_sref = kv.first; - const tir::ForNode* block_loop = block_loop_sref->StmtAs(); - const tir::ForNode* desc_loop = kv.second.get(); + const tirx::StmtSRef& block_loop_sref = kv.first; + const tirx::ForNode* block_loop = block_loop_sref->StmtAs(); + const tirx::ForNode* desc_loop = kv.second.get(); TVM_FFI_ICHECK(block_loop != nullptr && desc_loop != nullptr); // Extract the loop extent PrimExpr block_extent = analyzer.Simplify(block_loop->extent); @@ -408,7 +408,7 @@ ffi::Optional TileWithTensorIntrin(const s_tir::Schedule& sch, TVM_FFI_ICHECK_EQ(split.size(), 2); inner_loops.insert(sch->GetSRef(split[1]).operator->()); // The inner split will be reordered to the loop domain that is tensorized - int desc_loop_index = info->desc_loop_indexer.at(ffi::GetRef(desc_loop)).IntValue(); + int desc_loop_index = info->desc_loop_indexer.at(ffi::GetRef(desc_loop)).IntValue(); reorder_suffix[desc_loop_index] = split[1]; } // Reorder the loops @@ -533,7 +533,7 @@ ffi::Optional NormalizePrimFunc(Schedule sch) { } } if (index_map_outputs.empty() || !has_spatial_iter) { - index_map_outputs.insert(index_map_outputs.begin(), tir::make_const(DataType::Int(64), 0)); + index_map_outputs.insert(index_map_outputs.begin(), tirx::make_const(DataType::Int(64), 0)); } try { sch->TransformBlockLayout(block, IndexMap(index_map_inputs, index_map_outputs)); diff --git a/src/s_tir/schedule/transform.h b/src/s_tir/schedule/transform.h index 6451d69354b4..cabaf8404bc1 100644 --- a/src/s_tir/schedule/transform.h +++ b/src/s_tir/schedule/transform.h @@ -21,17 +21,17 @@ #include #include -#include +#include #include #include #include "../../arith/ir_mutator_with_analyzer.h" -#include "../../tir/ir/functor_common.h" +#include "../../tirx/ir/functor_common.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /******** Annotation ********/ diff --git a/src/s_tir/schedule/utils.h b/src/s_tir/schedule/utils.h index b1c3903eb065..4026694df169 100644 --- a/src/s_tir/schedule/utils.h +++ b/src/s_tir/schedule/utils.h @@ -29,10 +29,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -52,7 +52,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Convert an array of loop StmtSRefs to an array of loops * \param loop_srefs The loop StmtSRefs to be converted @@ -325,7 +325,7 @@ inline void ReorderAndFuseReductionLoops(const s_tir::Schedule& sch, s_tir::LoopRV* fused_reduce_loop, size_t* num_spatial_loops) { ffi::Array loops = sch->GetLoops(block_rv); - ffi::Array loop_srefs; + ffi::Array loop_srefs; for (const s_tir::LoopRV& loop_rv : loops) { loop_srefs.push_back(sch->GetSRef(loop_rv)); } @@ -334,7 +334,7 @@ inline void ReorderAndFuseReductionLoops(const s_tir::Schedule& sch, // Step 1. Add spatial loops. *num_spatial_loops = 0; for (size_t i = 0; i < loops.size(); ++i) { - if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) { + if (GetLoopIterType(loop_srefs[i]) == tirx::kDataPar) { new_order.push_back(loops[i]); (*num_spatial_loops)++; } @@ -342,7 +342,7 @@ inline void ReorderAndFuseReductionLoops(const s_tir::Schedule& sch, // Step 2. Add reduction loops. ffi::Array reduction_loops; for (size_t i = 0; i < loops.size(); ++i) { - if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { + if (GetLoopIterType(loop_srefs[i]) == tirx::kCommReduce) { new_order.push_back(loops[i]); reduction_loops.push_back(loops[i]); } @@ -385,15 +385,15 @@ inline ffi::String BufferIndexType2Str(BufferIndexType buffer_index_type) { /*! \brief Returns the names of the blocks in the provided module. */ inline std::unordered_set GetSBlockNames(const IRModule& mod) { - struct BlockNameCollector : public tir::StmtVisitor { - void VisitStmt_(const tir::SBlockNode* block) override { + struct BlockNameCollector : public tirx::StmtVisitor { + void VisitStmt_(const tirx::SBlockNode* block) override { block_names.insert(block->name_hint); StmtVisitor::VisitStmt(block->body); } std::unordered_set block_names; }; - if (auto prim_func = tir::FindEntryFunc(mod, nullptr)) { + if (auto prim_func = tirx::FindEntryFunc(mod, nullptr)) { BlockNameCollector collector; collector(prim_func->body); return collector.block_names; diff --git a/src/s_tir/transform/annotate_irregular_loop.cc b/src/s_tir/transform/annotate_irregular_loop.cc index 711d87c3af4f..8c92de64d991 100644 --- a/src/s_tir/transform/annotate_irregular_loop.cc +++ b/src/s_tir/transform/annotate_irregular_loop.cc @@ -22,13 +22,13 @@ #include #include #include -#include -#include -#include +#include +#include +#include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class IrregularLoopAnnotator : public StmtMutator { public: @@ -45,7 +45,7 @@ class IrregularLoopAnnotator : public StmtMutator { TVM_FFI_ICHECK(op->kind == ForKind::kSerial) << "Loop kind " << op->kind << " is invalid for irregular loop " << op->loop_var; for (const char* key : - {tir::attr::pragma_auto_unroll_max_step, tir::attr::pragma_unroll_explicit, + {tirx::attr::pragma_auto_unroll_max_step, tirx::attr::pragma_unroll_explicit, s_tir::attr::pragma_loop_partition_hint, s_tir::attr::software_pipeline_stage}) { TVM_FFI_ICHECK(!res->annotations.count(key)) << "Annotation `" << key << "` is invalid for irregular loop " << op->loop_var; diff --git a/src/s_tir/transform/bound_checker.cc b/src/s_tir/transform/bound_checker.cc index 241492d2537f..4d3a9b1471c8 100644 --- a/src/s_tir/transform/bound_checker.cc +++ b/src/s_tir/transform/bound_checker.cc @@ -26,10 +26,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -39,7 +39,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; // TODO(Lunderberg): Move this pass to be before // FlattenBuffer. That will simplify this pass, @@ -49,7 +49,7 @@ class BoundCollector : public StmtVisitor { BoundCollector() {} void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::buffer_bound) { + if (op->attr_key == tirx::attr::buffer_bound) { const VarNode* key = op->node.as(); const CallNode* container = op->value.as(); if (key && container) { diff --git a/src/s_tir/transform/canonicalize_loop.cc b/src/s_tir/transform/canonicalize_loop.cc index 99ee0f614b14..3b5d0c589492 100644 --- a/src/s_tir/transform/canonicalize_loop.cc +++ b/src/s_tir/transform/canonicalize_loop.cc @@ -25,16 +25,16 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class LoopCanonicalizer : public StmtExprMutator { public: diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index b2e8db67bbe5..3d73fa1912b2 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -27,8 +27,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -36,12 +36,12 @@ #include "../../support/arena.h" #include "../../support/nd_int_set.h" #include "../../support/utils.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "../schedule/utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using support::NDIntSet; @@ -312,7 +312,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { + if (op->attr_key == tirx::attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread) { IterVar iter = Downcast(op->node); ancestor_iters_.push_back(iter); Range dom = iter->dom; diff --git a/src/s_tir/transform/convert_blocks_to_opaque.cc b/src/s_tir/transform/convert_blocks_to_opaque.cc index 53b47735b89d..fe4faf063ab7 100644 --- a/src/s_tir/transform/convert_blocks_to_opaque.cc +++ b/src/s_tir/transform/convert_blocks_to_opaque.cc @@ -24,13 +24,13 @@ #include #include -#include +#include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Substitute expr via BlockRealize value bindings and convert each block into opaque diff --git a/src/s_tir/transform/decorate_device_scope.cc b/src/s_tir/transform/decorate_device_scope.cc index 37b6b4441690..f6330cb1b0c9 100644 --- a/src/s_tir/transform/decorate_device_scope.cc +++ b/src/s_tir/transform/decorate_device_scope.cc @@ -23,15 +23,15 @@ #include #include #include -#include -#include +#include +#include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; Stmt DecorateDeviceScopeImpl(Stmt&& stmt) { - Stmt body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::device_scope, 0, stmt); + Stmt body = AttrStmt(make_zero(DataType::Int(32)), tirx::attr::device_scope, 0, stmt); return body; } diff --git a/src/s_tir/transform/default_gpu_schedule.cc b/src/s_tir/transform/default_gpu_schedule.cc index 216182e0f434..796e87aafcb2 100644 --- a/src/s_tir/transform/default_gpu_schedule.cc +++ b/src/s_tir/transform/default_gpu_schedule.cc @@ -23,7 +23,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; namespace transform { /*! * \brief A helper function to do default thread binding for a block. @@ -42,16 +42,16 @@ void ThreadBind(s_tir::Schedule sch, const s_tir::SBlockRV& block, int64_t max_t return; } } - ffi::Array iters = sch->Get(block)->iter_vars; + ffi::Array iters = sch->Get(block)->iter_vars; - // when there is no loops, tir will add a dummy iter var for the block + // when there is no loops, tirx will add a dummy iter var for the block // so loops.size() == 0 && iters.size() == 1 TVM_FFI_ICHECK(loops.size() == iters.size() || (loops.size() == 0 && iters.size() == 1)); ffi::Array data_parallel_loops; // only fuse data parallel loops for (size_t i = 0; i < loops.size(); ++i) { - if (iters[i]->iter_type == tir::IterVarType::kDataPar) { + if (iters[i]->iter_type == tirx::IterVarType::kDataPar) { data_parallel_loops.push_back(loops[i]); } } @@ -64,8 +64,8 @@ void ThreadBind(s_tir::Schedule sch, const s_tir::SBlockRV& block, int64_t max_t // fuse all data parallel loops s_tir::LoopRV fused = sch->Fuse(data_parallel_loops, /*preserve_unit_iters=*/false); int64_t product = std::numeric_limits::max(); - if (sch->Get(fused)->extent->IsInstance()) { - product = sch->Get(fused)->extent.as()->value; + if (sch->Get(fused)->extent->IsInstance()) { + product = sch->Get(fused)->extent.as()->value; } // schedule the fused loop if (product > max_thread_per_block * max_threadblocks) { @@ -87,9 +87,9 @@ IRModule MarkScheduled(const IRModule& mod) { ffi::Map result; for (const auto& [gv, base_func] : mod->functions) { - if (const auto* prim_func_node = base_func.as()) { - tir::PrimFunc prim_func = ffi::GetRef(prim_func_node); - tir::PrimFunc new_prim_func = WithAttr(std::move(prim_func), tir::attr::kIsScheduled, true); + if (const auto* prim_func_node = base_func.as()) { + tirx::PrimFunc prim_func = ffi::GetRef(prim_func_node); + tirx::PrimFunc new_prim_func = WithAttr(std::move(prim_func), tirx::attr::kIsScheduled, true); result.Set(gv, new_prim_func); } else { result.Set(gv, base_func); @@ -127,8 +127,8 @@ Pass DefaultGPUSchedule() { s_tir::Schedule sch = s_tir::Schedule::Traced(m, /*seed=*/-1, /*debug_mask=*/0, s_tir::ScheduleErrorRenderLevel::kDetail); for (const auto& [gv, func] : m->functions) { - if (func->IsInstance() && - !func->HasNonzeroAttr(tir::attr::kIsScheduled) && IsScheduledOnGPU(func)) { + if (func->IsInstance() && + !func->HasNonzeroAttr(tirx::attr::kIsScheduled) && IsScheduledOnGPU(func)) { // get the target from context. tvm::Target target = tvm::Target::Current(); // get the target from kTarget attribute diff --git a/src/s_tir/transform/hoist_expression.cc b/src/s_tir/transform/hoist_expression.cc index 6403f3a7801a..e13a2afed703 100644 --- a/src/s_tir/transform/hoist_expression.cc +++ b/src/s_tir/transform/hoist_expression.cc @@ -24,9 +24,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -36,11 +36,11 @@ #include "../../arith/interval_set.h" #include "../../arith/ir_mutator_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; enum class HoistedConditionals : int { kNone = 0, @@ -577,8 +577,8 @@ Pass HoistExpression() { return tvm::transform::Sequential( { insertion_pass, - tir::transform::Simplify(), - tir::transform::RemoveNoOp(), + tirx::transform::Simplify(), + tirx::transform::RemoveNoOp(), }, "s_tir.HoistExpression"); } @@ -592,7 +592,7 @@ static Pass HoistIfThenElseImpl() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); auto cfg = ctx->GetConfig("s_tir.HoistIfThenElse"); - auto flag = f->GetAttr("tir.HoistIfThenElseExprWithBlock"); + auto flag = f->GetAttr("tirx.HoistIfThenElseExprWithBlock"); if (flag && flag.value().IntValue() == 1) { HoistExpressionConfig config(static_cast(HoistedConditionals::kUsingBlockVar) | static_cast(HoistedConditionals::kIfElseExpr), @@ -615,8 +615,8 @@ static Pass HoistIfThenElseImpl() { return tvm::transform::Sequential( { insertion_pass, - tir::transform::Simplify(), - tir::transform::RemoveNoOp(), + tirx::transform::Simplify(), + tirx::transform::RemoveNoOp(), }, "s_tir.HoistIfThenElse"); } @@ -633,8 +633,8 @@ static Pass HoistIfThenElseBasicImpl() { return tvm::transform::Sequential( { insertion_pass, - tir::transform::Simplify(), - tir::transform::RemoveNoOp(), + tirx::transform::Simplify(), + tirx::transform::RemoveNoOp(), }, "s_tir.HoistIfThenElseBasic"); } diff --git a/src/s_tir/transform/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc index 7a0835bcad34..786869598a1f 100644 --- a/src/s_tir/transform/inject_double_buffer.cc +++ b/src/s_tir/transform/inject_double_buffer.cc @@ -25,14 +25,14 @@ #include #include #include -#include -#include +#include +#include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; struct InjectDoubleBufferConfigNode : public AttrsNodeReflAdapter { int split_loop; diff --git a/src/s_tir/transform/inject_permuted_layout.cc b/src/s_tir/transform/inject_permuted_layout.cc index ee0479bb914c..b5be6b540b34 100644 --- a/src/s_tir/transform/inject_permuted_layout.cc +++ b/src/s_tir/transform/inject_permuted_layout.cc @@ -24,18 +24,18 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../runtime/thread_storage_scope.h" #include "../../support/utils.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using namespace arith; using namespace runtime; diff --git a/src/s_tir/transform/inject_ptx_async_copy.cc b/src/s_tir/transform/inject_ptx_async_copy.cc index c0632fdd23ad..3c84a021b551 100644 --- a/src/s_tir/transform/inject_ptx_async_copy.cc +++ b/src/s_tir/transform/inject_ptx_async_copy.cc @@ -23,20 +23,20 @@ */ #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include -#include "../../tir/ir/buffer_common.h" +#include "../../tirx/ir/buffer_common.h" #include "storage_access.h" #include "tvm/s_tir/stmt.h" -#include "tvm/tir/stmt.h" +#include "tvm/tirx/stmt.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class PTXAsyncCopyInjector : public StmtMutator { public: @@ -89,7 +89,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (predicated) { args.push_back(predicate_value); } - return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); + return Evaluate(Call(store->buffer->dtype, tvm::tirx::builtin::ptx_cp_async(), args)); } // Predicated load don't support vectorized indexing. @@ -112,12 +112,12 @@ class PTXAsyncCopyInjector : public StmtMutator { auto* add = store->indices[0].as(); if (!add->a->IsInstance()) return PrimExpr(); if (!add->b->IsInstance()) return PrimExpr(); - return tir::Add(add->a.as()->base, add->b.as()->value); + return tirx::Add(add->a.as()->base, add->b.as()->value); } return PrimExpr(); }(); if (src_offset.defined() && dst_offset.defined()) { - return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + return Evaluate(Call(store->buffer->dtype, tvm::tirx::builtin::ptx_cp_async(), {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)})); } @@ -140,14 +140,14 @@ class PTXAsyncCopyInjector : public StmtMutator { auto* add = store->indices[0].as(); if (!add->a->IsInstance()) return PrimExpr(); if (!add->b->IsInstance()) return PrimExpr(); - return tir::Add(add->a.as()->base, add->b.as()->value); + return tirx::Add(add->a.as()->base, add->b.as()->value); } return PrimExpr(); }(); if (src_offset.defined() && dst_offset.defined()) { return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + Call(store->buffer->dtype, tvm::tirx::builtin::ptx_cp_async(), {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes), predicate_value})); } @@ -162,7 +162,7 @@ class PTXAsyncCopyInjector : public StmtMutator { if (auto* load = store->value.as()) { return InjectPTX(load, store); } else if (auto* call = store->value.as()) { - // tir.if_then_else is a call to tir::builtin::if_then_else() + // tirx.if_then_else is a call to tirx::builtin::if_then_else() if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { if (auto* load = call->args[1].as()) { // Only default value of 0 is supported since 0 is the default value used by cp.async diff --git a/src/s_tir/transform/inject_ptx_ldg32.cc b/src/s_tir/transform/inject_ptx_ldg32.cc index 4763699a432e..f02b253b29e8 100644 --- a/src/s_tir/transform/inject_ptx_ldg32.cc +++ b/src/s_tir/transform/inject_ptx_ldg32.cc @@ -22,17 +22,17 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "../../arith/const_fold.h" #include "../../arith/pattern_match.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class PTXRewriter : public StmtMutator { public: @@ -65,7 +65,7 @@ class PTXRewriter : public StmtMutator { const CallNode* call = load_value.as(); if (call != nullptr) { const OpNode* op = call->op.as(); - if (op != nullptr && op->name == "tir.if_then_else") { + if (op != nullptr && op->name == "tirx.if_then_else") { const PrimExpr& predicate = call->args[0]; const PrimExpr& lhs = call->args[1]; const PrimExpr& rhs = call->args[2]; @@ -97,7 +97,7 @@ class PTXRewriter : public StmtMutator { new_predicate = BufferLoad(predicate_buffer, {IntImm(DataType::Int(32), 0)}); new_indice = BufferLoad(addr_buffer, {IntImm(DataType::Int(32), 1)}); BufferStore value_store(store->buffer, imm_value, {new_indice}); - Evaluate ptx_load(Call(store->buffer->dtype, tvm::tir::builtin::ptx_ldg32(), + Evaluate ptx_load(Call(store->buffer->dtype, tvm::tirx::builtin::ptx_ldg32(), {store->buffer->data, new_predicate, new_lhs, new_indice})); ffi::Array tmp_seq = {addr_store, local_addr_store, predicate_store, value_store, ptx_load}; diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index ac618018814d..340d8fd5f804 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -26,18 +26,18 @@ #include #include #include -#include +#include #include #include #include "../../support/utils.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "../schedule/utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; namespace software_pipeline { diff --git a/src/s_tir/transform/inject_virtual_thread.cc b/src/s_tir/transform/inject_virtual_thread.cc index 6a74f36e0810..9f965733b119 100644 --- a/src/s_tir/transform/inject_virtual_thread.cc +++ b/src/s_tir/transform/inject_virtual_thread.cc @@ -24,18 +24,18 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include "../../arith/ir_mutator_with_analyzer.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; // If expression is touched by var. class ExprTouched final : public StmtExprVisitor { diff --git a/src/s_tir/transform/lift_thread_binding.cc b/src/s_tir/transform/lift_thread_binding.cc index e0f9987c172b..3aceecdb2867 100644 --- a/src/s_tir/transform/lift_thread_binding.cc +++ b/src/s_tir/transform/lift_thread_binding.cc @@ -24,14 +24,14 @@ #include #include -#include +#include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; std::pair>>, ObjectPtrHash, ObjectPtrEqual>, diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index 2afa17331934..718ced207a04 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -26,10 +26,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -37,11 +37,11 @@ #include "../../arith/interval_set.h" #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; struct LoopPartitionConfigNode : public AttrsNodeReflAdapter { bool partition_const_loop; @@ -133,7 +133,7 @@ class CandidateSelector final : public StmtExprVisitor { } void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { const IterVarNode* iv = op->node.as(); TVM_FFI_ICHECK(iv); Var var = iv->var; @@ -255,7 +255,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // handle thread_axis - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { const IterVarNode* thread_axis = op->node.as(); TVM_FFI_ICHECK(thread_axis); const VarNode* var = thread_axis->var.get(); @@ -383,7 +383,7 @@ class ThreadPartitionInserter : public StmtMutator { : ps_(ps), cond_(cond), innermost_thread_scope_(false) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { innermost_thread_scope_ = true; Stmt stmt = StmtMutator::VisitStmt_(op); // add branch code inside the innermost thread scope @@ -438,7 +438,7 @@ class LoopPartitioner : public StmtMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key != tir::attr::thread_extent) { + if (op->attr_key != tirx::attr::thread_extent) { return StmtMutator::VisitStmt_(op); } diff --git a/src/s_tir/transform/lower_async_dma.cc b/src/s_tir/transform/lower_async_dma.cc index fb6e9260eef1..f38b48988061 100644 --- a/src/s_tir/transform/lower_async_dma.cc +++ b/src/s_tir/transform/lower_async_dma.cc @@ -28,19 +28,19 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include #include "../../arith/ir_mutator_with_analyzer.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { public: @@ -89,7 +89,7 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { // attr [0] "async_wait_inflight_count" = 0; // // To this: - // @tir.dma_wait( + // @tirx.dma_wait( // 0, /* queue id */ // 0, /* in flight count */ // dtype=int32 @@ -129,10 +129,10 @@ class AsyncDMALowerer : public arith::IRMutatorWithAnalyzer { // } // // To this: - // @tir.dma_copy( + // @tirx.dma_copy( // 0, /* queue id */ - // @tir.address_of(A_global[0], dtype=handle), - // @tir.address_of(A[0], dtype=handle), + // @tirx.address_of(A_global[0], dtype=handle), + // @tirx.address_of(A[0], dtype=handle), // 128, /* size */ // dtype=int32 // ) @@ -172,7 +172,7 @@ Pass LowerAsyncDMA() { auto fptr = f.CopyOnWrite(); arith::Analyzer analyzer; bool dma_bypass_cache = - ctx->GetConfig("tir.experimental_dma_bypass_cache", Bool(false)).value(); + ctx->GetConfig("tirx.experimental_dma_bypass_cache", Bool(false)).value(); fptr->body = AsyncDMALowerer(dma_bypass_cache, &analyzer)(std::move(fptr->body)); return f; }; diff --git a/src/s_tir/transform/lower_cross_thread_reduction.cc b/src/s_tir/transform/lower_cross_thread_reduction.cc index 03df34b61791..47a445bdbb3d 100644 --- a/src/s_tir/transform/lower_cross_thread_reduction.cc +++ b/src/s_tir/transform/lower_cross_thread_reduction.cc @@ -24,18 +24,18 @@ #include #include #include -#include -#include +#include +#include #include "../../runtime/thread_storage_scope.h" #include "../../support/utils.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "../schedule/analysis.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using runtime::ThreadScope; using support::StartsWith; @@ -104,7 +104,7 @@ bool IsDominantBlock(const SBlock& scope_block, const SBlock& block) { * \param analyzer The analyzer * \return A boolean indicating whether the input block is a reduction block. * \note A similar check has been implemented in "src/s_tir/schedule/analysis.h", but that check is - * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the + * based on `tirx.Schedule`. Here we have no schedule information, and thus we must implement the * check again. */ bool IsReductionBlock(const SBlockRealize& realize, const ffi::Map& loop_range_map, @@ -423,7 +423,7 @@ Stmt TransformReductionBlock(const SBlockRealizeNode* realize, /*value=*/make_zero(DataType::Handle()), /*body=*/ Evaluate(Call(/*dtype=*/DataType::Handle(), - /*op=*/tir::builtin::tvm_thread_allreduce(), + /*op=*/tirx::builtin::tvm_thread_allreduce(), /*args=*/std::move(parameters))))))); } // Stmt 4: write cross-thread reduction result to the original buffer diff --git a/src/s_tir/transform/lower_init_block.cc b/src/s_tir/transform/lower_init_block.cc index 0efded756b1b..4ce75a1435bc 100644 --- a/src/s_tir/transform/lower_init_block.cc +++ b/src/s_tir/transform/lower_init_block.cc @@ -23,14 +23,14 @@ */ #include #include -#include -#include +#include +#include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class InitBlockLower : public StmtMutator { private: diff --git a/src/s_tir/transform/lower_match_buffer.cc b/src/s_tir/transform/lower_match_buffer.cc index 692f02895dc2..a33699308baf 100644 --- a/src/s_tir/transform/lower_match_buffer.cc +++ b/src/s_tir/transform/lower_match_buffer.cc @@ -25,16 +25,16 @@ #include #include #include -#include -#include -#include +#include +#include +#include -#include "../../tir/ir/functor_common.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/ir/functor_common.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class MatchBufferLower : public StmtExprMutator { public: explicit MatchBufferLower(const PrimFunc& func) { diff --git a/src/s_tir/transform/lower_opaque_block.cc b/src/s_tir/transform/lower_opaque_block.cc index 59c8b3e36c8d..8c3016808a68 100644 --- a/src/s_tir/transform/lower_opaque_block.cc +++ b/src/s_tir/transform/lower_opaque_block.cc @@ -24,13 +24,13 @@ #include #include #include -#include +#include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Remove SBlock to ensure that the TIR can not be scheduled again. @@ -144,7 +144,7 @@ class OpaqueBlockLower : public StmtExprMutator { ffi::String attr_key = (thread_tag == "vthread" || thread_tag == "vthread.x" || thread_tag == "vthread.y" || thread_tag == "vthread.z") ? s_tir::attr::virtual_thread - : tir::attr::thread_extent; + : tirx::attr::thread_extent; return AttrStmt(/*node=*/std::move(iter_var), /*attr_key=*/std::move(attr_key), /*value=*/std::move(extent), @@ -181,7 +181,7 @@ class OpaqueBlockLower : public StmtExprMutator { pragma_attrs->clear(); for (const auto& kv : annotations) { const ffi::String& key = kv.first; - if (tir::attr::IsPragmaKey(key)) { + if (tirx::attr::IsPragmaKey(key)) { pragma_attrs->emplace_back(key, ConvertAttrValue(key, kv.second)); } else if (!is_block) { // the loop annotation is preserved diff --git a/src/s_tir/transform/lower_thread_allreduce.cc b/src/s_tir/transform/lower_thread_allreduce.cc index c1a4b4e83739..f1e5a3cfafe0 100644 --- a/src/s_tir/transform/lower_thread_allreduce.cc +++ b/src/s_tir/transform/lower_thread_allreduce.cc @@ -27,19 +27,19 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" -#include "../../tir/transform/update_pointer_storage_scope.h" +#include "../../tirx/transform/ir_utils.h" +#include "../../tirx/transform/update_pointer_storage_scope.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class ThreadAllreduceBuilder final : public StmtExprMutator { public: @@ -49,7 +49,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { max_num_threads_(target->GetAttr("max_num_threads", -1).value().IntValue()) {} Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { thread_extents_.push_back(op); Stmt ret = StmtExprMutator::VisitStmt_(op); thread_extents_.pop_back(); @@ -102,7 +102,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { cow->buffer = replacement; if (replacement.scope() == "shared") { auto annotations = cow->annotations; - annotations.Set(tir::attr::kVolatile, Bool(true)); + annotations.Set(tirx::attr::kVolatile, Bool(true)); cow->annotations = annotations; } return node; @@ -864,7 +864,7 @@ class DeferredRemapper : public StmtExprMutator { cow->buffer = replacement; if (replacement.scope() == "shared") { auto annotations = cow->annotations; - annotations.Set(tir::attr::kVolatile, Bool(true)); + annotations.Set(tirx::attr::kVolatile, Bool(true)); cow->annotations = annotations; } } diff --git a/src/s_tir/transform/lower_vtcm_alloc.cc b/src/s_tir/transform/lower_vtcm_alloc.cc index 6b7e104ccb27..9aaaf7a1f655 100644 --- a/src/s_tir/transform/lower_vtcm_alloc.cc +++ b/src/s_tir/transform/lower_vtcm_alloc.cc @@ -19,14 +19,14 @@ #include #include -#include -#include +#include +#include #include "../../arith/ir_visitor_with_analyzer.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; inline bool IsVtcmStorage(std::string scope) { return scope.find("global.vtcm") != std::string::npos; diff --git a/src/s_tir/transform/manifest_shared_memory_local_stage.cc b/src/s_tir/transform/manifest_shared_memory_local_stage.cc index 5222d53f5371..f481b02b7ee4 100644 --- a/src/s_tir/transform/manifest_shared_memory_local_stage.cc +++ b/src/s_tir/transform/manifest_shared_memory_local_stage.cc @@ -30,19 +30,19 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include "../../runtime/thread_storage_scope.h" #include "../schedule/transform.h" -#include "tvm/tir/stmt.h" +#include "tvm/tirx/stmt.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief Rewriter for the block storing to the target buffer. Create an intermediate cache stage * to store the result. Rewrite the original block to load from the intermediate buffer. diff --git a/src/s_tir/transform/memhammer_coalesce.cc b/src/s_tir/transform/memhammer_coalesce.cc index 44a925fda77b..ce57a21e1d28 100644 --- a/src/s_tir/transform/memhammer_coalesce.cc +++ b/src/s_tir/transform/memhammer_coalesce.cc @@ -21,7 +21,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Fuse consecutive loops diff --git a/src/s_tir/transform/memhammer_intermediate_stage.cc b/src/s_tir/transform/memhammer_intermediate_stage.cc index 78f6170c56f7..7f3e6016a3c2 100644 --- a/src/s_tir/transform/memhammer_intermediate_stage.cc +++ b/src/s_tir/transform/memhammer_intermediate_stage.cc @@ -20,7 +20,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; Stmt CopyLoopChain(const std::vector loops, const Stmt& inner_body, int ith = -1, Stmt* ith_loop = nullptr) { @@ -279,7 +279,7 @@ std::pair InsertCacheStage(Stmt stmt, bool is_write_cache, ffi::S arith::Analyzer analyzer; const BufferLoadNode* target_buffer_load = nullptr; if (is_write_cache) { - tir::PreOrderVisit(stmt, [&](const ObjectRef& obj) { + tirx::PreOrderVisit(stmt, [&](const ObjectRef& obj) { if (const auto* buffer_load = obj.as()) { if (buffer_load->buffer.scope() == "wmma.accumulator" || buffer_load->buffer.scope() == "m16n8k8.matrixC") { diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index 59b535739967..e31606f0b9cb 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -23,22 +23,22 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "../schedule/utils.h" #include "./memhammer_rewrite_rule.h" -#include "tvm/tir/stmt.h" +#include "tvm/tirx/stmt.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using support::NDIntSet; diff --git a/src/s_tir/transform/memhammer_rewrite_rule.h b/src/s_tir/transform/memhammer_rewrite_rule.h index 90662dc17538..7cbdcc9c53dc 100644 --- a/src/s_tir/transform/memhammer_rewrite_rule.h +++ b/src/s_tir/transform/memhammer_rewrite_rule.h @@ -23,9 +23,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include @@ -33,7 +33,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! \brief The set containing all possible constraints of a data copy */ struct ConstraintSet { diff --git a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc index 09776f8a0624..c4dab611de8d 100644 --- a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc +++ b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc @@ -21,7 +21,7 @@ namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Tile the 2 innermost loops to extent=16. This helps further tensor core rewrite. diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index 7343f6366172..465f30bae135 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -27,9 +27,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -38,11 +38,11 @@ #include "../../runtime/thread_storage_scope.h" #include "../../support/arena.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using runtime::StorageRank; using runtime::StorageScope; @@ -231,11 +231,11 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // Only record the outer most thread extent. - if (op->attr_key == tir::attr::thread_extent && !in_thread_env_) { + if (op->attr_key == tirx::attr::thread_extent && !in_thread_env_) { in_thread_env_ = true; VisitNewScope(op); in_thread_env_ = false; - } else if (op->attr_key == tir::attr::extern_scope) { + } else if (op->attr_key == tirx::attr::extern_scope) { VisitNewScope(op); } else if (op->attr_key == s_tir::attr::virtual_thread) { VisitNewScope(op); @@ -296,7 +296,7 @@ class SharedMemoryRewriter : public StmtExprMutator { private: Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent && !allocated_) { + if (op->attr_key == tirx::attr::thread_extent && !allocated_) { // Allocate one dynamic shared memory allocation at the beginning of thread scope int max_layer_num = 0; std::vector all_entry; @@ -342,7 +342,7 @@ class SharedMemoryRewriter : public StmtExprMutator { Stmt visited_body = StmtExprMutator::VisitStmt(op->body); ffi::Map annotations; if (has_volatile_alloc_) { - annotations.Set(tir::attr::kVolatile, Bool(true)); + annotations.Set(tirx::attr::kVolatile, Bool(true)); } Stmt alloc_stmt = AllocBuffer(merged_buf, annotations); Stmt new_body = SeqStmt::Flatten(alloc_stmt, visited_body); @@ -353,7 +353,7 @@ class SharedMemoryRewriter : public StmtExprMutator { Stmt VisitStmt_(const AllocBufferNode* op) final { if (IsAppropriateSharedMemory(op->buffer->data)) { - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { has_volatile_alloc_ = true; } return Evaluate(0); @@ -723,7 +723,7 @@ namespace transform { Pass MergeSharedMemoryAllocations() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + bool merge_static_smem = ctx->GetConfig("tirx.merge_static_smem", Bool(false)).value(); auto* n = f.CopyOnWrite(); n->body = s_tir::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem); return f; diff --git a/src/s_tir/transform/plan_update_buffer_allocation_location.cc b/src/s_tir/transform/plan_update_buffer_allocation_location.cc index 3b66230a2cad..08f0095d42de 100644 --- a/src/s_tir/transform/plan_update_buffer_allocation_location.cc +++ b/src/s_tir/transform/plan_update_buffer_allocation_location.cc @@ -24,15 +24,15 @@ #include #include -#include -#include -#include +#include +#include +#include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class CollectManagedAllocations : public StmtExprVisitor { public: diff --git a/src/s_tir/transform/profile_instrumentation.cc b/src/s_tir/transform/profile_instrumentation.cc index af7e2ebd4637..dcf9d1cf3b68 100644 --- a/src/s_tir/transform/profile_instrumentation.cc +++ b/src/s_tir/transform/profile_instrumentation.cc @@ -26,14 +26,14 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; namespace lwp { TVM_REGISTER_PASS_CONFIG_OPTION("s_tir.lwp_disable_func_prof", Bool); diff --git a/src/s_tir/transform/remove_store_undef.cc b/src/s_tir/transform/remove_store_undef.cc index 1d383bec12b9..b6e33c7cfb63 100644 --- a/src/s_tir/transform/remove_store_undef.cc +++ b/src/s_tir/transform/remove_store_undef.cc @@ -19,20 +19,20 @@ /*! * \file remove_store_undef.cc - * \brief Remove stores of tir::builtin::undef + * \brief Remove stores of tirx::builtin::undef */ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; struct UndefInfo { std::unordered_set undef_stores; @@ -202,7 +202,7 @@ Pass ValidateAllUndefRemoved() { Pass RemoveStoreUndef() { return tvm::transform::Sequential( - {RemoveStoreUndefInternal(), tir::transform::RemoveNoOp(), ValidateAllUndefRemoved()}, + {RemoveStoreUndefInternal(), tirx::transform::RemoveNoOp(), ValidateAllUndefRemoved()}, "s_tir.RemoveStoreUndef"); } diff --git a/src/s_tir/transform/remove_weight_layout_rewrite_block.cc b/src/s_tir/transform/remove_weight_layout_rewrite_block.cc index bfeee5d850e0..bbcc2b0fe242 100644 --- a/src/s_tir/transform/remove_weight_layout_rewrite_block.cc +++ b/src/s_tir/transform/remove_weight_layout_rewrite_block.cc @@ -25,15 +25,15 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class RemoveLayoutRewriteBlock : public StmtMutator { public: @@ -126,7 +126,7 @@ class WeightLayoutRewriteBlockRemover : public StmtMutator { PrimFuncNode* n = f_.CopyOnWrite(); - ffi::Map buffer_map; + ffi::Map buffer_map; for (const auto& [param, buffer] : f_->buffer_map) { auto it = buf_map.find(buffer); if (it != buf_map.end()) { diff --git a/src/s_tir/transform/renew_defs.cc b/src/s_tir/transform/renew_defs.cc index f34fb3cd856d..87aa20ae902c 100644 --- a/src/s_tir/transform/renew_defs.cc +++ b/src/s_tir/transform/renew_defs.cc @@ -24,13 +24,13 @@ #include #include -#include +#include -#include "../../tir/ir/functor_common.h" +#include "../../tirx/ir/functor_common.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; #define STMT_REGENERATE_VAR_DEF(NODE, FIELD) \ Stmt VisitStmt_(const NODE* op) final { \ @@ -66,7 +66,7 @@ class RenewDefMutator : public StmtExprMutator { } // Redefine buffers in order // TODO(Siyuan Feng): checking var is used after define - ffi::Map buffer_map; + ffi::Map buffer_map; for (const auto& param : func->params) { if (param->dtype.is_handle()) { const Buffer& buffer = func->buffer_map.at(param); diff --git a/src/s_tir/transform/renormalize_split_pattern.cc b/src/s_tir/transform/renormalize_split_pattern.cc index 365dc344b3e3..ae3d048b8892 100644 --- a/src/s_tir/transform/renormalize_split_pattern.cc +++ b/src/s_tir/transform/renormalize_split_pattern.cc @@ -24,17 +24,17 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/pattern_match.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using namespace arith; @@ -148,7 +148,7 @@ class SplitPatternReNormalizer : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const LTNode* op) { PrimExpr a = VisitExpr(op->a); PrimExpr b = VisitExpr(op->b); - PrimExpr ret = tir::LT(a, b); + PrimExpr ret = tirx::LT(a, b); // Pattern var to match any expression PVar x; // Pattern var match IntImm diff --git a/src/s_tir/transform/rewrite_unsafe_select.cc b/src/s_tir/transform/rewrite_unsafe_select.cc index 267d9b4d008f..f43d3da820af 100644 --- a/src/s_tir/transform/rewrite_unsafe_select.cc +++ b/src/s_tir/transform/rewrite_unsafe_select.cc @@ -24,14 +24,14 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; // For now, rewrite unsafe select expression to if_then_else // TODO(tqchen) pattern matching to support masked load diff --git a/src/s_tir/transform/storage_access.cc b/src/s_tir/transform/storage_access.cc index ad0c9a5bd84d..928672391e74 100644 --- a/src/s_tir/transform/storage_access.cc +++ b/src/s_tir/transform/storage_access.cc @@ -23,16 +23,16 @@ #include "storage_access.h" #include -#include +#include #include #include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; void StorageAccessVisitor::VisitExpr_(const BufferLoadNode* op) { Var buf = op->buffer->data; @@ -126,7 +126,7 @@ void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { scope_.back().emplace_back(std::move(s)); } double_buffer_write_ = nullptr; - } else if (op->attr_key == tir::attr::thread_extent) { + } else if (op->attr_key == tirx::attr::thread_extent) { IterVar iv = Downcast(op->node); env_threads_.push_back(iv); if (!in_device_env_) { diff --git a/src/s_tir/transform/storage_access.h b/src/s_tir/transform/storage_access.h index 7635996acb7a..c9125f86881e 100644 --- a/src/s_tir/transform/storage_access.h +++ b/src/s_tir/transform/storage_access.h @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -36,7 +36,7 @@ namespace tvm { namespace s_tir { -using namespace tir; +using namespace tirx; using runtime::StorageRank; using runtime::StorageScope; diff --git a/src/s_tir/transform/tensorcore_infer_fragment.cc b/src/s_tir/transform/tensorcore_infer_fragment.cc index bb3ef30d7120..89f003cf456e 100644 --- a/src/s_tir/transform/tensorcore_infer_fragment.cc +++ b/src/s_tir/transform/tensorcore_infer_fragment.cc @@ -25,19 +25,19 @@ #include #include #include -#include -#include +#include +#include #include #include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "storage_access.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; // Get fragment information from tensor intrinsics class FragmentGetter : public StmtExprVisitor { @@ -117,13 +117,13 @@ class FragmentGetter : public StmtExprVisitor { } // namespace s_tir -namespace tir { +namespace tirx { std::unordered_map GetTensorCoreFragmentInfo(const Stmt& stmt) { s_tir::FragmentGetter getter; getter(stmt); return std::move(getter.fragments); } -} // namespace tir +} // namespace tirx namespace s_tir { diff --git a/src/s_tir/transform/thread_storage_sync.cc b/src/s_tir/transform/thread_storage_sync.cc index 834fdd79fb3c..f854de249a92 100644 --- a/src/s_tir/transform/thread_storage_sync.cc +++ b/src/s_tir/transform/thread_storage_sync.cc @@ -24,21 +24,21 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "storage_access.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; class ThreadSyncPlanner : public StorageAccessVisitor { public: @@ -233,7 +233,7 @@ class ThreadSyncPlanner : public StorageAccessVisitor { PrimExpr curr_index = curr_intset.PointValue(); has_same_index = ExprDeepEqual()(prev_index, curr_index); if (thread_index_var != nullptr) { - auto f_uses_thread_index = [=](const tvm::tir::VarNode* parameter) { + auto f_uses_thread_index = [=](const tvm::tirx::VarNode* parameter) { return parameter == thread_index_var; }; depends_on_thread_index = depends_on_thread_index && @@ -349,7 +349,7 @@ class ThreadSyncInserter : public StmtExprMutator { return StmtExprMutator::VisitStmt_(op); } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { bool temp = true; std::swap(temp, in_thread_env_); thread_extents_.push_back(op); @@ -373,7 +373,7 @@ class ThreadSyncInserter : public StmtExprMutator { if (volatile_vars_.count(op->buffer->data.get())) { auto* cow = node.CopyOnWrite(); auto annotations = cow->annotations; - annotations.Set(tir::attr::kVolatile, Bool(true)); + annotations.Set(tirx::attr::kVolatile, Bool(true)); cow->annotations = annotations; } return node; diff --git a/src/s_tir/transform/transform_mma_buffer_layout.cc b/src/s_tir/transform/transform_mma_buffer_layout.cc index 1437a89f93dd..832731ae0d6a 100644 --- a/src/s_tir/transform/transform_mma_buffer_layout.cc +++ b/src/s_tir/transform/transform_mma_buffer_layout.cc @@ -20,15 +20,15 @@ #include #include #include -#include -#include -#include +#include +#include +#include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; /*! * \brief Rewriter for all m16n8k8.matrix[A/B/C] buffer. This pass mainly do two things: @@ -37,7 +37,7 @@ using namespace tvm::tir; * 2. Rewrite access of m16n8k8.matrixC so it can access the correct part of the matrix. * The reason why access of m16n8k8.matrix[A/B] buffer doesn't need this kind of rewrite is * that their access is through opaque access inside ldmatrix and mma_sync. Please refer to - * get_index_[A/B] in python/tvm/tir/tensor_intrin/cuda.py. + * get_index_[A/B] in python/tvm/tirx/tensor_intrin/cuda.py. * We cannot use this kind of opaque access in matrixC too since the ptx stmatrix is only * supported for sm90 or higher. Therefore, writeback of matrixC is limited to the * transparent way. @@ -131,7 +131,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { if (buffer_map_.count(store->buffer)) { auto* n = store.CopyOnWrite(); if (store->buffer.scope() == "m16n8k8.matrixC") { - const auto index_map_func = tvm::ffi::Function::GetGlobal("tir.index_map_m16n8k8.matrixC"); + const auto index_map_func = tvm::ffi::Function::GetGlobal("tirx.index_map_m16n8k8.matrixC"); TVM_FFI_ICHECK(index_map_func.has_value()); auto index_map = IndexMap::FromFunc(2, *index_map_func); auto new_indices = index_map->MapIndices(store->indices, &analyzer); @@ -150,7 +150,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { if (buffer_map_.count(load->buffer)) { auto* n = load.CopyOnWrite(); if (load->buffer.scope() == "m16n8k8.matrixC") { - const auto index_map_func = tvm::ffi::Function::GetGlobal("tir.index_map_m16n8k8.matrixC"); + const auto index_map_func = tvm::ffi::Function::GetGlobal("tirx.index_map_m16n8k8.matrixC"); TVM_FFI_ICHECK(index_map_func.has_value()); auto index_map = IndexMap::FromFunc(2, *index_map_func); auto new_indices = index_map->MapIndices(load->indices, &analyzer); diff --git a/src/s_tir/transform/unify_thread_binding.cc b/src/s_tir/transform/unify_thread_binding.cc index c33380175fa6..42b84d190da4 100644 --- a/src/s_tir/transform/unify_thread_binding.cc +++ b/src/s_tir/transform/unify_thread_binding.cc @@ -25,15 +25,15 @@ #include #include #include -#include -#include +#include +#include #include "../../support/utils.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using support::StartsWith; @@ -49,7 +49,7 @@ class ThreadBindingUnifier : public StmtExprMutator { private: Stmt VisitStmt_(const AttrStmtNode* op) final { // If this AttrStmt is not thread binding attribute, return as usual. - if (op->attr_key != tir::attr::thread_extent && op->attr_key != s_tir::attr::virtual_thread) { + if (op->attr_key != tirx::attr::thread_extent && op->attr_key != s_tir::attr::virtual_thread) { return StmtMutator::VisitStmt_(op); } IterVar old_iter_var = Downcast(op->node); diff --git a/src/s_tir/transform/using_assume_to_reduce_branches.cc b/src/s_tir/transform/using_assume_to_reduce_branches.cc index e506d1985431..d5921de12209 100644 --- a/src/s_tir/transform/using_assume_to_reduce_branches.cc +++ b/src/s_tir/transform/using_assume_to_reduce_branches.cc @@ -39,10 +39,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -51,7 +51,7 @@ #include "tvm/ir/expr.h" namespace tvm { namespace s_tir { -using namespace tvm::tir; +using namespace tvm::tirx; using namespace arith; @@ -122,20 +122,20 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { using Parent::VisitStmt_; // This struct stores all the relevant data related to asssume statement - struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) - PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) - PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding - // bufferload expression (A[i] == 0) - tir::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 - PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 + struct assume_struct { // Consider the example : T.assume(i < 14 or A[i] == 0) + PrimExpr buffer_context; // The context of the assume statement (the bound on the axis) + PrimExpr buffer_predicate; // The condition inside assume statement (i < 14) excluding + // bufferload expression (A[i] == 0) + tirx::BufferLoad buffer_load; // Storing the buffer load Eg: A[i] in A[i] == 0 + PrimExpr buffer_value; // Storing the value for the buffer Eg : 0 in A[i] == 0 ffi::Array buffer_indices; // Storing the indices of the buffer Eg : i }; // List of conditions in a scope std::vector conditions_; // Storing all the buffer assumptions data in map - std::map map_buffer_assumption; - tir::Buffer current_bufferstorenode_name; + std::map map_buffer_assumption; + tirx::Buffer current_bufferstorenode_name; struct InternalConstraintContext { /* This stuct appends the constraint passed to it in the conditions list. @@ -144,10 +144,10 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { : self(self), analyzer_context(self->analyzer_, constraint) { old_num_constraints = self->conditions_.size(); - auto side_effect = tir::SideEffect(constraint); - if (side_effect <= tir::CallEffectKind::kPure) { + auto side_effect = tirx::SideEffect(constraint); + if (side_effect <= tirx::CallEffectKind::kPure) { self->conditions_.push_back(constraint); - } else if (side_effect <= tir::CallEffectKind::kReadState) { + } else if (side_effect <= tirx::CallEffectKind::kReadState) { assume = constraint; } @@ -285,8 +285,8 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { std::vector buffer_exprs; for (const auto& expr : arith::ExtractComponents(assumption)) { - auto side_effect = tir::SideEffect(expr); - if (side_effect <= tir::CallEffectKind::kPure) { + auto side_effect = tirx::SideEffect(expr); + if (side_effect <= tirx::CallEffectKind::kPure) { // Pulling out portions of the assumption that do not depend // on a buffer value allows the following two forms to be // treated identically. @@ -294,7 +294,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { // Option 1: if i < 3: T.assume(buf[i] == value) // Option 2: T.assume(i>=3 or buf[i] == value) additional_predicate = additional_predicate && logical_not(expr); - } else if (side_effect == tir::CallEffectKind::kReadState) { + } else if (side_effect == tirx::CallEffectKind::kReadState) { buffer_exprs.push_back(expr); } else { TVM_FFI_THROW(InternalError) @@ -307,7 +307,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { TVM_FFI_ICHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; - auto* as_equal_node = buffer_exprs[0].as(); + auto* as_equal_node = buffer_exprs[0].as(); TVM_FFI_ICHECK(as_equal_node) << "T.assume buffer constraint must be of the form 'buffer[indices] == " "value', but received " @@ -321,12 +321,12 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { // Parse the statement and store the desired values // Ex: A[i]==0, load = A[i], value = 0 - tir::BufferLoad load; + tirx::BufferLoad load; PrimExpr value; - if (auto opt = as_equal_node->a.as()) { + if (auto opt = as_equal_node->a.as()) { load = opt.value(); value = as_equal_node->b; - } else if (auto opt = as_equal_node->b.as()) { + } else if (auto opt = as_equal_node->b.as()) { load = opt.value(); value = as_equal_node->a; } else { @@ -346,7 +346,7 @@ class ParseAssumeAndOvercompute : public IRMutatorWithAnalyzer { } map_buffer_assumption[buf_data.buffer_load->buffer] = buf_data; - auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; + auto has_side_effect = tirx::SideEffect(value) > tirx::CallEffectKind::kPure; TVM_FFI_ICHECK(!has_side_effect) << "Buffer value in constraint must be pure expression, but was " << value; if (has_side_effect) { diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index cee620d3ddec..5bcf0dca7740 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -21,8 +21,8 @@ #include #include #include -#include -#include +#include +#include #include "./utils.h" @@ -41,7 +41,7 @@ IRModuleFrame IRModule() { inline relax::StructInfo GetGlobalVarStructInfo(const BaseFunc& func) { if (func->struct_info_.defined()) { return tvm::relax::GetStructInfo(func); - } else if (const auto* prim_func = func.as()) { + } else if (const auto* prim_func = func.as()) { return tvm::relax::FuncStructInfo::OpaqueFunc( tvm::relax::StructInfoFromType(prim_func->ret_type)); } else { @@ -55,7 +55,7 @@ GlobalVar DeclFunction(const ffi::String& func_name, const BaseFunc& func_signat << "function " << func_name << " already exists"; auto gvar_type = [&]() -> Type { - if (auto prim_func = func_signature.as()) { + if (auto prim_func = func_signature.as()) { ffi::Array arg_types = prim_func->params.Map([](const auto& var) { return GetType(var); }); return FuncType(arg_types, prim_func->ret_type); diff --git a/src/script/ir_builder/relax/distributed.cc b/src/script/ir_builder/relax/distributed.cc index a505d317787f..be2bf1a22aa5 100644 --- a/src/script/ir_builder/relax/distributed.cc +++ b/src/script/ir_builder/relax/distributed.cc @@ -22,7 +22,7 @@ #include #include #include -#include +#include #include "./utils.h" diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 35b44feae399..b5a434b7bb92 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include "./utils.h" diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tirx/frame.cc similarity index 81% rename from src/script/ir_builder/tir/frame.cc rename to src/script/ir_builder/tirx/frame.cc index 71dd98d481ca..659c23bf3b25 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tirx/frame.cc @@ -17,16 +17,16 @@ * under the License. */ #include -#include -#include +#include +#include -#include "../../../tir/ir/script/script_complete.h" +#include "../../../tirx/ir/script/script_complete.h" #include "./utils.h" namespace tvm { namespace script { namespace ir_builder { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { TIRFrameNode::RegisterReflection(); @@ -51,13 +51,13 @@ void PrimFuncFrameNode::ExitWithScope() { attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } - tvm::tir::PrimFunc func( + tvm::tirx::PrimFunc func( /*params=*/args, /*body=*/AsStmt(stmts), /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/buffer_map, /*attrs=*/DictAttrs(attrs)); - func = tvm::tir::ScriptComplete(func, root_alloc_buffers); + func = tvm::tirx::ScriptComplete(func, root_alloc_buffers); IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { TVM_FFI_CHECK(!builder->result.defined(), ValueError) << "Builder.result has already been set"; @@ -82,17 +82,17 @@ void PrimFuncFrameNode::ExitWithScope() { void SBlockFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - ffi::Array tir_alloc_buffers; - for (const tvm::tir::Buffer& buffer : alloc_buffers) { + ffi::Array tir_alloc_buffers; + for (const tvm::tirx::Buffer& buffer : alloc_buffers) { tir_alloc_buffers.push_back(buffer); } ffi::Map attrs = annotations.value_or({}); if (int detect_access = (!reads.defined()) | (!writes.defined() << 1)) { - attrs.Set("tir.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access)); + attrs.Set("tirx.script_parsing_detect_access", tvm::IntImm(DataType::Int(64), detect_access)); } - tvm::tir::SBlock block(iter_vars, reads.value_or(ffi::Array()), - writes.value_or(ffi::Array()), name, AsStmt(stmts), - init, tir_alloc_buffers, match_buffers, attrs); + tvm::tirx::SBlock block(iter_vars, reads.value_or(ffi::Array()), + writes.value_or(ffi::Array()), name, + AsStmt(stmts), init, tir_alloc_buffers, match_buffers, attrs); if (no_realize) { TVM_FFI_CHECK(iter_values.empty(), ValueError) << "Block bindings are not allowed when `no_realize=True`"; @@ -100,7 +100,7 @@ void SBlockFrameNode::ExitWithScope() { << "`T.where` is not allowed when `no_realize=True`"; AddToParent(block); } else { - AddToParent(tvm::tir::SBlockRealize(iter_values, predicate.value_or(Bool(true)), block)); + AddToParent(tvm::tirx::SBlockRealize(iter_values, predicate.value_or(Bool(true)), block)); } } @@ -126,30 +126,30 @@ void ForFrameNode::ExitWithScope() { void AssertFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); if (stmts.empty()) { - AddToParent(tvm::tir::AssertStmt(condition, error_kind, message_parts)); + AddToParent(tvm::tirx::AssertStmt(condition, error_kind, message_parts)); } else { - ffi::Array seq; - seq.push_back(tvm::tir::AssertStmt(condition, error_kind, message_parts)); + ffi::Array seq; + seq.push_back(tvm::tirx::AssertStmt(condition, error_kind, message_parts)); for (const auto& stmt : stmts) { seq.push_back(stmt); } - AddToParent(tvm::tir::SeqStmt(seq)); + AddToParent(tvm::tirx::SeqStmt(seq)); } } void LaunchThreadFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); + AddToParent(tvm::tirx::AttrStmt(iter_var, attr_key, extent, AsStmt(stmts))); } void AttrFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::AttrStmt(node, attr_key, value, AsStmt(stmts))); + AddToParent(tvm::tirx::AttrStmt(node, attr_key, value, AsStmt(stmts))); } void WhileFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::While(condition, AsStmt(stmts))); + AddToParent(tvm::tirx::While(condition, AsStmt(stmts))); } void IfFrameNode::ExitWithScope() { @@ -161,9 +161,9 @@ void IfFrameNode::ExitWithScope() { if (!then_stmts.defined()) { TVM_FFI_THROW(InternalError) << "IfThenElse frame should have at least one then branch"; } - AddToParent(tvm::tir::IfThenElse( + AddToParent(tvm::tirx::IfThenElse( condition, AsStmt(then_stmts.value()), - else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tir::Stmt(nullptr))); + else_stmts.defined() ? AsStmt(else_stmts.value()) : tvm::tirx::Stmt(nullptr))); } void ThenFrameNode::EnterWithScope() { @@ -197,7 +197,7 @@ void ElseFrameNode::ExitWithScope() { FindIfFrame("T.else_")->else_stmts = stmts; } -} // namespace tir +} // namespace tirx } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tirx/ir.cc similarity index 76% rename from src/script/ir_builder/tir/ir.cc rename to src/script/ir_builder/tirx/ir.cc index f3522daa754c..6695f6840ef6 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tirx/ir.cc @@ -19,16 +19,16 @@ #include #include #include -#include +#include #include "./utils.h" namespace tvm { namespace script { namespace ir_builder { -namespace tir { +namespace tirx { -using tvm::tir::IterVar; +using tvm::tirx::IterVar; Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer_name, ffi::Optional data, ffi::Optional> strides, @@ -44,17 +44,17 @@ Buffer BufferDecl(ffi::Array shape, DataType dtype, ffi::String buffer if (storage_dtype == DataType::Bool()) { storage_dtype = DataType::Int(8); } - buffer_data = tvm::tir::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope)); + buffer_data = tvm::tirx::Var(buffer_name, PointerType(PrimType(storage_dtype), storage_scope)); } else { buffer_data = data.value(); } if (!elem_offset.defined() && offset_factor) { DataType shape_dtype = shape.empty() ? DataType::Int(32) : shape[0]->dtype; - elem_offset = tvm::tir::Var("elem_offset", shape_dtype); + elem_offset = tvm::tirx::Var("elem_offset", shape_dtype); } return Buffer(buffer_data, dtype, shape, strides.value_or(ffi::Array()), elem_offset.value_or(PrimExpr()), buffer_name, align, offset_factor, - (buffer_type == "auto" ? tvm::tir::kAutoBroadcast : tvm::tir::kDefault), + (buffer_type == "auto" ? tvm::tirx::kAutoBroadcast : tvm::tirx::kDefault), axis_separators.value_or(ffi::Array())); } @@ -97,7 +97,7 @@ void FuncName(ffi::String name) { } void FuncAttrs(ffi::Map new_attrs) { - using namespace tvm::tir; + using namespace tvm::tirx; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); for (const auto& [key, value] : new_attrs) { if (key == tvm::attr::kGlobalSymbol && frame->is_private) { @@ -133,7 +133,7 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, ffi::String buffer_type_str, ffi::Optional> axis_separators) { Buffer buffer = BufferDecl(shape, dtype, "", data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type_str, axis_separators); - if (const auto* var = param.as()) { + if (const auto* var = param.as()) { PrimFuncFrame frame = FindPrimFuncFrame("T.match_buffer"); Var v = ffi::GetRef(var); for (auto const& arg : frame->args) { @@ -143,14 +143,14 @@ Buffer MatchBuffer(ObjectRef param, ffi::Array shape, DataType dtype, } } TVM_FFI_THROW(ValueError) << "Can not bind non-input param to buffer."; - } else if (const auto* buffer_load = param.as()) { + } else if (const auto* buffer_load = param.as()) { SBlockFrame frame = FindSBlockFrame("T.match_buffer"); - frame->match_buffers.push_back(tvm::tir::MatchBufferRegion( - buffer, BufferRegionFromLoad(ffi::GetRef(buffer_load)))); - } else if (const auto* buffer_region = param.as()) { + frame->match_buffers.push_back(tvm::tirx::MatchBufferRegion( + buffer, BufferRegionFromLoad(ffi::GetRef(buffer_load)))); + } else if (const auto* buffer_region = param.as()) { SBlockFrame frame = FindSBlockFrame("T.match_buffer"); frame->match_buffers.push_back( - tvm::tir::MatchBufferRegion(buffer, ffi::GetRef(buffer_region))); + tvm::tirx::MatchBufferRegion(buffer, ffi::GetRef(buffer_region))); } else { TVM_FFI_THROW(ValueError) << "Unexpected type for TIR MatchBuffer."; } @@ -185,7 +185,7 @@ void Where(PrimExpr predicate) { } void Reads(ffi::Array buffer_slices) { - using namespace tvm::tir; + using namespace tvm::tirx; SBlockFrame frame = FindSBlockFrame("T.reads"); if (frame->reads.defined()) { TVM_FFI_THROW(ValueError) << "Duplicate read region declaration, previous one is " @@ -205,7 +205,7 @@ void Reads(ffi::Array buffer_slices) { } void Writes(ffi::Array buffer_slices) { - using namespace tvm::tir; + using namespace tvm::tirx; SBlockFrame frame = FindSBlockFrame("T.writes"); if (frame->writes.defined()) { TVM_FFI_THROW(ValueError) << "Duplicate write region declaration, previous one is " @@ -307,14 +307,14 @@ IterVar PushBlockVar(IterVar iter_var, PrimExpr binding) { binding) \ ->var; \ } -TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tir::IterVarType::kDataPar, "Spatial"); -TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tir::IterVarType::kCommReduce, "Reduction"); -TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tir::IterVarType::kOrdered, "Scan"); -TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tir::IterVarType::kOpaque, "Opaque"); +TVM_TIR_IR_BUILDER_AXIS(Spatial, tvm::tirx::IterVarType::kDataPar, "Spatial"); +TVM_TIR_IR_BUILDER_AXIS(Reduce, tvm::tirx::IterVarType::kCommReduce, "Reduction"); +TVM_TIR_IR_BUILDER_AXIS(Scan, tvm::tirx::IterVarType::kOrdered, "Scan"); +TVM_TIR_IR_BUILDER_AXIS(Opaque, tvm::tirx::IterVarType::kOpaque, "Opaque"); #undef TVM_TIR_IR_BUILDER_AXIS ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType dtype) { - using namespace tvm::tir; + using namespace tvm::tirx; ffi::Array results; TVM_FFI_ICHECK_EQ(kinds.size(), bindings.size()); int n = bindings.size(); @@ -366,39 +366,39 @@ ffi::Array Remap(ffi::String kinds, ffi::Array bindings, DataType } // namespace axis -#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ - ForFrame Method(PrimExpr start, PrimExpr stop, \ - ffi::Optional> annotations, \ - ffi::Optional step) { \ - PrimExpr min = start; \ - PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ - ObjectPtr n = ffi::make_object(); \ - int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ - n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ - n->doms = {Range::FromMinExtent(min, extent)}; \ - n->steps = {step}; \ - n->f_make_for_loop = [annotations](ffi::Array vars, ffi::Array doms, \ - ffi::Array> steps, \ - tvm::tir::Stmt body) { \ - TVM_FFI_ICHECK_EQ(vars.size(), 1); \ - TVM_FFI_ICHECK_EQ(doms.size(), 1); \ - TVM_FFI_ICHECK_EQ(steps.size(), 1); \ - return tvm::tir::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ - annotations.value_or(ffi::Map()), steps[0]); \ - }; \ - return ForFrame(n); \ +#define TVM_TIR_IR_BUILDER_FOR_FRAME(Method, Kind) \ + ForFrame Method(PrimExpr start, PrimExpr stop, \ + ffi::Optional> annotations, \ + ffi::Optional step) { \ + PrimExpr min = start; \ + PrimExpr extent = arith::Analyzer().Simplify(stop - start); \ + ObjectPtr n = ffi::make_object(); \ + int bits = std::max(min.dtype().bits(), extent.dtype().bits()); \ + n->vars = {Var("v", DataType(min.dtype().code(), bits, 1))}; \ + n->doms = {Range::FromMinExtent(min, extent)}; \ + n->steps = {step}; \ + n->f_make_for_loop = [annotations](ffi::Array vars, ffi::Array doms, \ + ffi::Array> steps, \ + tvm::tirx::Stmt body) { \ + TVM_FFI_ICHECK_EQ(vars.size(), 1); \ + TVM_FFI_ICHECK_EQ(doms.size(), 1); \ + TVM_FFI_ICHECK_EQ(steps.size(), 1); \ + return tvm::tirx::For(vars[0], doms[0]->min, doms[0]->extent, Kind, body, std::nullopt, \ + annotations.value_or(ffi::Map()), steps[0]); \ + }; \ + return ForFrame(n); \ } -TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tir::ForKind::kSerial); -TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tir::ForKind::kParallel); -TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tir::ForKind::kVectorized); -TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tir::ForKind::kUnrolled); +TVM_TIR_IR_BUILDER_FOR_FRAME(Serial, tvm::tirx::ForKind::kSerial); +TVM_TIR_IR_BUILDER_FOR_FRAME(Parallel, tvm::tirx::ForKind::kParallel); +TVM_TIR_IR_BUILDER_FOR_FRAME(Vectorized, tvm::tirx::ForKind::kVectorized); +TVM_TIR_IR_BUILDER_FOR_FRAME(Unroll, tvm::tirx::ForKind::kUnrolled); #undef TVM_TIR_IR_BUILDER_FOR_FRAME ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, ffi::Optional> annotations) { - using namespace tvm::tir; + using namespace tvm::tirx; PrimExpr min = start; PrimExpr extent = arith::Analyzer().Simplify(stop - start); ObjectPtr n = ffi::make_object(); @@ -421,7 +421,7 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, ffi::String thread, } ForFrame Grid(ffi::Array extents) { - using namespace tvm::tir; + using namespace tvm::tirx; ObjectPtr n = ffi::make_object(); n->vars.reserve(extents.size()); n->doms.reserve(extents.size()); @@ -451,10 +451,10 @@ AssertFrame Assert(PrimExpr condition, ffi::String error_kind, ffi::Array message_parts) { ObjectPtr n = ffi::make_object(); n->condition = condition; - n->error_kind = tvm::tir::StringImm(error_kind); - ffi::Array parts; + n->error_kind = tvm::tirx::StringImm(error_kind); + ffi::Array parts; for (const auto& p : message_parts) { - parts.push_back(tvm::tir::StringImm(p)); + parts.push_back(tvm::tirx::StringImm(p)); } n->message_parts = parts; return AssertFrame(n); @@ -470,7 +470,7 @@ Var Bind(PrimExpr value, ffi::Optional type_annotation, ffi::Optional return Var("v", value.dtype()); } }(); - AddToParent(tvm::tir::Bind(bind_var, value)); + AddToParent(tvm::tirx::Bind(bind_var, value)); return bind_var; } @@ -489,8 +489,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { } ObjectPtr n = ffi::make_object(); if (!iter_var->dom.defined()) { - const_cast(iter_var.get())->dom = - Range(tvm::tir::make_zero(extent.dtype()), extent); + const_cast(iter_var.get())->dom = + Range(tvm::tirx::make_zero(extent.dtype()), extent); } else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) { TVM_FFI_THROW(ValueError) << "Inconsistent extents of environment thread. " << iter_var->dom->extent << " vs " << extent; @@ -542,7 +542,8 @@ ElseFrame Else() { } Var EnvThread(ffi::String thread_tag, DataType dtype) { - IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag); + IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tirx::IterVarType::kThreadIndex, + thread_tag); Var var = iter_var->var; if (ffi::Optional opt_frame = IRBuilder::Current()->FindFrame()) { opt_frame.value()->env_threads.Set(var, iter_var); @@ -615,7 +616,7 @@ void BufferStore(Buffer buffer, PrimExpr value, ffi::Array indices, } value = tvm::cast(lhs_dtype, value); } - AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); + AddToParent(tvm::tirx::BufferStore(buffer, value, indices, predicate)); } Buffer DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer_name, @@ -627,10 +628,10 @@ Buffer DeclBuffer(ffi::Array shape, DataType dtype, ffi::String buffer align, offset_factor, buffer_type, axis_separators); if (data.defined()) { // Alias an existing buffer: emit DeclBuffer statement - AddToParent(tvm::tir::DeclBuffer(buffer)); + AddToParent(tvm::tirx::DeclBuffer(buffer)); } else { // No backing data pointer: emit AllocBuffer statement - AddToParent(tvm::tir::AllocBuffer(buffer)); + AddToParent(tvm::tirx::AllocBuffer(buffer)); } return buffer; } @@ -640,52 +641,53 @@ Buffer AllocBuffer(ffi::Array shape, DataType dtype, ffi::String stora Buffer buffer = BufferDecl(shape, dtype, "", std::nullopt, std::nullopt, std::nullopt, storage_scope, 0, 0, "", std::nullopt); AddToParent( - tvm::tir::AllocBuffer(buffer, annotations.value_or(ffi::Map()))); + tvm::tirx::AllocBuffer(buffer, annotations.value_or(ffi::Map()))); return buffer; } -void Evaluate(PrimExpr value) { AddToParent(tvm::tir::Evaluate(value)); } +void Evaluate(PrimExpr value) { AddToParent(tvm::tirx::Evaluate(value)); } PrimExpr Ptr(runtime::DataType dtype, ffi::String storage_scope = "global", bool is_size_var = false) { PointerType type_annotation(PrimType(dtype), storage_scope); - return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); + return is_size_var ? tvm::tirx::SizeVar("", type_annotation) + : tvm::tirx::Var("", type_annotation); } using tvm::script::ir_builder::details::Namer; TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { - tvm::tir::BufferNode* buffer = - const_cast(node.as()); + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { + tvm::tirx::BufferNode* buffer = + const_cast(node.as()); buffer->name = name; Namer::Name(buffer->data, name); int n = buffer->strides.size(); for (int i = 0; i < n; ++i) { PrimExpr e = buffer->strides[i]; - if (auto v = e.as()) { + if (auto v = e.as()) { Namer::Name(v.value(), name + "_s" + std::to_string(i)); } } }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { - using namespace tvm::tir; + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { + using namespace tvm::tirx; SizeVarNode* var = const_cast(node.as()); var->name_hint = name; }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { - using namespace tvm::tir; + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { + using namespace tvm::tirx; VarNode* var = const_cast(node.as()); var->name_hint = name; }); TVM_STATIC_IR_FUNCTOR(Namer, vtable) - .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { - using namespace tvm::tir; + .set_dispatch([](const ObjectRef& node, ffi::String name) -> void { + using namespace tvm::tirx; IterVarNode* var = const_cast(node.as()); Namer::Name(var->var, name); }); @@ -693,11 +695,11 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Buffer", BufferDecl) - .def("script.ir_builder.tir.PrimFunc", PrimFunc) - .def("script.ir_builder.tir.Arg", + .def("script.ir_builder.tirx.Buffer", BufferDecl) + .def("script.ir_builder.tirx.PrimFunc", PrimFunc) + .def("script.ir_builder.tirx.Arg", [](ffi::String name, ObjectRef obj) -> ObjectRef { - using namespace tvm::tir; + using namespace tvm::tirx; if (auto var = obj.as()) { return Arg(name, var.value()); } @@ -707,40 +709,40 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_THROW(ValueError) << "Unexpected type for TIR Arg: " << obj->GetTypeKey(); throw; }) - .def("script.ir_builder.tir.FuncName", FuncName) - .def("script.ir_builder.tir.FuncAttrs", FuncAttrs) - .def("script.ir_builder.tir.FuncRet", FuncRet) - .def("script.ir_builder.tir.MatchBuffer", MatchBuffer) - .def("script.ir_builder.tir.Block", Block) - .def("script.ir_builder.tir.Init", Init) - .def("script.ir_builder.tir.Where", Where) - .def("script.ir_builder.tir.Reads", Reads) - .def("script.ir_builder.tir.Writes", Writes) - .def("script.ir_builder.tir.BlockAttrs", BlockAttrs) - .def("script.ir_builder.tir.SBlockAllocBuffer", SBlockAllocBuffer) - .def("script.ir_builder.tir.AxisSpatial", axis::Spatial) - .def("script.ir_builder.tir.AxisReduce", axis::Reduce) - .def("script.ir_builder.tir.AxisScan", axis::Scan) - .def("script.ir_builder.tir.AxisOpaque", axis::Opaque) - .def("script.ir_builder.tir.AxisRemap", axis::Remap) - .def("script.ir_builder.tir.Serial", Serial) - .def("script.ir_builder.tir.Parallel", Parallel) - .def("script.ir_builder.tir.Vectorized", Vectorized) - .def("script.ir_builder.tir.Unroll", Unroll) - .def("script.ir_builder.tir.ThreadBinding", ThreadBinding) - .def("script.ir_builder.tir.Grid", Grid) - .def("script.ir_builder.tir.Assert", Assert) - .def("script.ir_builder.tir.Bind", Bind) - .def("script.ir_builder.tir.Attr", Attr) - .def("script.ir_builder.tir.While", While) - .def("script.ir_builder.tir.If", If) - .def("script.ir_builder.tir.Then", Then) - .def("script.ir_builder.tir.Else", Else) - .def("script.ir_builder.tir.DeclBuffer", DeclBuffer) - .def("script.ir_builder.tir.AllocBuffer", AllocBuffer) - .def("script.ir_builder.tir.LaunchThread", - [](ffi::Variant thread_tag_or_var, PrimExpr extent) { - if (auto var = thread_tag_or_var.as()) { + .def("script.ir_builder.tirx.FuncName", FuncName) + .def("script.ir_builder.tirx.FuncAttrs", FuncAttrs) + .def("script.ir_builder.tirx.FuncRet", FuncRet) + .def("script.ir_builder.tirx.MatchBuffer", MatchBuffer) + .def("script.ir_builder.tirx.Block", Block) + .def("script.ir_builder.tirx.Init", Init) + .def("script.ir_builder.tirx.Where", Where) + .def("script.ir_builder.tirx.Reads", Reads) + .def("script.ir_builder.tirx.Writes", Writes) + .def("script.ir_builder.tirx.BlockAttrs", BlockAttrs) + .def("script.ir_builder.tirx.SBlockAllocBuffer", SBlockAllocBuffer) + .def("script.ir_builder.tirx.AxisSpatial", axis::Spatial) + .def("script.ir_builder.tirx.AxisReduce", axis::Reduce) + .def("script.ir_builder.tirx.AxisScan", axis::Scan) + .def("script.ir_builder.tirx.AxisOpaque", axis::Opaque) + .def("script.ir_builder.tirx.AxisRemap", axis::Remap) + .def("script.ir_builder.tirx.Serial", Serial) + .def("script.ir_builder.tirx.Parallel", Parallel) + .def("script.ir_builder.tirx.Vectorized", Vectorized) + .def("script.ir_builder.tirx.Unroll", Unroll) + .def("script.ir_builder.tirx.ThreadBinding", ThreadBinding) + .def("script.ir_builder.tirx.Grid", Grid) + .def("script.ir_builder.tirx.Assert", Assert) + .def("script.ir_builder.tirx.Bind", Bind) + .def("script.ir_builder.tirx.Attr", Attr) + .def("script.ir_builder.tirx.While", While) + .def("script.ir_builder.tirx.If", If) + .def("script.ir_builder.tirx.Then", Then) + .def("script.ir_builder.tirx.Else", Else) + .def("script.ir_builder.tirx.DeclBuffer", DeclBuffer) + .def("script.ir_builder.tirx.AllocBuffer", AllocBuffer) + .def("script.ir_builder.tirx.LaunchThread", + [](ffi::Variant thread_tag_or_var, PrimExpr extent) { + if (auto var = thread_tag_or_var.as()) { return LaunchThread(var.value(), extent); } else if (auto str = thread_tag_or_var.as()) { return LaunchThread(str.value(), extent); @@ -750,10 +752,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { throw; } }) - .def("script.ir_builder.tir.EnvThread", EnvThread) - .def("script.ir_builder.tir.BufferStore", BufferStore) - .def("script.ir_builder.tir.Evaluate", Evaluate) - .def("script.ir_builder.tir.Ptr", Ptr); + .def("script.ir_builder.tirx.EnvThread", EnvThread) + .def("script.ir_builder.tirx.BufferStore", BufferStore) + .def("script.ir_builder.tirx.Evaluate", Evaluate) + .def("script.ir_builder.tirx.Ptr", Ptr); } #define TVM_TMP_STR(x) #x @@ -780,109 +782,109 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.BFloat16", BFloat16) - .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.Float", Float) - .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt) - .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tir.Int", Int) - .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float) - .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt) - .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16); + .def("script.ir_builder.tirx.BFloat16", BFloat16) + .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tirx.Float", Float) + .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tirx.UInt", UInt) + .TVM_FFI_REFL_DEF_GLOBAL_SIZE("script.ir_builder.tirx.Int", Int) + .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tirx.Float", Float) + .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tirx.UInt", UInt) + .TVM_FFI_REFL_DEF_GLOBAL_SIZES_LANES("script.ir_builder.tirx.Int", Int) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.BFloat16", BFloat16); } // Float8 variants TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E3M4", Float8E3M4) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E3M4", Float8E3M4); + .def("script.ir_builder.tirx.Float8E3M4", Float8E3M4) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E3M4", Float8E3M4); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E4M3", Float8E4M3) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3", Float8E4M3); + .def("script.ir_builder.tirx.Float8E4M3", Float8E4M3) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E4M3", Float8E4M3); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); + .def("script.ir_builder.tirx.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E4M3B11FNUZ", Float8E4M3B11FNUZ); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN); + .def("script.ir_builder.tirx.Float8E4M3FN", Float8E4M3FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E4M3FN", Float8E4M3FN); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FNUZ", Float8E4M3FNUZ); + .def("script.ir_builder.tirx.Float8E4M3FNUZ", Float8E4M3FNUZ) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E4M3FNUZ", Float8E4M3FNUZ); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E5M2", Float8E5M2) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2); + .def("script.ir_builder.tirx.Float8E5M2", Float8E5M2) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E5M2", Float8E5M2); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2FNUZ", Float8E5M2FNUZ); + .def("script.ir_builder.tirx.Float8E5M2FNUZ", Float8E5M2FNUZ) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E5M2FNUZ", Float8E5M2FNUZ); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float8E8M0FNU", Float8E8M0FNU); + .def("script.ir_builder.tirx.Float8E8M0FNU", Float8E8M0FNU) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float8E8M0FNU", Float8E8M0FNU); } // Float6 variants TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E2M3FN", Float6E2M3FN); + .def("script.ir_builder.tirx.Float6E2M3FN", Float6E2M3FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float6E2M3FN", Float6E2M3FN); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float6E3M2FN", Float6E3M2FN); + .def("script.ir_builder.tirx.Float6E3M2FN", Float6E3M2FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float6E3M2FN", Float6E3M2FN); } // Float4 variant TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN) - .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tir.Float4E2M1FN", Float4E2M1FN); + .def("script.ir_builder.tirx.Float4E2M1FN", Float4E2M1FN) + .TVM_FFI_REFL_DEF_GLOBAL_LANES("script.ir_builder.tirx.Float4E2M1FN", Float4E2M1FN); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("script.ir_builder.tir.Boolean", Boolean) - .def("script.ir_builder.tir.Handle", Handle) - .def("script.ir_builder.tir.TensormapHandle", TensormapHandle) - .def("script.ir_builder.tir.Void", Void) - .def("script.ir_builder.tir.min", + .def("script.ir_builder.tirx.Boolean", Boolean) + .def("script.ir_builder.tirx.Handle", Handle) + .def("script.ir_builder.tirx.TensormapHandle", TensormapHandle) + .def("script.ir_builder.tirx.Void", Void) + .def("script.ir_builder.tirx.min", [](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::min(a, b); }) - .def("script.ir_builder.tir.max", + .def("script.ir_builder.tirx.max", [](PrimExpr a, PrimExpr b) -> PrimExpr { return tvm::max(a, b); }); } -} // namespace tir +} // namespace tirx } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/src/script/ir_builder/tir/utils.h b/src/script/ir_builder/tirx/utils.h similarity index 89% rename from src/script/ir_builder/tir/utils.h rename to src/script/ir_builder/tirx/utils.h index a53356b494ed..f9aecbc28857 100644 --- a/src/script/ir_builder/tir/utils.h +++ b/src/script/ir_builder/tirx/utils.h @@ -19,21 +19,21 @@ #ifndef TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ #define TVM_SCRIPT_IR_BUILDER_TIR_UTILS_H_ -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { namespace script { namespace ir_builder { -namespace tir { +namespace tirx { /*! - * \brief Add tir Stmt to the top frame in IRBuilder frame stack. + * \brief Add tirx Stmt to the top frame in IRBuilder frame stack. * \param stmt The Stmt. */ -inline void AddToParent(tvm::tir::Stmt stmt) { +inline void AddToParent(tvm::tirx::Stmt stmt) { IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { TVM_FFI_CHECK(!builder->result.defined(), ValueError) << "Builder.result has already been set"; @@ -46,12 +46,12 @@ inline void AddToParent(tvm::tir::Stmt stmt) { } /*! - * \brief Convert array of tir Stmt to single Stmt. + * \brief Convert array of tirx Stmt to single Stmt. * \param stmt The array of Stmt. * \return The SeqStmt. */ -inline tvm::tir::Stmt AsStmt(const ffi::Array& stmt) { - return tvm::tir::SeqStmt::Flatten(stmt); +inline tvm::tirx::Stmt AsStmt(const ffi::Array& stmt) { + return tvm::tirx::SeqStmt::Flatten(stmt); } /*! @@ -124,15 +124,15 @@ inline IfFrame FindIfFrame(const ffi::String& method) { * \param buffer_load The BufferLoad. * \return The converted BufferRegion. */ -inline tvm::tir::BufferRegion BufferRegionFromLoad(tvm::tir::BufferLoad buffer_load) { +inline tvm::tirx::BufferRegion BufferRegionFromLoad(tvm::tirx::BufferLoad buffer_load) { ffi::Array ranges; for (const PrimExpr& index : buffer_load->indices) { ranges.push_back(Range::FromMinExtent(index, IntImm(index->dtype, 1))); } - return tvm::tir::BufferRegion(buffer_load->buffer, ranges); + return tvm::tirx::BufferRegion(buffer_load->buffer, ranges); } -} // namespace tir +} // namespace tirx } // namespace ir_builder } // namespace script } // namespace tvm diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index ee1dcde1035f..8398b1f0a484 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -35,7 +35,7 @@ struct SortableFunction { : priority(0), gv(obj.first), func(obj.second) { if (gv->name_hint == "main") { priority = 1000; - } else if (obj.second->GetTypeKey() == "tir.PrimFunc") { + } else if (obj.second->GetTypeKey() == "tirx.PrimFunc") { priority = 1; } else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc") { priority = 2; diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index f2baac0f5375..9abe464d24fd 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -144,7 +144,7 @@ ffi::Optional PrintCallTIRDPSPacked(const relax::Call& n, const AccessP if (n->op.same_as(call_dps_packed_op)) { return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values); } - // Step 4. Print n->args[2], the tir variables + // Step 4. Print n->args[2], the tirx variables if (n->args.size() == 3) { kwargs_keys.push_back("tir_vars"); kwargs_values.push_back(d->AsDoc(n->args[2], n_p->Attr("args")->ArrayItem(2))); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 978c4a8243da..24ae192c73b3 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -41,7 +41,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { RelaxFrameNode::RegisterReflection(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](relax::Function n, AccessPath n_p, IRDocsifier d) -> Doc { - std::unordered_set func_vars; + std::unordered_set func_vars; With f(d); IdDoc func_name(""); diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index e597df64501d..0b0cd2939769 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include #include "./utils.h" @@ -45,8 +45,8 @@ ExprDoc PrintShapeVar(const PrimExpr& e, const AccessPath& e_p, const IRDocsifie // Step 2. Figure out if the PrimExpr contains at least a func var bool func_var_mode = false; if (f != nullptr) { - tir::PostOrderVisit(e, [f, &func_var_mode](const ObjectRef& obj) -> void { - if (const auto* var = obj.as()) { + tirx::PostOrderVisit(e, [f, &func_var_mode](const ObjectRef& obj) -> void { + if (const auto* var = obj.as()) { if (f->func_vars->count(var)) { func_var_mode = true; } diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index ae35f018cbfe..d08476873c7c 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -18,7 +18,7 @@ */ #include -#include "../tir/utils.h" +#include "../tirx/utils.h" #include "./utils.h" namespace tvm { @@ -41,7 +41,7 @@ RelaxFrameNode* GetRelaxFrame(IRDocsifier d) { return f; } -Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { +Doc PrintTIRVar(tirx::Var n, AccessPath n_p, IRDocsifier d) { TVM_FFI_CHECK(n->dtype.is_scalar(), TypeError) << "Relax only uses scalar TIR variables," << "but received TIR variable " << n << " with dtype " << n->dtype; @@ -69,8 +69,8 @@ Doc PrintTIRVar(tir::Var n, AccessPath n_p, IRDocsifier d) { TVM_FFI_UNREACHABLE(); } -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch("relax", PrintTIRVar); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 558abaef3350..d6aa98fda704 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -42,7 +42,7 @@ class RelaxFrameNode : public FrameNode { public: bool is_func = false; bool module_alias_printed = false; - std::unordered_set* func_vars = nullptr; + std::unordered_set* func_vars = nullptr; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tirx/block.cc similarity index 81% rename from src/script/printer/tir/block.cc rename to src/script/printer/tirx/block.cc index 38b89f1c704f..00f8fc0df02a 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tirx/block.cc @@ -22,22 +22,22 @@ namespace tvm { namespace script { namespace printer { -Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // - ffi::Optional opt_realize, +Doc PrintBlock(IRDocsifier d, tirx::SBlock block, AccessPath block_p, // + ffi::Optional opt_realize, ffi::Optional opt_realize_p) { With frame(d, block); TVM_FFI_ICHECK_EQ(opt_realize.defined(), opt_realize_p.defined()); - const tir::SBlockRealizeNode* realize = + const tirx::SBlockRealizeNode* realize = opt_realize.defined() ? opt_realize.value().get() : nullptr; AccessPath realize_p = *opt_realize_p; // Step 1. Handle block var and block bindings // Step 1.1. Obtain all loop var defined along path - std::unordered_map loop_vars; + std::unordered_map loop_vars; for (Frame f : d->frames) { if (const auto* tir_f = f.as()) { - if (auto for_loop = tir_f->tir.as()) { - for (ffi::Optional loop = for_loop; loop; - loop = loop.value()->body.as()) { + if (auto for_loop = tir_f->tirx.as()) { + for (ffi::Optional loop = for_loop; loop; + loop = loop.value()->body.as()) { loop_vars.insert(std::make_pair(loop.value()->loop_var.get(), loop.value())); } } @@ -47,14 +47,14 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // std::vector remap_vars_indices; auto add_remapped_iter_var = [&](int i) -> bool { if (realize && d->cfg->syntax_sugar) { - tir::ExprDeepEqual expr_equal; - tir::IterVar iter_var = block->iter_vars[i]; + tirx::ExprDeepEqual expr_equal; + tirx::IterVar iter_var = block->iter_vars[i]; PrimExpr value = realize->iter_values[i]; - if (iter_var->iter_type == tir::IterVarType::kDataPar || - iter_var->iter_type == tir::IterVarType::kCommReduce) { - if (const auto* var = value.as()) { + if (iter_var->iter_type == tirx::IterVarType::kDataPar || + iter_var->iter_type == tirx::IterVarType::kCommReduce) { + if (const auto* var = value.as()) { if (loop_vars.count(var)) { - tir::For for_loop = loop_vars.at(var); + tirx::For for_loop = loop_vars.at(var); if (expr_equal(for_loop->min, iter_var->dom->min) && expr_equal(for_loop->extent, iter_var->dom->extent)) { remap_vars_indices.push_back(i); @@ -68,23 +68,23 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // }; auto print_single_iter_var = [&](int i) { - tir::IterVar iter_var = block->iter_vars[i]; + tirx::IterVar iter_var = block->iter_vars[i]; AccessPath iter_var_p = block_p->Attr("iter_var")->ArrayItem(i); ExprDoc rhs = TIR(d, "axis"); - if (iter_var->iter_type == tir::IterVarType::kDataPar) { + if (iter_var->iter_type == tirx::IterVarType::kDataPar) { rhs = rhs->Attr("spatial"); - } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) { + } else if (iter_var->iter_type == tirx::IterVarType::kCommReduce) { rhs = rhs->Attr("reduce"); - } else if (iter_var->iter_type == tir::IterVarType::kOrdered) { + } else if (iter_var->iter_type == tirx::IterVarType::kOrdered) { rhs = rhs->Attr("scan"); - } else if (iter_var->iter_type == tir::IterVarType::kOpaque) { + } else if (iter_var->iter_type == tirx::IterVarType::kOpaque) { rhs = rhs->Attr("opaque"); } else { TVM_FFI_THROW(ValueError) << "Unknown IterVarType in block signature: " - << tir::IterVarType2String(iter_var->iter_type); + << tirx::IterVarType2String(iter_var->iter_type); } ExprDoc dom{ffi::UnsafeInit()}; - if (tir::is_zero(iter_var->dom->min)) { + if (tirx::is_zero(iter_var->dom->min)) { ExprDoc extent = d->AsDoc(iter_var->dom->extent, // iter_var_p->Attr("dom")->Attr("extent")); dom = extent; @@ -122,13 +122,13 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // std::string binding_type = ""; ffi::Array binding_paths; for (int i : remap_vars_indices) { - tir::IterVar iter_var = block->iter_vars[i]; + tirx::IterVar iter_var = block->iter_vars[i]; AccessPath iter_var_p = block_p->Attr("iter_vars")->ArrayItem(i); lhs.push_back(DefineVar(iter_var->var, *frame, d)); loop_var_doc.push_back(d->AsDoc(realize->iter_values[i], realize_p->Attr("iter_values")->ArrayItem(i))); binding_paths.push_back(iter_var_p->Attr("iter_type")); - binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; + binding_type += iter_var->iter_type == tirx::IterVarType::kDataPar ? "S" : "R"; } ExprDoc rhs = TIR(d, "axis")->Attr("remap"); ExprDoc binding_str = LiteralDoc::Str(binding_type, std::nullopt); @@ -152,7 +152,7 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // // Step 2. Handle block predicate if (realize) { TVM_FFI_ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); - if (!tir::is_one(realize->predicate)) { + if (!tirx::is_one(realize->predicate)) { (*frame)->stmts.push_back(ExprStmtDoc( TIR(d, "where") ->Call({d->AsDoc(realize->predicate, realize_p->Attr("predicate"))}))); @@ -179,7 +179,7 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // } // Step 5. Handle `alloc_buffer` for (int i = 0, n = block->alloc_buffers.size(); i < n; ++i) { - tir::Buffer buffer = block->alloc_buffers[i]; + tirx::Buffer buffer = block->alloc_buffers[i]; AccessPath buffer_p = block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *frame, d); ExprDoc rhs = BufferDecl(buffer, "sblock_alloc_buffer", {}, buffer_p, *frame, d, @@ -188,14 +188,14 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // } // Step 6. Handle `match_buffer` for (int i = 0, n = block->match_buffers.size(); i < n; ++i) { - tir::MatchBufferRegion buffer_region = block->match_buffers[i]; + tirx::MatchBufferRegion buffer_region = block->match_buffers[i]; AccessPath buffer_region_p = block_p->Attr("match_buffers")->ArrayItem(i); StmtDoc doc = d->AsDoc(buffer_region, buffer_region_p); (*frame)->stmts.push_back(doc); } // Step 7. Handle init block if (block->init.defined()) { - tir::Stmt init = block->init.value(); + tirx::Stmt init = block->init.value(); With init_frame(d, init); AsDocBody(init, block_p->Attr("init"), init_frame->get(), d); (*frame)->stmts.push_back( @@ -217,8 +217,8 @@ Doc PrintBlock(IRDocsifier d, tir::SBlock block, AccessPath block_p, // } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](tir::SBlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](tirx::SBlockRealize realize, AccessPath p, IRDocsifier d) -> Doc { Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p); // since we do not have d->AsDoc for realize->block, // we should add possible doc decoration manually. @@ -227,12 +227,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::SBlock block, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SBlock block, AccessPath p, IRDocsifier d) -> Doc { return PrintBlock(d, block, p, std::nullopt, std::nullopt); }); -TVM_SCRIPT_REPR(tir::SBlockNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::SBlockRealizeNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::SBlockNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::SBlockRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tirx/buffer.cc similarity index 89% rename from src/script/printer/tir/buffer.cc rename to src/script/printer/tirx/buffer.cc index e9616d1dc6bf..7d08ab198485 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tirx/buffer.cc @@ -24,11 +24,11 @@ namespace tvm { namespace script { namespace printer { -ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& buffer_p, +ffi::Map BufferAttrs(tirx::Buffer buffer, const AccessPath& buffer_p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { - using tvm::tir::Var; - using tvm::tir::VarNode; + using tvm::tirx::Var; + using tvm::tirx::VarNode; ffi::Map kwargs; ffi::Array var_def_lhs; ffi::Array var_def_rhs; @@ -36,7 +36,7 @@ ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& // Step 0. Set up statistics std::unordered_map use_count; auto update_use_count = [&](const PrimExpr& e) { - tir::PostOrderVisit(e, [&](const ObjectRef& n) { + tirx::PostOrderVisit(e, [&](const ObjectRef& n) { if (const VarNode* var = n.as()) { ++use_count[var]; } @@ -167,7 +167,7 @@ ffi::Map BufferAttrs(tir::Buffer buffer, const AccessPath& LiteralDoc::Int(buffer->offset_factor, buffer_p->Attr("offset_factor"))); } // Step 9. Handle `buffer.buffer_type` - if (buffer->buffer_type != tir::BufferType::kDefault) { + if (buffer->buffer_type != tirx::BufferType::kDefault) { kwargs.Set("buffer_type", LiteralDoc::Str("auto", buffer_p->Attr("buffer_type"))); } // Step 10. Handle `buffer.axis_separator` @@ -202,7 +202,7 @@ ExprDoc BufferCall(const ExprDoc& prefix, const ffi::Map& return prefix->Call(args, kwargs_keys, kwargs_values); } -ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, +ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, const ffi::Array& args, const AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions) { return BufferCall(/*prefix=*/TIR(d, method), @@ -210,7 +210,7 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, /*args=*/args); } -ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& frame, +ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d) { ffi::Map attrs = BufferAttrs(buffer, p, frame, d, BufferVarDefinition::DataPointer); @@ -226,7 +226,7 @@ ffi::Array BufferIndices(const ffi::Array& indices, const AccessP ffi::Array indices_doc; indices_doc.reserve(n); for (int i = 0; i < n; ++i) { - if (const auto* ramp = indices[i].as()) { + if (const auto* ramp = indices[i].as()) { if (const auto* stride = ramp->stride.as()) { AccessPath ramp_p = p->Attr("indices")->ArrayItem(i); AccessPath stride_p = ramp_p->Attr("stride"); @@ -256,7 +256,7 @@ ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& Range range = region[i]; AccessPath range_p = p->ArrayItem(i); ExprDoc min = d->AsDoc(range->min, range_p->Attr("min")); - if (tir::is_one(range->extent)) { + if (tirx::is_one(range->extent)) { indices.push_back(min); } else { ExprDoc max = d->AsDoc(range->min + range->extent, range_p->Attr("extent")); @@ -267,15 +267,15 @@ ffi::Array BufferSlices(const ffi::Array& region, const AccessPath& } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](tir::BufferRegion buffer_region, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](tirx::BufferRegion buffer_region, AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = d->AsDoc(buffer_region->buffer, p->Attr("buffer")); return prefix[BufferSlices(buffer_region->region, p->Attr("region"), d)]; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::BufferStore store, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::BufferStore store, AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(store->buffer, p->Attr("buffer")); ExprDoc value = d->AsDoc(store->value, p->Attr("value")); @@ -293,8 +293,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::BufferLoad load, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::BufferLoad load, AccessPath p, IRDocsifier d) -> Doc { ExprDoc buffer = d->AsDoc(load->buffer, p->Attr("buffer")); // Use .vload(...) syntax when there is a predicate @@ -308,7 +308,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tir::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Buffer buffer, AccessPath p, IRDocsifier d) -> Doc { if (!d->IsVarDefined(buffer)) { if (ffi::Optional opt_f = FindLowestVarDef(buffer, d)) { ExprDoc lhs = DefineBuffer(buffer, opt_f.value(), d); @@ -325,8 +325,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](tir::MatchBufferRegion stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](tirx::MatchBufferRegion stmt, AccessPath p, IRDocsifier d) -> Doc { Frame frame = d->frames.back(); ExprDoc lhs = DefineBuffer(stmt->buffer, frame, d); ExprDoc src_buffer = d->AsDoc(stmt->source, p->Attr("source")); @@ -336,18 +336,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::ProducerLoad load, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::ProducerLoad load, AccessPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(load->producer->GetNameHint()); return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); -TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::BufferRegionNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::BufferLoadNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::BufferStoreNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::BufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::MatchBufferRegionNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::ProducerLoadNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tirx/expr.cc similarity index 72% rename from src/script/printer/tir/expr.cc rename to src/script/printer/tirx/expr.cc index 69b047b0027e..364b2a40cedb 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tirx/expr.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include +#include #include "./utils.h" @@ -24,14 +24,14 @@ namespace tvm { namespace script { namespace printer { -ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { +ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d) { Type type = var->type_annotation; AccessPath type_p = var_p->Attr("type_annotation"); ExprDoc rhs{ffi::UnsafeInit()}; ffi::Array kwargs_keys; ffi::Array kwargs_values; - if (var->IsInstance()) { + if (var->IsInstance()) { kwargs_keys.push_back("is_size_var"); kwargs_values.push_back(LiteralDoc::Boolean(true, std::nullopt)); } @@ -64,7 +64,7 @@ ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRD return rhs; } -Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) { +Doc PrintVar(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d) { if (!d->IsVarDefined(var)) { if (ffi::Optional opt_f = FindLowestVarDef(var, d)) { ExprDoc lhs = DefineVar(var, opt_f.value(), d); @@ -82,17 +82,17 @@ Doc PrintVar(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d) } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tir::Var var, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Var var, AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // - .set_dispatch("", [](tir::SizeVar var, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SizeVar var, AccessPath p, IRDocsifier d) -> Doc { return PrintVar(var, p, d); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::IterVar var, AccessPath var_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::IterVar var, AccessPath var_p, IRDocsifier d) -> Doc { return TIR(d, "iter_var") ->Call({ d->AsDoc(var->var, var_p->Attr("var")), @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Not node, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Not node, AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); if (a->IsInstance()) { return TIR(d, "Not")->Call({a}); @@ -112,7 +112,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::StringImm s, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::StringImm s, AccessPath p, IRDocsifier d) -> Doc { if (HasMultipleLines(s->value)) { return d->AddMetadata(s); } else { @@ -121,14 +121,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Cast cast, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Cast cast, AccessPath p, IRDocsifier d) -> Doc { ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); return TIR(d, "Cast")->Call({dtype, value}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Select select, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Select select, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Select") ->Call({ d->AsDoc(select->condition, p->Attr("condition")), @@ -138,7 +138,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Ramp ramp, AccessPath ramp_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Ramp ramp, AccessPath ramp_p, IRDocsifier d) -> Doc { return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), @@ -147,17 +147,18 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Broadcast bc, AccessPath bc_p, IRDocsifier d) -> Doc { - return TIR(d, "Broadcast") - ->Call({ - d->AsDoc(bc->value, bc_p->Attr("value")), - d->AsDoc(bc->lanes, bc_p->Attr("lanes")), - }); - }); + .set_dispatch("", + [](tirx::Broadcast bc, AccessPath bc_p, IRDocsifier d) -> Doc { + return TIR(d, "Broadcast") + ->Call({ + d->AsDoc(bc->value, bc_p->Attr("value")), + d->AsDoc(bc->lanes, bc_p->Attr("lanes")), + }); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::Shuffle shuffle, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::Shuffle shuffle, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Shuffle") ->Call({ d->AsDoc(shuffle->vectors, p->Attr("vectors")), @@ -166,8 +167,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::CommReducer r, AccessPath p, IRDocsifier d) -> Doc { TVM_FFI_ICHECK_EQ(r->lhs.size(), r->rhs.size()); ffi::Optional lambda; { @@ -197,7 +198,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return TIR(d, "comm_reducer")->Call({lambda.value(), id}); }); -LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, +LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, const AccessPath& vs_p, const ffi::Array& es, const AccessPath& es_p, const IRDocsifier& d) { With f(d, map); @@ -213,12 +214,12 @@ LambdaDoc PrintIndexMap(const ObjectRef& map, const ffi::Array& vs, } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::IndexMap m, AccessPath m_p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::IndexMap m, AccessPath m_p, IRDocsifier d) -> Doc { LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices"), m->final_indices, m_p->Attr("final_indices"), d); if (m->inverse_index_map.defined()) { - tir::IndexMap inverse = Downcast(m->inverse_index_map); + tirx::IndexMap inverse = Downcast(m->inverse_index_map); LambdaDoc inv = PrintIndexMap(inverse, inverse->initial_indices, m_p->Attr("inverse_index_map")->Attr("initial_indices"), inverse->final_indices, @@ -230,7 +231,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Let let, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Let let, AccessPath p, IRDocsifier d) -> Doc { DictDoc where({d->AsDoc(let->var, p->Attr("var"))}, {d->AsDoc(let->value, p->Attr("value"))}); return TIR(d, "Let")->Call({d->AsDoc(let->body, p->Attr("body"))}, // @@ -238,12 +239,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Call call, AccessPath call_p, IRDocsifier d) -> Doc { - static const OpAttrMap& op_names = - Op::GetAttrMap("TScriptPrinterName"); - static const OpAttrMap dtype_locations = - Op::GetAttrMap("TScriptDtypePrintLocation"); - tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; + .set_dispatch("", [](tirx::Call call, AccessPath call_p, IRDocsifier d) -> Doc { + static const OpAttrMap& op_names = + Op::GetAttrMap("TScriptPrinterName"); + static const OpAttrMap dtype_locations = + Op::GetAttrMap("TScriptDtypePrintLocation"); + tirx::ScriptDtypePrintLocation dtype_print_location = tirx::ScriptDtypePrintLocation::kNone; ffi::Optional prefix; if (auto optional_op = call->op.as()) { auto op = optional_op.value(); @@ -254,7 +255,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) prefix = TIR(d, name); if (dtype_locations.count(op)) { dtype_print_location = - static_cast(dtype_locations[op].IntValue()); + static_cast(dtype_locations[op].IntValue()); } if (name == "call_llvm_pure_intrin" || name == "call_llvm_intrin") { int n_args = call->args.size(); @@ -264,7 +265,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array args; args.reserve(n_args + 1); - if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { + if (dtype_print_location == tirx::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } @@ -276,7 +277,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayItem(i))); } } - if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { + if (dtype_print_location == tirx::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } return prefix.value()->Call(args); @@ -289,21 +290,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Array args; int n_args = call->args.size(); args.reserve(n_args + 1); - if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { + if (dtype_print_location == tirx::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } for (int i = 0; i < n_args; ++i) { args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayItem(i))); } - if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { + if (dtype_print_location == tirx::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } return prefix.value()->Call(args); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Reduce r, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Reduce r, AccessPath p, IRDocsifier d) -> Doc { ExprDoc combiner = d->AsDoc(r->combiner, p->Attr("combiner")); ExprDoc source = d->AsDoc(r->source, p->Attr("source")); ExprDoc init = d->AsDoc(r->init, p->Attr("init")); @@ -316,14 +317,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_FFI_THROW(ValueError) << "Reduce should never exist in TIR: " << r; }); -#define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ - .set_dispatch("", \ - [](tir::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ - ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ - ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - return TIR(d, OpString)->Call({a, b}); \ - }); +#define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ + .set_dispatch("", \ + [](tirx::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ + ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ + ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ + return TIR(d, OpString)->Call({a, b}); \ + }); bool IsNumber(const ExprDoc& e) { if (const auto* n = e.as()) { @@ -335,11 +336,11 @@ bool IsNumber(const ExprDoc& e) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Div node, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Div node, AccessPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); ExprDoc b = d->AsDoc(node->b, p->Attr("b")); PrimExpr ret = tvm::div(node->a, node->b); - if (!ret->IsInstance()) { + if (!ret->IsInstance()) { return TIR(d, "Div")->Call({a, b}); } if ((node->a->dtype.is_int() || node->a->dtype.is_uint()) && @@ -351,12 +352,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) #define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, NodeObj, NodeFunc, OpString, OpKind) \ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) \ - .set_dispatch( \ - "", [](tir::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ + .set_dispatch( \ + "", [](tirx::NodeType node, AccessPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ PrimExpr ret = tvm::NodeFunc(node->a, node->b); \ - if (const auto* ret_node = ret.as()) { \ + if (const auto* ret_node = ret.as()) { \ if (ret_node->a.same_as(node->a) && ret_node->b.same_as(node->b)) { \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ } \ @@ -385,38 +386,38 @@ TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max"); #undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR #undef TVM_SCRIPT_PRINTER_DEF_BINARY -TVM_SCRIPT_REPR(tir::VarNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::SizeVarNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::IterVarNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::StringImmNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::CastNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::AddNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::SubNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::MulNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::DivNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ModNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::FloorDivNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::FloorModNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::MinNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::MaxNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::LTNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::LENode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::EQNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::NENode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::GTNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::GENode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::AndNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::OrNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::NotNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::SelectNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::RampNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::BroadcastNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::LetNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::VarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::SizeVarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::IterVarNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::StringImmNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::CastNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::AddNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::SubNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::MulNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::DivNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::ModNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::FloorDivNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::FloorModNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::MinNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::MaxNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::LTNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::LENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::EQNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::NENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::GTNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::GENode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::AndNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::OrNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::NotNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::SelectNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::RampNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::BroadcastNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::LetNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::CallNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::ShuffleNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::CommReducerNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::IndexMapNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::ReduceNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tirx/for_loop.cc similarity index 80% rename from src/script/printer/tir/for_loop.cc rename to src/script/printer/tirx/for_loop.cc index ec0a31e44ff2..a4d9d2c4b0f1 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tirx/for_loop.cc @@ -23,22 +23,22 @@ namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::For loop, AccessPath loop_p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::For loop, AccessPath loop_p, IRDocsifier d) -> Doc { // Step 1. Check syntactic sugar: `T.grid` - std::vector grid; - std::unordered_set grid_loop_vars; + std::vector grid; + std::unordered_set grid_loop_vars; auto f_var_dep = [&grid_loop_vars](const PrimExpr& e) -> bool { - return tir::UsesVar(e, [&grid_loop_vars](const tir::VarNode* v) -> bool { // + return tirx::UsesVar(e, [&grid_loop_vars](const tirx::VarNode* v) -> bool { // return grid_loop_vars.count(v); }); }; if (d->cfg->syntax_sugar) { - for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as()) { + for (const tirx::ForNode* l = loop.get(); l != nullptr; l = l->body.as()) { TVM_FFI_ICHECK(l->loop_var->dtype == l->min->dtype); TVM_FFI_ICHECK(l->loop_var->dtype == l->extent->dtype); - if (l->kind != tir::ForKind::kSerial || // - !tir::is_zero(l->min) || // - !l->annotations.empty() || // + if (l->kind != tirx::ForKind::kSerial || // + !tirx::is_zero(l->min) || // + !l->annotations.empty() || // !l->HasTrivialStep() || f_var_dep(l->extent)) { break; } @@ -55,7 +55,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) lhs.reserve(n); rhs.reserve(n); for (int i = 0; i < n; ++i) { - const tir::ForNode* loop = grid[i]; + const tirx::ForNode* loop = grid[i]; lhs.push_back(DefineVar(loop->loop_var, *f, d)); rhs.push_back(d->AsDoc(loop->extent, loop_p->Attr("extent"))); loop_p = loop_p->Attr("body"); @@ -69,7 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ffi::Optional max = std::nullopt; ffi::Optional annotations = std::nullopt; ffi::Optional thread = std::nullopt; - if (tir::is_zero(loop->min) && loop->HasTrivialStep()) { + if (tirx::is_zero(loop->min) && loop->HasTrivialStep()) { max = d->AsDoc(loop->extent, loop_p->Attr("extent")); } else { min = d->AsDoc(loop->min, loop_p->Attr("min")); @@ -80,25 +80,25 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } bool use_range_sugar = false; ExprDoc prefix{ffi::UnsafeInit()}; - if (loop->kind == tir::ForKind::kSerial) { + if (loop->kind == tirx::ForKind::kSerial) { if (loop->annotations.empty()) { prefix = IdDoc("range"); use_range_sugar = true; } else { prefix = TIR(d, "serial"); } - } else if (loop->kind == tir::ForKind::kParallel) { + } else if (loop->kind == tirx::ForKind::kParallel) { prefix = TIR(d, "parallel"); - } else if (loop->kind == tir::ForKind::kUnrolled) { + } else if (loop->kind == tirx::ForKind::kUnrolled) { prefix = TIR(d, "unroll"); - } else if (loop->kind == tir::ForKind::kVectorized) { + } else if (loop->kind == tirx::ForKind::kVectorized) { prefix = TIR(d, "vectorized"); - } else if (loop->kind == tir::ForKind::kThreadBinding) { + } else if (loop->kind == tirx::ForKind::kThreadBinding) { prefix = TIR(d, "thread_binding"); thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag, loop_p->Attr("thread_binding")); } else { - TVM_FFI_THROW(ValueError) << "Unknown ForKind: " << tir::ForKind2String(loop->kind); + TVM_FFI_THROW(ValueError) << "Unknown ForKind: " << tirx::ForKind2String(loop->kind); } ffi::Array args; ffi::Array kwargs_keys; @@ -131,7 +131,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return ForDoc(lhs, rhs, (*f)->stmts); }); -TVM_SCRIPT_REPR(tir::ForNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::ForNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tirx/function.cc similarity index 80% rename from src/script/printer/tir/function.cc rename to src/script/printer/tirx/function.cc index bab991bc005f..6c05791bc610 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tirx/function.cc @@ -24,21 +24,21 @@ namespace tvm { namespace script { namespace printer { -bool IsSimpleBuffer(const tir::Buffer& buf) { +bool IsSimpleBuffer(const tirx::Buffer& buf) { if (!buf->strides.empty()) { return false; } for (const PrimExpr& shp_i : buf->shape) { - if (!tir::UndefinedVars(shp_i).empty()) { + if (!tirx::UndefinedVars(shp_i).empty()) { return false; } } for (const PrimExpr& stride_i : buf->strides) { - if (!tir::UndefinedVars(stride_i).empty()) { + if (!tirx::UndefinedVars(stride_i).empty()) { return false; } } - if (!tir::UndefinedVars(buf->elem_offset).empty()) { + if (!tirx::UndefinedVars(buf->elem_offset).empty()) { return false; } else if (buf->elem_offset->IsInstance()) { IntImm elem_offset = Downcast(buf->elem_offset); @@ -47,14 +47,14 @@ bool IsSimpleBuffer(const tir::Buffer& buf) { } } return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment && - buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault && + buf->offset_factor == 1 && buf->buffer_type == tirx::BufferType::kDefault && !buf->axis_separators.size(); } -int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { +int CountVarOccurrence(const tirx::PrimFunc& f, const tirx::Var& v) { OccurrenceCounter counter(v.get()); counter(f->body); - for (const tir::Var& v : f->params) { + for (const tirx::Var& v : f->params) { counter(v); } for (const auto& pair : f->buffer_map) { @@ -65,17 +65,17 @@ int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::PrimFunc func, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::PrimFunc func, AccessPath p, IRDocsifier d) -> Doc { With f(d, func); - (*f)->AddDispatchToken(d, "tir"); + (*f)->AddDispatchToken(d, "tirx"); IdDoc func_name = IdDoc(FindFunctionName(d, func).value_or("main")); d->SetCommonPrefix(func, [](const ObjectRef& obj) { - return obj->IsInstance() || obj->IsInstance(); + return obj->IsInstance() || obj->IsInstance(); }); int n_args = func->params.size(); - std::unordered_map buffer_data_counter; + std::unordered_map buffer_data_counter; for (const auto& pair : func->buffer_map) { - const tir::VarNode* data_var = pair.second->data.get(); + const tirx::VarNode* data_var = pair.second->data.get(); if (!buffer_data_counter.count(data_var)) { buffer_data_counter.insert({data_var, 0}); } @@ -84,13 +84,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 1. Handle `func->params` ffi::Array args; args.reserve(n_args); - std::unordered_set buffer_inlined; + std::unordered_set buffer_inlined; for (int i = 0; i < n_args; ++i) { - tir::Var var = func->params[i]; + tirx::Var var = func->params[i]; AccessPath var_p = p->Attr("params")->ArrayItem(i); if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { - tir::Buffer buffer = func->buffer_map[var]; + tirx::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { AccessPath buffer_p = p->Attr("buffer_map")->MapItem(var); IdDoc lhs = DefineBuffer(buffer, *f, d); @@ -128,9 +128,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } // Step 3. Handle `func->buffer_map` for (int i = 0; i < n_args; ++i) { - tir::Var param = func->params[i]; + tirx::Var param = func->params[i]; if (func->buffer_map.count(param)) { - tir::Buffer buffer = func->buffer_map[param]; + tirx::Buffer buffer = func->buffer_map[param]; if (buffer_inlined.count(buffer.get())) { continue; } @@ -143,19 +143,20 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } // Step 4. Handle `func->body` - ffi::Optional implicit_root_block = [&]() -> ffi::Optional { - const tir::SBlockRealizeNode* root_block_realize = func->body.as(); + ffi::Optional implicit_root_block = [&]() -> ffi::Optional { + const tirx::SBlockRealizeNode* root_block_realize = + func->body.as(); if (root_block_realize && !root_block_realize->iter_values.size() && - tir::is_one(root_block_realize->predicate)) { - tir::SBlock root_block = root_block_realize->block; + tirx::is_one(root_block_realize->predicate)) { + tirx::SBlock root_block = root_block_realize->block; if (!root_block->annotations.size() && !root_block->match_buffers.size() && !root_block->reads.size() && !root_block->writes.size() && !root_block->init.defined()) { - const tir::SBlockRealizeNode* block_realize = - root_block->body.as(); + const tirx::SBlockRealizeNode* block_realize = + root_block->body.as(); if (root_block->alloc_buffers.size() || (block_realize && block_realize->block->iter_vars.size()) || - (!block_realize && tir::ContainsNode(root_block->body))) { + (!block_realize && tirx::ContainsNode(root_block->body))) { return root_block; } } @@ -163,12 +164,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return std::nullopt; }(); if (d->cfg->syntax_sugar && implicit_root_block) { - tir::SBlock root_block = implicit_root_block.value(); + tirx::SBlock root_block = implicit_root_block.value(); AccessPath root_block_p = p->Attr("body")->Attr("block"); (*f)->stmts.push_back(CommentDoc("with T.sblock(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { - tir::Buffer buffer = root_block->alloc_buffers[i]; + tirx::Buffer buffer = root_block->alloc_buffers[i]; AccessPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayItem(i); IdDoc lhs = DefineBuffer(buffer, *f, d); ExprDoc rhs = BufferDecl(buffer, "sblock_alloc_buffer", {}, buffer_p, *f, d, @@ -203,11 +204,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) /*body=*/(*f)->stmts)); }); -TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::PrimFuncNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "tir", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "tirx", [](tvm::GlobalVar n, AccessPath n_p, IRDocsifier d) -> Doc { // if (ffi::Optional doc = d->GetVarDoc(n)) { return doc.value(); } else { @@ -218,8 +219,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "tir", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // + .set_dispatch( // + "tirx", [](tvm::IRModule mod, AccessPath n_p, IRDocsifier d) -> Doc { // ffi::Optional doc = d->GetVarDoc(mod); TVM_FFI_ICHECK(doc) << "Unable to print IRModule before definition in TIR."; return doc.value(); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tirx/ir.cc similarity index 97% rename from src/script/printer/tir/ir.cc rename to src/script/printer/tirx/ir.cc index a2409865cbb1..4a7517599b10 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tirx/ir.cc @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](Range range, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("tirx", [](Range range, AccessPath p, IRDocsifier d) -> Doc { return TIR(d, "Range") ->Call({ d->AsDoc(range->min, p->Attr("min")), diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tirx/stmt.cc similarity index 80% rename from src/script/printer/tir/stmt.cc rename to src/script/printer/tirx/stmt.cc index 900bf6a2f24d..1c861f91dce2 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tirx/stmt.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include "../../../tir/transform/ir_utils.h" // For `GetPtrStorageScope` +#include "../../../tirx/transform/ir_utils.h" // For `GetPtrStorageScope` #include "./utils.h" namespace tvm { @@ -52,7 +52,7 @@ bool AllowConciseScoping(const IRDocsifier& d, const ObjectRef& obj) { TVM_FFI_UNREACHABLE(); } -bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { +bool IsAncestorOfAllVarUse(const tirx::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { if (!d->common_prefix.count(var.get())) { return false; } @@ -65,14 +65,14 @@ bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IR return false; } -ffi::Optional FindReturnValue(const tir::Stmt& node) { - auto eval = node.as(); +ffi::Optional FindReturnValue(const tirx::Stmt& node) { + auto eval = node.as(); if (!eval) return std::nullopt; - auto call = eval->value.as(); + auto call = eval->value.as(); if (!call) return std::nullopt; - if (!call->op.same_as(tir::builtin::ret())) return std::nullopt; + if (!call->op.same_as(tirx::builtin::ret())) return std::nullopt; if (call->args.size() != 1) return std::nullopt; @@ -80,7 +80,7 @@ ffi::Optional FindReturnValue(const tir::Stmt& node) { } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Evaluate eval, AccessPath p, IRDocsifier d) -> Doc { if (d->cfg->syntax_sugar) { if (auto return_value = FindReturnValue(eval)) { ExprDoc value = @@ -90,14 +90,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } ExprDoc value = d->AsDoc(eval->value, p->Attr("value")); - if (eval->value->IsInstance()) { + if (eval->value->IsInstance()) { return ExprStmtDoc(value); } return ExprStmtDoc(TIR(d, "evaluate")->Call({value})); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::Bind stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::Bind stmt, AccessPath p, IRDocsifier d) -> Doc { // Step 1. Type annotation ffi::Optional type_doc = d->AsDoc(stmt->var->type_annotation, // p->Attr("var")->Attr("type_annotation")); @@ -120,8 +120,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( - "", [](tir::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( + "", [](tirx::AssertStmt stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); // Always emit the canonical tuple form: assert cond, ("Kind", ["part0", "part1", ...]) ffi::Array parts; @@ -134,7 +134,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::While stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::While stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); @@ -142,7 +142,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); namespace { -Doc DeclBufferDoc(tir::DeclBuffer stmt, AccessPath p, IRDocsifier d, +Doc DeclBufferDoc(tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d, BufferVarDefinition var_definitions) { ExprDoc rhs = BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d, var_definitions); @@ -152,14 +152,14 @@ Doc DeclBufferDoc(tir::DeclBuffer stmt, AccessPath p, IRDocsifier d, } // namespace TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::DeclBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { return DeclBufferDoc(stmt, p, d, BufferVarDefinition::None); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::IfThenElse stmt, AccessPath p, IRDocsifier d) -> Doc { ExprDoc cond = d->AsDoc(stmt->condition, p->Attr("condition")); ffi::Array then_branch; ffi::Array else_branch; @@ -177,13 +177,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("", [](tir::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc { + .set_dispatch("", [](tirx::SeqStmt stmt, AccessPath p, IRDocsifier d) -> Doc { With f(d, stmt); AsDocBody(stmt, p, f->get(), d); return StmtBlockDoc((*f)->stmts); }); -void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath& iter_var_p, +void InsertEnvThread(const tirx::IterVar& iter_var, const AccessPath& iter_var_p, const IRDocsifier& d) { Frame f = FindLowestVarDef(iter_var->var, d).value(); DefineVar(iter_var->var, f, d); @@ -194,9 +194,9 @@ void InsertEnvThread(const tir::IterVar& iter_var, const AccessPath& iter_var_p, f->stmts.push_back(AssignDoc(lhs, rhs, std::nullopt)); } -ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, - ffi::Optional* define_var, const IRDocsifier& d) { - tir::IterVar iter_var = Downcast(attr_stmt->node); +ExprDoc DocsifyLaunchThread(const tirx::AttrStmt& attr_stmt, const AccessPath& attr_stmt_p, + ffi::Optional* define_var, const IRDocsifier& d) { + tirx::IterVar iter_var = Downcast(attr_stmt->node); AccessPath iter_var_p = attr_stmt_p->Attr("node"); ExprDoc var_doc{ffi::UnsafeInit()}; @@ -217,16 +217,16 @@ ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const AccessPath& at } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { + .set_dispatch( // + "", [](tirx::AttrStmt stmt, AccessPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d, stmt); ffi::Optional lhs = std::nullopt; ffi::Optional rhs = std::nullopt; - ffi::Optional define_var = std::nullopt; - tir::Stmt body = stmt->body; + ffi::Optional define_var = std::nullopt; + tirx::Stmt body = stmt->body; AccessPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { - if (stmt->node.as()) { + if (stmt->node.as()) { rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d); } } @@ -245,14 +245,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise); }); -TVM_SCRIPT_REPR(tir::BindNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::BindNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::AttrStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::AssertStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::WhileNode, ReprPrintTIR); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::AllocBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { - tir::Buffer buffer = stmt->buffer; + .set_dispatch( // + "", [](tirx::AllocBuffer stmt, AccessPath p, IRDocsifier d) -> Doc { + tirx::Buffer buffer = stmt->buffer; AccessPath buffer_p = p->Attr("buffer"); Frame frame = d->frames.back(); // Define buffer's data var inline as buffer.data @@ -275,11 +275,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n; ++i) { PrimExpr e = buffer->shape[i]; AccessPath e_p = shape_p->ArrayItem(i); - if (!d->IsVarDefined(e) && e->IsInstance()) { - ExprDoc lhs = DefineVar(Downcast(e), frame, d); + if (!d->IsVarDefined(e) && e->IsInstance()) { + ExprDoc lhs = DefineVar(Downcast(e), frame, d); lhs->source_paths.push_back(e_p); frame->stmts.push_back( - AssignDoc(lhs, PrintVarCreation(Downcast(e), e_p, d), std::nullopt)); + AssignDoc(lhs, PrintVarCreation(Downcast(e), e_p, d), std::nullopt)); } shape_docs.push_back(d->AsDoc(e, e_p)); } @@ -308,11 +308,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return AssignDoc(lhs, rhs, std::nullopt); }); -TVM_SCRIPT_REPR(tir::AllocBufferNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::AllocBufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::DeclBufferNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::SeqStmtNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::IfThenElseNode, ReprPrintTIR); +TVM_SCRIPT_REPR(tirx::EvaluateNode, ReprPrintTIR); } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tirx/utils.h similarity index 80% rename from src/script/printer/tir/utils.h rename to src/script/printer/tirx/utils.h index 736b3c62b56a..fb512769b0a5 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tirx/utils.h @@ -21,14 +21,14 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -45,14 +45,14 @@ namespace printer { class TIRFrameNode : public FrameNode { public: /*! \brief The TIR fragment the frame corresponds to */ - ObjectRef tir; + ObjectRef tirx; /*! \brief Whether or not the frame allows concise scoping */ bool allow_concise_scoping{false}; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("tir", &TIRFrameNode::tir) + .def_ro("tirx", &TIRFrameNode::tirx) .def_ro("allow_concise_scoping", &TIRFrameNode::allow_concise_scoping); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.printer.TIRFrame", TIRFrameNode, FrameNode); @@ -62,11 +62,11 @@ class TIRFrameNode : public FrameNode { class TIRFrame : public Frame { public: /*! \brief Constructor */ - explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tir) { + explicit TIRFrame(const IRDocsifier& d, const ObjectRef& tirx) { ObjectPtr n = ffi::make_object(); n->stmts.clear(); n->d = d.get(); - n->tir = tir; + n->tirx = tirx; data_ = std::move(n); } @@ -81,7 +81,7 @@ class TIRFrame : public Frame { * \param frame The frame to define the variable in * \return The IdDoc corresponding to the variable */ -inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsifier& d) { +inline ExprDoc DefineVar(const tirx::Var& var, const Frame& frame, const IRDocsifier& d) { if (ffi::Optional doc = d->GetVarDoc(var)) { return doc.value(); } @@ -96,7 +96,7 @@ inline ExprDoc DefineVar(const tir::Var& var, const Frame& frame, const IRDocsif * \param d The IRDocsifier * \return The IdDoc corresponding to the buffer */ -inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const IRDocsifier& d) { +inline IdDoc DefineBuffer(const tirx::Buffer& buffer, const Frame& frame, const IRDocsifier& d) { return d->Define(buffer, frame, buffer->name.empty() ? "buffer" : buffer->name); } @@ -107,9 +107,9 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I * \param f The frame * \param d The IRDocsifier */ -inline void AsDocBody(const tir::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { - if (const auto* seq_stmt = stmt.as()) { - ffi::Array body = seq_stmt->seq; +inline void AsDocBody(const tirx::Stmt& stmt, AccessPath p, TIRFrameNode* f, const IRDocsifier& d) { + if (const auto* seq_stmt = stmt.as()) { + ffi::Array body = seq_stmt->seq; for (int i = 0, n = body.size(); i < n; ++i) { f->allow_concise_scoping = (i == n - 1); Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayItem(i)); @@ -147,8 +147,8 @@ inline ffi::Optional FindLowestVarDef(const ObjectRef& var, const IRDocsi tir_to_frame.reserve(n_frames); for (int i = n_frames - 1; i >= 0; --i) { if (const auto* f = d->frames[i].as()) { - if (f->tir.defined()) { - tir_to_frame[f->tir.get()] = f; + if (f->tirx.defined()) { + tir_to_frame[f->tirx.get()] = f; } else if (fallback_frame == nullptr) { fallback_frame = f; } @@ -170,10 +170,10 @@ inline ffi::Optional FindLowestVarDef(const ObjectRef& var, const IRDocsi inline std::string ReprPrintTIR(const ObjectRef& obj, const PrinterConfig& cfg) { IRDocsifier d(cfg); d->SetCommonPrefix(obj, [](const ObjectRef& obj) { - return obj->IsInstance() || obj->IsInstance(); + return obj->IsInstance() || obj->IsInstance(); }); With f(d, ObjectRef{nullptr}); - (*f)->AddDispatchToken(d, "tir"); + (*f)->AddDispatchToken(d, "tirx"); return Docsify(obj, d, *f, cfg); } @@ -212,7 +212,7 @@ enum class BufferVarDefinition { * the buffer. * \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, +ExprDoc BufferDecl(const tirx::Buffer& buffer, const ffi::String& method, const ffi::Array& args, const AccessPath& p, const Frame& frame, const IRDocsifier& d, BufferVarDefinition var_definitions); @@ -224,7 +224,7 @@ ExprDoc BufferDecl(const tir::Buffer& buffer, const ffi::String& method, * \param d The IRDocsifier * \return The ExprDoc corresponding to the buffer declaration */ -ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& frame, +ExprDoc BufferAttn(const tirx::Buffer& buffer, const AccessPath& p, const Frame& frame, const IRDocsifier& d); /*! @@ -234,44 +234,44 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const AccessPath& p, const Frame& * \param d The IRDocsifier * \return The ExprDoc corresponding to the Var creation */ -ExprDoc PrintVarCreation(const tir::Var& var, const AccessPath& var_p, const IRDocsifier& d); +ExprDoc PrintVarCreation(const tirx::Var& var, const AccessPath& var_p, const IRDocsifier& d); /*! \brief A Var occurrence counter visitor */ -class OccurrenceCounter : public tir::StmtExprVisitor { +class OccurrenceCounter : public tirx::StmtExprVisitor { public: /*! \brief The occurrence counter */ int count = 0; /*! \brief The Var to count occurrence */ - const tir::VarNode* v = nullptr; + const tirx::VarNode* v = nullptr; - void VisitExpr_(const tir::VarNode* op) final { + void VisitExpr_(const tirx::VarNode* op) final { if (op == v) { ++count; } - tir::StmtExprVisitor::VisitExpr_(op); + tirx::StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const tir::BufferStoreNode* op) final { + void VisitStmt_(const tirx::BufferStoreNode* op) final { VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitStmt_(op); + tirx::StmtExprVisitor::VisitStmt_(op); } - void VisitExpr_(const tir::BufferLoadNode* op) final { + void VisitExpr_(const tirx::BufferLoadNode* op) final { VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitExpr_(op); + tirx::StmtExprVisitor::VisitExpr_(op); } - void VisitStmt_(const tir::AllocBufferNode* op) final { + void VisitStmt_(const tirx::AllocBufferNode* op) final { VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitStmt_(op); + tirx::StmtExprVisitor::VisitStmt_(op); } - void VisitStmt_(const tir::DeclBufferNode* op) final { + void VisitStmt_(const tirx::DeclBufferNode* op) final { VisitBuffer(op->buffer.get()); - tir::StmtExprVisitor::VisitStmt_(op); + tirx::StmtExprVisitor::VisitStmt_(op); } - void VisitBuffer(const tir::BufferNode* buffer) { + void VisitBuffer(const tirx::BufferNode* buffer) { VisitExpr(buffer->data); for (const PrimExpr& shape_i : buffer->shape) { VisitExpr(shape_i); @@ -282,7 +282,7 @@ class OccurrenceCounter : public tir::StmtExprVisitor { VisitExpr(buffer->elem_offset); } - explicit OccurrenceCounter(const tir::VarNode* var) { v = var; } + explicit OccurrenceCounter(const tirx::VarNode* var) { v = var; } }; } // namespace printer diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index ddeab5e754d2..78f12d4983d6 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -104,7 +104,7 @@ inline ExprDoc IR(const IRDocsifier& d, const ffi::String& attr) { /*! \brief Creates the TIR common prefix, which is by default `T` */ inline ExprDoc TIR(const IRDocsifier& d, const ffi::String& attr) { - d->ir_usage.insert("tir"); + d->ir_usage.insert("tirx"); return IdDoc(d->cfg->tir_prefix)->Attr(attr); } @@ -125,8 +125,8 @@ inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) { if (d->ir_usage.count("ir")) { stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix)); } - if (d->ir_usage.count("tir")) { - stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix)); + if (d->ir_usage.count("tirx")) { + stmts.push_back(CommentDoc("from tvm.script import tirx as " + d->cfg->tir_prefix)); } if (d->ir_usage.count("relax")) { stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix)); diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index d82c96f2636e..29d671108d82 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/support/nd_int_set.h b/src/support/nd_int_set.h index 01a9ee8ae398..012aefd18b4e 100644 --- a/src/support/nd_int_set.h +++ b/src/support/nd_int_set.h @@ -36,7 +36,7 @@ using NDIntSet = std::vector; * \param region The region. * \return The constructed set. */ -inline NDIntSet NDIntSetFromRegion(const tir::Region& region) { +inline NDIntSet NDIntSetFromRegion(const tirx::Region& region) { NDIntSet result; result.reserve(region.size()); for (const Range& range : region) { @@ -135,7 +135,7 @@ inline NDIntSet NDIntSetEmpty(int ndim) { */ inline NDIntSet NDIntSetEval( const NDIntSet& nd_int_set, - const std::unordered_map& dom_map) { + const std::unordered_map& dom_map) { NDIntSet ret; ret.reserve(nd_int_set.size()); for (const arith::IntSet& s : nd_int_set) { diff --git a/src/target/build_common.h b/src/target/build_common.h index b1192eeca8e0..aebdc992a566 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -28,9 +28,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include @@ -43,15 +43,15 @@ inline ffi::Map ExtractFuncInfo(const IRModu ffi::Map fmap; for (auto kv : mod->functions) { - TVM_FFI_ICHECK(kv.second->IsInstance()) + TVM_FFI_ICHECK(kv.second->IsInstance()) << "Can only lower IR Module with PrimFuncs"; - auto f = Downcast(kv.second); + auto f = Downcast(kv.second); ffi::Array arg_types; ffi::Array arg_extra_tags; for (size_t i = 0; i < f->params.size(); ++i) { arg_types.push_back(f->params[i].dtype()); - auto is_tensormap = [](const tir::Var& var) -> bool { + auto is_tensormap = [](const tirx::Var& var) -> bool { const auto* type = var->type_annotation.as(); if (type == nullptr) { return false; @@ -62,7 +62,7 @@ inline ffi::Map ExtractFuncInfo(const IRModu : runtime::ArgExtraTags::kNone); } ffi::Array launch_param_tags; - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>(tirx::attr::kKernelLaunchParams)) { for (const auto& tag : opt.value()) { launch_param_tags.push_back(tag); } diff --git a/src/target/codegen.cc b/src/target/codegen.cc index b1bb9ae3e2dc..714b0d7256ba 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -30,8 +30,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -46,9 +46,9 @@ namespace codegen { ffi::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() - ->GetConfig("tir.disable_assert", Bool(false)) + ->GetConfig("tirx.disable_assert", Bool(false)) .value()) { - mod = tir::transform::SkipAssert()(mod); + mod = tirx::transform::SkipAssert()(mod); } // the build function. diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 91701b067b47..31e8b6a83290 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -23,106 +23,106 @@ */ #include "intrin_rule.h" -#include -#include -#include +#include +#include +#include namespace tvm { namespace codegen { namespace intrin { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; -TVM_REGISTER_OP("tir.exp").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.exp") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.erf").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.erf") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.log") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log1p") +TVM_REGISTER_OP("tirx.log1p") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tan").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.tan") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.atan") +TVM_REGISTER_OP("tirx.atan") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.atanh") +TVM_REGISTER_OP("tirx.atanh") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.atan2") +TVM_REGISTER_OP("tirx.atan2") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cos").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.cos") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.acos") +TVM_REGISTER_OP("tirx.acos") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.acosh") +TVM_REGISTER_OP("tirx.acosh") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sin").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.sin") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.asin") +TVM_REGISTER_OP("tirx.asin") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.asinh") +TVM_REGISTER_OP("tirx.asinh") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.hypot") +TVM_REGISTER_OP("tirx.hypot") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.nextafter") +TVM_REGISTER_OP("tirx.nextafter") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.copysign") +TVM_REGISTER_OP("tirx.copysign") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.ldexp") +TVM_REGISTER_OP("tirx.ldexp") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.pow") + .set_attr("default.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tvm_access_ptr") +TVM_REGISTER_OP("tirx.tvm_access_ptr") .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); @@ -158,9 +158,9 @@ PrimExpr DispatchFastErf(const PrimExpr& e) { } PrimExpr DispatchNumericalStableTanh(const PrimExpr& e) { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); + using tirx::make_const; + using tirx::make_zero; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1); @@ -172,16 +172,16 @@ PrimExpr DispatchNumericalStableTanh(const PrimExpr& e) { PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); + return tirx::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); } } // namespace intrin namespace legalize { -using namespace tir; +using namespace tirx; -TVM_REGISTER_OP("tir.rsqrt") +TVM_REGISTER_OP("tirx.rsqrt") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); @@ -189,7 +189,7 @@ TVM_REGISTER_OP("tir.rsqrt") return one / sqrt(call->args[0]); }); -TVM_REGISTER_OP("tir.sigmoid") +TVM_REGISTER_OP("tirx.sigmoid") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); @@ -197,14 +197,14 @@ TVM_REGISTER_OP("tir.sigmoid") return one / (one + exp(-call->args[0])); }); -TVM_REGISTER_OP("tir.isfinite") +TVM_REGISTER_OP("tirx.isfinite") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); return isfinite(call->args[0]); }); -TVM_REGISTER_OP("tir.isinf") +TVM_REGISTER_OP("tirx.isinf") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); @@ -236,7 +236,7 @@ static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left PrimExpr one = make_const(hp_dtype, 1); x = cast(hp_dtype, x); y = cast(hp_dtype, y); - x = tir::Select(is_left_shift_required, x << left_shift, x); + x = tirx::Select(is_left_shift_required, x << left_shift, x); // 2) Perform the multiplication in higher precision. x = x * y; @@ -253,11 +253,11 @@ static PrimExpr QMultiplyShift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr left return cast(lp_dtype, x); } -TVM_REGISTER_OP("tir.q_multiply_shift") +TVM_REGISTER_OP("tirx.q_multiply_shift") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; + using tirx::make_const; - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); PrimExpr x = call->args[0]; @@ -300,17 +300,17 @@ TVM_REGISTER_OP("tir.q_multiply_shift") // Calculating integer shifts PrimExpr zero = make_const(s.dtype(), 0); - PrimExpr left_shift = tir::Select(s > zero, s, zero); - PrimExpr right_shift = tir::Select(s > zero, zero, -s); + PrimExpr left_shift = tirx::Select(s > zero, s, zero); + PrimExpr right_shift = tirx::Select(s > zero, zero, -s); PrimExpr is_left_shift_required = (left_shift != zero); return QMultiplyShift(x, y, q, left_shift, right_shift, is_left_shift_required); } }); -TVM_REGISTER_OP("tir.q_multiply_shift_per_axis") +TVM_REGISTER_OP("tirx.q_multiply_shift_per_axis") .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); PrimExpr x = call->args[0]; diff --git a/src/target/intrin_rule.h b/src/target/intrin_rule.h index 3f5ac43211ea..a5f5a8931283 100644 --- a/src/target/intrin_rule.h +++ b/src/target/intrin_rule.h @@ -25,15 +25,15 @@ #define TVM_TARGET_INTRIN_RULE_H_ #include -#include -#include +#include +#include #include namespace tvm { namespace codegen { namespace intrin { -using namespace tir; +using namespace tirx; // Add float suffix to the intrinsics struct FloatSuffix { @@ -68,7 +68,7 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { const OpNode* op = call->op.as(); TVM_FFI_ICHECK(op != nullptr); std::string name = op->name; - TVM_FFI_ICHECK_EQ(name.substr(0, 4), "tir."); + TVM_FFI_ICHECK_EQ(name.substr(0, 5), "tirx."); DataType dtype; if (dtype_from_arg) { TVM_FFI_ICHECK_EQ(call->args.size(), 1U); @@ -76,7 +76,7 @@ inline PrimExpr DispatchPureExtern(const PrimExpr& e) { } else { dtype = call->dtype; } - name = T()(dtype, name.substr(4)); + name = T()(dtype, name.substr(5)); if (name.length() != 0) { ffi::Array new_args = {StringImm(name)}; diff --git a/src/target/llvm/codegen_aarch64.cc b/src/target/llvm/codegen_aarch64.cc index ced6d8dc2d50..7c328f18ab12 100644 --- a/src/target/llvm/codegen_aarch64.cc +++ b/src/target/llvm/codegen_aarch64.cc @@ -74,7 +74,7 @@ void CodeGenAArch64::SetTargetAttributes(llvm::Function* func) { void CodeGenAArch64::VisitStmt_(const AttrStmtNode* op) { std::string attr_key = op->attr_key; - if (!tir::attr::IsPragmaKey(attr_key)) { + if (!tirx::attr::IsPragmaKey(attr_key)) { CodeGenCPU::VisitStmt_(op); return; } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 22636bce1591..a524dcce6df1 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -140,7 +140,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { buf, llvmGetPointerTo(DTypeToLLVMType(dtype), buf->getType()->getPointerAddressSpace())); TVM_FFI_ICHECK(!var_map_.count(op->buffer->data.get())); var_map_[op->buffer->data.get()] = buf; - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { volatile_buf_.insert(op->buffer->data.get()); } } diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc index 5f41bf569947..8cbdf0b19012 100644 --- a/src/target/llvm/codegen_arm.cc +++ b/src/target/llvm/codegen_arm.cc @@ -64,7 +64,7 @@ llvm::Value* CodeGenARM::CreateIntrinsic(const CallNode* op) { } PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { - using namespace tir; + using namespace tirx; const PrimExpr& e = call->args[1]; llvm::Intrinsic::ID ctpop_id = llvm::Intrinsic::ctpop; llvm::Intrinsic::ID vpaddlu_id = llvm::Intrinsic::arm_neon_vpaddlu; @@ -76,7 +76,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { ffi::Array vcnt_args; vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt_args.push_back(e); - return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args); + return tirx::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt_args); } // Popcount lowering rule: @@ -99,13 +99,13 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { ffi::Array vcnt8_args; vcnt8_args.push_back(IntImm(DataType::UInt(32), ctpop_id)); vcnt8_args.push_back(input8); - PrimExpr vcnt8 = tir::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); + PrimExpr vcnt8 = tirx::Call(uint8_type, builtin_call_llvm_pure_intrin_, vcnt8_args); // Accumulation 8->16bit ffi::Array vcnt16_args; vcnt16_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt16_args.push_back(vcnt8); - PrimExpr vcnt16 = tir::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); + PrimExpr vcnt16 = tirx::Call(uint16_type, builtin_call_llvm_pure_intrin_, vcnt16_args); if (call->dtype.bits() == 16) { return vcnt16; } @@ -114,7 +114,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { ffi::Array vcnt32_args; vcnt32_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt32_args.push_back(vcnt16); - PrimExpr vcnt32 = tir::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); + PrimExpr vcnt32 = tirx::Call(uint32_type, builtin_call_llvm_pure_intrin_, vcnt32_args); if (call->dtype.bits() == 32) { return vcnt32; } @@ -123,7 +123,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) { ffi::Array vcnt64_args; vcnt64_args.push_back(IntImm(DataType::UInt(32), vpaddlu_id)); vcnt64_args.push_back(vcnt32); - return tir::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); + return tirx::Call(call->dtype, builtin_call_llvm_pure_intrin_, vcnt64_args); } TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 8f9ee62cc878..2536386f669f 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -50,7 +50,7 @@ #include #include #include -#include +#include #include #include @@ -526,7 +526,7 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { // - Make sure the generated compute function is clearly separately(though it can get inlined) // - Set noalias on all the pointer arguments, some of them are loaded from ffi::PackedArgs. // This is easier than set the alias scope manually. - ffi::Array vargs = tir::UndefinedVars(op->body, {}); + ffi::Array vargs = tirx::UndefinedVars(op->body, {}); std::vector arg_values; std::vector arg_types; for (Var v : vargs) { @@ -633,7 +633,7 @@ void CodeGenCPU::CreateParallelLaunch(const Stmt& body, int num_task, std::strin SetTargetAttributes(f); // allocate and setup the closure, call the closure. - ffi::Array vfields = tir::UndefinedVars(body, {}); + ffi::Array vfields = tirx::UndefinedVars(body, {}); uint64_t nbytes; TypedPointer cdata = PackClosureData(vfields, &nbytes, "closure_" + name); auto launch_callee = llvm::FunctionCallee(ftype_tvm_parallel_launch_, RuntimeTVMParallelLaunch()); @@ -702,7 +702,7 @@ void CodeGenCPU::CreateStaticInit(const std::string& init_fname, const Stmt& bod } // allocate and setup the closure, call the closure. uint64_t nbytes; - ffi::Array vfields = tir::UndefinedVars(body, {}); + ffi::Array vfields = tirx::UndefinedVars(body, {}); TypedPointer cdata = PackClosureData(vfields, &nbytes); llvm::BasicBlock* init_end = CheckCallSuccess(builder_->CreateCall( finit, {gv, f, builder_->CreatePointerCast(cdata.addr, t_void_p_), ConstInt32(nbytes)})); @@ -791,7 +791,7 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::Array(); - TVM_FFI_ICHECK(ptr) << "Expected first argument of tir::Call to be " + TVM_FFI_ICHECK(ptr) << "Expected first argument of tirx::Call to be " << "a string containing the callee's name, " << "but instead contained " << args[0]; return ptr->value; @@ -836,7 +836,7 @@ CodeGenCPU::PackedCall CodeGenCPU::MakeCallPackedLowered(const ffi::ArrayCreateInBoundsGEP(t_tvm_ffi_any_, result, {ConstInt32(0), ConstInt32(2)}); @@ -1127,9 +1127,9 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { EmitDebugLocation(op); - if (op->attr_key == tir::attr::compute_scope) { + if (op->attr_key == tirx::attr::compute_scope) { this->CreateComputeScope(op); - } else if (tir::attr::IsPragmaKey(op->attr_key)) { + } else if (tirx::attr::IsPragmaKey(op->attr_key)) { if (op->attr_key == "pragma_parallel_stride_pattern") { TVM_FFI_ICHECK(parallel_env_.penv != nullptr) << "Pragma parallel_stride_pattern only valid in parallel launch"; @@ -1147,7 +1147,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { auto bar_callee = llvm::FunctionCallee(ftype_tvm_parallel_barrier_, RuntimeTVMParallelBarrier()); builder_->CreateCall(bar_callee, {MakeValue(parallel_env_.task_id), parallel_env_.penv}); - } else if (op->attr_key == tir::attr::pragma_import_llvm) { + } else if (op->attr_key == tirx::attr::pragma_import_llvm) { const StringImmNode* value = op->value.as(); TVM_FFI_ICHECK(value != nullptr); this->HandleImport(value->value); diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 4babf4b733de..bb1227b4641b 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -44,7 +44,7 @@ #include #include #include -#include +#include #include #include @@ -476,7 +476,7 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { continue; } auto f = Downcast(kv.second); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (f->HasNonzeroAttr(tirx::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); TVM_FFI_ICHECK(global_symbol.has_value()); entry_func = global_symbol.value(); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9c8d021f2a36..3fcaaaf46d68 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -79,7 +79,7 @@ #include #include #include -#include +#include #include #include @@ -246,7 +246,7 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const GlobalVar& gvar, cons << "Cannot codegen function with buffer_map, please lower them first"; std::vector param_types; - is_restricted_ = func->HasNonzeroAttr(tir::attr::kNoAlias); + is_restricted_ = func->HasNonzeroAttr(tirx::attr::kNoAlias); for (Var param : func->params) { param_types.push_back(GetLLVMType(param)); if (!is_restricted_ && param.dtype().is_handle()) { @@ -645,7 +645,7 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_va if (arith::ramp(pbase, pstride, planes).Match(index)) { base = pbase.Eval()->value; xwith = planes.Eval()->value * pstride.Eval()->value; - } else if (auto* ptr = index.as()) { + } else if (auto* ptr = index.as()) { base = ptr->value; xwith = 1; } @@ -1395,9 +1395,9 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return value; } else if (op->op.same_as(builtin::ret())) { auto const* val = op->args[0].as(); - TVM_FFI_ICHECK(val) << "the tir.ret should be transformed to return zero " + TVM_FFI_ICHECK(val) << "the tirx.ret should be transformed to return zero " << "before the llvm code generation."; - TVM_FFI_ICHECK_EQ(val->value, 0) << "the tir.ret should be transformed to " + TVM_FFI_ICHECK_EQ(val->value, 0) << "the tirx.ret should be transformed to " << "return zero before the llvm code generation."; builder_->CreateRet(ConstInt32(0)); // LLVM allows exactly one terminator in a single basic block @@ -1408,7 +1408,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return ret_dummy; } else if (op->op.same_as(builtin::continue_loop())) { TVM_FFI_ICHECK(!loop_frame_jump_tgts_.empty()) - << "the tir.continue_loop should be inserted under at least one For or While stmts."; + << "the tirx.continue_loop should be inserted under at least one For or While stmts."; builder_->CreateBr(loop_frame_jump_tgts_.back().first); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. @@ -1418,7 +1418,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { return post_dummy; } else if (op->op.same_as(builtin::break_loop())) { TVM_FFI_ICHECK(!loop_frame_jump_tgts_.empty()) - << "the tir.break_loop should be inserted under at least one For or While stmts."; + << "the tirx.break_loop should be inserted under at least one For or While stmts."; builder_->CreateBr(loop_frame_jump_tgts_.back().second); // LLVM allows exactly one terminator in a single basic block // append a new dummy basic block to avoid error. @@ -2015,14 +2015,14 @@ void CodeGenLLVM::VisitStmt_(const AllocBufferNode* op) { TVM_FFI_ICHECK(!var_map_.count(op->buffer->data.get())); var_map_[op->buffer->data.get()] = buf; - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { volatile_buf_.insert(op->buffer->data.get()); } } void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { EmitDebugLocation(op); - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { @@ -2030,7 +2030,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } - } else if (op->attr_key == tir::attr::storage_alignment) { + } else if (op->attr_key == tirx::attr::storage_alignment) { const VarNode* v = op->node.as(); TVM_FFI_ICHECK(v); alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); @@ -2062,7 +2062,7 @@ void CodeGenLLVM::VisitStmt_(const BindNode* op) { // TIR has type-annotations on variables, but not on each PrimExpr. // Therefore, to have the correct LLVM type for pointers, we may // need to introduce a pointer-cast, even though pointer-to-pointer - // casts are not expressible with the `tir::CastNode`. + // casts are not expressible with the `tirx::CastNode`. if (v->dtype.is_handle() && v->type_annotation.defined()) { TVM_FFI_ICHECK(op->value->dtype.is_handle()) << "Variable " << op->var << " is a pointer with type " << op->value diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 3fdbbec86fa9..b57a1a446bcf 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -43,13 +43,13 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -61,7 +61,7 @@ #include #include "../../runtime/thread_storage_scope.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "codegen_params.h" #include "llvm_instance.h" @@ -86,7 +86,7 @@ class MDBuilder; namespace tvm { namespace codegen { -using namespace tir; +using namespace tirx; /*! * \brief A base class to generate a LLVM. diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 545d93caa27b..191874bd5739 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -124,7 +124,7 @@ class CodeGenNVPTX : public CodeGenLLVM { buf, llvmGetPointerTo(DTypeToLLVMType(dtype), buf->getType()->getPointerAddressSpace())); TVM_FFI_ICHECK(!var_map_.count(op->buffer->data.get())); var_map_[op->buffer->data.get()] = buf; - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { volatile_buf_.insert(op->buffer->data.get()); } } diff --git a/src/target/llvm/codegen_x86_64.cc b/src/target/llvm/codegen_x86_64.cc index dd54c511fb0d..490f191e4543 100644 --- a/src/target/llvm/codegen_x86_64.cc +++ b/src/target/llvm/codegen_x86_64.cc @@ -65,9 +65,9 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) { llvm::Intrinsic::x86_avx512_mask_vcvtph2ps_512, 16, DTypeToLLVMType(DataType::Float(32, from.lanes())), { - MakeValue(tir::Call(DataType::Int(16, from.lanes()), tir::builtin::reinterpret(), - {op->value})), - MakeValue(tir::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), + MakeValue(tirx::Call(DataType::Int(16, from.lanes()), tirx::builtin::reinterpret(), + {op->value})), + MakeValue(tirx::Broadcast(FloatImm(DataType::Float(32), 0), from.lanes())), /*mask=*/MakeValue(IntImm(DataType::Int(16), -1)), /*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)), }); diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index dc71e69a2122..79e91c20a3b8 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -20,14 +20,14 @@ #ifdef TVM_LLVM_VERSION #include -#include -#include +#include +#include #include "intrin_rule_llvm.h" #define TVM_REGISTER_QHL_OP_FP16(INTRIN_FUNC, WRAPPER_FUNC, NUM_SIGN) \ std::string tvm_qhl_ahf_##INTRIN_FUNC = WRAPPER_FUNC; \ - TVM_REGISTER_OP("tir." #INTRIN_FUNC) \ + TVM_REGISTER_OP("tirx." #INTRIN_FUNC) \ .set_attr( \ "hexagon.FLowerIntrinsic", \ DispatchTVMQHLWrapperFp16 new_args = {tir::StringImm(fname)}; +inline PrimExpr TVMExternCall(const tirx::CallNode* call, const std::string& fname) { + ffi::Array new_args = {tirx::StringImm(fname)}; for (PrimExpr arg : call->args) { new_args.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_pure_extern(), new_args); + return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(), new_args); } template inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { - using namespace tir; + using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array new_args; @@ -72,33 +72,35 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { new_args.push_back(IntImm(DataType::UInt(32), id)); new_args.push_back(IntImm(DataType::UInt(32), num_sign)); new_args.insert(new_args.end(), call->args.begin(), call->args.end()); - return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), new_args); + return tirx::Call(call->dtype, tirx::builtin::call_llvm_pure_intrin(), new_args); } -TVM_REGISTER_OP("tir.fma").set_attr( - "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); +TVM_REGISTER_OP("tirx.fma") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_OP("tir.log").set_attr( - "hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); +TVM_REGISTER_OP("tirx.log") + .set_attr("hexagon.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); -TVM_REGISTER_OP("tir.ctpop") +TVM_REGISTER_OP("tirx.ctpop") .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; @@ -119,22 +121,22 @@ TVM_REGISTER_OP("tir.tanh") return TVMExternCall(call, tvm_wrapper); } #endif - PrimExpr one = tir::make_const(x.dtype(), 1); - PrimExpr two = tir::make_const(x.dtype(), 2); - PrimExpr neg_two = tir::make_const(x.dtype(), -2); + PrimExpr one = tirx::make_const(x.dtype(), 1); + PrimExpr two = tirx::make_const(x.dtype(), 2); + PrimExpr neg_two = tirx::make_const(x.dtype(), -2); PrimExpr exp_neg2x = exp(neg_two * x); PrimExpr exp_pos2x = exp(two * x); PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x); PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one); - PrimExpr tanh_x = tir::Select(x >= tir::make_zero(x.dtype()), tanh_pos, tanh_neg); + PrimExpr tanh_x = tirx::Select(x >= tirx::make_zero(x.dtype()), tanh_pos, tanh_neg); return tanh_x; }); -TVM_REGISTER_OP("tir.tan").set_attr( - "hexagon.FLowerIntrinsic", [](const PrimExpr& e) { - const tir::CallNode* call = e.as(); +TVM_REGISTER_OP("tirx.tan") + .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; #if ENABLE_QHL @@ -158,13 +160,13 @@ TVM_REGISTER_OP("tir.tan").set_attr( return tan_x; }); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); -TVM_REGISTER_OP("tir.sigmoid") +TVM_REGISTER_OP("tirx.sigmoid") .set_attr("hexagon.FLowerIntrinsic", [](const PrimExpr& e) { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; #if ENABLE_QHL @@ -178,13 +180,13 @@ TVM_REGISTER_OP("tir.sigmoid") useqhl = tstring.find("+hvx-qfloat") != std::string::npos; } - PrimExpr MinBound = tir::make_const(x.dtype(), -8); - PrimExpr MaxBound = tir::make_const(x.dtype(), 8); - const PrimExpr v1 = tir::Max(x, MinBound); - const PrimExpr v2 = tir::Min(v1, MaxBound); + PrimExpr MinBound = tirx::make_const(x.dtype(), -8); + PrimExpr MaxBound = tirx::make_const(x.dtype(), 8); + const PrimExpr v1 = tirx::Max(x, MinBound); + const PrimExpr v2 = tirx::Min(v1, MaxBound); ffi::Array new_args = {v2}; - const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); + const tirx::Call new_call = tirx::Call(call->dtype, call->op, new_args); // Enable QHL library for FP16 data type if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { @@ -192,7 +194,7 @@ TVM_REGISTER_OP("tir.sigmoid") return TVMExternCall(new_call.get(), tvm_wrapper); } #endif - PrimExpr one = tir::make_const(x.dtype(), 1); + PrimExpr one = tirx::make_const(x.dtype(), 1); return one / (one + exp(-x)); }); diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 4406a5949052..468f0fb7b59f 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -27,8 +27,8 @@ #include #define _USE_MATH_DEFINES #include -#include -#include +#include +#include #include @@ -38,87 +38,93 @@ namespace tvm { namespace codegen { namespace llvm { namespace intrin { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; -TVM_REGISTER_OP("tir.prefetch") +TVM_REGISTER_OP("tirx.prefetch") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>); -TVM_REGISTER_OP("tir.exp").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); +TVM_REGISTER_OP("tirx.exp") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -TVM_REGISTER_OP("tir.fma").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); +TVM_REGISTER_OP("tirx.fma") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_OP("tir.log").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); +TVM_REGISTER_OP("tirx.log") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); -TVM_REGISTER_OP("tir.pow").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); +TVM_REGISTER_OP("tirx.pow") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); -TVM_REGISTER_OP("tir.popcount") +TVM_REGISTER_OP("tirx.popcount") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); -TVM_REGISTER_OP("tir.cos").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); +TVM_REGISTER_OP("tirx.cos") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); -TVM_REGISTER_OP("tir.sin").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); +TVM_REGISTER_OP("tirx.sin") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("llvm.FLowerIntrinsic", ::tvm::codegen::intrin::DispatchNumericalStableTanh); } // namespace intrin namespace legalize { -using tir::FLegalize; +using tirx::FLegalize; -TVM_REGISTER_OP("tir.exp10") +TVM_REGISTER_OP("tirx.exp10") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); + using tirx::make_const; + using tirx::make_zero; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr ln10 = make_const(x.dtype(), 2.302585093); @@ -126,19 +132,20 @@ TVM_REGISTER_OP("tir.exp10") return ret; }); -TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); - TVM_FFI_ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr tan_x = sin(x) / cos(x); - return tan_x; -}); +TVM_REGISTER_OP("tirx.tan") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tirx::CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr tan_x = sin(x) / cos(x); + return tan_x; + }); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); + using tirx::make_const; + using tirx::make_zero; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); @@ -149,11 +156,11 @@ TVM_REGISTER_OP("tir.cosh") return ret; }); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); + using tirx::make_const; + using tirx::make_zero; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; PrimExpr two = make_const(x.dtype(), 2); @@ -164,11 +171,11 @@ TVM_REGISTER_OP("tir.sinh") return ret; }); -TVM_REGISTER_OP("tir.asin") +TVM_REGISTER_OP("tirx.asin") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; + using tirx::make_const; using namespace intrin; - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const PrimExpr& x = call->args[0]; @@ -190,17 +197,17 @@ TVM_REGISTER_OP("tir.asin") PrimExpr lower = make_const(x.dtype(), -1.0); PrimExpr upper = make_const(x.dtype(), 1.0); - PrimExpr out_range = tir::Or(x upper); + PrimExpr out_range = tirx::Or(x upper); PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); - return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, series)); + return tirx::Select(out_range, nan_const, tirx::Select(use_lib, lib_result, series)); }); -TVM_REGISTER_OP("tir.acos") +TVM_REGISTER_OP("tirx.acos") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; + using tirx::make_const; using namespace intrin; - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acos legalization"; const PrimExpr& x = call->args[0]; @@ -217,16 +224,16 @@ TVM_REGISTER_OP("tir.acos") PrimExpr lower = make_const(x.dtype(), -1.0); PrimExpr upper = make_const(x.dtype(), 1.0); - PrimExpr out_range = tir::Or(x upper); + PrimExpr out_range = tirx::Or(x upper); PrimExpr nan_const = make_const(x.dtype(), std::numeric_limits::quiet_NaN()); - return tir::Select(out_range, nan_const, tir::Select(use_lib, lib_result, formula_result)); + return tirx::Select(out_range, nan_const, tirx::Select(use_lib, lib_result, formula_result)); }); -TVM_REGISTER_OP("tir.atan") +TVM_REGISTER_OP("tirx.atan") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - const tir::CallNode* call = e.as(); + using tirx::make_const; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atan legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); @@ -234,10 +241,10 @@ TVM_REGISTER_OP("tir.atan") return asin(x / denom); }); -TVM_REGISTER_OP("tir.asinh") +TVM_REGISTER_OP("tirx.asinh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - const tir::CallNode* call = e.as(); + using tirx::make_const; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in asinh legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); @@ -245,10 +252,10 @@ TVM_REGISTER_OP("tir.asinh") return log(x + sqrt_val); }); -TVM_REGISTER_OP("tir.acosh") +TVM_REGISTER_OP("tirx.acosh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - const tir::CallNode* call = e.as(); + using tirx::make_const; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acosh legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); @@ -256,46 +263,48 @@ TVM_REGISTER_OP("tir.acosh") return log(x + sqrt_val); }); -TVM_REGISTER_OP("tir.atanh") +TVM_REGISTER_OP("tirx.atanh") .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - const tir::CallNode* call = e.as(); + using tirx::make_const; + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atanh legalization"; const PrimExpr& x = call->args[0]; PrimExpr one = make_const(x.dtype(), 1.0); return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5); }); -TVM_REGISTER_OP("tir.erf").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - const tir::CallNode* call = e.as(); - TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf legalization"; - const PrimExpr& x = call->args[0]; - PrimExpr abs_x = tvm::abs(x); - PrimExpr t = make_const(x.dtype(), 1.0) / - (make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x); - PrimExpr a1 = make_const(x.dtype(), 0.254829592); - PrimExpr a2 = make_const(x.dtype(), -0.284496736); - PrimExpr a3 = make_const(x.dtype(), 1.421413741); - PrimExpr a4 = make_const(x.dtype(), -1.453152027); - PrimExpr a5 = make_const(x.dtype(), 1.061405429); - PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t); - PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x); - return tvm::tir::Select(x < 0, -approx, approx); -}); - -TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); - TVM_FFI_ICHECK(call != nullptr); - TVM_FFI_ICHECK_EQ(call->args.size(), 1); - ffi::Array cargs; - cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); - cargs.push_back(call->args[0]); - cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef - // LLVM requires that the return type must match the first argument type - auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); - return cast(call->dtype, clz); -}); +TVM_REGISTER_OP("tirx.erf") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tirx::make_const; + const tirx::CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf legalization"; + const PrimExpr& x = call->args[0]; + PrimExpr abs_x = tvm::abs(x); + PrimExpr t = make_const(x.dtype(), 1.0) / + (make_const(x.dtype(), 1.0) + make_const(x.dtype(), 0.3275911) * abs_x); + PrimExpr a1 = make_const(x.dtype(), 0.254829592); + PrimExpr a2 = make_const(x.dtype(), -0.284496736); + PrimExpr a3 = make_const(x.dtype(), 1.421413741); + PrimExpr a4 = make_const(x.dtype(), -1.453152027); + PrimExpr a5 = make_const(x.dtype(), 1.061405429); + PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t); + PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x * abs_x); + return tvm::tirx::Select(x < 0, -approx, approx); + }); + +TVM_REGISTER_OP("tirx.clz") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tirx::CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 1); + ffi::Array cargs; + cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); + cargs.push_back(call->args[0]); + cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef + // LLVM requires that the return type must match the first argument type + auto clz = tirx::Call(call->args[0]->dtype, tirx::builtin::call_llvm_intrin(), cargs); + return cast(call->dtype, clz); + }); } // namespace legalize } // namespace llvm diff --git a/src/target/llvm/intrin_rule_llvm.h b/src/target/llvm/intrin_rule_llvm.h index f1bed6378060..b70d2b8001e0 100644 --- a/src/target/llvm/intrin_rule_llvm.h +++ b/src/target/llvm/intrin_rule_llvm.h @@ -29,8 +29,8 @@ #include #include #include -#include -#include +#include +#include #include "llvm_instance.h" @@ -39,7 +39,7 @@ namespace codegen { // num_signature means number of arguments used to query signature template inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. @@ -51,12 +51,12 @@ inline PrimExpr DispatchLLVMPureIntrin(const PrimExpr& e) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_llvm_pure_intrin(), cargs); + return tirx::Call(call->dtype, tirx::builtin::call_llvm_pure_intrin(), cargs); } template inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. @@ -67,7 +67,7 @@ inline PrimExpr DispatchLLVMIntrin(const PrimExpr& e) { for (PrimExpr arg : call->args) { cargs.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_llvm_intrin(), cargs); + return tirx::Call(call->dtype, tirx::builtin::call_llvm_intrin(), cargs); } } // namespace codegen diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 42f1352e36f3..4560205a6094 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -23,10 +23,10 @@ #ifdef TVM_LLVM_VERSION #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -34,7 +34,7 @@ namespace tvm { namespace codegen { inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { - using namespace tir; + using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) @@ -43,10 +43,10 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { const OpNode* op = call->op.as(); TVM_FFI_ICHECK(op != nullptr); std::string name = op->name; - TVM_FFI_ICHECK_EQ(name.substr(0, 4), "tir."); + TVM_FFI_ICHECK_EQ(name.substr(0, 5), "tirx."); std::ostringstream intrinsic_name; - intrinsic_name << "__nv_" << name.substr(4); + intrinsic_name << "__nv_" << name.substr(5); if (call->dtype.bits() == 32) intrinsic_name << "f"; ffi::Array new_args = {StringImm(intrinsic_name.str())}; @@ -57,75 +57,75 @@ inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { } namespace llvm { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.exp").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.exp") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.exp10") +TVM_REGISTER_OP("tirx.exp10") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.erf").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.erf") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.fma").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.fma") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.log").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.log") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.pow").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.pow") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.tan").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.tan") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.cos").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.cos") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.sin").set_attr("nvptx.FLowerIntrinsic", - DispatchPureExternLibDevice); +TVM_REGISTER_OP("tirx.sin") + .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); -TVM_REGISTER_OP("tir.atan") +TVM_REGISTER_OP("tirx.atan") .set_attr("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice); } // namespace llvm diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc index eec2cf2d1dc0..6d72c777834c 100644 --- a/src/target/llvm/intrin_rule_rocm.cc +++ b/src/target/llvm/intrin_rule_rocm.cc @@ -24,10 +24,10 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -40,17 +40,17 @@ namespace codegen { inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { // NOTE: OCML dispatch fails to work properly with vectorization, and thus should be used with // extreme caution. - using namespace tir; + using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); const OpNode* op = call->op.as(); TVM_FFI_ICHECK(op != nullptr); std::string name = op->name; - TVM_FFI_ICHECK_EQ(name.substr(0, 4), "tir."); + TVM_FFI_ICHECK_EQ(name.substr(0, 5), "tirx."); std::ostringstream intrinsic_name; - intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); + intrinsic_name << "__ocml_" << name.substr(5) << "_f" << call->dtype.bits(); ffi::Array new_args = {StringImm(intrinsic_name.str())}; for (auto arg : call->args) { @@ -61,7 +61,7 @@ inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { } inline PrimExpr DispatchShuffle(const PrimExpr& e) { - using namespace tir; + using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size @@ -69,8 +69,8 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { TVM_FFI_ICHECK_EQ(var.dtype().bits(), 32); // get own lane in self (__lane_id) - PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); - PrimExpr zero = tir::make_zero(DataType::Int(32)); + PrimExpr minus_one = tirx::make_const(DataType::Int(32), -1); + PrimExpr zero = tirx::make_zero(DataType::Int(32)); PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero}); PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), @@ -104,102 +104,108 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { } namespace llvm { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; // dummy because we don't have the activemask -TVM_REGISTER_OP("tir.tvm_warp_activemask") +TVM_REGISTER_OP("tirx.tvm_warp_activemask") .set_attr("rocm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - PrimExpr zero = tir::make_zero(DataType::Int(32)); + PrimExpr zero = tirx::make_zero(DataType::Int(32)); return zero; }); -TVM_REGISTER_OP("tir.tvm_warp_shuffle") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle") .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); -TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up") .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); -TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") .set_attr("rocm.FLowerIntrinsic", DispatchShuffle); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>); -TVM_REGISTER_OP("tir.exp").set_attr( - "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); +TVM_REGISTER_OP("tirx.exp") + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -TVM_REGISTER_OP("tir.fma").set_attr( - "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); +TVM_REGISTER_OP("tirx.fma") + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); -TVM_REGISTER_OP("tir.log").set_attr( - "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); +TVM_REGISTER_OP("tirx.log") + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>); -TVM_REGISTER_OP("tir.pow").set_attr( - "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); +TVM_REGISTER_OP("tirx.pow") + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); -TVM_REGISTER_OP("tir.cos").set_attr( - "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); +TVM_REGISTER_OP("tirx.cos") + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); -TVM_REGISTER_OP("tir.sin").set_attr( - "rocm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); +TVM_REGISTER_OP("tirx.sin") + .set_attr("rocm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("rocm.FLowerIntrinsic", ::tvm::codegen::intrin::DispatchNumericalStableTanh); -TVM_REGISTER_OP("tir.erf").set_attr("rocm.FLowerIntrinsic", - ::tvm::codegen::intrin::DispatchFastErf); +TVM_REGISTER_OP("tirx.erf") + .set_attr("rocm.FLowerIntrinsic", ::tvm::codegen::intrin::DispatchFastErf); -// TVM_REGISTER_OP("tir.tan").set_attr("rocm.FLowerIntrinsic", +// TVM_REGISTER_OP("tirx.tan").set_attr("rocm.FLowerIntrinsic", // DispatchPureExternOCML); -// TVM_REGISTER_OP("tir.cosh") +// TVM_REGISTER_OP("tirx.cosh") // .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -// TVM_REGISTER_OP("tir.sinh") +// TVM_REGISTER_OP("tirx.sinh") // .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -// TVM_REGISTER_OP("tir.atan") +// TVM_REGISTER_OP("tirx.atan") // .set_attr("rocm.FLowerIntrinsic", DispatchPureExternOCML); -// TVM_REGISTER_OP("tir.exp10") +// TVM_REGISTER_OP("tirx.exp10") // .set_attr("rocm.FLowerIntrinsic", // DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index ef10daff6981..36cab28bc6c0 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -310,7 +310,7 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { } auto f = Downcast(kv.second); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc); + bool is_entry_func = f->HasNonzeroAttr(tirx::attr::kIsEntryFunc); TVM_FFI_ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 116aa20bda8f..45bc954c80aa 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -34,7 +34,7 @@ namespace tvm { namespace codegen { -using namespace tir; +using namespace tirx; void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } @@ -83,7 +83,7 @@ void CodeGenC::PrintFunctionSignature(const ffi::String& function_name, const Pr PrintExtraAttrs(func, os); os << " " << function_name << "("; for (size_t i = 0; i < func->params.size(); ++i) { - tir::Var v = func->params[i]; + tirx::Var v = func->params[i]; if (i > 0) { os << ", "; @@ -105,7 +105,7 @@ void CodeGenC::PrintFunctionSignature(const ffi::String& function_name, const Pr PrintType(GetType(v), os); } - bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias); + bool no_alias = func->HasNonzeroAttr(tirx::attr::kNoAlias); bool is_handle = v.dtype().is_handle(); auto* ptr = v->type_annotation.as(); if (ptr && ptr->element_type.as()) { @@ -1072,20 +1072,20 @@ void CodeGenC::VisitStmt_(const AllocBufferNode* op) { stream << ' ' << vid << '[' << constant_size << "];\n"; RegisterHandleType(op->buffer->data.get(), op->buffer->dtype); - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { MarkVolatile(op->buffer->data.get()); } } void CodeGenC::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { if (!var_idmap_.count(iv->var.get())) { BindThreadIndex(iv); } } - } else if (op->attr_key == tir::attr::pragma_import_c) { + } else if (op->attr_key == tirx::attr::pragma_import_c) { const StringImmNode* value = op->value.as(); TVM_FFI_ICHECK(value != nullptr); decl_stream << value->value; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 94080c5a1d4d..29c5e420997e 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -26,26 +26,26 @@ #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include #include #include #include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "codegen_source_base.h" namespace tvm { namespace codegen { -using namespace tir; +using namespace tirx; /*! * \brief A base class to generate C code. * diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 8e7fb11b640d..a4e90e897b59 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -74,7 +74,7 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) { + if (func->HasNonzeroAttr(tirx::attr::kIsEntryFunc) && !has_tvm_ffi_main_func_) { TVM_FFI_ICHECK(global_symbol.has_value()) << "CodeGenCHost: The entry func must have the global_symbol attribute, " << "but function " << gvar << " only has attributes " << func->attrs; diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index feb0f715d847..edeebe7da1cc 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -32,7 +32,7 @@ #include "codegen_c.h" #include "tvm/target/codegen.h" -#include "tvm/tir/expr.h" +#include "tvm/tirx/expr.h" namespace tvm { namespace codegen { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 402d1665678b..8584f05804fd 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -26,15 +26,15 @@ #include #include #include -#include -#include +#include +#include #include #include #include #include -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" #include "literal/cuda_half_t.h" #include "literal/cuda_int8_t.h" #include "ptx.h" @@ -157,10 +157,10 @@ void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, const CodeGenC::PrintFunctionSignature(function_name, func, os); } -class ThreadIdxExtractor : public tir::StmtVisitor { +class ThreadIdxExtractor : public tirx::StmtVisitor { private: void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->var->name_hint == "threadIdx.x" || iv->thread_tag == "threadIdx.x") { threadIdx_x_ext = op->value; @@ -327,8 +327,8 @@ std::string CodeGenCUDA::Finish() { return CodeGenC::Finish(); } -void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { - if (op->kind == tir::ForKind::kUnrolled) { +void CodeGenCUDA::VisitStmt_(const tirx::ForNode* op) { + if (op->kind == tirx::ForKind::kUnrolled) { PrintIndent(); stream << "#pragma unroll\n"; } @@ -1120,7 +1120,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { // to determine the output location for each 8 element. const auto index_map_func = - tvm::ffi::Function::GetGlobal("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout"); + tvm::ffi::Function::GetGlobal("tirx.index_map.shared_16x16_to_ldmatrix_32x8_layout"); TVM_FFI_ICHECK(index_map_func.has_value()); arith::Analyzer analyzer; @@ -1134,10 +1134,10 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { class LowerFloorDivMod : public ExprMutator { public: PrimExpr VisitExpr_(const FloorDivNode* op) { - return tir::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); + return tirx::Div(this->VisitExpr(op->a), this->VisitExpr(op->b)); } PrimExpr VisitExpr_(const FloorModNode* op) { - return tir::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); + return tirx::Mod(this->VisitExpr(op->a), this->VisitExpr(op->b)); } }; @@ -1300,44 +1300,44 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { if (tgt_dtype.is_float4_e2m1fn()) { // We view the source as an uint16, and then extract bits of two fp4 numbers, // and finally reinterpret the result as fp4x2. - value = tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value}); - tir::Var temp_var("temp_var", DataType::UInt(16)); - value = tir::Let( - temp_var, value, - tir::Cast(DataType::UInt(8), (temp_var & IntImm(DataType::UInt(16), 0xF)) | - ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); + value = tirx::Call(DataType::UInt(16), tirx::builtin::reinterpret(), {value}); + tirx::Var temp_var("temp_var", DataType::UInt(16)); + value = tirx::Let(temp_var, value, + tirx::Cast(DataType::UInt(8), + (temp_var & IntImm(DataType::UInt(16), 0xF)) | + ((temp_var >> 4) & IntImm(DataType::UInt(16), 0xF0)))); } else { - value = tir::Cast(DataType::UInt(16), - tir::Call(DataType::UInt(8), tir::builtin::reinterpret(), {value})); - tir::Var temp_var("temp_var", DataType::UInt(16)); - value = tir::Let(temp_var, value, - (temp_var & IntImm(DataType::UInt(16), 0xF)) | - ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); + value = tirx::Cast(DataType::UInt(16), + tirx::Call(DataType::UInt(8), tirx::builtin::reinterpret(), {value})); + tirx::Var temp_var("temp_var", DataType::UInt(16)); + value = tirx::Let(temp_var, value, + (temp_var & IntImm(DataType::UInt(16), 0xF)) | + ((temp_var & IntImm(DataType::UInt(16), 0xF0)) << 4)); } - os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + os << PrintExpr(tirx::Call(tgt_dtype, tirx::builtin::reinterpret(), {value})); } else if (lanes == 4) { if (tgt_dtype.is_float4_e2m1fn()) { // We view the source as an uint32, and then extract bits of four fp4 numbers, // and finally reinterpret the result as fp4x4. - value = tir::Call(DataType::UInt(32), tir::builtin::reinterpret(), {value}); - tir::Var temp_var("temp_var", DataType::UInt(32)); - value = tir::Let(temp_var, value, - tir::Cast(DataType::UInt(16), - (temp_var & IntImm(DataType::UInt(32), 0xF)) | - ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | - ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | - ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); + value = tirx::Call(DataType::UInt(32), tirx::builtin::reinterpret(), {value}); + tirx::Var temp_var("temp_var", DataType::UInt(32)); + value = tirx::Let(temp_var, value, + tirx::Cast(DataType::UInt(16), + (temp_var & IntImm(DataType::UInt(32), 0xF)) | + ((temp_var >> 4) & IntImm(DataType::UInt(32), 0xF0)) | + ((temp_var >> 8) & IntImm(DataType::UInt(32), 0xF00)) | + ((temp_var >> 12) & IntImm(DataType::UInt(32), 0xF000)))); } else { - value = tir::Cast(DataType::UInt(32), - tir::Call(DataType::UInt(16), tir::builtin::reinterpret(), {value})); - tir::Var temp_var("temp_var", DataType::UInt(32)); - value = tir::Let(temp_var, value, - (temp_var & IntImm(DataType::UInt(32), 0xF)) | - ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | - ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | - ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); + value = tirx::Cast(DataType::UInt(32), + tirx::Call(DataType::UInt(16), tirx::builtin::reinterpret(), {value})); + tirx::Var temp_var("temp_var", DataType::UInt(32)); + value = tirx::Let(temp_var, value, + (temp_var & IntImm(DataType::UInt(32), 0xF)) | + ((temp_var & IntImm(DataType::UInt(32), 0xF0)) << 4) | + ((temp_var & IntImm(DataType::UInt(32), 0xF00)) << 8) | + ((temp_var & IntImm(DataType::UInt(32), 0xF000)) << 12)); } - os << PrintExpr(tir::Call(tgt_dtype, tir::builtin::reinterpret(), {value})); + os << PrintExpr(tirx::Call(tgt_dtype, tirx::builtin::reinterpret(), {value})); } else { TVM_FFI_THROW(InternalError) << "Invalid number of lanes for float4_e2m1fn reinterpret: " << lanes; @@ -1434,7 +1434,7 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) { } RegisterHandleType(op->buffer->data.get(), dtype); - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { MarkVolatile(op->buffer->data.get()); } } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index dcd85c14caf5..4a384ffe16a6 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -25,8 +25,8 @@ #define TVM_TARGET_SOURCE_CODEGEN_CUDA_H_ #include -#include -#include +#include +#include #include #include diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 7c4d3b926bee..6831596c810f 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -23,7 +23,7 @@ #include "codegen_metal.h" #include -#include +#include #include #include @@ -151,7 +151,7 @@ void CodeGenMetal::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { TVM_FFI_ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; auto launch_params = - func->GetAttr>(tir::attr::kKernelLaunchParams).value(); + func->GetAttr>(tirx::attr::kKernelLaunchParams).value(); for (const auto& tag : launch_params) { if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); @@ -209,7 +209,7 @@ void CodeGenMetal::PrintType(DataType t, std::ostream& os) { // NOLINT(*) } bool fail = false; if (t.is_float()) { - // Need to care about sizes and alignment of half3/float3 because tir representation might not + // Need to care about sizes and alignment of half3/float3 because tirx representation might not // be aware of Metal half3/float3 details and can treat them as just three elements, // while sizes and alignmnents of half3/float3 are one element more (half3-8 bytes/ // float13 - 16bytes). @@ -345,7 +345,7 @@ void CodeGenMetal::VisitStmt_(const AllocBufferNode* op) { } RegisterHandleType(op->buffer->data.get(), dtype); - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { MarkVolatile(op->buffer->data.get()); } } @@ -444,7 +444,7 @@ void CodeGenMetal::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NO ffi::Module BuildMetal(IRModule mod, Target target) { bool output_ssa = false; - mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod)); std::ostringstream source_maker; std::unordered_map smap; diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 34c040b1b2ba..5d9135ef223b 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -495,12 +495,13 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op, std::ostream& os) { std::string rhs = SSAGetID(ss.str(), op->dtype.with_lanes(data_lanes)); if (auto ramp = op->args.back().as()) { - if (ramp->base.as() && *tir::as_const_int(ramp->base) == 0 && - *tir::as_const_int(ramp->lanes) == data_lanes && *tir::as_const_int(ramp->stride) == 1) { + if (ramp->base.as() && *tirx::as_const_int(ramp->base) == 0 && + *tirx::as_const_int(ramp->lanes) == data_lanes && + *tirx::as_const_int(ramp->stride) == 1) { os << rhs; - } else if (*tir::as_const_int(ramp->stride) == 1) { + } else if (*tirx::as_const_int(ramp->stride) == 1) { os << "(*("; - this->PrintType(op->dtype.with_lanes(*tir::as_const_int(ramp->lanes)), os); + this->PrintType(op->dtype.with_lanes(*tirx::as_const_int(ramp->lanes)), os); os << "*)"; os << "(("; this->PrintType(op->dtype.with_lanes(1), os); diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc index c0f906a34f62..5a07e3c7aa07 100644 --- a/src/target/source/codegen_source_base.cc +++ b/src/target/source/codegen_source_base.cc @@ -52,7 +52,7 @@ std::string CodeGenSourceBase::SSAGetID(std::string src, DataType t) { return e.vid; } -std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { +std::string CodeGenSourceBase::AllocVarID(const tirx::VarNode* v) { TVM_FFI_ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint; std::string key = v->name_hint; std::string vid = name_supply_->FreshName(key); @@ -63,7 +63,7 @@ std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) { return vid; } -std::string CodeGenSourceBase::GetVarID(const tir::VarNode* v) const { +std::string CodeGenSourceBase::GetVarID(const tirx::VarNode* v) const { auto it = var_idmap_.find(v); TVM_FFI_ICHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint; return it->second; diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 4c593c460557..2f05c4ad2c09 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -27,8 +27,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -83,13 +83,13 @@ class CodeGenSourceBase { * \param v The variable. * \return the variable name. */ - std::string AllocVarID(const tir::VarNode* v); + std::string AllocVarID(const tirx::VarNode* v); /*! * \brief Get a variable name. * \param v The variable. * \return the variable name. */ - std::string GetVarID(const tir::VarNode* v) const; + std::string GetVarID(const tirx::VarNode* v) const; /*! * \brief Get the SSA ID corresponds to src * If necessary, generate new assignment @@ -122,7 +122,7 @@ class CodeGenSourceBase { /*! \brief the forward declaration stream */ std::ostringstream fwd_decl_stream; /*! \brief name of each variable */ - std::unordered_map var_idmap_; + std::unordered_map var_idmap_; /*! \brief NameSupply for allocation */ NameSupply name_supply_; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 6c2e5cffddb2..25a691659e6e 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -26,8 +26,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -78,14 +78,14 @@ class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { void VisitStmt_(const AttrStmtNode* op) final { // record workgroup size - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag); if (ts.rank == 1) { TVM_FFI_ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; TVM_FFI_ICHECK_LT(ts.dim_index, 3); - auto* sizeptr = op->value.as(); + auto* sizeptr = op->value.as(); TVM_FFI_ICHECK(sizeptr) << "CodeGenWebGPU: only allows constant thread group size " << " get " << op->value; info_.workgroup_size[ts.dim_index] = static_cast(sizeptr->value); @@ -238,7 +238,7 @@ runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_re << "var " << val_pod_args << " : " << type_pod_args << ";\n\n"; // setup thread tags and param access in launch param tags; - if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { + if (auto opt = f->GetAttr>(tirx::attr::kKernelLaunchParams)) { for (const auto& thread_tag : opt.value()) { func_launch_param_tags.push_back(thread_tag); } @@ -767,14 +767,14 @@ class WebGPUSourceModuleNode final : public ffi::ModuleObj { // Build logic. //------------------------------------------------- ffi::Module BuildWebGPU(IRModule mod, Target target) { - mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; bool skip_readonly_decl = false; std::unordered_map smap; ffi::Map fmap; // narrow all i64 to i32 - mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod)); + mod = tirx::transform::ForceNarrowIndexToInt32()(std::move(mod)); for (auto kv : mod->functions) { CodeGenWebGPU cg(target); diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc index 4f041ff96e67..bcd158432bcd 100644 --- a/src/target/source/intrin_rule_cuda.cc +++ b/src/target/source/intrin_rule_cuda.cc @@ -21,8 +21,8 @@ * \file intrin_rule_cuda.cc * \brief CUDA intrinsic rules. */ -#include -#include +#include +#include #include "../intrin_rule.h" @@ -30,7 +30,7 @@ namespace tvm { namespace codegen { namespace intrin { // Add float suffix to the intrinsics, CUDA fast math. -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; struct CUDAMath { std::string operator()(DataType t, std::string name) const { @@ -124,19 +124,19 @@ struct CUDAPopcount { struct CUDAWarpIntrinsic { const Op operator()(DataType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { - return Op::Get("tir.cuda.__shfl_sync"); + return Op::Get("tirx.cuda.__shfl_sync"); } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { - return Op::Get("tir.cuda.__shfl_up_sync"); + return Op::Get("tirx.cuda.__shfl_up_sync"); } else { TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); - return Op::Get("tir.cuda.__shfl_down_sync"); + return Op::Get("tirx.cuda.__shfl_down_sync"); } } }; static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = e.as(); - return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); + return Call(call->dtype, Op::Get("tirx.cuda.__activemask"), call->args); } template @@ -148,96 +148,97 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); } -TVM_REGISTER_OP("tir.clz").set_attr( - "cuda.FLowerIntrinsic", DispatchPureExtern); +TVM_REGISTER_OP("tirx.clz") + .set_attr("cuda.FLowerIntrinsic", + DispatchPureExtern); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.exp") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp10") +TVM_REGISTER_OP("tirx.exp10") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.erf").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.erf") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.log") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tan").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.tan") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cos").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.cos") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sin").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.sin") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.atan") +TVM_REGISTER_OP("tirx.atan") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.pow").set_attr("cuda.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.pow") + .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.popcount") +TVM_REGISTER_OP("tirx.popcount") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tvm_warp_shuffle") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle") .set_attr("cuda.FLowerIntrinsic", DispatchCUDAShuffle); -TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up") .set_attr("cuda.FLowerIntrinsic", DispatchCUDAShuffle); -TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") .set_attr("cuda.FLowerIntrinsic", DispatchCUDAShuffle); -TVM_REGISTER_OP("tir.tvm_warp_activemask") +TVM_REGISTER_OP("tirx.tvm_warp_activemask") .set_attr("cuda.FLowerIntrinsic", DispatchCUDAWarpActiveMask); -TVM_REGISTER_OP("tir.fmod") +TVM_REGISTER_OP("tirx.fmod") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); // Register low-level builtin ops. // TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins. -TVM_REGISTER_OP("tir.cuda.__shfl_sync") +TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .set_num_inputs(4) .add_argument("mask", "Expr", "The thread mask.") .add_argument("var", "Expr", "The variable to sync.") @@ -247,7 +248,7 @@ TVM_REGISTER_OP("tir.cuda.__shfl_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); -TVM_REGISTER_OP("tir.cuda.__shfl_up_sync") +TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .set_num_inputs(4) .add_argument("mask", "Expr", "The thread mask.") .add_argument("var", "Expr", "The variable to sync.") @@ -257,7 +258,7 @@ TVM_REGISTER_OP("tir.cuda.__shfl_up_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); -TVM_REGISTER_OP("tir.cuda.__shfl_down_sync") +TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .set_num_inputs(4) .add_argument("mask", "Expr", "The thread mask.") .add_argument("var", "Expr", "The variable to sync.") @@ -267,7 +268,7 @@ TVM_REGISTER_OP("tir.cuda.__shfl_down_sync") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); -TVM_REGISTER_OP("tir.cuda.__activemask") +TVM_REGISTER_OP("tirx.cuda.__activemask") .set_num_inputs(0) .set_attr("TGlobalSymbol", "__activemask") .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index be888d47fb98..d61bf1256f64 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -21,24 +21,24 @@ * \file intrin_rule_metal.cc * \brief Metal intrinsic rules. */ -#include +#include #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; struct MetalWarpIntrinsic { const Op operator()(DataType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { - return Op::Get("tir.metal.simd_shuffle"); + return Op::Get("tirx.metal.simd_shuffle"); } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { - return Op::Get("tir.metal.simd_shuffle_up"); + return Op::Get("tirx.metal.simd_shuffle_up"); } else { TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); - return Op::Get("tir.metal.simd_shuffle_down"); + return Op::Get("tirx.metal.simd_shuffle_down"); } } }; @@ -52,99 +52,99 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args); } -TVM_REGISTER_OP("tir.clz").set_attr("metal.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.clz") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp").set_attr("metal.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.exp") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp10") +TVM_REGISTER_OP("tirx.exp10") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log").set_attr("metal.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.log") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("metal.FLowerIntrinsic", DispatchNumericalStableTanh); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.pow").set_attr("metal.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.pow") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.popcount") +TVM_REGISTER_OP("tirx.popcount") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.fmod") +TVM_REGISTER_OP("tirx.fmod") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sin").set_attr("metal.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.sin") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cos").set_attr("metal.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.cos") + .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.erf").set_attr("metal.FLowerIntrinsic", DispatchFastErf); +TVM_REGISTER_OP("tirx.erf").set_attr("metal.FLowerIntrinsic", DispatchFastErf); -TVM_REGISTER_OP("tir.tvm_warp_shuffle") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle") .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); -TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle_up") .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); -TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle_down") .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); // Register low-level builtin ops. -TVM_REGISTER_OP("tir.metal.simd_shuffle") +TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("lane", "Expr", "The source thread id.") .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_REGISTER_OP("tir.metal.simd_shuffle_up") +TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be added.") .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_REGISTER_OP("tir.metal.simd_shuffle_down") +TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") .set_num_inputs(2) .add_argument("var", "Expr", "The variable to sync.") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc index 01c1e038cd1f..85084b1a1649 100644 --- a/src/target/source/intrin_rule_opencl.cc +++ b/src/target/source/intrin_rule_opencl.cc @@ -22,82 +22,82 @@ * \brief OpenCL intrinsic rules. */ #include -#include +#include #include "../intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; -TVM_REGISTER_OP("tir.clz").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.clz") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.exp") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.erf").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.erf") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp10") +TVM_REGISTER_OP("tirx.exp10") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.log") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log10") +TVM_REGISTER_OP("tirx.log10") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.pow").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.pow") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.popcount") +TVM_REGISTER_OP("tirx.popcount") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.fmod") +TVM_REGISTER_OP("tirx.fmod") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sin").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.sin") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cos").set_attr("opencl.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.cos") + .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); // There is no warp shuffle instruction in standard OpenCL @@ -114,7 +114,7 @@ static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { return Call(call->dtype, builtin::call_pure_extern(), opencl_args); } -TVM_REGISTER_OP("tir.tvm_warp_shuffle") +TVM_REGISTER_OP("tirx.tvm_warp_shuffle") .set_attr("opencl.FLowerIntrinsic", DispatchIntelShuffle); } // namespace intrin diff --git a/src/target/source/intrin_rule_webgpu.cc b/src/target/source/intrin_rule_webgpu.cc index f3e561f71477..968df9a579f4 100644 --- a/src/target/source/intrin_rule_webgpu.cc +++ b/src/target/source/intrin_rule_webgpu.cc @@ -22,7 +22,7 @@ * \brief WebGPU intrinsic rules. */ #include -#include +#include #include "../intrin_rule.h" @@ -30,7 +30,7 @@ namespace tvm { namespace codegen { namespace intrin { -using tir::FLowerIntrinsic; +using tirx::FLowerIntrinsic; // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions @@ -38,80 +38,80 @@ struct ReturnAbs { std::string operator()(DataType t, std::string name) const { return "abs"; } }; -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.acos") +TVM_REGISTER_OP("tirx.acos") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.acosh") +TVM_REGISTER_OP("tirx.acosh") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.asin") +TVM_REGISTER_OP("tirx.asin") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.asinh") +TVM_REGISTER_OP("tirx.asinh") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.atan") +TVM_REGISTER_OP("tirx.atan") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.atan2") +TVM_REGISTER_OP("tirx.atan2") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cos").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.cos") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.exp") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.floor") +TVM_REGISTER_OP("tirx.floor") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.fma").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.fma") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.log") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.pow") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sin").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.sin") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tan").set_attr("webgpu.FLowerIntrinsic", - DispatchPureExtern); +TVM_REGISTER_OP("tirx.tan") + .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("webgpu.FLowerIntrinsic", DispatchNumericalStableTanh); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("webgpu.FLowerIntrinsic", DispatchPureExtern); // extra dispatch -TVM_REGISTER_OP("tir.erf").set_attr("webgpu.FLowerIntrinsic", DispatchFastErf); +TVM_REGISTER_OP("tirx.erf").set_attr("webgpu.FLowerIntrinsic", DispatchFastErf); } // namespace intrin } // namespace codegen diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index b68957619038..7a8abce7df8e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -24,15 +24,15 @@ #include "codegen_spirv.h" #include -#include -#include -#include +#include +#include +#include #include #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" -#include "../../tir/transform/ir_utils.h" +#include "../../tirx/transform/ir_utils.h" namespace tvm { namespace codegen { @@ -41,7 +41,7 @@ CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - TVM_FFI_ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) + TVM_FFI_ICHECK(f->HasNonzeroAttr(tirx::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t i_buffer = 0; @@ -141,7 +141,7 @@ spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& ext spirv::Value v; if (ts.rank == 1) { v = builder_->GetLocalID(ts.dim_index); - auto* sizeptr = extent.as(); + auto* sizeptr = extent.as(); TVM_FFI_ICHECK(sizeptr) << "SPIRV only allows constant thread group size " << " get " << extent; TVM_FFI_ICHECK_GE(ts.dim_index, 0) << "vthread should have been optimized out by here"; @@ -165,7 +165,7 @@ spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) { // Synchronize control at the Subgroup level, but memory at the // Workgroup level. This is because different invocations in a // subgroup may have each modified memory that exists at the - // workgroup scope. This should be changed if/when tir exposes + // workgroup scope. This should be changed if/when tirx exposes // more information as to which memory access needs to be // synchronized. sync_scope = spv::ScopeSubgroup; @@ -452,7 +452,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { PrimExpr index_d = op->args[1]; PrimExpr index_a = op->args[3]; PrimExpr index_b = op->args[5]; - tvm::tir::ExprDeepEqual expr_equal; + tvm::tirx::ExprDeepEqual expr_equal; PrimExpr index_c = op->args[7]; bool is_equal = ((buffer_d == buffer_c) && expr_equal(index_d, index_c)); spirv::SType& fragment_type_d = fragment_info_[buffer_d].stype; @@ -857,7 +857,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocBufferNode* op) { TVM_FFI_ICHECK(!var_map_.count(var_node)); var_map_[var_node] = buf; - if (op->annotations.count(tir::attr::kVolatile)) { + if (op->annotations.count(tirx::attr::kVolatile)) { storage_info_[var_node].is_volatile = true; } } @@ -867,7 +867,7 @@ void CodeGenSPIRV::VisitStmt_(const DeclBufferNode* op) { } void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == tir::attr::thread_extent) { + if (op->attr_key == tirx::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { // Will throw error if rebinding same local variable to a different extent. diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 8daac154ecd9..a77ff82d3dde 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -26,10 +26,10 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -45,7 +45,7 @@ namespace tvm { namespace codegen { -using namespace tir; +using namespace tirx; /*! * \brief Code generator into SPIRV @@ -119,7 +119,7 @@ class CodeGenSPIRV : public ExprFunctor, protected: /*! \brief Storage information for a buffer */ struct StorageInfo { - /*! \brief The name of the tir::Var for the buffer + /*! \brief The name of the tirx::Var for the buffer * * Used for error messages. */ diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index 7415367df8b3..cde1e0165f82 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -22,10 +22,10 @@ */ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "../intrin_rule.h" @@ -35,7 +35,7 @@ namespace spirv { // num_signature means number of arguments used to query signature template PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); ffi::Array cargs; // intrin id. @@ -44,12 +44,12 @@ PrimExpr CallGLSLIntrin(PrimExpr e, const ffi::Array& args) { for (PrimExpr arg : args) { cargs.push_back(arg); } - return tir::Call(call->dtype, tir::builtin::call_spirv_pure_glsl450(), cargs); + return tirx::Call(call->dtype, tirx::builtin::call_spirv_pure_glsl450(), cargs); } template PrimExpr CallGLSLIntrin(PrimExpr e) { - const tir::CallNode* call = e.as(); + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); return CallGLSLIntrin(e, call->args); } @@ -60,91 +60,91 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) { } namespace intrin { -using tir::FLowerIntrinsic; -TVM_REGISTER_OP("tir.floor") +using tirx::FLowerIntrinsic; +TVM_REGISTER_OP("tirx.floor") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.ceil") +TVM_REGISTER_OP("tirx.ceil") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.round") +TVM_REGISTER_OP("tirx.round") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.nearbyint") +TVM_REGISTER_OP("tirx.nearbyint") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.trunc") +TVM_REGISTER_OP("tirx.trunc") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.fabs") +TVM_REGISTER_OP("tirx.fabs") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.exp").set_attr("vulkan.FLowerIntrinsic", - DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tirx.exp") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.exp2") +TVM_REGISTER_OP("tirx.exp2") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.sin").set_attr("vulkan.FLowerIntrinsic", - DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tirx.sin") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.cos").set_attr("vulkan.FLowerIntrinsic", - DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tirx.cos") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.tan").set_attr("vulkan.FLowerIntrinsic", - DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tirx.tan") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.asin") +TVM_REGISTER_OP("tirx.asin") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.acos") +TVM_REGISTER_OP("tirx.acos") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.atan") +TVM_REGISTER_OP("tirx.atan") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.sinh") +TVM_REGISTER_OP("tirx.sinh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.cosh") +TVM_REGISTER_OP("tirx.cosh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.tanh") +TVM_REGISTER_OP("tirx.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.asinh") +TVM_REGISTER_OP("tirx.asinh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.acosh") +TVM_REGISTER_OP("tirx.acosh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.atanh") +TVM_REGISTER_OP("tirx.atanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.atan2") +TVM_REGISTER_OP("tirx.atan2") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.log").set_attr("vulkan.FLowerIntrinsic", - DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tirx.log") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.log2") +TVM_REGISTER_OP("tirx.log2") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.sqrt") +TVM_REGISTER_OP("tirx.sqrt") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", - DispatchGLSLPureIntrin); +TVM_REGISTER_OP("tirx.pow") + .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.erf").set_attr("vulkan.FLowerIntrinsic", - codegen::intrin ::DispatchFastErf); +TVM_REGISTER_OP("tirx.erf") + .set_attr("vulkan.FLowerIntrinsic", codegen::intrin ::DispatchFastErf); } // namespace intrin namespace legalize { -using tir::FLegalize; -TVM_REGISTER_OP("tir.clz").set_attr( - "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); +using tirx::FLegalize; +TVM_REGISTER_OP("tirx.clz") + .set_attr("vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tirx::CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); TVM_FFI_ICHECK_EQ(call->args.size(), 1); PrimExpr arg = call->args[0]; @@ -152,8 +152,8 @@ TVM_REGISTER_OP("tir.clz").set_attr( if (arg.dtype().bits() == 64) { // SPIR-V FindUMsb intrinsic only supports 32 bit input auto int32 = DataType::Int(32); - PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32); - PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg); + PrimExpr arg_hi32 = tvm::tirx::Cast(int32, arg >> 32); + PrimExpr arg_lo32 = tvm::tirx::Cast(int32, arg); PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 8be080406506..2fab0ae0884a 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -25,7 +25,7 @@ #define TVM_TARGET_SPIRV_IR_BUILDER_H_ #include -#include +#include // clang-format off #include diff --git a/src/target/spirv/spirv_utils.cc b/src/target/spirv/spirv_utils.cc index 1724abb52b7f..df420a8f1097 100644 --- a/src/target/spirv/spirv_utils.cc +++ b/src/target/spirv/spirv_utils.cc @@ -30,7 +30,7 @@ #include "codegen_spirv.h" #endif -#include +#include #include #include @@ -118,7 +118,7 @@ std::pair, std::string> Lo auto postproc = tvm::ffi::Function::GetGlobal("tvm_callback_vulkan_postproc"); - mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); + mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod)); CodeGenSPIRV cg(target); @@ -143,7 +143,7 @@ std::pair, std::string> Lo ss << path << "/" << f_name << "_"; std::string prefix = ss.str(); - std::ofstream(prefix + "tir.txt") << f; + std::ofstream(prefix + "tirx.txt") << f; std::ofstream(prefix + "spv.txt") << spirv_tools.BinaryToText(shader.data); std::ofstream(prefix + "spv.spv", std::ios::binary) .write(reinterpret_cast(shader.data.data()), diff --git a/src/target/target.cc b/src/target/target.cc index 7d03d5e14c69..4f1b4e3af2cb 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 5ee7feb11608..40ac3232c331 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -26,10 +26,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -37,7 +37,7 @@ namespace tvm { namespace te { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { OperationNode::RegisterReflection(); @@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /// Verify if ComputeOp is valid with respect to Reduce operations. static void VerifyComputeOp(const ComputeOpNode* op); -static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { +static inline void AssertReduceEqual(const tirx::ReduceNode* a, const tirx::ReduceNode* b) { const char* shared_text = "When a TE compute node produces multiple outputs, " "each of which is a reduction, " @@ -147,8 +147,8 @@ ComputeOp::ComputeOp(std::string name, std::string tag, ffi::Mapattrs = std::move(attrs); n->axis = std::move(axis); n->body = std::move(body); - if (n->body[0]->IsInstance()) { - const tir::ReduceNode* reduce = n->body[0].as(); + if (n->body[0]->IsInstance()) { + const tirx::ReduceNode* reduce = n->body[0].as(); n->reduce_axis = reduce->axis; } VerifyComputeOp(n.get()); @@ -169,8 +169,8 @@ ffi::Array ComputeOpNode::InputTensors() const { ffi::Array ret; std::unordered_set visited; for (auto& e : body) { - tir::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { - if (auto* pload = n.as()) { + tirx::PostOrderVisit(e, [&ret, &visited](const ObjectRef& n) { + if (auto* pload = n.as()) { Tensor t = Downcast(pload->producer); if (!visited.count(t)) { ret.push_back(t); @@ -194,12 +194,12 @@ namespace { * must be Reduce as well; and their inputs should have the * same attribute except value_index. */ -class ComputeVerifier final : protected tir::ExprVisitor { +class ComputeVerifier final : protected tirx::ExprVisitor { public: /// Special member functions //@{ explicit ComputeVerifier(const ComputeOpNode* compute) - : compute_(compute), reduce_(compute->body[0].as()) {} + : compute_(compute), reduce_(compute->body[0].as()) {} virtual ~ComputeVerifier() = default; ComputeVerifier(const ComputeVerifier&) = delete; ComputeVerifier(ComputeVerifier&&) = delete; @@ -211,7 +211,7 @@ class ComputeVerifier final : protected tir::ExprVisitor { void Run() { for (const PrimExpr e : compute_->body) { // Check for consistency of top level reductions - const tir::ReduceNode* reduce = e.as(); + const tirx::ReduceNode* reduce = e.as(); TVM_FFI_ICHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent " << "with being Reduce operation or not."; @@ -234,7 +234,7 @@ class ComputeVerifier final : protected tir::ExprVisitor { --level_; } - void VisitExpr_(const tir::ReduceNode* op) final { + void VisitExpr_(const tirx::ReduceNode* op) final { // Check for non top level reductions TVM_FFI_ICHECK(0 == level_) << "Reductions are only allowed at the top level of compute. " << "Please create another tensor for further composition."; @@ -242,9 +242,9 @@ class ComputeVerifier final : protected tir::ExprVisitor { //@} private: - const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify - const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation - int level_{0}; ///< Level of op being processed + const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify + const tirx::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation + int level_{0}; ///< Level of op being processed }; } // namespace diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 14650efcb77d..d34222117f95 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -25,9 +25,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -37,12 +37,12 @@ #include #include "../../support/array.h" -#include "../../tir/ir/data_type_rewriter.h" -#include "../../tir/ir/functor_common.h" +#include "../../tirx/ir/data_type_rewriter.h" +#include "../../tirx/ir/functor_common.h" #include "graph.h" namespace tvm { -namespace tir { +namespace tirx { /*! \brief The helper mutator that transforms ProducerLoad to BufferLoad */ class ProducerToBufferTransformer : public StmtExprMutator { @@ -177,7 +177,7 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator { return block; } - std::unordered_map buffer2index_; + std::unordered_map buffer2index_; std::set layout_free_buffer_indices_; ffi::String topi_attr = "layout_free_placeholders"; std::vector blocklist = {"const_matrix", @@ -259,10 +259,10 @@ ffi::Array GenerateOutputBuffers(const te::ComputeOp& compute_op, Create }; PrimExpr expr_body = compute_op->body[0]; tensors.push_back(compute_op.output(0)); - const tir::ReduceNode* reduce = expr_body.as(); + const tirx::ReduceNode* reduce = expr_body.as(); // specially handle reduction inline for multiplre reductions. for (size_t k = 1; k < compute_op->body.size(); ++k) { - const tir::ReduceNode* reduce_ = compute_op->body[k].as(); + const tirx::ReduceNode* reduce_ = compute_op->body[k].as(); TVM_FFI_ICHECK(reduce_); TVM_FFI_ICHECK(f_reducer_equal(reduce_, reduce)) << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " @@ -750,7 +750,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_list, /*body=*/SeqStmt::Flatten(root_stmts), /*ret_type=*/VoidType(), /*buffer_map=*/std::move(buffer_map)), - {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); + {{"global_symbol", ffi::String("main")}, {"tirx.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); TVM_FFI_ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); @@ -812,7 +812,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_tir_var_li auto it = info->tensor2buffers.find(tensor); TVM_FFI_ICHECK(it != info->tensor2buffers.end()); buffer_map.Set(arg, it->second); - } else if (auto var = arg.as()) { + } else if (auto var = arg.as()) { parameters.push_back(var.value()); } } @@ -820,7 +820,7 @@ PrimFunc GenerateAndCompletePrimFunc(const ffi::Array& arg_tir_var_li /*body=*/SeqStmt::Flatten(root_stmts), /*ret_type=*/VoidType(), /*buffer_map=*/std::move(buffer_map)), - {{"global_symbol", ffi::String("main")}, {"tir.noalias", true}}); + {{"global_symbol", ffi::String("main")}, {"tirx.noalias", true}}); const auto fcomplete = tvm::ffi::Function::GetGlobal("script.Complete"); TVM_FFI_ICHECK(fcomplete.has_value()); func = (*fcomplete)(std::move(func), info->root_alloc).cast(); @@ -861,5 +861,5 @@ PrimFunc CreatePrimFunc(const ffi::Array& arg_list, return result; } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index a5bd9b16ed10..4ebaa1a2bbc5 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -22,12 +22,12 @@ #include #include -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const ffi::Array& arg_list, @@ -37,7 +37,7 @@ PrimFunc CreatePrimFunc(const ffi::Array& arg_list, PrimFunc CreatePrimFunc(const ffi::Array& arg_list, std::optional index_dtype_override); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TE_OPERATION_CREATE_PRIMFUNC_H_ diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index b15ae9f67624..714915609188 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -25,11 +25,11 @@ #include #include #include -#include +#include namespace tvm { namespace te { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { ExternOpNode::RegisterReflection(); } diff --git a/src/te/operation/graph.cc b/src/te/operation/graph.cc index bddea5f7f2d4..34a885480df8 100644 --- a/src/te/operation/graph.cc +++ b/src/te/operation/graph.cc @@ -26,8 +26,8 @@ #include #include #include -#include -#include +#include +#include #include #include diff --git a/src/te/operation/graph.h b/src/te/operation/graph.h index dc2b211cf3cb..db3e5c53293b 100644 --- a/src/te/operation/graph.h +++ b/src/te/operation/graph.h @@ -25,7 +25,7 @@ #define TVM_TE_OPERATION_GRAPH_H_ #include -#include +#include namespace tvm { namespace te { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 25d09c931a22..09f464f466d1 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -24,11 +24,11 @@ #include #include #include -#include +#include namespace tvm { namespace te { -using namespace tir; +using namespace tirx; TVM_FFI_STATIC_INIT_BLOCK() { ScanOpNode::RegisterReflection(); } diff --git a/src/tir/analysis/check_contains.cc b/src/tirx/analysis/check_contains.cc similarity index 98% rename from src/tir/analysis/check_contains.cc rename to src/tirx/analysis/check_contains.cc index 2ba752905339..4ae3535af78e 100644 --- a/src/tir/analysis/check_contains.cc +++ b/src/tirx/analysis/check_contains.cc @@ -25,12 +25,12 @@ #include "check_contains.h" -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Toplevel (static) function that tells if an expression contains a subexpression that @@ -94,5 +94,5 @@ void CheckContains::VisitStmt(const Stmt& stmt) { // As otherwise we already have our answer } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/check_contains.h b/src/tirx/analysis/check_contains.h similarity index 93% rename from src/tir/analysis/check_contains.h rename to src/tirx/analysis/check_contains.h index 8b1a9e21aee9..41411d4249e3 100644 --- a/src/tir/analysis/check_contains.h +++ b/src/tirx/analysis/check_contains.h @@ -26,11 +26,11 @@ #ifndef TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_ #define TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_ -#include -#include // For the class StmtExprVisitor +#include +#include // For the class StmtExprVisitor namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Visitor which tells if a given expression or statement contains a subexpression @@ -54,7 +54,7 @@ class CheckContains : public StmtExprVisitor { bool contains_it_ = false; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_ANALYSIS_CHECK_CONTAINS_H_ diff --git a/src/tir/analysis/collect_call_map.cc b/src/tirx/analysis/collect_call_map.cc similarity index 82% rename from src/tir/analysis/collect_call_map.cc rename to src/tirx/analysis/collect_call_map.cc index 98f7585c6b79..4acb0859ec96 100644 --- a/src/tir/analysis/collect_call_map.cc +++ b/src/tirx/analysis/collect_call_map.cc @@ -19,17 +19,17 @@ /*! * - * \file src/tir/analysis/collect_call_map.cc + * \file src/tirx/analysis/collect_call_map.cc * * \brief Collect cross-IR call graph */ #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace { using ir::CalleeCollector; @@ -48,10 +48,10 @@ struct Visitor : StmtExprVisitor { } // namespace TVM_STATIC_IR_FUNCTOR(CalleeCollector, vtable) - .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) { + .set_dispatch([](const ObjectRef& func, CalleeCollector* collector) { Visitor visitor{collector}; - visitor(Downcast(func)->body); + visitor(Downcast(func)->body); }); -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tirx/analysis/control_flow_graph.cc similarity index 97% rename from src/tir/analysis/control_flow_graph.cc rename to src/tirx/analysis/control_flow_graph.cc index d5214f085cf6..eec474a8cb41 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tirx/analysis/control_flow_graph.cc @@ -25,11 +25,11 @@ #include "control_flow_graph.h" #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include @@ -47,7 +47,7 @@ #include "../../arith/unwrap_vector_expr.h" namespace tvm { -namespace tir { +namespace tirx { using namespace arith; @@ -70,7 +70,7 @@ ffi::Optional SubstituteParamValues(const ffi::Array& param_vars, << "Expression was defined as having " << param_vars.size() << " parameters, but received " << param_values.size() << " arguments."; - ffi::Map var_map; + ffi::Map var_map; for (size_t i = 0; i < param_values.size(); i++) { var_map.Set(param_vars[i], param_values[i]); } @@ -253,8 +253,8 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { std::vector buffer_exprs; for (const auto& expr : ExtractComponents(assumption)) { - auto side_effect = tir::SideEffect(expr); - if (side_effect <= tir::CallEffectKind::kPure) { + auto side_effect = tirx::SideEffect(expr); + if (side_effect <= tirx::CallEffectKind::kPure) { // Pulling out portions of the assumption that do not depend // on a buffer value allows the following two forms to be // treated identically. @@ -262,7 +262,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { // Option 1: if i < 3: T.assume(buf[i] == value) // Option 2: T.assume(i>=3 or buf[i] == value) additional_predicate = additional_predicate && logical_not(expr); - } else if (side_effect == tir::CallEffectKind::kReadState) { + } else if (side_effect == tirx::CallEffectKind::kReadState) { buffer_exprs.push_back(expr); } else { TVM_FFI_THROW(InternalError) @@ -279,7 +279,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { TVM_FFI_ICHECK_EQ(buffer_exprs.size(), 1) << "T.assume must contain only a single buffer expression"; - auto* as_equal_node = buffer_exprs[0].as(); + auto* as_equal_node = buffer_exprs[0].as(); TVM_FFI_ICHECK(as_equal_node || !from_assume_statement) << "T.assume buffer constraint must be of the form 'buffer[indices] == " "value', but received " @@ -291,12 +291,12 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { return; } - tir::BufferLoad load; + tirx::BufferLoad load; PrimExpr value; - if (auto opt = as_equal_node->a.as()) { + if (auto opt = as_equal_node->a.as()) { load = opt.value(); value = as_equal_node->b; - } else if (auto opt = as_equal_node->b.as()) { + } else if (auto opt = as_equal_node->b.as()) { load = opt.value(); value = as_equal_node->a; } else if (!from_assume_statement) { @@ -306,7 +306,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { << "T.assume buffer constraint must be of the form 'buffer[indices] == value'"; } - auto has_side_effect = tir::SideEffect(value) > tir::CallEffectKind::kPure; + auto has_side_effect = tirx::SideEffect(value) > tirx::CallEffectKind::kPure; TVM_FFI_ICHECK(!has_side_effect || !from_assume_statement) << "Buffer value in constraint must be pure expression, but was " << value; if (has_side_effect) { @@ -537,10 +537,10 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { : self(self), analyzer_context(&self->analyzer_, constraint) { old_num_constraints = self->conditions_.size(); - auto side_effect = tir::SideEffect(constraint); - if (side_effect <= tir::CallEffectKind::kPure) { + auto side_effect = tirx::SideEffect(constraint); + if (side_effect <= tirx::CallEffectKind::kPure) { self->conditions_.push_back(constraint); - } else if (side_effect <= tir::CallEffectKind::kReadState) { + } else if (side_effect <= tirx::CallEffectKind::kReadState) { assume = constraint; } @@ -638,7 +638,7 @@ class ControlFlowGraphBuilder final : public IRVisitorWithAnalyzer { }; std::pair> ControlFlowGraph::ControlFlowBlock::MakeBufferTouch( - const tir::Buffer& buf, ffi::Array index_variables, ffi::Array indices, + const tirx::Buffer& buf, ffi::Array index_variables, ffi::Array indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { const auto& current_block = *this; @@ -809,7 +809,7 @@ std::pair> ControlFlowGraph::ControlFlowBlock: } BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph* graph, - const tir::Buffer& buf, + const tirx::Buffer& buf, const ffi::Array& indices, BufferTouch::AccessType touch_type, PrimExpr known_value_expr) const { @@ -822,7 +822,7 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph return buffer_touch; } -ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplification_steps, +ControlFlowGraph::ControlFlowGraph(const tirx::Stmt& stmt, int64_t max_simplification_steps, size_t max_revisits) : max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) { ControlFlowGraphBuilder::Build(this, stmt); @@ -830,7 +830,7 @@ ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplifica BackwardPropagateUnusedValues(); } -void ControlFlowGraph::RemoveStore(const tir::BufferStore& store) { +void ControlFlowGraph::RemoveStore(const tirx::BufferStore& store) { size_t context_index = [&]() { auto it = control_flow_lookup_.find(store.get()); TVM_FFI_ICHECK(it != control_flow_lookup_.end()) @@ -951,7 +951,7 @@ std::ostream& operator<<(std::ostream& os, const BufferState& state) { } PrimExpr BufferState::SubstituteKnownBufferValues( - PrimExpr expr, const ffi::Map>& axis_var_lookup, + PrimExpr expr, const ffi::Map>& axis_var_lookup, Analyzer* analyzer) const { BufferConstraintApply mutator(axis_var_lookup, constraints_, analyzer); return mutator(std::move(expr)); @@ -966,7 +966,7 @@ void BufferState::AddCondition(const PrimExpr& condition) { void BufferState::Substitute(const ffi::Map& var_remap, Analyzer* analyzer) { if (var_remap.size()) { for (auto& prior : constraints_) { - PrimExpr updated = tvm::tir::Substitute(prior.predicate, var_remap); + PrimExpr updated = tvm::tirx::Substitute(prior.predicate, var_remap); if (!updated.same_as(prior.predicate)) { prior.predicate = SimplifyAsAndOfOrs(updated, analyzer); } @@ -1287,7 +1287,7 @@ void BufferState::BackpropUnusedIndices(const ffi::Map>& // Otherwise, add new "touch" to represent the unused values for (auto [buffer, predicate] : regions_written) { constraints_.push_back( - BufferTouch{buffer, predicate, tir::Call(buffer->dtype, builtin::undef(), {})}); + BufferTouch{buffer, predicate, tirx::Call(buffer->dtype, builtin::undef(), {})}); } // If buffer is read out, narrow the predicate @@ -1621,7 +1621,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ } } -bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, +bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tirx::BufferStore& store, const Stmt& context) const { ffi::Optional> index_variables = GetIndexVariables(store->buffer); if (!index_variables) { @@ -1662,7 +1662,7 @@ bool ControlFlowGraph::IsOverwrittenWithoutEffect(const tir::BufferStore& store, return false; } -PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& context, +PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tirx::Stmt& context, Analyzer* analyzer) const { size_t context_index = [&]() { auto it = control_flow_lookup_.find(context.get()); @@ -1687,5 +1687,5 @@ PrimExpr ControlFlowGraph::SimplifyInContext(PrimExpr expr, const tir::Stmt& con return expr; } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/control_flow_graph.h b/src/tirx/analysis/control_flow_graph.h similarity index 98% rename from src/tir/analysis/control_flow_graph.h rename to src/tirx/analysis/control_flow_graph.h index 7bde341c38fa..8f97d06f384e 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tirx/analysis/control_flow_graph.h @@ -25,9 +25,9 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -38,7 +38,7 @@ #define TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ namespace tvm { -namespace tir { +namespace tirx { /*! \brief Represents an interaction with a buffer */ struct BufferTouch { @@ -49,7 +49,7 @@ struct BufferTouch { /*! \brief Buffer access occurs in BufferStore */ Write, - /*! \brief Buffer access occurs in tir::builtin::assume() */ + /*! \brief Buffer access occurs in tirx::builtin::assume() */ Assume, }; @@ -293,10 +293,10 @@ class BufferState { }; /*! - * \brief Represents the flow of control through a `tir::Stmt` + * \brief Represents the flow of control through a `tirx::Stmt` * * This class contains an internal representation of the possible - * control flow that may occur during execution of a `tir::Stmt`. It + * control flow that may occur during execution of a `tirx::Stmt`. It * consists of a collection of ControlFlowBlock objects, each of which * represents a subset of operations performed during execution, along * with edges that represent allowed transitions between @@ -445,7 +445,7 @@ class ControlFlowGraph { * * \param store The store to remove */ - void RemoveStore(const tir::BufferStore& store); + void RemoveStore(const tirx::BufferStore& store); friend std::ostream& operator<<(std::ostream& os, const ControlFlowGraph& pattern); @@ -662,6 +662,6 @@ class ControlFlowGraph { int64_t max_simplification_steps_; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_ANALYSIS_CONTROL_FLOW_GRAPH_H_ diff --git a/src/tir/analysis/deep_equal.cc b/src/tirx/analysis/deep_equal.cc similarity index 97% rename from src/tir/analysis/deep_equal.cc rename to src/tirx/analysis/deep_equal.cc index 60a3e0d448d2..f164ba427ca8 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tirx/analysis/deep_equal.cc @@ -18,16 +18,16 @@ */ /*! - * \file tir/analysis/deep_equal.cc + * \file tirx/analysis/deep_equal.cc * \brief Deep equality checking. */ #include #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { #define DEFINE_DEEP_EQUAL_BIN_EXPR(OpNode) \ bool VisitExpr_(const OpNode* plhs, const PrimExpr& rhs) final { \ @@ -199,9 +199,9 @@ bool ExprDeepEqual::operator()(const PrimExpr& lhs, const PrimExpr& rhs) const { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.analysis.expr_deep_equal", + "tirx.analysis.expr_deep_equal", [](const PrimExpr& lhs, const PrimExpr& rhs) { return ExprDeepEqual()(lhs, rhs); }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/expr_complexity.cc b/src/tirx/analysis/expr_complexity.cc similarity index 90% rename from src/tir/analysis/expr_complexity.cc rename to src/tirx/analysis/expr_complexity.cc index e809668bb624..dc1d156ae3ac 100644 --- a/src/tir/analysis/expr_complexity.cc +++ b/src/tirx/analysis/expr_complexity.cc @@ -18,14 +18,14 @@ */ /*! - * \file tir/analysis/expr_complexity.cc + * \file tirx/analysis/expr_complexity.cc * \brief Calculate expr complexity. */ -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { /*! \brief Count the size of the PrimExpr. */ class PrimExprSizeCounter : public ExprVisitor { @@ -49,5 +49,5 @@ class PrimExprSizeCounter : public ExprVisitor { size_t CalculateExprComplexity(const PrimExpr& expr) { return PrimExprSizeCounter::Count(expr); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/side_effect.cc b/src/tirx/analysis/side_effect.cc similarity index 92% rename from src/tir/analysis/side_effect.cc rename to src/tirx/analysis/side_effect.cc index e20e60d24a66..b64ddccfeddf 100644 --- a/src/tir/analysis/side_effect.cc +++ b/src/tirx/analysis/side_effect.cc @@ -22,13 +22,13 @@ * \brief side effect analysis */ #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { class ExprSideEffect : public ExprVisitor { public: @@ -71,5 +71,5 @@ CallEffectKind SideEffect(const PrimExpr& e) { return visitor.kind_; } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/stmt_finding.cc b/src/tirx/analysis/stmt_finding.cc similarity index 81% rename from src/tir/analysis/stmt_finding.cc rename to src/tirx/analysis/stmt_finding.cc index 6093c5da9f5c..0ba6146213cc 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tirx/analysis/stmt_finding.cc @@ -17,24 +17,24 @@ * under the License. */ #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) { GlobalVar result = NullValue(); - // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + // Priority 1: PrimFunc marked as `tirx::attr::kIsEntryFunc` int num_prim_func = 0; - const tir::PrimFuncNode* main_func = nullptr; - const tir::PrimFuncNode* last_func = nullptr; + const tirx::PrimFuncNode* main_func = nullptr; + const tirx::PrimFuncNode* last_func = nullptr; for (const auto& kv : mod->functions) { GlobalVar gv = kv.first; BaseFunc base_func = kv.second; - if (const auto* func = base_func.as()) { + if (const auto* func = base_func.as()) { last_func = func; - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (func->HasNonzeroAttr(tirx::attr::kIsEntryFunc)) { if (result_g_var != nullptr) { *result_g_var = gv; } @@ -64,5 +64,5 @@ const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) return nullptr; } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/var_touch.cc b/src/tirx/analysis/var_touch.cc similarity index 95% rename from src/tir/analysis/var_touch.cc rename to src/tirx/analysis/var_touch.cc index 8c2ed6c43255..b8c1a4424ebb 100644 --- a/src/tir/analysis/var_touch.cc +++ b/src/tirx/analysis/var_touch.cc @@ -21,11 +21,11 @@ * \file var_touch.cc * \brief Implementation of simple passes */ -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { class VarTouchVisitor : public StmtExprVisitor { public: @@ -76,5 +76,5 @@ bool UsesVar(const PrimExpr& expr, std::function var_set) return visitor.use_var_; } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tirx/analysis/var_use_def_analysis.cc similarity index 98% rename from src/tir/analysis/var_use_def_analysis.cc rename to src/tirx/analysis/var_use_def_analysis.cc index 6951e25f8c99..c7899d9a1984 100644 --- a/src/tir/analysis/var_use_def_analysis.cc +++ b/src/tirx/analysis/var_use_def_analysis.cc @@ -25,7 +25,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { VarUseDefAnalyzer::VarUseDefAnalyzer(const ffi::Array& defined_vars, bool visit_thread_extent) : visit_thread_extent_(visit_thread_extent) { @@ -210,7 +210,7 @@ ffi::Array UndefinedVars(const PrimExpr& expr, const ffi::Array& args) TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def_packed( - "tir.analysis.UndefinedVars", [](ffi::PackedArgs args, ffi::Any* rv) { + "tirx.analysis.UndefinedVars", [](ffi::PackedArgs args, ffi::Any* rv) { if (auto opt_stmt = args[0].as()) { *rv = UndefinedVars(opt_stmt.value(), args[1].cast>()); } else if (auto opt_expr = args[0].as()) { @@ -221,5 +221,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { } }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/var_use_def_analysis.h b/src/tirx/analysis/var_use_def_analysis.h similarity index 94% rename from src/tir/analysis/var_use_def_analysis.h rename to src/tirx/analysis/var_use_def_analysis.h index a887acb1d3c4..b96887eb306f 100644 --- a/src/tir/analysis/var_use_def_analysis.h +++ b/src/tirx/analysis/var_use_def_analysis.h @@ -18,19 +18,19 @@ */ /*! - * \file tvm/src/tir/analysis/var_use_def_analyzer.h + * \file tvm/src/tirx/analysis/var_use_def_analyzer.h * \brief Variable definition and usage analysis class. */ #ifndef TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ #define TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ -#include -#include +#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Visitor class to perform use/def analysis, also delete unreferenced lets. @@ -84,7 +84,7 @@ class VarUseDefAnalyzer : public StmtExprVisitor { void VisitBuffer(const Buffer& buffer); }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_ANALYSIS_VAR_USE_DEF_ANALYSIS_H_ diff --git a/src/tir/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc similarity index 93% rename from src/tir/analysis/verify_memory.cc rename to src/tirx/analysis/verify_memory.cc index 35f682519c2a..a6c3c0ef3552 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -25,13 +25,13 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace { /*! @@ -159,7 +159,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { bool in_thread_env_{false}; std::vector errs_; //@} - tir::PrimFunc func_{nullptr}; ///< Function to be verified. + tirx::PrimFunc func_{nullptr}; ///< Function to be verified. int dev_type_{kDLCPU}; ///< Device type std::unordered_map defs_; ///< Variable definitions }; @@ -188,7 +188,7 @@ bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0 TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.analysis.verify_memory", VerifyMemory); + refl::GlobalDef().def("tirx.analysis.verify_memory", VerifyMemory); } namespace transform { @@ -211,14 +211,14 @@ Pass VerifyMemory() { } return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.VerifyMemory", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.VerifyMemory", VerifyMemory); + refl::GlobalDef().def("tirx.transform.VerifyMemory", VerifyMemory); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/verify_ssa.cc b/src/tirx/analysis/verify_ssa.cc similarity index 92% rename from src/tir/analysis/verify_ssa.cc rename to src/tirx/analysis/verify_ssa.cc index b8fb99d701e8..f63e5a47c896 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tirx/analysis/verify_ssa.cc @@ -25,16 +25,16 @@ */ #include #include -#include -#include -#include +#include +#include +#include #include #include #include namespace tvm { -namespace tir { +namespace tirx { class SSAVerifier final : public StmtExprVisitor { public: @@ -142,7 +142,7 @@ bool VerifySSA(const PrimFunc& func) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.analysis.verify_ssa", VerifySSA); + refl::GlobalDef().def("tirx.analysis.verify_ssa", VerifySSA); } namespace transform { @@ -157,15 +157,15 @@ Pass VerifySSA() { } return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifySSA", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.VerifySSA", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.VerifySSA", VerifySSA); + refl::GlobalDef().def("tirx.transform.VerifySSA", VerifySSA); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tirx/analysis/verify_well_formed.cc similarity index 96% rename from src/tir/analysis/verify_well_formed.cc rename to src/tirx/analysis/verify_well_formed.cc index 5d8a2d577842..d857767257d3 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tirx/analysis/verify_well_formed.cc @@ -18,14 +18,14 @@ */ /*! - * \file tir/analysis/verify_well_formed.cc - * \brief Check if schedulable tir is well-formed. + * \file tirx/analysis/verify_well_formed.cc + * \brief Check if schedulable tirx is well-formed. */ #include #include -#include -#include +#include +#include #include #include @@ -37,7 +37,7 @@ #include "tvm/ir/module.h" namespace tvm { -namespace tir { +namespace tirx { using AccessPath = ffi::reflection::AccessPath; @@ -373,13 +373,13 @@ class UndefinedBufferVerifier : public Verifier { std::unordered_map previously_defined_; }; -/* \brief Verify unique tir::Var for each environment thread +/* \brief Verify unique tirx::Var for each environment thread * * Environment threads, such as CUDA's `threadIdx.x`, are defined in * TIR using an `AttrStmt` with the key `attr::thread_extent`. A * `PrimFunc` may contain multiple such attributes for the same * environment thread. However, all such attributes must use the same - * `tir::Var` for a given thread. + * `tirx::Var` for a given thread. */ class SingleEnvThreadVerifier : public Verifier { public: @@ -398,8 +398,8 @@ class SingleEnvThreadVerifier : public Verifier { Verify(prev_var.same_as(iter_var->var)) << "PrimFunc uses multiple distinct TIR variables " << " for the environment thread \"" << iter_var->thread_tag << "\". " - << "While multiple tir::AttrStmt may define the same environment thread, " - << "all definitions within a single PrimFunc must share the same tir::Var. " + << "While multiple tirx::AttrStmt may define the same environment thread, " + << "all definitions within a single PrimFunc must share the same tirx::Var. " << "Binding of environment thread \"" << iter_var->thread_tag << "\" to the TIR variable " << iter_var->var << " at " << path << " conflicts with the previous binding to the TIR variable " << prev_var << " at " @@ -446,7 +446,7 @@ bool VerifyWellFormed(const IRModule& mod, bool assert_mode) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.analysis.VerifyWellFormed", [](const ObjectRef& obj, bool assert_mode) { + "tirx.analysis.VerifyWellFormed", [](const ObjectRef& obj, bool assert_mode) { if (auto opt = obj.as()) { return VerifyWellFormed(opt.value(), assert_mode); } else if (auto opt = obj.as()) { @@ -459,5 +459,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/buffer.cc b/src/tirx/ir/buffer.cc similarity index 93% rename from src/tir/ir/buffer.cc rename to src/tirx/ir/buffer.cc index c1ea74f3d2ea..8a8f81068cdd 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tirx/ir/buffer.cc @@ -24,11 +24,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include @@ -37,12 +37,12 @@ #include "../../arith/pattern_match.h" namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { BufferNode::RegisterReflection(); } -using IndexMod = tir::FloorModNode; -using IndexDiv = tir::FloorDivNode; +using IndexMod = tirx::FloorModNode; +using IndexDiv = tirx::FloorDivNode; ffi::Array SimplifyArray(arith::Analyzer* ana, ffi::Array array) { for (size_t i = 0; i < array.size(); ++i) { @@ -62,7 +62,7 @@ Buffer decl_buffer(ffi::Array shape, DataType dtype, ffi::String name, // Split the given expression w.r.t the add operator inline std::vector ExprSplitAddition(const PrimExpr& expr) { - using namespace tir; + using namespace tirx; std::vector ret; std::stack split_buffer; split_buffer.push(&expr); @@ -92,7 +92,7 @@ inline std::pair MergeMulModInner(arith::Analyzer* analyzer, const PrimExpr& mult_expr, const PrimExpr& mod_l_expr, const PrimExpr& mod_r_expr) { - using namespace tir; + using namespace tirx; const MulNode* mult_ptr = mult_expr.as(); if (!mult_ptr) return std::make_pair(false, PrimExpr()); PrimExpr mult_outer = mult_ptr->b; @@ -116,7 +116,7 @@ inline std::pair MergeMulModInner(arith::Analyzer* analyzer, const PrimExpr* search_ptr = inner; PrimExpr mult_inner; // The inner multiplication factor PrimExpr no_opt_sum; // Sum of the exprs that cannot be optimized - tir::ExprDeepEqual expr_equal; + tirx::ExprDeepEqual expr_equal; while (true) { auto inner_div_ptr = search_ptr->as(); @@ -160,7 +160,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, std::list* mult_exprs, std::list>* mod_exprs, PrimExpr* no_opt_sum, bool* has_mult, bool* has_mod) { - using namespace tir; + using namespace tirx; *has_mult = false; *has_mod = false; for (const PrimExpr* ele : eles) { @@ -186,7 +186,7 @@ inline void MergeMulModInsertElements(const std::vector& eles, // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) { - using namespace tir; + using namespace tirx; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. @@ -335,7 +335,7 @@ inline ffi::Array BufferOffset(const BufferNode* n, ffi::Arrayelem_offset + offset; if (content_lanes > 1) { - e_dtype = tir::TypeAnnotation(self->dtype.with_lanes(content_lanes)); + e_dtype = tirx::TypeAnnotation(self->dtype.with_lanes(content_lanes)); extent = extent / make_const(self->elem_offset.dtype(), content_lanes); elem_offset = self->elem_offset / make_const(self->elem_offset.dtype(), content_lanes); } else { - e_dtype = tir::TypeAnnotation(self->dtype); + e_dtype = tirx::TypeAnnotation(self->dtype); } if (input_extent.defined()) { @@ -556,7 +556,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } ffi::Array acc_args{e_dtype, self->data, elem_offset, extent, make_const(DataType::Int(32), access_mask)}; - return tir::Call(ptr_type, tir::builtin::tvm_access_ptr(), acc_args); + return tirx::Call(ptr_type, tirx::builtin::tvm_access_ptr(), acc_args); } Buffer::Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array strides, @@ -615,37 +615,37 @@ Buffer::Buffer(Var data, DataType dtype, ffi::Array shape, ffi::Array< data_ = std::move(n); } -tir::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, std::string name, - int data_alignment, int offset_factor, bool compact, - std::string memory_scope) { +tirx::Buffer BufferWithOffsetAlignment(ffi::Array shape, DataType dtype, std::string name, + int data_alignment, int offset_factor, bool compact, + std::string memory_scope) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); - auto data = tir::Var(name, PointerType(PrimType(storage_dtype), memory_scope)); + auto data = tirx::Var(name, PointerType(PrimType(storage_dtype), memory_scope)); bool has_any = false; if (!compact) { for (const auto& it : shape) { - if (it.as()) { + if (it.as()) { has_any = true; break; } } } - tir::BufferType buffer_type = has_any ? tir::kAutoBroadcast : tir::kDefault; + tirx::BufferType buffer_type = has_any ? tirx::kAutoBroadcast : tirx::kDefault; PrimExpr elem_offset; if (offset_factor != 0) { - elem_offset = tir::Var(name + "_elem_offset", shape[0].dtype()); + elem_offset = tirx::Var(name + "_elem_offset", shape[0].dtype()); } else { elem_offset = PrimExpr(); } - return tir::Buffer(data, dtype, shape, ffi::Array(), elem_offset, name, data_alignment, - offset_factor, buffer_type); + return tirx::Buffer(data, dtype, shape, ffi::Array(), elem_offset, name, data_alignment, + offset_factor, buffer_type); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def_packed("tir.Buffer", + .def_packed("tirx.Buffer", [](ffi::PackedArgs args, ffi::Any* ret) { TVM_FFI_ICHECK_EQ(args.size(), 11); auto buffer_type = args[8].cast(); @@ -663,13 +663,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { *ret = Buffer(data, dtype, shape, strides, elem_offset, name, data_alignment, offset_factor, type, axis_separators, span); }) - .def_method("tir.BufferAccessPtr", &Buffer::access_ptr) - .def_method("tir.BufferGetFlattenedBuffer", &Buffer::GetFlattenedBuffer) - .def_method("tir.BufferOffsetOf", &Buffer::OffsetOf) - .def_method("tir.BufferVLoad", &Buffer::vload) - .def_method("tir.BufferVStore", &Buffer::vstore) - .def_method("tir.BufferStorageScope", &Buffer::scope); + .def_method("tirx.BufferAccessPtr", &Buffer::access_ptr) + .def_method("tirx.BufferGetFlattenedBuffer", &Buffer::GetFlattenedBuffer) + .def_method("tirx.BufferOffsetOf", &Buffer::OffsetOf) + .def_method("tirx.BufferVLoad", &Buffer::vload) + .def_method("tirx.BufferVStore", &Buffer::vstore) + .def_method("tirx.BufferStorageScope", &Buffer::scope); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/buffer_common.h b/src/tirx/ir/buffer_common.h similarity index 95% rename from src/tir/ir/buffer_common.h rename to src/tirx/ir/buffer_common.h index 5921c54d985e..b6aebba2d327 100644 --- a/src/tir/ir/buffer_common.h +++ b/src/tirx/ir/buffer_common.h @@ -17,7 +17,7 @@ * under the License. */ /*! - * \file tir/ir/buffer_common.h + * \file tirx/ir/buffer_common.h * \brief Common utils for buffer access */ #ifndef TVM_TIR_IR_BUFFER_COMMON_H_ @@ -29,7 +29,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Returns the type of object pointed to. @@ -52,6 +52,6 @@ inline std::optional GetPointerType(const Type& type) { return std::nullopt; } -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_IR_BUFFER_COMMON_H_ diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tirx/ir/data_type_rewriter.cc similarity index 98% rename from src/tir/ir/data_type_rewriter.cc rename to src/tirx/ir/data_type_rewriter.cc index 37ae4f70b2cc..229baf6070be 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tirx/ir/data_type_rewriter.cc @@ -25,20 +25,20 @@ #include "data_type_rewriter.h" #include -#include -#include +#include +#include #include #include #include "./functor_common.h" #include "tvm/ir/expr.h" -#include "tvm/tir/expr.h" -#include "tvm/tir/stmt.h" -#include "tvm/tir/var.h" +#include "tvm/tirx/expr.h" +#include "tvm/tirx/stmt.h" +#include "tvm/tirx/var.h" namespace tvm { -namespace tir { +namespace tirx { Stmt DataTypeLegalizer::VisitStmt_(const ForNode* op) { Stmt s = StmtExprMutator::VisitStmt_(op); @@ -231,7 +231,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { Call before = ffi::GetRef(op); PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); - static const Op& builtin_pow_ = Op::Get("tir.pow"); + static const Op& builtin_pow_ = Op::Get("tirx.pow"); TVM_FFI_ICHECK(op != nullptr) << "Expected type to be CallNode" << ", but get " << e->GetTypeKey(); if (op->op.same_as(builtin::shift_right())) { @@ -248,7 +248,7 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { return pow(op->args[0], op->args[1]); } else if (op->op.same_as(builtin::if_then_else())) { return if_then_else(op->args[0], op->args[1], op->args[2]); - } else if (op->op.same_as(Op::Get("tir.clz"))) { + } else if (op->op.same_as(Op::Get("tirx.clz"))) { DataType before_dtype = before->args[0]->dtype; DataType after_dtype = op->args[0]->dtype; TVM_FFI_ICHECK((before_dtype.is_int() || before_dtype.is_uint()) && @@ -653,5 +653,5 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const CastNode* op) { return IndexDataTypeRewriter::VisitExpr_(op); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/data_type_rewriter.h b/src/tirx/ir/data_type_rewriter.h similarity index 98% rename from src/tir/ir/data_type_rewriter.h rename to src/tirx/ir/data_type_rewriter.h index e19c555c6ed0..1bea362f6283 100644 --- a/src/tir/ir/data_type_rewriter.h +++ b/src/tirx/ir/data_type_rewriter.h @@ -24,12 +24,12 @@ #ifndef TVM_TIR_IR_DATA_TYPE_REWRITER_H_ #define TVM_TIR_IR_DATA_TYPE_REWRITER_H_ -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Legalize the data types of expressions to make sure they are consistent with other @@ -158,7 +158,7 @@ class IndexDataTypeNormalizer : public IndexDataTypeRewriter { DataType target_data_type_ = DataType::Int(64); }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_IR_DATA_TYPE_REWRITER_H_ diff --git a/src/tir/ir/expr.cc b/src/tirx/ir/expr.cc similarity index 90% rename from src/tir/ir/expr.cc rename to src/tirx/ir/expr.cc index 77fccc040f43..8b841e18dc45 100644 --- a/src/tir/ir/expr.cc +++ b/src/tirx/ir/expr.cc @@ -22,10 +22,10 @@ */ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -34,7 +34,7 @@ #include "buffer_common.h" namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { VarNode::RegisterReflection(); @@ -82,7 +82,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { */ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.convert", + refl::GlobalDef().def("tirx.convert", [](ffi::Variant> expr) { return expr; }); } @@ -168,7 +168,7 @@ Var Var::copy_with_dtype(DataType dtype) const { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Var", [](ffi::String name_hint, ffi::AnyView type, Span span) { + refl::GlobalDef().def("tirx.Var", [](ffi::String name_hint, ffi::AnyView type, Span span) { if (type.as()) { return Var(name_hint, type.cast(), span); } else { @@ -198,7 +198,7 @@ SizeVar::SizeVar(ffi::String name_hint, Type type_annotation, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.SizeVar", + refl::GlobalDef().def("tirx.SizeVar", [](ffi::String s, DataType t, Span span) { return SizeVar(s, t, span); }); } @@ -225,7 +225,7 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, ffi::String thread_tag, Span TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.IterVar", [](Range dom, Var var, int iter_type, ffi::String thread_tag, Span span) { + "tirx.IterVar", [](Range dom, Var var, int iter_type, ffi::String thread_tag, Span span) { return IterVar(dom, var, static_cast(iter_type), thread_tag, span); }); } @@ -241,7 +241,7 @@ StringImm::StringImm(ffi::String value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.StringImm", + refl::GlobalDef().def("tirx.StringImm", [](ffi::String value, Span span) { return StringImm(value, span); }); } @@ -259,7 +259,7 @@ Cast::Cast(DataType t, PrimExpr value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Cast", [](DataType dtype, PrimExpr value, Span span) { + refl::GlobalDef().def("tirx.Cast", [](DataType dtype, PrimExpr value, Span span) { return Cast(dtype, value, span); }); } @@ -269,7 +269,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Add); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Add", + refl::GlobalDef().def("tirx.Add", [](PrimExpr a, PrimExpr b, Span span) { return Add(a, b, span); }); } @@ -278,7 +278,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Sub); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Sub", + refl::GlobalDef().def("tirx.Sub", [](PrimExpr a, PrimExpr b, Span span) { return Sub(a, b, span); }); } @@ -287,7 +287,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Mul); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Mul", + refl::GlobalDef().def("tirx.Mul", [](PrimExpr a, PrimExpr b, Span span) { return Mul(a, b, span); }); } @@ -296,7 +296,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Div); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Div", + refl::GlobalDef().def("tirx.Div", [](PrimExpr a, PrimExpr b, Span span) { return Div(a, b, span); }); } @@ -305,7 +305,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Mod); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Mod", + refl::GlobalDef().def("tirx.Mod", [](PrimExpr a, PrimExpr b, Span span) { return Mod(a, b, span); }); } @@ -314,7 +314,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(FloorDiv); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.FloorDiv", + refl::GlobalDef().def("tirx.FloorDiv", [](PrimExpr a, PrimExpr b, Span span) { return FloorDiv(a, b, span); }); } @@ -323,7 +323,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(FloorMod); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.FloorMod", + refl::GlobalDef().def("tirx.FloorMod", [](PrimExpr a, PrimExpr b, Span span) { return FloorMod(a, b, span); }); } @@ -332,7 +332,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Min); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Min", + refl::GlobalDef().def("tirx.Min", [](PrimExpr a, PrimExpr b, Span span) { return Min(a, b, span); }); } @@ -341,7 +341,7 @@ TVM_DEFINE_BINOP_CONSTRUCTOR(Max); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Max", + refl::GlobalDef().def("tirx.Max", [](PrimExpr a, PrimExpr b, Span span) { return Max(a, b, span); }); } @@ -350,7 +350,8 @@ TVM_DEFINE_CMPOP_CONSTRUCTOR(EQ); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.EQ", [](PrimExpr a, PrimExpr b, Span span) { return EQ(a, b, span); }); + refl::GlobalDef().def("tirx.EQ", + [](PrimExpr a, PrimExpr b, Span span) { return EQ(a, b, span); }); } // NE @@ -358,7 +359,8 @@ TVM_DEFINE_CMPOP_CONSTRUCTOR(NE); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.NE", [](PrimExpr a, PrimExpr b, Span span) { return NE(a, b, span); }); + refl::GlobalDef().def("tirx.NE", + [](PrimExpr a, PrimExpr b, Span span) { return NE(a, b, span); }); } // LT @@ -366,7 +368,8 @@ TVM_DEFINE_CMPOP_CONSTRUCTOR(LT); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.LT", [](PrimExpr a, PrimExpr b, Span span) { return LT(a, b, span); }); + refl::GlobalDef().def("tirx.LT", + [](PrimExpr a, PrimExpr b, Span span) { return LT(a, b, span); }); } // LE @@ -374,7 +377,8 @@ TVM_DEFINE_CMPOP_CONSTRUCTOR(LE); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.LE", [](PrimExpr a, PrimExpr b, Span span) { return LE(a, b, span); }); + refl::GlobalDef().def("tirx.LE", + [](PrimExpr a, PrimExpr b, Span span) { return LE(a, b, span); }); } // GT @@ -382,7 +386,8 @@ TVM_DEFINE_CMPOP_CONSTRUCTOR(GT); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.GT", [](PrimExpr a, PrimExpr b, Span span) { return GT(a, b, span); }); + refl::GlobalDef().def("tirx.GT", + [](PrimExpr a, PrimExpr b, Span span) { return GT(a, b, span); }); } // GE @@ -390,7 +395,8 @@ TVM_DEFINE_CMPOP_CONSTRUCTOR(GE); TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.GE", [](PrimExpr a, PrimExpr b, Span span) { return GE(a, b, span); }); + refl::GlobalDef().def("tirx.GE", + [](PrimExpr a, PrimExpr b, Span span) { return GE(a, b, span); }); } // And @@ -412,7 +418,7 @@ And::And(PrimExpr a, PrimExpr b, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.And", + refl::GlobalDef().def("tirx.And", [](PrimExpr a, PrimExpr b, Span span) { return And(a, b, span); }); } @@ -435,7 +441,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Or", [](PrimExpr a, PrimExpr b, Span span) { return Or(a, b, span); }); + refl::GlobalDef().def("tirx.Or", + [](PrimExpr a, PrimExpr b, Span span) { return Or(a, b, span); }); } // Not @@ -453,7 +460,7 @@ Not::Not(PrimExpr a, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Not", [](PrimExpr a, Span span) { return Not(a, span); }); + refl::GlobalDef().def("tirx.Not", [](PrimExpr a, Span span) { return Not(a, span); }); } // Select @@ -481,7 +488,7 @@ Select::Select(PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Sp TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.Select", [](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { + "tirx.Select", [](PrimExpr condition, PrimExpr true_value, PrimExpr false_value, Span span) { return Select(condition, true_value, false_value, span); }); } @@ -509,7 +516,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes; node->dtype = base.dtype().with_scalable_vscale_factor(vscale_factor.value()); - lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value()); + lanes = Mul(Call(DataType::Int(32), tirx::builtin::vscale(), {}), vscale_factor.value()); node->lanes = lanes; } node->base = base; @@ -520,7 +527,7 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Ramp", [](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { + refl::GlobalDef().def("tirx.Ramp", [](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) { return Ramp(base, stride, lanes, span); }); } @@ -543,7 +550,7 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { TVM_FFI_ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes; node->dtype = value.dtype().with_scalable_vscale_factor(vscale_factor.value()); - lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}), vscale_factor.value()); + lanes = Mul(Call(DataType::Int(32), tirx::builtin::vscale(), {}), vscale_factor.value()); node->lanes = lanes; } node->value = std::move(value); @@ -553,7 +560,7 @@ Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Broadcast", [](PrimExpr value, PrimExpr lanes, Span span) { + refl::GlobalDef().def("tirx.Broadcast", [](PrimExpr value, PrimExpr lanes, Span span) { return Broadcast(value, lanes, span); }); } @@ -575,7 +582,7 @@ Let::Let(Var var, PrimExpr value, PrimExpr body, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Let", [](Var var, PrimExpr value, PrimExpr body, Span span) { + refl::GlobalDef().def("tirx.Let", [](Var var, PrimExpr value, PrimExpr body, Span span) { return Let(var, value, body, span); }); } @@ -597,7 +604,7 @@ Call::Call(DataType dtype, RelaxExpr op, ffi::Array args, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.Call", + "tirx.Call", [](ffi::Optional dtype, RelaxExpr op, ffi::Array> args, Span span) { @@ -615,7 +622,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { if (is_one(r->extent)) { indices.push_back(r->min); } else if (r->extent.as()) { - indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); + indices.push_back(tirx::Ramp(r->min, make_const(r->min->dtype, 1), r->extent)); } else { TVM_FFI_THROW(ValueError) << "Cannot convert to BufferLoad: " << ffi::GetRef(br); @@ -673,7 +680,7 @@ PrimExpr Shuffle::ExtractElement(PrimExpr vector, int index, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Shuffle", + refl::GlobalDef().def("tirx.Shuffle", [](ffi::Array vectors, ffi::Array indices, Span span) { return Shuffle(vectors, indices, span); }); @@ -738,11 +745,11 @@ ffi::Array CommReducerNode::operator()(ffi::Array a, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.CommReducer", + .def("tirx.CommReducer", [](ffi::Array lhs, ffi::Array rhs, ffi::Array result, ffi::Array identity_element, Span span) { return CommReducer(lhs, rhs, result, identity_element, span); }) - .def_method("tir.CommReducerCombine", &tir::CommReducerNode::operator()); + .def_method("tirx.CommReducerCombine", &tirx::CommReducerNode::operator()); } // Reduce @@ -784,8 +791,8 @@ Reduce::Reduce(CommReducer combiner, ffi::Array source, ffi::Array source, ffi::Array axis, - PrimExpr condition, int value_index, ffi::Array init, Span span) { + "tirx.Reduce", [](CommReducer combiner, ffi::Array source, ffi::Array axis, + PrimExpr condition, int value_index, ffi::Array init, Span span) { return Reduce(combiner, source, axis, condition, value_index, init, span); }); } @@ -859,8 +866,8 @@ BufferLoad::BufferLoad(Buffer buffer, ffi::Array indices, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferLoad", [](Buffer buffer, ffi::Array indices, - ffi::Optional predicate, Span span) { + refl::GlobalDef().def("tirx.BufferLoad", [](Buffer buffer, ffi::Array indices, + ffi::Optional predicate, Span span) { return BufferLoad(buffer, indices, predicate, span); }); } @@ -877,11 +884,11 @@ ProducerLoad::ProducerLoad(DataProducer producer, ffi::Array indices, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.ProducerLoad", + refl::GlobalDef().def("tirx.ProducerLoad", [](DataProducer producer, ffi::Array indices, Span span) { return ProducerLoad(producer, indices, span); }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/expr_functor.cc b/src/tirx/ir/expr_functor.cc similarity index 99% rename from src/tir/ir/expr_functor.cc rename to src/tirx/ir/expr_functor.cc index 19277d1013c1..c6abfd021e42 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tirx/ir/expr_functor.cc @@ -19,12 +19,12 @@ /*! * \file expr_functor.cc */ -#include +#include #include "functor_common.h" namespace tvm { -namespace tir { +namespace tirx { void ExprVisitor::VisitExpr_(const VarNode* op) {} @@ -283,5 +283,5 @@ PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { } } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/function.cc b/src/tirx/ir/function.cc similarity index 89% rename from src/tir/ir/function.cc rename to src/tirx/ir/function.cc index 1516bd8aa00e..b817ccfd7328 100644 --- a/src/tir/ir/function.cc +++ b/src/tirx/ir/function.cc @@ -18,18 +18,18 @@ */ /*! - * \file src/tir/ir/function.cc + * \file src/tirx/ir/function.cc * \brief The function data structure. */ #include #include #include #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { PrimFuncNode::RegisterReflection(); @@ -75,8 +75,8 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { } // namespace // Get the function type of a PrimFunc -PrimFunc::PrimFunc(ffi::Array params, Stmt body, Type ret_type, - ffi::Map buffer_map, DictAttrs attrs, Span span) { +PrimFunc::PrimFunc(ffi::Array params, Stmt body, Type ret_type, + ffi::Map buffer_map, DictAttrs attrs, Span span) { if (!attrs.defined()) { attrs = DictAttrs(); } @@ -108,7 +108,7 @@ FuncType PrimFuncNode::func_type_annotation() const { class TensorIntrinManager { public: - ffi::Map reg; + ffi::Map reg; static TensorIntrinManager* Global() { static TensorIntrinManager* inst = new TensorIntrinManager(); @@ -162,17 +162,17 @@ ffi::Optional TensorIntrin::Get(ffi::String name, bool allow_missi TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.PrimFunc", - [](ffi::Array params, Stmt body, Type ret_type, - ffi::Map buffer_map, DictAttrs attrs, + .def("tirx.PrimFunc", + [](ffi::Array params, Stmt body, Type ret_type, + ffi::Map buffer_map, DictAttrs attrs, Span span) { return PrimFunc(params, body, ret_type, buffer_map, attrs, span); }) - .def("tir.TensorIntrin", + .def("tirx.TensorIntrin", [](PrimFunc desc_func, PrimFunc intrin_func) { return TensorIntrin(desc_func, intrin_func); }) - .def("tir.TensorIntrinRegister", TensorIntrin::Register) - .def("tir.TensorIntrinGet", TensorIntrin::Get); + .def("tirx.TensorIntrinRegister", TensorIntrin::Register) + .def("tirx.TensorIntrinGet", TensorIntrin::Get); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/functor_common.h b/src/tirx/ir/functor_common.h similarity index 95% rename from src/tir/ir/functor_common.h rename to src/tirx/ir/functor_common.h index c9f21b1b38ec..96bb8af2c36a 100644 --- a/src/tir/ir/functor_common.h +++ b/src/tirx/ir/functor_common.h @@ -19,14 +19,14 @@ #include /*! - * \file tir/ir/functor_common.h + * \file tirx/ir/functor_common.h * \brief Common utils for implementing functors */ #ifndef TVM_TIR_IR_FUNCTOR_COMMON_H_ #define TVM_TIR_IR_FUNCTOR_COMMON_H_ namespace tvm { -namespace tir { +namespace tirx { // Implementation of Visitors template @@ -41,6 +41,6 @@ inline ffi::Array MutateArray(ffi::Array arr, F fmutate) { return arr.Map(fmutate); } -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_IR_FUNCTOR_COMMON_H_ diff --git a/src/tir/ir/index_map.cc b/src/tirx/ir/index_map.cc similarity index 97% rename from src/tir/ir/index_map.cc rename to src/tirx/ir/index_map.cc index 1c48d5fd1cc3..506a149dca86 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tirx/ir/index_map.cc @@ -26,14 +26,14 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { IndexMapNode::RegisterReflection(); } @@ -408,7 +408,7 @@ ffi::String IndexMapNode::ToPythonString( std::string inverse_lambda_expr = IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices); std::ostringstream oss; - oss << "tvm.tir.IndexMap.from_func(" << lambda_expr + oss << "tvm.tirx.IndexMap.from_func(" << lambda_expr << ", inverse_index_map=" << inverse_lambda_expr << ")"; return ffi::String(oss.str()); } @@ -427,29 +427,29 @@ IndexMap Substitute(const IndexMap& index_map, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.IndexMap", + .def("tirx.IndexMap", [](ffi::Array initial_indices, ffi::Array final_indices, ffi::Optional inverse_index_map) { return IndexMap(initial_indices, final_indices, inverse_index_map); }) - .def("tir.IndexMapMapIndices", + .def("tirx.IndexMapMapIndices", [](IndexMap map, ffi::Array indices) { arith::Analyzer analyzer; return map->MapIndices(indices, &analyzer); }) - .def("tir.IndexMapMapShape", + .def("tirx.IndexMapMapShape", [](IndexMap map, ffi::Array shape) { arith::Analyzer analyzer; return map->MapShape(shape, &analyzer); }) - .def("tir.IndexMapInverse", + .def("tirx.IndexMapInverse", [](IndexMap map, ffi::Array initial_ranges) { arith::Analyzer analyzer; return map.Inverse(initial_ranges, &analyzer); }) - .def("tir.IndexMapMapTensor", + .def("tirx.IndexMapMapTensor", [](IndexMap map, runtime::Tensor arr) { return map->MapTensor(arr); }) - .def("tir.IndexMapNonSurjectiveInverse", + .def("tirx.IndexMapNonSurjectiveInverse", [](IndexMap forward, ffi::Array initial_ranges) { arith::Analyzer analyzer; auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer); @@ -457,5 +457,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/py_functor.cc b/src/tirx/ir/py_functor.cc similarity index 98% rename from src/tir/ir/py_functor.cc rename to src/tirx/ir/py_functor.cc index b385922cb950..c6835727d559 100644 --- a/src/tir/ir/py_functor.cc +++ b/src/tirx/ir/py_functor.cc @@ -18,17 +18,17 @@ */ /*! - * \file src/tir/ir/py_functor.cc + * \file src/tirx/ir/py_functor.cc * \brief The python interface of ExprVisitor/ExprMutator, StmtVisitor/StmtMutator, * StmtExprVisitor/StmtExprMutator. */ #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { // ================================================ // Helper Macros @@ -216,7 +216,7 @@ class PyStmtExprVisitorNode : public Object, public StmtExprVisitor { } static constexpr const bool _type_mutable = true; - TVM_FFI_DECLARE_OBJECT_INFO("tir.PyStmtExprVisitor", PyStmtExprVisitorNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tirx.PyStmtExprVisitor", PyStmtExprVisitorNode, Object); private: // Statement functions @@ -571,7 +571,7 @@ class PyStmtExprMutatorNode : public Object, public StmtExprMutator { } static constexpr const bool _type_mutable = true; - TVM_FFI_DECLARE_OBJECT_INFO("tir.PyStmtExprMutator", PyStmtExprMutatorNode, Object); + TVM_FFI_DECLARE_OBJECT_INFO("tirx.PyStmtExprMutator", PyStmtExprMutatorNode, Object); private: // Statement functions @@ -812,21 +812,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) - .def("tir.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); + .def("tirx.MakePyStmtExprVisitor", PyStmtExprVisitor::MakePyStmtExprVisitor) + .def("tirx.MakePyStmtExprMutator", PyStmtExprMutator::MakePyStmtExprMutator); } // StmtExprVisitor TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.PyStmtExprVisitorDefaultVisitExpr", + .def("tirx.PyStmtExprVisitorDefaultVisitExpr", [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->DefaultVisitExpr(expr); }) - .def("tir.PyStmtExprVisitorDefaultVisitStmt", + .def("tirx.PyStmtExprVisitorDefaultVisitStmt", [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->DefaultVisitStmt(stmt); }) - .def("tir.PyStmtExprVisitorVisitStmt", + .def("tirx.PyStmtExprVisitorVisitStmt", [](PyStmtExprVisitor visitor, const Stmt& stmt) { visitor->VisitStmt(stmt); }) - .def("tir.PyStmtExprVisitorVisitExpr", + .def("tirx.PyStmtExprVisitorVisitExpr", [](PyStmtExprVisitor visitor, const PrimExpr& expr) { visitor->VisitExpr(expr); }); } @@ -834,19 +834,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.PyStmtExprMutatorDefaultVisitExpr", + .def("tirx.PyStmtExprMutatorDefaultVisitExpr", [](PyStmtExprMutator mutator, const PrimExpr& expr) { return mutator->DefaultVisitExpr(expr); }) - .def("tir.PyStmtExprMutatorDefaultVisitStmt", + .def("tirx.PyStmtExprMutatorDefaultVisitStmt", [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->DefaultVisitStmt(stmt); }) - .def("tir.PyStmtExprMutatorVisitExpr", + .def("tirx.PyStmtExprMutatorVisitExpr", [](PyStmtExprMutator mutator, const PrimExpr& expr) { return mutator->VisitExpr(expr); }) - .def("tir.PyStmtExprMutatorVisitStmt", + .def("tirx.PyStmtExprMutatorVisitStmt", [](PyStmtExprMutator mutator, const Stmt& stmt) { return mutator->VisitStmt(stmt); }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/script/script_complete.cc b/src/tirx/ir/script/script_complete.cc similarity index 96% rename from src/tir/ir/script/script_complete.cc rename to src/tirx/ir/script/script_complete.cc index 1f516f8aa71a..9be44a07dbfb 100644 --- a/src/tir/ir/script/script_complete.cc +++ b/src/tirx/ir/script/script_complete.cc @@ -18,7 +18,7 @@ */ /*! - * \file tir/ir/script/script_complete.cc + * \file tirx/ir/script/script_complete.cc * \brief Used by TVM Script parser to expand incomplete TIR input */ @@ -27,12 +27,12 @@ #include #include #include -#include +#include #include namespace tvm { -namespace tir { +namespace tirx { /*! \brief Generate surrounding loops automatically */ class ScriptCompleter : public StmtMutator { @@ -87,7 +87,7 @@ class ScriptCompleter : public StmtMutator { const ffi::Array& writes = access_region[1]; const ffi::Array& opaque = access_region[2]; TVM_FFI_CHECK(opaque.empty(), ValueError) - << "Can not auto detect buffer access region from tir.Load, tir.Store or " + << "Can not auto detect buffer access region from tirx.Load, tirx.Store or " "direct access by buffer data. Please annotation the access region manually"; auto n = CopyOnWrite(block.operator->()); if (!is_root_block) { @@ -173,5 +173,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("script.Complete", ScriptComplete); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/script/script_complete.h b/src/tirx/ir/script/script_complete.h similarity index 89% rename from src/tir/ir/script/script_complete.h rename to src/tirx/ir/script/script_complete.h index 1facab664346..d49d1f73750b 100644 --- a/src/tir/ir/script/script_complete.h +++ b/src/tirx/ir/script/script_complete.h @@ -18,20 +18,20 @@ */ /*! - * \file tir/ir/script/script_complete.h + * \file tirx/ir/script/script_complete.h * \brief Used by TVM Script parser to expand incomplete TIR input */ #ifndef TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ #define TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { PrimFunc ScriptComplete(PrimFunc func, const ffi::Array& root_allocates); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_IR_SCRIPT_SCRIPT_COMPLETE_H_ diff --git a/src/tir/ir/specialize.cc b/src/tirx/ir/specialize.cc similarity index 96% rename from src/tir/ir/specialize.cc rename to src/tirx/ir/specialize.cc index b3e9f795e4f9..3dd99cc81728 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tirx/ir/specialize.cc @@ -18,15 +18,15 @@ */ /*! - * \file src/tir/ir/specialize.cc + * \file src/tirx/ir/specialize.cc * \brief Specialize parameters of PrimFunc. */ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -34,7 +34,7 @@ #include "functor_common.h" namespace tvm { -namespace tir { +namespace tirx { using VarMap = std::unordered_map; @@ -153,8 +153,8 @@ class PrimFuncSpecializer : public StmtExprMutator { } // If the buffer variable is being remapped to an expression, we - // still need a tir::Var to be used as a the buffer variable. - // Therefore, generate a Bind that will provide a tir::Var for + // still need a tirx::Var to be used as a the buffer variable. + // Therefore, generate a Bind that will provide a tirx::Var for // the buffer to use. // // This step is only required when a buffer definition is using a @@ -269,7 +269,7 @@ class PrimFuncSpecializer : public StmtExprMutator { << "(see discussion on https://github.com/apache/tvm/pull/14565 for more details). " << "Please add a definition for this buffer, " << "either in the PrimFunc's buffer_map, " - << "in a tir::SBlock's alloc_buffer, " + << "in a tirx::SBlock's alloc_buffer, " << "or in a DeclBuffer statement."; return old_buffer; @@ -304,7 +304,7 @@ class PrimFuncSpecializer : public StmtExprMutator { * For example, we define a buffer in PrimFunc: * A = T.match_buffer(a, [m, n]) * - * Then we match it with a buffer B = tir.decl_buffer((8, 16)) + * Then we match it with a buffer B = tirx.decl_buffer((8, 16)) * * It means we have two var mappings here: m = 8 and n = 16 * @@ -314,7 +314,7 @@ class PrimFuncSpecializer : public StmtExprMutator { void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf, VarMap* var_map) { // preliminaries - tir::ExprDeepEqual equal; + tirx::ExprDeepEqual equal; auto it = func->buffer_map.find(param); TVM_FFI_CHECK(it != func->buffer_map.end(), ValueError) @@ -410,8 +410,8 @@ PrimFunc Specialize(PrimFunc func, const ffi::Map #include #include -#include -#include -#include +#include +#include +#include #include "buffer_common.h" namespace tvm { -namespace tir { +namespace tirx { TVM_FFI_STATIC_INIT_BLOCK() { StmtNode::RegisterReflection(); @@ -72,7 +72,7 @@ Bind::Bind(Var var, PrimExpr value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Bind", + refl::GlobalDef().def("tirx.Bind", [](Var var, PrimExpr value, Span span) { return Bind(var, value, span); }); } @@ -89,7 +89,7 @@ AttrStmt::AttrStmt(ffi::Any node, ffi::String attr_key, PrimExpr value, Stmt bod TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.AttrStmt", + refl::GlobalDef().def("tirx.AttrStmt", [](Any node, ffi::String attr_key, PrimExpr value, Stmt body, Span span) { // when node is a POD data type like int or bool, first convert to // primexpr. @@ -119,8 +119,8 @@ AssertStmt::AssertStmt(PrimExpr condition, StringImm error_kind, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.AssertStmt", [](PrimExpr condition, StringImm error_kind, - ffi::Array message_parts, Span span) { + refl::GlobalDef().def("tirx.AssertStmt", [](PrimExpr condition, StringImm error_kind, + ffi::Array message_parts, Span span) { return AssertStmt(condition, error_kind, message_parts, span); }); } @@ -189,10 +189,10 @@ bool ForNode::HasTrivialStep() const { return !step.has_value() || is_one(*step) TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, - Stmt body, ffi::Optional thread_binding, - ffi::Optional> annotations, - ffi::Optional step, Span span) { + refl::GlobalDef().def("tirx.For", [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, + Stmt body, ffi::Optional thread_binding, + ffi::Optional> annotations, + ffi::Optional step, Span span) { return For(loop_var, min, extent, static_cast(kind), body, thread_binding, annotations.value_or(ffi::Map()), step, span); }); @@ -234,7 +234,7 @@ While::While(PrimExpr condition, Stmt body, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.While", [](PrimExpr condition, Stmt body, Span span) { + refl::GlobalDef().def("tirx.While", [](PrimExpr condition, Stmt body, Span span) { return While(condition, body, span); }); } @@ -249,7 +249,7 @@ DeclBuffer::DeclBuffer(Buffer buffer, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.DeclBuffer", + refl::GlobalDef().def("tirx.DeclBuffer", [](Buffer buffer, Span span) { return DeclBuffer(buffer, span); }); } @@ -265,7 +265,7 @@ AllocBuffer::AllocBuffer(Buffer buffer, ffi::Map annotations, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.AllocBuffer", + "tirx.AllocBuffer", [](Buffer buffer, ffi::Optional> annotations, Span span) { return AllocBuffer(buffer, annotations.value_or(ffi::Map()), span); }); @@ -300,8 +300,9 @@ SeqStmt::SeqStmt(ffi::Array seq, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tir.SeqStmt", [](ffi::Array seq, Span span) { return SeqStmt(std::move(seq), span); }); + refl::GlobalDef().def("tirx.SeqStmt", [](ffi::Array seq, Span span) { + return SeqStmt(std::move(seq), span); + }); } // IfThenElse @@ -320,7 +321,7 @@ IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, ffi::Optional e TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.IfThenElse", + refl::GlobalDef().def("tirx.IfThenElse", [](PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { return IfThenElse(condition, then_case, else_case, span); }); @@ -338,7 +339,7 @@ Evaluate::Evaluate(PrimExpr value, Span span) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.Evaluate", + refl::GlobalDef().def("tirx.Evaluate", [](PrimExpr value, Span span) { return Evaluate(value, span); }); } @@ -420,7 +421,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array ind TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferStore", + refl::GlobalDef().def("tirx.BufferStore", [](Buffer buffer, PrimExpr value, ffi::Array indices, ffi::Optional predicate, Span span) { return BufferStore(buffer, value, indices, predicate, span); @@ -433,16 +434,16 @@ PrimExpr BufferRegionNode::ToPrimExpr() const { ffi::Array indices; indices.reserve(this->region.size()); for (const Range& r : this->region) { - if (tvm::tir::is_one(r->extent)) { + if (tvm::tirx::is_one(r->extent)) { indices.push_back(r->min); } else if (r->extent.as()) { - indices.push_back(tir::Ramp(r->min, tvm::tir::make_const(r->min->dtype, 1), r->extent)); + indices.push_back(tirx::Ramp(r->min, tvm::tirx::make_const(r->min->dtype, 1), r->extent)); } else { TVM_FFI_THROW(ValueError) << "Cannot convert to BufferLoad: " << ffi::GetRef(this); } } - return tir::BufferLoad(this->buffer, indices); + return tirx::BufferLoad(this->buffer, indices); } BufferRegion::BufferRegion(Buffer buffer, ffi::Array region) { @@ -478,7 +479,7 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, ffi::Array indices TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.BufferRegion", [](Buffer buffer, ffi::Array region) { + refl::GlobalDef().def("tirx.BufferRegion", [](Buffer buffer, ffi::Array region) { return BufferRegion(buffer, region); }); } @@ -536,7 +537,7 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.MatchBufferRegion", [](Buffer buffer, BufferRegion source) { + refl::GlobalDef().def("tirx.MatchBufferRegion", [](Buffer buffer, BufferRegion source) { return MatchBufferRegion(buffer, source); }); } @@ -563,7 +564,7 @@ SBlock::SBlock(ffi::Array iter_vars, ffi::Array reads, TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.SBlock", + refl::GlobalDef().def("tirx.SBlock", [](ffi::Array iter_vars, ffi::Array reads, ffi::Array writes, ffi::String name_hint, Stmt body, ffi::Optional init, ffi::Array alloc_buffers, @@ -591,15 +592,15 @@ SBlockRealize::SBlockRealize(ffi::Array values, PrimExpr predicate, SB TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.SBlockRealize", [](ffi::Array iter_values, - PrimExpr predicate, SBlock block, Span span) { + refl::GlobalDef().def("tirx.SBlockRealize", [](ffi::Array iter_values, + PrimExpr predicate, SBlock block, Span span) { return SBlockRealize(iter_values, predicate, block, span); }); } PrimExpr TypeAnnotation(DataType dtype, Span span) { - static auto op = Op::Get("tir.type_annotation"); - return tir::Call(dtype, op, {}, span); + static auto op = Op::Get("tirx.type_annotation"); + return tirx::Call(dtype, op, {}, span); } TVM_TIR_REGISTER_OP("type_annotation") @@ -607,5 +608,5 @@ TVM_TIR_REGISTER_OP("type_annotation") .set_attr("TScriptDtypePrintLocation", Integer(ScriptDtypePrintLocation::kFirst)); -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/stmt_functor.cc b/src/tirx/ir/stmt_functor.cc similarity index 98% rename from src/tir/ir/stmt_functor.cc rename to src/tirx/ir/stmt_functor.cc index 2db7cf5c55fc..33b89043281b 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tirx/ir/stmt_functor.cc @@ -22,8 +22,8 @@ #include #include #include -#include -#include +#include +#include #include @@ -31,7 +31,7 @@ #include "functor_common.h" namespace tvm { -namespace tir { +namespace tirx { void StmtVisitor::VisitStmt_(const BindNode* op) { // Bind has no body -- only visit the value expression. @@ -661,7 +661,7 @@ class IRSubstitute : public StmtExprMutator { TVM_FFI_ICHECK(new_data_expr->IsInstance()) << "Buffer " << new_buf << " uses backing allocation " << new_buf->data << ", which was substituted into the expression " << new_data_expr - << " and the backing allocation must be a tir::Var"; + << " and the backing allocation must be a tirx::Var"; Var data = Downcast(new_data_expr); if (!data.same_as(new_buf->data)) { auto* n = new_buf.CopyOnWrite(); @@ -786,16 +786,16 @@ PrimExpr SubstituteWithDataTypeLegalization( TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.IRTransform", IRTransform) - .def("tir.PostOrderVisit", + .def("tirx.IRTransform", IRTransform) + .def("tirx.PostOrderVisit", [](ObjectRef node, ffi::Function f) { - tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); + tirx::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }) - .def("tir.PreOrderVisit", + .def("tirx.PreOrderVisit", [](ObjectRef node, ffi::Function f) { - tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); + tirx::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n).cast(); }); }) - .def("tir.Substitute", [](ObjectRef node, ffi::Map vmap) -> ObjectRef { + .def("tirx.Substitute", [](ObjectRef node, ffi::Map vmap) -> ObjectRef { if (node->IsInstance()) { return Substitute(Downcast(node), vmap); } else { @@ -804,5 +804,5 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tirx/ir/tir_visitor_with_path.cc similarity index 99% rename from src/tir/ir/tir_visitor_with_path.cc rename to src/tirx/ir/tir_visitor_with_path.cc index 914f90c2c4dc..857ccca08eff 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tirx/ir/tir_visitor_with_path.cc @@ -18,7 +18,7 @@ */ /*! - * \file tir/ir/tir_visitor_with_path.cc + * \file tirx/ir/tir_visitor_with_path.cc * \brief Provide a TIR visitor that tracks the current location */ #include "tir_visitor_with_path.h" @@ -33,7 +33,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { using AccessPath = ffi::reflection::AccessPath; @@ -187,7 +187,7 @@ void TIRVisitorWithPath::VisitStmt_(const AttrStmtNode* op, AccessPath path) { iter_var && (op->attr_key == attr::thread_extent || op->attr_key == s_tir::attr::virtual_thread)) { // Some attributes serve as a source of definition for the - // tir::Var they annotate. + // tirx::Var they annotate. context.push_back(WithDef(iter_var.value(), path->Attr("node"))); } else if (auto expr = op->node.as()) { @@ -401,5 +401,5 @@ void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, AccessPath path) { Visit(op->lanes, path->Attr("lanes")); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tirx/ir/tir_visitor_with_path.h similarity index 97% rename from src/tir/ir/tir_visitor_with_path.h rename to src/tirx/ir/tir_visitor_with_path.h index f64a398d807f..fbaf7a95f91e 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tirx/ir/tir_visitor_with_path.h @@ -18,7 +18,7 @@ */ /*! - * \file tir/ir/tir_visitor_with_path.h + * \file tirx/ir/tir_visitor_with_path.h * \brief Provide a TIR visitor that tracks the current location */ #ifndef TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ @@ -26,8 +26,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -37,7 +37,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! \brief Visit TIR while tracking the ffi::reflection::AccessPath */ class TIRVisitorWithPath @@ -75,13 +75,13 @@ class TIRVisitorWithPath virtual void EnterDef(const GlobalVar& var, ffi::reflection::AccessPath path) {} virtual void ExitDef(const GlobalVar& var, ffi::reflection::AccessPath path) {} - // Called when entering/exiting the scope of a tir::Var definition. + // Called when entering/exiting the scope of a tirx::Var definition. virtual void EnterDef(const Var& var, ffi::reflection::AccessPath path) {} virtual void ExitDef(const Var& var, ffi::reflection::AccessPath path) {} // Called when entering/exiting the scope of an IterVar definition. // By default, visits the `Range IterVarNode::dom`, then enters the - // scope of the internal `tir::Var`. + // scope of the internal `tirx::Var`. virtual void EnterDef(const IterVar& var, ffi::reflection::AccessPath path); virtual void ExitDef(const IterVar& var, ffi::reflection::AccessPath path); @@ -262,6 +262,6 @@ class TIRVisitorWithPath ScopeStack> bind_scope_; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_IR_TIR_VISITOR_WITH_PATH_H_ diff --git a/src/tir/ir/transform.cc b/src/tirx/ir/transform.cc similarity index 78% rename from src/tir/ir/transform.cc rename to src/tirx/ir/transform.cc index 065b4813c7a7..14b0cf6b1102 100644 --- a/src/tir/ir/transform.cc +++ b/src/tirx/ir/transform.cc @@ -18,37 +18,37 @@ */ /*! - * \file tir/ir/transform.cc + * \file tirx/ir/transform.cc * \brief TIR specific transformation passes. */ #include #include #include #include -#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace transform { // Register build pipeline related options -TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", ffi::Array>); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.noalias", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.detect_global_barrier", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_bound_checkers", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_assert", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.enable_buffer_level_predication", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.enable_debug", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_storage_rewrite", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.is_entry_func", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.add_lower_pass", ffi::Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.debug_keep_trivial_loop", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.use_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.merge_static_smem", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_lwp", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.vtcm_capacity", Integer); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.ptx_ldg32", Bool); /*! * \brief Function level pass that applies transformations to all @@ -81,7 +81,7 @@ class PrimFuncPassNode : public PassNode { * \brief Get the pass information/meta data. */ PassInfo Info() const override { return pass_info; } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.PrimFuncPass", PrimFuncPassNode, PassNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.PrimFuncPass", PrimFuncPassNode, PassNode); }; class PrimFuncPass : public Pass { @@ -114,7 +114,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) auto* func_dict = mod_ptr->functions.CopyOnWrite(); // directly loop over the underlying dict for (auto& kv : *func_dict) { - // only picks up tir::PrimFunc + // only picks up tirx::PrimFunc if (auto opt_func = kv.second.as()) { // reset the original Any state so the value contains only copy // use move semantics as follows to avoid only copy. @@ -149,7 +149,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { PrimFuncPassNode::RegisterReflection(); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def( - "tir.transform.CreatePrimFuncPass", + "tirx.transform.CreatePrimFuncPass", [](ffi::TypedFunction, IRModule, PassContext)> pass_func, PassInfo pass_info) { auto wrapped_pass_func = [pass_func](PrimFunc func, IRModule mod, PassContext ctx) { @@ -167,5 +167,5 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tirx/op/builtin.cc similarity index 97% rename from src/tir/op/builtin.cc rename to src/tirx/op/builtin.cc index 5891c97d8cb0..68f9ce219bb7 100644 --- a/src/tir/op/builtin.cc +++ b/src/tirx/op/builtin.cc @@ -18,24 +18,24 @@ */ /*! - * \file tir/op/builtin.cc + * \file tirx/op/builtin.cc * * builtin intrinsic operators. */ #include -#include -#include -#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace builtin { -#define TIR_DEFINE_BUILTIN_FUNC(OpName) \ - const Op& OpName() { \ - static const Op& op = Op::Get("tir." #OpName); \ - return op; \ - } \ +#define TIR_DEFINE_BUILTIN_FUNC(OpName) \ + const Op& OpName() { \ + static const Op& op = Op::Get("tirx." #OpName); \ + return op; \ + } \ TVM_TIR_REGISTER_OP(#OpName) TIR_DEFINE_BUILTIN_FUNC(reinterpret) @@ -437,5 +437,5 @@ TIR_DEFINE_BUILTIN_FUNC(ignore_loop_partition) Integer(ScriptDtypePrintLocation::kNone)); } // namespace builtin -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/op/op.cc b/src/tirx/op/op.cc similarity index 82% rename from src/tir/op/op.cc rename to src/tirx/op/op.cc index 8c2b4bd85962..f1c9c8a9b507 100644 --- a/src/tir/op/op.cc +++ b/src/tirx/op/op.cc @@ -18,17 +18,17 @@ */ /*! - * \file tir/op/op.cc + * \file tirx/op/op.cc * - * Common operator definitions for ops in tir/op.h + * Common operator definitions for ops in tirx/op.h */ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include // Centralized header for constant folders. @@ -38,7 +38,7 @@ namespace tvm { -using namespace tir; +using namespace tirx; // macro to register an unary op #define TVM_TIR_REGISTER_PURE_UNARY_OP(OpName) \ @@ -66,7 +66,7 @@ runtime::DataType GetRuntimeDataType(const Type& type) { Type GetType(const PrimExpr& expr) { // TODO(tqchen): add recursive type inference for Call here // once we introduced the corresponding fields to the IR. - if (auto* ptr = expr.as()) { + if (auto* ptr = expr.as()) { // If Var has a more refined type annotation, // return the type anotation if (ptr->type_annotation.defined()) { @@ -74,12 +74,12 @@ Type GetType(const PrimExpr& expr) { } } - if (auto* access = expr.as()) { + if (auto* access = expr.as()) { if (access->op.same_as(builtin::tvm_access_ptr())) { TVM_FFI_ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments"; auto type_annotation = Downcast(access->args[0]); - static auto builtin_op = Op::Get("tir.type_annotation"); + static auto builtin_op = Op::Get("tirx.type_annotation"); TVM_FFI_ICHECK(type_annotation->op.same_as(builtin_op)) << "Expected the first argument of builtin tvm_access_ptr() " << "to be a type annotation, but found " << type_annotation->op; @@ -87,7 +87,7 @@ Type GetType(const PrimExpr& expr) { } } - if (auto* address_of = expr.as()) { + if (auto* address_of = expr.as()) { if (address_of->op.same_as(builtin::address_of())) { TVM_FFI_ICHECK_EQ(address_of->args.size(), 1) << "Builtin address_of() expects a single argument, but received arguments " @@ -114,16 +114,16 @@ Type GetTypeFromRuntimeDataType(const DataType& dtype) { // LargeUIntImm PrimExpr LargeUIntImm(DataType t, int64_t low, int64_t high, Span span) { - return tir::Call( - t, tir::builtin::large_uint_imm(), + return tirx::Call( + t, tirx::builtin::large_uint_imm(), {make_const(DataType::UInt(32), low, span), make_const(DataType::UInt(32), high, span)}, span); } // Q-multiplication PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y, PrimExpr q, PrimExpr s, Span span) { - return tir::Call(DataType::Int(32, x.dtype().lanes()), tir::builtin::q_multiply_shift(), - {x, y, q, s}, span); + return tirx::Call(DataType::Int(32, x.dtype().lanes()), tirx::builtin::q_multiply_shift(), + {x, y, q, s}, span); } void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) @@ -133,10 +133,10 @@ void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*) if (!dtype_a.is_scalable_or_fixed_length_vector() && dtype_b.is_scalable_or_fixed_length_vector()) { if (dtype_b.is_scalable_vector()) { - op_a = tir::Broadcast( - op_a, tir::Mul(dtype_b.vscale_factor(), Call(DataType::Int(32), builtin::vscale(), {}))); + op_a = tirx::Broadcast( + op_a, tirx::Mul(dtype_b.vscale_factor(), Call(DataType::Int(32), builtin::vscale(), {}))); } else { - op_a = tir::Broadcast(op_a, dtype_b.lanes()); + op_a = tirx::Broadcast(op_a, dtype_b.lanes()); } } } @@ -251,33 +251,33 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) PrimExpr ret(PrimExpr value, Span span) { TVM_FFI_ICHECK(value.defined()); - return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); + return tirx::Call(value.dtype(), tirx::builtin::ret(), {value}, span); } PrimExpr thread_return(Span span) { - return tir::Call(DataType::Void(), tir::builtin::thread_return(), {}, span); + return tirx::Call(DataType::Void(), tirx::builtin::thread_return(), {}, span); } PrimExpr continue_loop(Span span) { - return tir::Call(DataType::Void(), tir::builtin::continue_loop(), {}, span); + return tirx::Call(DataType::Void(), tirx::builtin::continue_loop(), {}, span); } PrimExpr break_loop(Span span) { - return tir::Call(DataType::Void(), tir::builtin::break_loop(), {}, span); + return tirx::Call(DataType::Void(), tirx::builtin::break_loop(), {}, span); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.ret", ret) - .def("tir.thread_return", thread_return) - .def("tir.continue_loop", continue_loop) - .def("tir.break_loop", break_loop); + .def("tirx.ret", ret) + .def("tirx.thread_return", thread_return) + .def("tirx.continue_loop", continue_loop) + .def("tirx.break_loop", break_loop); }; // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { - using namespace tir; + using namespace tirx; TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); if (dtype.is_int()) { if (dtype.bits() == 64) { @@ -336,7 +336,7 @@ PrimExpr max_value(const DataType& dtype, Span span) { } PrimExpr min_value(const DataType& dtype, Span span) { - using namespace tir; + using namespace tirx; TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) { // TODO(tkonolige): need to convert all registered min functions to use the span. @@ -399,7 +399,7 @@ PrimExpr min_value(const DataType& dtype, Span span) { // infinity PrimExpr infinity(const DataType& dtype, Span span) { - using namespace tir; + using namespace tirx; TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); if (dtype.is_float()) { if (dtype.bits() == 64) { @@ -411,7 +411,7 @@ PrimExpr infinity(const DataType& dtype, Span span) { TVM_FFI_THROW(InternalError) << "Cannot decide infinity for type " << dtype; } -namespace tir { +namespace tirx { template inline bool ConstPowerHelper(ValueType val, int* shift) { if (val <= 0) return false; @@ -427,16 +427,16 @@ inline bool ConstPowerHelper(ValueType val, int* shift) { } bool is_const_power_of_two_integer(const PrimExpr& x, int* shift) { - if (const auto* op = x.as()) { + if (const auto* op = x.as()) { return ConstPowerHelper(op->value, shift); } else { return false; } } -} // namespace tir +} // namespace tirx PrimExpr cast(const DataType& t, PrimExpr value, Span span) { - using tir::FloatImmNode; + using tirx::FloatImmNode; if (value.dtype() == t) return value; // const fold IntImm as they are used in index computations if (t.is_scalar()) { @@ -446,7 +446,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { return make_const(t, op->value, op->span); } TVM_FFI_ICHECK(!value.dtype().is_handle()) << "Can't cast a handle to other types."; - return tir::Cast(t, value, span); + return tirx::Cast(t, value, span); } else { DataType vtype = t.element_of(); if (!value.dtype().is_scalable_or_fixed_length_vector()) { @@ -457,15 +457,15 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { } else if (const FloatImmNode* op = value.as()) { value = make_const(vtype, op->value, op->span); } else { - value = tir::Cast(vtype, value, span); + value = tirx::Cast(vtype, value, span); } } if (t.is_scalable_vector()) { - return tir::Broadcast( - value, tir::Mul(t.vscale_factor(), Call(DataType::Int(32), builtin::vscale(), {})), + return tirx::Broadcast( + value, tirx::Mul(t.vscale_factor(), Call(DataType::Int(32), builtin::vscale(), {})), span); } else { - return tir::Broadcast(value, t.lanes(), span); + return tirx::Broadcast(value, t.lanes(), span); } } else { /* value is a vector */ TVM_FFI_ICHECK(value.dtype().is_scalable_vector() == t.is_scalable_vector()); @@ -477,16 +477,16 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span) { lanes_match = value.dtype().lanes() == t.lanes(); } TVM_FFI_ICHECK(lanes_match); - if (const auto* broadcast = value.as()) { - return tir::Broadcast(cast(vtype, broadcast->value, span), broadcast->lanes, span); - } else if (const auto* ramp = value.as()) { + if (const auto* broadcast = value.as()) { + return tirx::Broadcast(cast(vtype, broadcast->value, span), broadcast->lanes, span); + } else if (const auto* ramp = value.as()) { if (t.is_int() || t.is_uint()) { // only cast to index data type can be folded to ramp - return tir::Ramp(cast(vtype, ramp->base, span), cast(vtype, ramp->stride, span), - ramp->lanes, span); + return tirx::Ramp(cast(vtype, ramp->base, span), cast(vtype, ramp->stride, span), + ramp->lanes, span); } } - return tir::Cast(t, value, span); + return tirx::Cast(t, value, span); } } } @@ -500,7 +500,7 @@ PrimExpr reinterpret(const DataType& t, PrimExpr value, Span span) { value.dtype().bytes() * value.dtype().lanes() == t.bytes() * t.lanes())) << "Reinterpret requires size match " << t << " vs " << value.dtype(); } - return tir::Call(t, tir::builtin::reinterpret(), {value}, span); + return tirx::Call(t, tirx::builtin::reinterpret(), {value}, span); } // operator+ @@ -508,16 +508,16 @@ PrimExpr operator+(PrimExpr a, PrimExpr b) { return add(a, b); } PrimExpr add(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Add(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Add(a, b, span); } // negation PrimExpr operator-(PrimExpr a) { return neg(a); } PrimExpr neg(PrimExpr a, Span span) { - using tir::FloatImmNode; - using tir::IntImmNode; + using tirx::FloatImmNode; + using tirx::IntImmNode; const IntImmNode* pa = a.as(); const FloatImmNode* fa = a.as(); if (pa) return IntImm(a.dtype(), -pa->value, span); @@ -529,21 +529,21 @@ PrimExpr operator-(PrimExpr a, PrimExpr b) { return sub(a, b); } PrimExpr sub(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Sub(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Sub(a, b, span); } PrimExpr operator*(PrimExpr a, PrimExpr b) { return mul(a, b); } PrimExpr mul(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Mul(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Mul(a, b, span); } PrimExpr div(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Div(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Div(a, b, span); } PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) { @@ -554,8 +554,8 @@ PrimExpr truncdiv(PrimExpr a, PrimExpr b, Span span) { PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Mod(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Mod(a, b, span); } PrimExpr operator/(PrimExpr a, PrimExpr b) { return div(a, b); } @@ -573,8 +573,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::FloorDiv(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::FloorDiv(a, b, span); } PrimExpr logaddexp(PrimExpr a, PrimExpr b, Span span) { @@ -590,16 +590,16 @@ PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) { TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a + b - 1, b)) return ret.value(); - return tir::FloorDiv(a + b - 1, b, span); + if (auto ret = arith::TryConstFold(a + b - 1, b)) return ret.value(); + return tirx::FloorDiv(a + b - 1, b, span); } PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) { TVM_FFI_ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; TVM_FFI_ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::FloorMod(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::FloorMod(a, b, span); } PrimExpr min(PrimExpr a, PrimExpr b, Span span) { @@ -611,8 +611,8 @@ PrimExpr min(PrimExpr a, PrimExpr b, Span span) { if (is_pos_inf(b)) return a; if (is_neg_inf(b)) return b; BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Min(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Min(a, b, span); } PrimExpr max(PrimExpr a, PrimExpr b, Span span) { @@ -624,8 +624,8 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) { if (is_pos_inf(b)) return b; if (is_neg_inf(b)) return a; BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Max(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Max(a, b, span); } // if_then_else @@ -641,58 +641,58 @@ PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr false_value, } } - return tir::Call(true_value.dtype(), tir::builtin::if_then_else(), - {cond, true_value, false_value}, span); + return tirx::Call(true_value.dtype(), tirx::builtin::if_then_else(), + {cond, true_value, false_value}, span); } // likely PrimExpr likely(PrimExpr cond, Span span) { if (is_const_int(cond)) return cond; - return tir::Call(cond.dtype(), tir::builtin::likely(), {cond}, span); + return tirx::Call(cond.dtype(), tirx::builtin::likely(), {cond}, span); } // operator> PrimExpr operator>(PrimExpr a, PrimExpr b) { return greater(a, b); } PrimExpr greater(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::GT(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::GT(a, b, span); } PrimExpr operator>=(PrimExpr a, PrimExpr b) { return greater_equal(a, b); } PrimExpr greater_equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::GE(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::GE(a, b, span); } PrimExpr operator<(PrimExpr a, PrimExpr b) { return less(a, b); } PrimExpr less(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::LT(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::LT(a, b, span); } PrimExpr operator<=(PrimExpr a, PrimExpr b) { return less_equal(a, b); } PrimExpr less_equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::LE(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::LE(a, b, span); } PrimExpr operator==(PrimExpr a, PrimExpr b) { return equal(a, b); } PrimExpr equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); if (arith::IsVScaleCall(a) && arith::IsVScaleCall(b)) return true; - return tir::EQ(a, b, span); + return tirx::EQ(a, b, span); } PrimExpr operator!=(PrimExpr a, PrimExpr b) { return not_equal(a, b); } PrimExpr not_equal(PrimExpr a, PrimExpr b, Span span) { BinaryOpMatchTypes(a, b, span); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::NE(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::NE(a, b, span); } namespace { @@ -735,22 +735,22 @@ void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); } PrimExpr logical_and(PrimExpr a, PrimExpr b, Span span) { type_check_boolean_args(a, b, "&& operator (logical AND)"); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::And(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::And(a, b, span); } PrimExpr operator||(PrimExpr a, PrimExpr b) { return logical_or(a, b); } PrimExpr logical_or(PrimExpr a, PrimExpr b, Span span) { type_check_boolean_args(a, b, "|| operator (logical OR)"); - if (auto ret = arith::TryConstFold(a, b)) return ret.value(); - return tir::Or(a, b, span); + if (auto ret = arith::TryConstFold(a, b)) return ret.value(); + return tirx::Or(a, b, span); } PrimExpr operator!(PrimExpr a) { return logical_not(a); } PrimExpr logical_not(PrimExpr a, Span span) { type_check_boolean_args(a, "! operator (logical NOT)"); - if (auto ret = arith::TryConstFold(a)) return ret.value(); - return tir::Not(a, span); + if (auto ret = arith::TryConstFold(a)) return ret.value(); + return tirx::Not(a, span); } // shift right @@ -774,7 +774,7 @@ PrimExpr right_shift(PrimExpr a, PrimExpr b, Span span) { } }); - return tir::Call(a.dtype(), tir::builtin::shift_right(), {a, b}, span); + return tirx::Call(a.dtype(), tirx::builtin::shift_right(), {a, b}, span); } // shift left @@ -793,7 +793,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) { if (pb->value == 0) return a; } }); - return tir::Call(a.dtype(), tir::builtin::shift_left(), {a, b}, span); + return tirx::Call(a.dtype(), tirx::builtin::shift_left(), {a, b}, span); } // bitwise and @@ -805,7 +805,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value & pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_and(), {a, b}, span); + return tirx::Call(a.dtype(), tirx::builtin::bitwise_and(), {a, b}, span); } // bitwise_or @@ -817,7 +817,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value | pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_or(), {a, b}, span); + return tirx::Call(a.dtype(), tirx::builtin::bitwise_or(), {a, b}, span); } // bitwise_xor @@ -829,7 +829,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) { const DataType& rtype = a.dtype(); if (pa && pb) return IntImm(rtype, (pa->value ^ pb->value), span); }); - return tir::Call(a.dtype(), tir::builtin::bitwise_xor(), {a, b}, span); + return tirx::Call(a.dtype(), tirx::builtin::bitwise_xor(), {a, b}, span); } // bitwise_not @@ -837,12 +837,12 @@ PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); } PrimExpr bitwise_neg(PrimExpr a, Span span) { type_check_int_or_bool_args(a, "~ operator (bitwise NOT)"); - return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span); + return tirx::Call(a.dtype(), tirx::builtin::bitwise_not(), {a}, span); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.bitwise_not", + refl::GlobalDef().def("tirx.bitwise_not", [](PrimExpr a, Span span) { return bitwise_neg(a, span); }); } @@ -853,7 +853,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { // If we detect pow(x, 3), suggest using x * x * x if (y.dtype().is_int()) { - using tir::IntImmNode; + using tirx::IntImmNode; const IntImmNode* px = y.as(); if (px) { if (px->value >= 3) { @@ -864,7 +864,7 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { } } } else if (y.dtype().is_float()) { - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = y.as(); if (fx) { if (fx->value >= 3.0) { @@ -876,8 +876,8 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { } } - static auto op = Op::Get("tir.pow"); - return tir::Call(x.dtype(), op, {x, y}, span); + static auto op = Op::Get("tirx.pow"); + return tirx::Call(x.dtype(), op, {x, y}, span); } TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr("TVectorizable", true); @@ -885,20 +885,20 @@ TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr("TVectorizable", // abs PrimExpr abs(PrimExpr x, Span span) { if (x.dtype().is_int()) { - using tir::IntImmNode; + using tirx::IntImmNode; const IntImmNode* px = x.as(); if (px) { return IntImm(x.dtype(), std::abs(px->value), px->span); } - return tir::Select(x >= make_zero(x.dtype()), x, -x, span); + return tirx::Select(x >= make_zero(x.dtype()), x, -x, span); } else if (x.dtype().is_float() || x.dtype().is_bfloat()) { - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value), fx->span); } - static auto op = Op::Get("tir.fabs"); - return tir::Call(x.dtype(), op, {x}, span); + static auto op = Op::Get("tirx.fabs"); + return tirx::Call(x.dtype(), op, {x}, span); } else if (x.dtype().is_uint()) { return x; } else { @@ -916,16 +916,16 @@ PrimExpr isnan(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint()) { return make_const(t, false); } else if (x.dtype().is_float()) { - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return make_const(t, std::isnan(fx->value), fx->span); } - static auto op = Op::Get("tir.isnan"); + static auto op = Op::Get("tirx.isnan"); if (x.dtype().bits() == 16) { - return tir::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, span); + return tirx::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, span); } else { - return tir::Call(t, op, {x}, span); + return tirx::Call(t, op, {x}, span); } } else { TVM_FFI_THROW(InternalError) << "Data type " << x.dtype() @@ -952,60 +952,60 @@ PrimExpr isfinite(PrimExpr x, Span span) { return !isinf(x, span) && !isnan(x, s PrimExpr sum(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); - PrimExpr result = tir::Add(x, y, span); + PrimExpr result = tirx::Add(x, y, span); PrimExpr identity_element = make_zero(source.dtype(), span); - tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); + tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, {identity_element}, span); + return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } PrimExpr all(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { type_check_boolean_args(source, "tvm::all"); Var x("x", source.dtype(), span), y("y", source.dtype()); - PrimExpr result = tir::And(x, y, span); + PrimExpr result = tirx::And(x, y, span); PrimExpr identity_element = make_const(source.dtype(), true, span); - tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); + tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, {identity_element}, span); + return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } PrimExpr any(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { type_check_boolean_args(source, "tvm::any"); Var x("x", source.dtype(), span), y("y", source.dtype(), span); - PrimExpr result = tir::Or(x, y, span); + PrimExpr result = tirx::Or(x, y, span); PrimExpr identity_element = make_const(source.dtype(), false, span); - tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); + tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, {identity_element}, span); + return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } PrimExpr max(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); - PrimExpr result = tir::Max(x, y, span); + PrimExpr result = tirx::Max(x, y, span); PrimExpr identity_element = min_value(source.dtype(), span); - tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); + tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, {identity_element}, span); + return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } PrimExpr min(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); - PrimExpr result = tir::Min(x, y, span); + PrimExpr result = tirx::Min(x, y, span); PrimExpr identity_element = max_value(source.dtype(), span); - tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); + tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, {identity_element}, span); + return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } PrimExpr prod(PrimExpr source, ffi::Array rdom, ffi::Array init, Span span) { Var x("x", source.dtype(), span), y("y", source.dtype(), span); - PrimExpr result = tir::Mul(x, y, span); + PrimExpr result = tirx::Mul(x, y, span); PrimExpr identity_element = make_const(source.dtype(), 1, span); - tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, {identity_element}, span); - return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); + tirx::CommReducer combiner = tirx::CommReducer({x}, {y}, {result}, {identity_element}, span); + return tirx::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), true), 0, init, span); } // fmod PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); TVM_FFI_ICHECK(x.dtype().is_float()) << "fmod only applies to float"; - static auto op = Op::Get("tir.fmod"); - return tir::Call(x.dtype(), op, {x, y}, span); + static auto op = Op::Get("tirx.fmod"); + return tirx::Call(x.dtype(), op, {x, y}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("fmod"); @@ -1015,11 +1015,11 @@ PrimExpr floor(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span); - static auto op = Op::Get("tir.floor"); - return tir::Call(x.dtype(), op, {x}, span); + static auto op = Op::Get("tirx.floor"); + return tirx::Call(x.dtype(), op, {x}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr("TVectorizable", true); @@ -1029,11 +1029,11 @@ PrimExpr ceil(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span); - static auto op = Op::Get("tir.ceil"); - return tir::Call(x.dtype(), op, {x}, span); + static auto op = Op::Get("tirx.ceil"); + return tirx::Call(x.dtype(), op, {x}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr("TVectorizable", true); @@ -1043,11 +1043,11 @@ PrimExpr round(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); - static auto op = Op::Get("tir.round"); - return tir::Call(x.dtype(), op, {x}, span); + static auto op = Op::Get("tirx.round"); + return tirx::Call(x.dtype(), op, {x}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr("TVectorizable", true); @@ -1057,11 +1057,11 @@ PrimExpr nearbyint(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); - static auto op = Op::Get("tir.nearbyint"); - return tir::Call(x.dtype(), op, {x}, span); + static auto op = Op::Get("tirx.nearbyint"); + return tirx::Call(x.dtype(), op, {x}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint"); @@ -1071,14 +1071,14 @@ PrimExpr trunc(PrimExpr x, Span span) { if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) { return x; } - using tir::FloatImmNode; + using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)), fx->span); } - static auto op = Op::Get("tir.trunc"); - return tir::Call(x.dtype(), op, {x}, span); + static auto op = Op::Get("tirx.trunc"); + return tirx::Call(x.dtype(), op, {x}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr("TVectorizable", true); @@ -1160,40 +1160,40 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_packed("node._const", [](ffi::PackedArgs args, ffi::Any* ret) { if (auto opt = args[0].try_cast()) { - *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); + *ret = tirx::make_const(args[1].cast(), *opt, args[2].cast()); } else if (auto opt = args[0].try_cast()) { - *ret = tir::make_const(args[1].cast(), *opt, args[2].cast()); + *ret = tirx::make_const(args[1].cast(), *opt, args[2].cast()); } else { TVM_FFI_THROW(InternalError) - << "First argument to tvm.tir.const must be int, float, or bool, " + << "First argument to tvm.tirx.const must be int, float, or bool, " << "but instead received argument with type code " << args[0].GetTypeKey(); } }) .def("node.LargeUIntImm", LargeUIntImm) - .def("tir.min_value", min_value) - .def("tir.max_value", max_value) - .def("tir.infinity", infinity) - .def("tir.abs", tvm::abs) - .def("tir.likely", tvm::likely) - .def("tir.isnan", tvm::isnan) - .def("tir.isfinite", tvm::isfinite) - .def("tir.isinf", tvm::isinf) - .def("tir.floor", tvm::floor) - .def("tir.ceil", tvm::ceil) - .def("tir.round", tvm::round) - .def("tir.nearbyint", tvm::nearbyint) - .def("tir.trunc", tvm::trunc) - .def("tir._cast", tvm::cast) - .def("tir.reinterpret", tvm::reinterpret); + .def("tirx.min_value", min_value) + .def("tirx.max_value", max_value) + .def("tirx.infinity", infinity) + .def("tirx.abs", tvm::abs) + .def("tirx.likely", tvm::likely) + .def("tirx.isnan", tvm::isnan) + .def("tirx.isfinite", tvm::isfinite) + .def("tirx.isinf", tvm::isinf) + .def("tirx.floor", tvm::floor) + .def("tirx.ceil", tvm::ceil) + .def("tirx.round", tvm::round) + .def("tirx.nearbyint", tvm::nearbyint) + .def("tirx.trunc", tvm::trunc) + .def("tirx._cast", tvm::cast) + .def("tirx.reinterpret", tvm::reinterpret); } // operator overloading, smarter than make #define DEF_MAKE_BINARY_OP(Node, Func) \ - def("tir." #Node, [](PrimExpr a, PrimExpr b, Span span) { return (Func(a, b, span)); }) + def("tirx." #Node, [](PrimExpr a, PrimExpr b, Span span) { return (Func(a, b, span)); }) #define DEF_MAKE_BIT_OP(Node, Func) \ - def_packed("tir." #Node, [](ffi::PackedArgs args, ffi::Any* ret) { \ + def_packed("tirx." #Node, [](ffi::PackedArgs args, ffi::Any* ret) { \ bool lhs_is_int = args[0].type_index() == ffi::TypeIndex::kTVMFFIInt; \ bool rhs_is_int = args[1].type_index() == ffi::TypeIndex::kTVMFFIInt; \ if (lhs_is_int) { \ @@ -1208,11 +1208,11 @@ TVM_FFI_STATIC_INIT_BLOCK() { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir._OpIfThenElse", + .def("tirx._OpIfThenElse", [](PrimExpr cond, PrimExpr true_value, PrimExpr false_value, Span span) { return if_then_else(cond, true_value, false_value, span); }) - .def("tir.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); }) + .def("tirx.const_true", [](DataType t, Span span) { return const_true(t.lanes(), span); }) .DEF_MAKE_BINARY_OP(_OpAdd, add) .DEF_MAKE_BINARY_OP(_OpSub, sub) .DEF_MAKE_BINARY_OP(_OpMul, mul) diff --git a/src/tir/op/runtime.cc b/src/tirx/op/runtime.cc similarity index 85% rename from src/tir/op/runtime.cc rename to src/tirx/op/runtime.cc index 9ee6c67ec96b..e013b21d6676 100644 --- a/src/tir/op/runtime.cc +++ b/src/tirx/op/runtime.cc @@ -18,24 +18,24 @@ */ /*! - * \file tir/op/runtime.cc + * \file tirx/op/runtime.cc * \brief TIR ops for runtime functions. */ #include -#include +#include namespace tvm { -namespace tir { +namespace tirx { -TVM_REGISTER_OP("tir.TVMBackendAnyListSetPackedArg") +TVM_REGISTER_OP("tirx.TVMBackendAnyListSetPackedArg") .set_num_inputs(5) .set_attr("TGlobalSymbol", "TVMBackendAnyListSetPackedArg") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TVM_REGISTER_OP("tir.TVMBackendAnyListMoveFromPackedReturn") +TVM_REGISTER_OP("tirx.TVMBackendAnyListMoveFromPackedReturn") .set_num_inputs(3) .set_attr("TGlobalSymbol", "TVMBackendAnyListMoveFromPackedReturn") .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/annotate_device_regions.cc b/src/tirx/transform/annotate_device_regions.cc similarity index 88% rename from src/tir/transform/annotate_device_regions.cc rename to src/tirx/transform/annotate_device_regions.cc index 22ddd074959e..2ecea55400b0 100644 --- a/src/tir/transform/annotate_device_regions.cc +++ b/src/tirx/transform/annotate_device_regions.cc @@ -25,13 +25,13 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { class DeviceRegionAnnotater : public StmtMutator { public: @@ -71,14 +71,14 @@ Pass AnnotateDeviceRegions() { return func; }; - return CreatePrimFuncPass(pass_func, 0, "tir.AnnotateDeviceRegions", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.AnnotateDeviceRegions", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.AnnotateDeviceRegions", AnnotateDeviceRegions); + refl::GlobalDef().def("tirx.transform.AnnotateDeviceRegions", AnnotateDeviceRegions); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/bind_target.cc b/src/tirx/transform/bind_target.cc similarity index 97% rename from src/tir/transform/bind_target.cc rename to src/tirx/transform/bind_target.cc index eaf13c44bd7d..d19a5cf24419 100644 --- a/src/tir/transform/bind_target.cc +++ b/src/tirx/transform/bind_target.cc @@ -38,15 +38,15 @@ #include #include #include -#include -#include +#include +#include #include #include "tvm/ir/attrs.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Visitor class to classify function calls as host or device calls. @@ -286,11 +286,11 @@ IRModule BindTarget(IRModule mod, const Target& target) { continue; } - if (prim_func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { + if (prim_func->HasNonzeroAttr(tvm::tirx::attr::kIsHostFunc)) { // Rule 2: If the function is marked as host function, bind the host target to the function prim_func = WithAttr(std::move(prim_func), tvm::attr::kTarget, Target::WithHost(target_host, target_host)); - new_mod->Update(gvar, WithoutAttr(std::move(prim_func), tvm::tir::attr::kIsHostFunc)); + new_mod->Update(gvar, WithoutAttr(std::move(prim_func), tvm::tirx::attr::kIsHostFunc)); continue; } @@ -370,16 +370,16 @@ namespace transform { */ transform::Pass BindTarget(Target target) { auto fpass = [target](IRModule mod, transform::PassContext ctx) { - return tvm::tir::BindTarget(mod, target); + return tvm::tirx::BindTarget(mod, target); }; - return tir::transform::CreateModulePass(fpass, 0, "tir.BindTarget", {}); + return tirx::transform::CreateModulePass(fpass, 0, "tirx.BindTarget", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.BindTarget", BindTarget); + refl::GlobalDef().def("tirx.transform.BindTarget", BindTarget); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/common_subexpr_elim.cc b/src/tirx/transform/common_subexpr_elim.cc similarity index 98% rename from src/tir/transform/common_subexpr_elim.cc rename to src/tirx/transform/common_subexpr_elim.cc index fb4e01a1371c..215f225b5126 100644 --- a/src/tir/transform/common_subexpr_elim.cc +++ b/src/tirx/transform/common_subexpr_elim.cc @@ -64,13 +64,13 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include #include @@ -81,7 +81,7 @@ #include "../analysis/check_contains.h" namespace tvm { -namespace tir { +namespace tirx { // ============================================================================ // Plan interface types (internal, C++ only) @@ -777,14 +777,14 @@ Pass CommonSubexprElim() { } return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.CommonSubexprElim", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.CommonSubexprElim", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.CommonSubexprElim", CommonSubexprElim); + refl::GlobalDef().def("tirx.transform.CommonSubexprElim", CommonSubexprElim); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/dtype_conversion.cc b/src/tirx/transform/dtype_conversion.cc similarity index 99% rename from src/tir/transform/dtype_conversion.cc rename to src/tirx/transform/dtype_conversion.cc index 84530a778d5c..af5c63f231f6 100644 --- a/src/tir/transform/dtype_conversion.cc +++ b/src/tirx/transform/dtype_conversion.cc @@ -24,7 +24,7 @@ #include "dtype_conversion.h" namespace tvm { -namespace tir { +namespace tirx { PrimExpr ReinterpretAsUInt(PrimExpr value) { return reinterpret(GetStorageUIntDType(value.dtype()), value); @@ -98,5 +98,5 @@ PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode ro } } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/dtype_conversion.h b/src/tirx/transform/dtype_conversion.h similarity index 98% rename from src/tir/transform/dtype_conversion.h rename to src/tirx/transform/dtype_conversion.h index bc258301fa6c..21bd5bf355bd 100644 --- a/src/tir/transform/dtype_conversion.h +++ b/src/tirx/transform/dtype_conversion.h @@ -24,13 +24,13 @@ #ifndef TVM_TIR_TRANSFORM_DTYPE_CONVERSION_H_ #define TVM_TIR_TRANSFORM_DTYPE_CONVERSION_H_ -#include -#include -#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Rounding mode: https://en.wikipedia.org/wiki/Rounding @@ -196,6 +196,6 @@ DataType GetStorageUIntDType(DataType dtype); PrimExpr DTypeConversion(PrimExpr src_value, DataType tgt_dtype, RoundingMode round_mode = RoundingMode::kHalfToEven); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_DTYPE_CONVERSION_H_ diff --git a/src/tir/transform/flatten_buffer.cc b/src/tirx/transform/flatten_buffer.cc similarity index 95% rename from src/tir/transform/flatten_buffer.cc rename to src/tirx/transform/flatten_buffer.cc index 65d0a1f9afd0..0218b55a97ac 100644 --- a/src/tir/transform/flatten_buffer.cc +++ b/src/tirx/transform/flatten_buffer.cc @@ -23,9 +23,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include @@ -33,7 +33,7 @@ #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension @@ -80,8 +80,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer { Stmt VisitStmt_(const SBlockNode* op) final { TVM_FFI_ICHECK_EQ(op->match_buffers.size(), 0) - << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. " - << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; + << "Unexpected MatchBufferRegion found during tirx.transform.FlattenBuffer. " + << "All MatchBufferRegion should be removed in tirx.transform.LowerMatchBuffer."; SBlock block = ffi::GetRef(op); @@ -255,14 +255,14 @@ Pass FlattenBuffer() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { return FlattenBuffer(std::move(f)); }; - return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.FlattenBuffer", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.FlattenBuffer", FlattenBuffer); + refl::GlobalDef().def("tirx.transform.FlattenBuffer", FlattenBuffer); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/force_narrow_index_to_i32.cc b/src/tirx/transform/force_narrow_index_to_i32.cc similarity index 92% rename from src/tir/transform/force_narrow_index_to_i32.cc rename to src/tirx/transform/force_narrow_index_to_i32.cc index 46ff7739d2ad..2f417a3fffd0 100644 --- a/src/tir/transform/force_narrow_index_to_i32.cc +++ b/src/tirx/transform/force_narrow_index_to_i32.cc @@ -24,13 +24,13 @@ */ #include -#include -#include +#include +#include #include "../ir/data_type_rewriter.h" namespace tvm { -namespace tir { +namespace tirx { class Int32DTypeNarrower : public IndexDataTypeNormalizer { public: @@ -87,14 +87,14 @@ Pass ForceNarrowIndexToInt32() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { return ForceNarrowIndexToInt32(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.NarrowDataType", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.ForceNarrowIndexToInt32", ForceNarrowIndexToInt32); + refl::GlobalDef().def("tirx.transform.ForceNarrowIndexToInt32", ForceNarrowIndexToInt32); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/inline_private_functions.cc b/src/tirx/transform/inline_private_functions.cc similarity index 92% rename from src/tir/transform/inline_private_functions.cc rename to src/tirx/transform/inline_private_functions.cc index a44cdb37add2..16f211a53d82 100644 --- a/src/tir/transform/inline_private_functions.cc +++ b/src/tirx/transform/inline_private_functions.cc @@ -23,15 +23,15 @@ */ #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace transform { namespace { @@ -120,7 +120,7 @@ bool IsInlinablePrimFunc(const GlobalVar& gvar, const PrimFunc& prim_func, // We do not currently support inlining of schedulable TIR // functions. To support this use case, repeated names in - // `tir::SBlock` nodes resulting from multiple calls to the same + // `tirx::SBlock` nodes resulting from multiple calls to the same // inlined function will need to be de-duplicated. bool has_block_node = prim_func->body.as(); if (has_block_node) return false; @@ -203,12 +203,12 @@ class PrimFuncInliner : StmtExprMutator { PrimExpr VisitExpr_(const CallNode* call) override { // Because the current implementation inlines a subroutine inserts - // the `tir::Stmt` body at the point of use, replacement must - // occur in a context where a `tir::Stmt` can be returned. Support + // the `tirx::Stmt` body at the point of use, replacement must + // occur in a context where a `tirx::Stmt` can be returned. Support // of subroutines that are called within an expression // (e.g. Replacing func in `Buf[0] = func(1) + func(2)`) would // require hoisting preprocessing done in the subroutine to the - // parent `tir::Stmt`. + // parent `tirx::Stmt`. // // See `TestInlineCallOccurringInExpression` in // `test_tir_inline_private_functions.py` for a test of this @@ -233,7 +233,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - ffi::Map> param_map; + ffi::Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } @@ -291,15 +291,15 @@ Pass InlinePrivateFunctions() { return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.InlinePrivateFunctions", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.InlinePrivateFunctions", InlinePrivateFunctions); + refl::GlobalDef().def("tirx.transform.InlinePrivateFunctions", InlinePrivateFunctions); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/ir_utils.cc b/src/tirx/transform/ir_utils.cc similarity index 98% rename from src/tir/transform/ir_utils.cc rename to src/tirx/transform/ir_utils.cc index 5e781d634476..dc1ed5b8b860 100644 --- a/src/tir/transform/ir_utils.cc +++ b/src/tirx/transform/ir_utils.cc @@ -28,15 +28,15 @@ #include #include #include -#include -#include +#include +#include #include #include #include namespace tvm { -namespace tir { +namespace tirx { Stmt MergeNest(const std::vector& nest, Stmt body) { // use reverse iteration @@ -90,7 +90,7 @@ class IRConvertSSA final : public StmtExprMutator { // Remap parameters, if they were used in another function. // Function-scope remaps use function_scope_var_remap_ (not the scope stack), // because they persist across the entire function body. - auto params = func->params.Map([&](const tir::Var& var) -> tir::Var { + auto params = func->params.Map([&](const tirx::Var& var) -> tirx::Var { if (defined_.count(var.get())) { Var new_var = MakeNewVar(var); PushVarRemap(var, new_var); @@ -927,7 +927,7 @@ std::pair GetWmmaFragmentDimSize(const std::string& shape_str, } std::optional IsHostFunc(const PrimFunc& func) { - if (func->HasNonzeroAttr(tvm::tir::attr::kIsHostFunc)) { + if (func->HasNonzeroAttr(tvm::tirx::attr::kIsHostFunc)) { return true; } else if (auto target = func->GetAttr(tvm::attr::kTarget)) { return target.value()->HasKey("cpu"); @@ -939,12 +939,12 @@ std::optional IsHostFunc(const PrimFunc& func) { namespace transform { Pass ConvertSSA() { auto pass_func = [](IRModule mod, PassContext ctx) { - tir::IRConvertSSA converter; + tirx::IRConvertSSA converter; ffi::Map functions; bool made_change = false; for (auto [gvar, base_func] : mod->functions) { - if (auto* ptr = base_func.as()) { - auto updated = converter.VisitPrimFunc(ffi::GetRef(ptr)); + if (auto* ptr = base_func.as()) { + auto updated = converter.VisitPrimFunc(ffi::GetRef(ptr)); if (!updated.same_as(base_func)) { made_change = true; base_func = updated; @@ -957,14 +957,14 @@ Pass ConvertSSA() { } return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.ConvertSSA", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.ConvertSSA", ConvertSSA); + refl::GlobalDef().def("tirx.transform.ConvertSSA", ConvertSSA); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/ir_utils.h b/src/tirx/transform/ir_utils.h similarity index 98% rename from src/tir/transform/ir_utils.h rename to src/tirx/transform/ir_utils.h index 1282778ae09d..dad1d92eafce 100644 --- a/src/tir/transform/ir_utils.h +++ b/src/tirx/transform/ir_utils.h @@ -30,10 +30,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -43,7 +43,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief combine the nest stmt, whose body is not defined. * \param nest A list of For and Bind, whose body is not defined. @@ -345,6 +345,6 @@ std::pair GetWmmaFragmentDimSize(const std::string& shape_str, */ std::optional IsHostFunc(const PrimFunc& func); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_IR_UTILS_H_ diff --git a/src/tir/transform/lower_custom_datatypes.cc b/src/tirx/transform/lower_custom_datatypes.cc similarity index 97% rename from src/tir/transform/lower_custom_datatypes.cc rename to src/tirx/transform/lower_custom_datatypes.cc index 8f2bf539a8c9..b82a2e5a07e9 100644 --- a/src/tir/transform/lower_custom_datatypes.cc +++ b/src/tirx/transform/lower_custom_datatypes.cc @@ -24,14 +24,14 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include "../../target/datatype/registry.h" namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Helper mutator to implement lowering of custom datatypes. @@ -251,15 +251,15 @@ Pass LowerCustomDatatypes() { n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.LowerCustomDatatypes", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerCustomDatatypes", LowerCustomDatatypes); + refl::GlobalDef().def("tirx.transform.LowerCustomDatatypes", LowerCustomDatatypes); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc similarity index 95% rename from src/tir/transform/lower_device_kernel_launch.cc rename to src/tirx/transform/lower_device_kernel_launch.cc index 6df115652216..3ff4cf17c585 100644 --- a/src/tir/transform/lower_device_kernel_launch.cc +++ b/src/tirx/transform/lower_device_kernel_launch.cc @@ -25,16 +25,16 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { namespace { struct KernelInfo { @@ -212,7 +212,7 @@ class DeviceKernelMutator : public StmtExprMutator { << "This case is not yet supported."; if (is_kernel_launch || is_call_extern) { - func = WithAttr(std::move(func), tvm::tir::attr::kIsGlobalFunc, true); + func = WithAttr(std::move(func), tvm::tirx::attr::kIsGlobalFunc, true); } if (is_kernel_launch) { @@ -228,7 +228,7 @@ class DeviceKernelMutator : public StmtExprMutator { func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, static_cast(tvm::CallingConv::kDeviceKernelLaunch)}, - {tvm::tir::attr::kKernelLaunchParams, info.launch_params}, + {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, {tvm::attr::kGlobalSymbol, info.global_symbol}}); } else if (is_call_extern && !func->GetAttr(tvm::attr::kGlobalSymbol)) { @@ -279,7 +279,7 @@ class DeviceKernelMutator : public StmtExprMutator { TVM_FFI_ICHECK(dev_info.launch_params.defined()) << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " << dev_info.target << ", but subroutine " << gvar->name_hint - << " did not have the tir::attr::kKernelLaunchParams attribute " + << " did not have the tirx::attr::kKernelLaunchParams attribute " << "required for cross-target kernel launch"; // Collected kernel information may be in terms of the callee's @@ -369,14 +369,14 @@ Pass LowerDeviceKernelLaunch() { return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.LowerDeviceKernelLaunch", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.LowerDeviceKernelLaunch", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); + refl::GlobalDef().def("tirx.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/lower_intrin.cc b/src/tirx/transform/lower_intrin.cc similarity index 94% rename from src/tir/transform/lower_intrin.cc rename to src/tirx/transform/lower_intrin.cc index fc85e541038f..103a47e888fc 100644 --- a/src/tir/transform/lower_intrin.cc +++ b/src/tirx/transform/lower_intrin.cc @@ -24,10 +24,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -36,7 +36,7 @@ #include "../../arith/pattern_match.h" namespace tvm { -namespace tir { +namespace tirx { class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: @@ -61,7 +61,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (Op::HasAttrMap(pattern)) { attr_maps_.push_back(Op::GetAttrMap(pattern)); if (fma_ == nullptr) { - fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr); + fma_ = (*attr_maps_.rbegin()).get(Op::Get("tirx.fma"), nullptr); } } } @@ -135,7 +135,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // equivalent to rdiv + (rmod >= 0 ? 0: -1); return rdiv + (rmod >> make_const(dtype, dtype.bits() - 1)); } else { - return tir::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); + return tirx::Select(rmod >= 0, rdiv, rdiv - make_const(dtype, 1)); } } else { @@ -145,14 +145,14 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } else { // uncommon case DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divisor"; - auto rmod = tir::Var("rmod", dtype); - auto rdiv = tir::Var("rdiv", dtype); + auto rmod = tirx::Var("rmod", dtype); + auto rdiv = tirx::Var("rdiv", dtype); // b >= 0 => (rmod >=0 ? rdiv : rdiv - 1) // b < 0 => (rmod <= 0 ? rdiv : rdiv - 1) PrimExpr let_rdiv = - tir::Let(rdiv, truncdiv(op->a, op->b), - tir::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, - rdiv - make_const(dtype, 1))); + tirx::Let(rdiv, truncdiv(op->a, op->b), + tirx::Select((op->b >= 0 && rmod >= 0) || (op->b < 0 && rmod <= 0), rdiv, + rdiv - make_const(dtype, 1))); return Let(rmod, truncmod(op->a, op->b), let_rdiv); } } @@ -197,7 +197,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { // -> rmod >= 0 ? 0 : b return rmod + (op->b & (rmod >> make_const(dtype, dtype.bits() - 1))); } else { - return tir::Select(rmod >= 0, rmod, rmod + op->b); + return tirx::Select(rmod >= 0, rmod, rmod + op->b); } } else { @@ -207,7 +207,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { } else { // uncommon case DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divsor and divident"; - auto rmod = tir::Var("rmod", dtype); + auto rmod = tirx::Var("rmod", dtype); // b > 0 && rmod >= 0 -> rmod // b > 0 && rmod < 0 -> rmod + b // b < 0 && rmod < 0 -> rmod @@ -368,15 +368,15 @@ Pass LowerIntrin() { IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.LowerIntrin", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerIntrin", LowerIntrin); + refl::GlobalDef().def("tirx.transform.LowerIntrin", LowerIntrin); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/lower_tvm_builtin.cc b/src/tirx/transform/lower_tvm_builtin.cc similarity index 97% rename from src/tir/transform/lower_tvm_builtin.cc rename to src/tirx/transform/lower_tvm_builtin.cc index 496302abbf4b..401194b9a18a 100644 --- a/src/tir/transform/lower_tvm_builtin.cc +++ b/src/tirx/transform/lower_tvm_builtin.cc @@ -19,22 +19,22 @@ /*! * Lower TVM related builtin intrinsics such as packed call. - * \file tir/transform/lower_tvm_builtin.cc + * \file tirx/transform/lower_tvm_builtin.cc */ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { // Calculate the statistics of packed function. // These information are needed during codegen. @@ -266,7 +266,7 @@ class BuiltinLower : public StmtExprMutator { Stmt alloc_nullptr_check = IfThenElse( Call(DataType::Bool(), builtin::isnullptr(), {op->buffer->data}), throw_last_error); - PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), + PrimExpr free_op = Call(DataType::Int(32), Op::Get("tirx.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_.value()), cast(DataType::Int(32), device_id_.value()), op->buffer->data}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); @@ -275,7 +275,7 @@ class BuiltinLower : public StmtExprMutator { scope_.Current().pending_frees.push_back(free_stmt); Stmt alloc_bind = Bind(op->buffer->data, - Call(op->buffer->data.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + Call(op->buffer->data.dtype(), Op::Get("tirx.TVMBackendAllocWorkspace"), {cast(DataType::Int(32), device_type_.value()), cast(DataType::Int(32), device_id_.value()), total_bytes, IntImm(DataType::Int(32), op->buffer->dtype.code()), @@ -520,12 +520,12 @@ class BuiltinLower : public StmtExprMutator { } void SetPackedArg(PrimExpr arg, const Var& args_stack, size_t stack_offset, - std::vector* prep_seq) { + std::vector* prep_seq) { auto* call_pattern = arg.as(); if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) { // call runtime function to set anylist prep_seq->emplace_back(Evaluate(Call( - DataType::Int(32), Op::Get("tir.TVMBackendAnyListSetPackedArg"), + DataType::Int(32), Op::Get("tirx.TVMBackendAnyListSetPackedArg"), {call_pattern->args[0], call_pattern->args[1], args_stack, ConstInt32(stack_offset)}))); } else { DataType api_dtype = APIType(arg.dtype()); @@ -582,7 +582,7 @@ class BuiltinLower : public StmtExprMutator { PrimExpr ret_offset = call->args[3]; auto& prep_seq = prep_seq_stack_.back(); prep_seq.emplace_back(Evaluate(call)); - return Call(DataType::Int(32), Op::Get("tir.TVMBackendAnyListMoveFromPackedReturn"), + return Call(DataType::Int(32), Op::Get("tirx.TVMBackendAnyListMoveFromPackedReturn"), {list_handle, list_index, args_stack, ret_offset}); } /*! @@ -745,14 +745,14 @@ Pass LowerTVMBuiltin() { } return func; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerTVMBuiltin", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.LowerTVMBuiltin", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerTVMBuiltin", LowerTVMBuiltin); + refl::GlobalDef().def("tirx.transform.LowerTVMBuiltin", LowerTVMBuiltin); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/lower_warp_memory.cc b/src/tirx/transform/lower_warp_memory.cc similarity index 98% rename from src/tir/transform/lower_warp_memory.cc rename to src/tirx/transform/lower_warp_memory.cc index b63fb77fc695..17b525ac320d 100644 --- a/src/tir/transform/lower_warp_memory.cc +++ b/src/tirx/transform/lower_warp_memory.cc @@ -31,12 +31,12 @@ #include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include @@ -46,7 +46,7 @@ #include "update_pointer_storage_scope.h" namespace tvm { -namespace tir { +namespace tirx { // Rewrite Rule // @@ -496,15 +496,15 @@ Pass LowerWarpMemory() { n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.LowerWarpMemory", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.LowerWarpMemory", LowerWarpMemory); + refl::GlobalDef().def("tirx.transform.LowerWarpMemory", LowerWarpMemory); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/make_packed_api.cc b/src/tirx/transform/make_packed_api.cc similarity index 85% rename from src/tir/transform/make_packed_api.cc rename to src/tirx/transform/make_packed_api.cc index e94c355c7a69..bc568589e5dc 100644 --- a/src/tir/transform/make_packed_api.cc +++ b/src/tirx/transform/make_packed_api.cc @@ -27,12 +27,12 @@ #include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include #include @@ -42,7 +42,7 @@ #include "tvm_ffi_binder.h" namespace tvm { -namespace tir { +namespace tirx { namespace { class ReturnRewriter : public StmtMutator { @@ -62,8 +62,8 @@ class ReturnRewriter : public StmtMutator { TVM_FFI_ICHECK(eval); if (const CallNode* call = eval->value.as()) { if (call->op.same_as(builtin::ret())) { - TVM_FFI_ICHECK_EQ(in_parallel_, 0) << "tir.ret cannot be used in parallel scope."; - TVM_FFI_ICHECK_EQ(call->args.size(), 1) << "tir.ret expect a single argument."; + TVM_FFI_ICHECK_EQ(in_parallel_, 0) << "tirx.ret cannot be used in parallel scope."; + TVM_FFI_ICHECK_EQ(call->args.size(), 1) << "tirx.ret expect a single argument."; ret = WriteToOut(call->args[0]); } } @@ -103,19 +103,19 @@ class ReturnRewriter : public StmtMutator { Stmt WriteToOut(PrimExpr val) { auto info = ConvertForFFI(val); Stmt store_tindex = - tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), - {ret_var_, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex), - IntImm(DataType::Int(32), info.type_index)})); + tirx::Evaluate(tirx::Call(DataType::Int(32), tirx::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tirx::builtin::kTVMFFIAnyTypeIndex), + IntImm(DataType::Int(32), info.type_index)})); Stmt store_zero_padding = - tir::Evaluate(tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), - {ret_var_, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding), - IntImm(DataType::Int(32), 0)})); - Stmt store_val = tir::Evaluate( - tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(), - {ret_var_, IntImm(DataType::Int(32), 0), - IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue), info.expr})); + tirx::Evaluate(tirx::Call(DataType::Int(32), tirx::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tirx::builtin::kTVMFFIAnyZeroPadding), + IntImm(DataType::Int(32), 0)})); + Stmt store_val = tirx::Evaluate( + tirx::Call(DataType::Int(32), tirx::builtin::tvm_struct_set(), + {ret_var_, IntImm(DataType::Int(32), 0), + IntImm(DataType::Int(32), tirx::builtin::kTVMFFIAnyUnionValue), info.expr})); Stmt ret_zero = Evaluate(tvm::ret(0)); return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero}); } @@ -148,15 +148,15 @@ class SubroutineCallRewriter : public StmtExprMutator { auto gvar = ffi::GetRef(gvar_ptr); if (auto symbol = packed_func_methods.Get(gvar)) { ffi::Array cpacked_args; - cpacked_args.push_back(tir::StringImm(symbol.value())); + cpacked_args.push_back(tirx::StringImm(symbol.value())); for (auto arg : node->args) { cpacked_args.push_back(arg); } // push an empty handle to be compatible with current cpacked convention - cpacked_args.push_back(tir::make_zero(DataType::Handle())); + cpacked_args.push_back(tirx::make_zero(DataType::Handle())); made_change_ = true; - return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(), cpacked_args); + return tirx::Call(node->dtype, tirx::builtin::tvm_call_cpacked(), cpacked_args); } } @@ -328,13 +328,13 @@ Pass MakePackedAPI() { return mod; }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.MakePackedAPI", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.MakePackedAPI", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.MakePackedAPI", []() { return MakePackedAPI(); }); + refl::GlobalDef().def("tirx.transform.MakePackedAPI", []() { return MakePackedAPI(); }); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/narrow_datatype.cc b/src/tirx/transform/narrow_datatype.cc similarity index 97% rename from src/tir/transform/narrow_datatype.cc rename to src/tirx/transform/narrow_datatype.cc index 0ca492c0036d..9a677f530dfb 100644 --- a/src/tir/transform/narrow_datatype.cc +++ b/src/tirx/transform/narrow_datatype.cc @@ -25,16 +25,16 @@ #include #include #include -#include -#include -#include +#include +#include +#include #include "../../arith/ir_mutator_with_analyzer.h" #include "../../arith/ir_visitor_with_analyzer.h" #include "../ir/data_type_rewriter.h" namespace tvm { -namespace tir { +namespace tirx { // This pass narrows indexing expressions (like BufferStoreNode::indices) // that trivially fit into i32/i16 (denoted by `target_bits_`) to @@ -319,14 +319,14 @@ Pass NarrowDataType(int target_bits) { n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.NarrowDataType", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.NarrowDataType", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.NarrowDataType", NarrowDataType); + refl::GlobalDef().def("tirx.transform.NarrowDataType", NarrowDataType); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/primfunc_utils.cc b/src/tirx/transform/primfunc_utils.cc similarity index 79% rename from src/tir/transform/primfunc_utils.cc rename to src/tirx/transform/primfunc_utils.cc index 136a92abde31..d790c02b2f6c 100644 --- a/src/tir/transform/primfunc_utils.cc +++ b/src/tirx/transform/primfunc_utils.cc @@ -23,10 +23,10 @@ */ #include -#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace transform { transform::Pass AnnotateEntryFunc() { @@ -34,9 +34,9 @@ transform::Pass AnnotateEntryFunc() { // If only a single function exists, that function must be the entry if (mod->functions.size() == 1) { auto [gvar, base_func] = *mod->functions.begin(); - if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (!base_func->HasNonzeroAttr(tirx::attr::kIsEntryFunc)) { if (auto ptr = base_func.as()) { - mod->Update(gvar, WithAttr(ffi::GetRef(ptr), tir::attr::kIsEntryFunc, true)); + mod->Update(gvar, WithAttr(ffi::GetRef(ptr), tirx::attr::kIsEntryFunc, true)); } } return mod; @@ -51,7 +51,7 @@ transform::Pass AnnotateEntryFunc() { if (is_external) { if (auto ptr = base_func.as()) { with_annotations->Add( - gvar, WithAttr(ffi::GetRef(ptr), tir::attr::kIsEntryFunc, true)); + gvar, WithAttr(ffi::GetRef(ptr), tirx::attr::kIsEntryFunc, true)); } else { has_external_non_primfuncs = true; } @@ -65,27 +65,27 @@ transform::Pass AnnotateEntryFunc() { // Default fallback, no annotations may be inferred. return mod; }; - return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {}); + return tvm::transform::CreateModulePass(fpass, 0, "tirx.AnnotateEntryFunc", {}); } transform::Pass Filter(ffi::TypedFunction fcond) { - auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { + auto fpass = [fcond](tirx::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { return f; } else { - return tir::PrimFunc(nullptr); + return tirx::PrimFunc(nullptr); } }; - return tir::transform::CreatePrimFuncPass(fpass, 0, "tir.Filter", {}); + return tirx::transform::CreatePrimFuncPass(fpass, 0, "tirx.Filter", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() - .def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc) - .def("tir.transform.Filter", Filter); + .def("tirx.transform.AnnotateEntryFunc", AnnotateEntryFunc) + .def("tirx.transform.Filter", Filter); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/remap_thread_axis.cc b/src/tirx/transform/remap_thread_axis.cc similarity index 87% rename from src/tir/transform/remap_thread_axis.cc rename to src/tirx/transform/remap_thread_axis.cc index 4a47e43d06d1..4ed69713fd2f 100644 --- a/src/tir/transform/remap_thread_axis.cc +++ b/src/tirx/transform/remap_thread_axis.cc @@ -22,14 +22,14 @@ */ #include #include -#include -#include -#include +#include +#include +#include #include namespace tvm { -namespace tir { +namespace tirx { // Mutator to change the read pattern class ThreadAxisRewriter : private StmtExprMutator { @@ -76,8 +76,8 @@ PrimFunc RemapThreadAxis(PrimFunc func, ffi::Map thread_ma tmap[kv.first] = kv.second; } - if (auto opt = func->GetAttr>(tir::attr::kKernelLaunchParams)) { - TVM_FFI_ICHECK(opt != nullptr) << "Require attribute " << tir::attr::kKernelLaunchParams; + if (auto opt = func->GetAttr>(tirx::attr::kKernelLaunchParams)) { + TVM_FFI_ICHECK(opt != nullptr) << "Require attribute " << tirx::attr::kKernelLaunchParams; auto launch_params = opt.value(); // replace the thread axis attribute for (size_t i = 0; i < launch_params.size(); ++i) { @@ -87,7 +87,7 @@ PrimFunc RemapThreadAxis(PrimFunc func, ffi::Map thread_ma } } - func = WithAttr(std::move(func), tir::attr::kKernelLaunchParams, launch_params); + func = WithAttr(std::move(func), tirx::attr::kKernelLaunchParams, launch_params); } auto* n = func.CopyOnWrite(); @@ -101,14 +101,14 @@ Pass RemapThreadAxis(ffi::Map thread_map) { auto pass_func = [thread_map](PrimFunc f, IRModule m, PassContext ctx) { return RemapThreadAxis(std::move(f), thread_map); }; - return CreatePrimFuncPass(pass_func, 0, "tir.RemapThreadAxis", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.RemapThreadAxis", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.RemapThreadAxis", RemapThreadAxis); + refl::GlobalDef().def("tirx.transform.RemapThreadAxis", RemapThreadAxis); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/remove_assume.cc b/src/tirx/transform/remove_assume.cc similarity index 80% rename from src/tir/transform/remove_assume.cc rename to src/tirx/transform/remove_assume.cc index 6475befa1cf8..a2adc794d42e 100644 --- a/src/tir/transform/remove_assume.cc +++ b/src/tirx/transform/remove_assume.cc @@ -19,19 +19,19 @@ /*! * \file remove_store_undef.cc - * \brief Remove stores of tir::builtin::undef + * \brief Remove stores of tirx::builtin::undef */ #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { // Remove any builtin::assume calls class AssumeRemover : public StmtExprMutator { @@ -55,19 +55,19 @@ Pass RemoveAssumeInternal() { n->body = AssumeRemover()(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.RemoveAssumeInternal", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.RemoveAssumeInternal", {}); } Pass RemoveAssume() { - return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume"); + return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tirx.RemoveAssume"); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.RemoveAssume", RemoveAssume); + refl::GlobalDef().def("tirx.transform.RemoveAssume", RemoveAssume); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc similarity index 95% rename from src/tir/transform/remove_no_op.cc rename to src/tirx/transform/remove_no_op.cc index 60b6a679d418..a70c2e037a13 100644 --- a/src/tir/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -25,11 +25,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include @@ -41,7 +41,7 @@ #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter { bool use_dataflow_analysis; @@ -60,7 +60,7 @@ struct RemoveNoOpConfigNode : public AttrsNodeReflAdapter "For use in debug and testing purposes.", refl::DefaultValue(0)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.RemoveNoOpConfig", RemoveNoOpConfigNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.RemoveNoOpConfig", RemoveNoOpConfigNode, BaseAttrsNode); }; @@ -71,7 +71,7 @@ class RemoveNoOpConfig : public Attrs { TVM_FFI_STATIC_INIT_BLOCK() { RemoveNoOpConfigNode::RegisterReflection(); } -TVM_REGISTER_PASS_CONFIG_OPTION("tir.RemoveNoOp", RemoveNoOpConfig); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.RemoveNoOp", RemoveNoOpConfig); // Mark the statement of each stage. class NoOpRemover : public arith::IRMutatorWithAnalyzer { @@ -285,7 +285,7 @@ Pass RemoveNoOp() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { std::optional touch_pattern = std::nullopt; - RemoveNoOpConfig config = ctx->GetConfig("tir.RemoveNoOp") + RemoveNoOpConfig config = ctx->GetConfig("tirx.RemoveNoOp") .value_or(AttrsWithDefaultValues()); if (config->use_dataflow_analysis) { @@ -302,15 +302,15 @@ Pass RemoveNoOp() { } return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.RemoveNoOp", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.RemoveNoOp", RemoveNoOp); + refl::GlobalDef().def("tirx.transform.RemoveNoOp", RemoveNoOp); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/remove_no_op.h b/src/tirx/transform/remove_no_op.h similarity index 93% rename from src/tir/transform/remove_no_op.h rename to src/tirx/transform/remove_no_op.h index 493b98c90d4d..3f6d1c112470 100644 --- a/src/tir/transform/remove_no_op.h +++ b/src/tirx/transform/remove_no_op.h @@ -25,18 +25,18 @@ #define TVM_TIR_TRANSFORM_REMOVE_NO_OP_H_ #include -#include +#include #include #include "../analysis/control_flow_graph.h" namespace tvm { -namespace tir { +namespace tirx { /* \brief Remove no-ops from the statement * - * Applies the same behavior as the tir.transform.RemoveNoOp pass, but + * Applies the same behavior as the tirx.transform.RemoveNoOp pass, but * on a single statement, usable as a subroutine in other passes. * * \param stmt The TIR statement from which to remove no-ops @@ -55,6 +55,6 @@ Stmt RemoveNoOp(Stmt stmt, arith::Analyzer* analyzer, std::optional touch_pattern = std::nullopt, const StmtNode* context = nullptr); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_REMOVE_NO_OP_H_ diff --git a/src/tir/transform/replace_global_vars.cc b/src/tirx/transform/replace_global_vars.cc similarity index 88% rename from src/tir/transform/replace_global_vars.cc rename to src/tirx/transform/replace_global_vars.cc index 8bb7dee37882..1723ff75c146 100644 --- a/src/tir/transform/replace_global_vars.cc +++ b/src/tirx/transform/replace_global_vars.cc @@ -19,17 +19,17 @@ /*! * - * \file src/tir/transform/replace_global_vars.cc + * \file src/tirx/transform/replace_global_vars.cc * * \brief GlobalVar replacement across IR types */ #include -#include -#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { namespace { using tvm::transform::GlobalVarReplacer; @@ -52,8 +52,8 @@ struct Mutator : StmtExprMutator { } // namespace TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) - .set_dispatch([](const ObjectRef& obj, - ffi::Map replacements) -> BaseFunc { + .set_dispatch([](const ObjectRef& obj, + ffi::Map replacements) -> BaseFunc { Mutator mutator(replacements); auto func = Downcast(obj); auto new_body = mutator(func->body); @@ -80,5 +80,5 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable) return func; }); -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/replace_selected_expr.cc b/src/tirx/transform/replace_selected_expr.cc similarity index 94% rename from src/tir/transform/replace_selected_expr.cc rename to src/tirx/transform/replace_selected_expr.cc index ce133b2f5d6a..ed62f5863484 100644 --- a/src/tir/transform/replace_selected_expr.cc +++ b/src/tirx/transform/replace_selected_expr.cc @@ -27,15 +27,15 @@ #include "replace_selected_expr.h" #include // For the class Pass and the class PassContext -#include -#include -#include // For the class PrimFunc -#include -#include -#include // For the declaration of the pass +#include +#include +#include // For the class PrimFunc +#include +#include +#include // For the declaration of the pass namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Toplevel (static) function that replace in an expression @@ -105,5 +105,5 @@ PrimExpr ReplaceSelectedExpr::VisitExpr(const PrimExpr& expr) { } } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/replace_selected_expr.h b/src/tirx/transform/replace_selected_expr.h similarity index 93% rename from src/tir/transform/replace_selected_expr.h rename to src/tirx/transform/replace_selected_expr.h index e3ee62fb8d07..a3bc0c0168a5 100644 --- a/src/tir/transform/replace_selected_expr.h +++ b/src/tirx/transform/replace_selected_expr.h @@ -27,13 +27,13 @@ #ifndef TVM_TIR_TRANSFORM_REPLACE_SELECTED_EXPR_H_ #define TVM_TIR_TRANSFORM_REPLACE_SELECTED_EXPR_H_ -#include -#include -#include -#include // For the class StmtExprMutator +#include +#include +#include +#include // For the class StmtExprMutator namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Mutator for replacing the expressions selected by a predicate in a statement and/or @@ -69,7 +69,7 @@ class ReplaceSelectedExpr : public StmtExprMutator { std::function can_replace_inside_; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_REPLACE_SELECTED_EXPR_H_ diff --git a/src/tir/transform/simplify.cc b/src/tirx/transform/simplify.cc similarity index 93% rename from src/tir/transform/simplify.cc rename to src/tirx/transform/simplify.cc index 3d295e1764be..fed46f59a83d 100644 --- a/src/tir/transform/simplify.cc +++ b/src/tirx/transform/simplify.cc @@ -22,26 +22,26 @@ * \brief Statement simplifier based on analyzer */ -#include "../../tir/transform/simplify.h" +#include "../../tirx/transform/simplify.h" #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include "../../arith/ir_mutator_with_analyzer.h" -#include "../../tir/analysis/control_flow_graph.h" +#include "../../tirx/analysis/control_flow_graph.h" namespace tvm { namespace arith { -using namespace tir; +using namespace tirx; struct SimplifyConfigNode : public AttrsNodeReflAdapter { bool transitively_prove_inequalities; @@ -76,7 +76,7 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter { "branch", refl::DefaultValue(false)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.SimplifyConfig", SimplifyConfigNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.SimplifyConfig", SimplifyConfigNode, BaseAttrsNode); RewriteSimplifier::Extension GetEnabledExtensions() const { @@ -103,7 +103,7 @@ class SimplifyConfig : public Attrs { TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); } -TVM_REGISTER_PASS_CONFIG_OPTION("tir.Simplify", SimplifyConfig); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.Simplify", SimplifyConfig); class StmtSimplifier : public IRMutatorWithAnalyzer { public: @@ -233,7 +233,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (const BufferLoadNode* load = store->value.as()) { if (load->buffer->data.same_as(store->buffer->data) && ArrayDeepEqual(load->indices, store->indices) && - tir::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && + tirx::ExprDeepEqual()(load->buffer->elem_offset, store->buffer->elem_offset) && ArrayDeepEqual(load->buffer->shape, store->buffer->shape) && ArrayDeepEqual(load->buffer->strides, store->buffer->strides)) { return Evaluate(0); @@ -248,7 +248,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return false; } for (size_t i = 0; i < lhs.size(); i++) { - if (!tir::ExprDeepEqual()(lhs[i], rhs[i])) { + if (!tirx::ExprDeepEqual()(lhs[i], rhs[i])) { return false; } } @@ -286,7 +286,7 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } // namespace arith -namespace tir { +namespace tirx { PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) { return arith::StmtSimplifier::Apply(std::move(func), analyzer); @@ -297,18 +297,18 @@ namespace transform { Pass Simplify() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { arith::Analyzer analyzer; - auto cfg = ctx->GetConfig("tir.Simplify"); + auto cfg = ctx->GetConfig("tirx.Simplify"); return arith::StmtSimplifier::Apply(f, &analyzer, cfg); }; - return CreatePrimFuncPass(pass_func, 0, "tir.Simplify", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.Simplify", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.Simplify", Simplify); + refl::GlobalDef().def("tirx.transform.Simplify", Simplify); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/simplify.h b/src/tirx/transform/simplify.h similarity index 89% rename from src/tir/transform/simplify.h rename to src/tirx/transform/simplify.h index 03a493d51bc0..c59797fcff95 100644 --- a/src/tir/transform/simplify.h +++ b/src/tirx/transform/simplify.h @@ -25,17 +25,17 @@ #define TVM_TIR_TRANSFORM_SIMPLIFY_H_ #include -#include +#include namespace tvm { -namespace tir { +namespace tirx { /* \brief Simplifies the prim func * - * Applies the same behavior as the tir.transform.Simplify pass. + * Applies the same behavior as the tirx.transform.Simplify pass. */ PrimFunc Simplify(PrimFunc stmt, arith::Analyzer* analyzer); -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_SIMPLIFY_H_ diff --git a/src/tir/transform/skip_assert.cc b/src/tirx/transform/skip_assert.cc similarity index 84% rename from src/tir/transform/skip_assert.cc rename to src/tirx/transform/skip_assert.cc index 8e997bc9eeb4..d1f9dff9d7c7 100644 --- a/src/tir/transform/skip_assert.cc +++ b/src/tirx/transform/skip_assert.cc @@ -19,12 +19,12 @@ #include #include -#include -#include -#include +#include +#include +#include namespace tvm { -namespace tir { +namespace tirx { class AssertSkipper : public StmtMutator { public: @@ -44,15 +44,15 @@ Pass SkipAssert() { n->body = AssertSkipper()(std::move(n->body)); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.SkipAssert", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.SkipAssert", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.SkipAssert", SkipAssert); + refl::GlobalDef().def("tirx.transform.SkipAssert", SkipAssert); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc similarity index 91% rename from src/tir/transform/split_host_device.cc rename to src/tirx/transform/split_host_device.cc index 59ca8610700e..f41ca8eed8b0 100644 --- a/src/tir/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -26,17 +26,17 @@ #include #include #include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include "../analysis/var_use_def_analysis.h" namespace tvm { -namespace tir { +namespace tirx { class HostDeviceSplitter : public StmtMutator { public: @@ -93,8 +93,8 @@ class HostDeviceSplitter : public StmtMutator { } PrimFunc device_func(params, body, kernel_ret_type); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, - {tir::attr::kNoAlias, true}, - {tir::attr::kIsGlobalFunc, true}}); + {tirx::attr::kNoAlias, true}, + {tirx::attr::kIsGlobalFunc, true}}); GlobalVar kernel_symbol_global = var_supply_(); (*device_mod_)->Add(kernel_symbol_global, device_func); @@ -161,14 +161,14 @@ Pass SplitHostDevice() { return ConvertSSA()(mod); }; - return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", {}); + return tvm::transform::CreateModulePass(pass_func, 0, "tirx.SplitHostDevice", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.SplitHostDevice", SplitHostDevice); + refl::GlobalDef().def("tirx.transform.SplitHostDevice", SplitHostDevice); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/storage_rewrite.cc b/src/tirx/transform/storage_rewrite.cc similarity index 98% rename from src/tir/transform/storage_rewrite.cc rename to src/tirx/transform/storage_rewrite.cc index 2967ac3d387b..41244ba9b0c2 100644 --- a/src/tir/transform/storage_rewrite.cc +++ b/src/tirx/transform/storage_rewrite.cc @@ -27,11 +27,11 @@ #include #include #include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include @@ -44,7 +44,7 @@ #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { using runtime::StorageRank; using runtime::StorageScope; @@ -361,7 +361,7 @@ class InplaceOpVerifier : public StmtExprVisitor { << "Store/Load occur to the same buffer " << buf->name_hint << " with differing number of indices"; for (size_t i = 0; i < store_->indices.size(); i++) { - if (!tir::ExprDeepEqual()(store_->indices[i], op->indices[i])) { + if (!tirx::ExprDeepEqual()(store_->indices[i], op->indices[i])) { result_ = false; return; } @@ -1085,7 +1085,7 @@ struct BufferVarInfo { kLetNode = (1 << 3), }; - // The tir::Var that represents this buffer. + // The tirx::Var that represents this buffer. Var var; // The data type of an element of the buffer. @@ -1173,7 +1173,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { * missing a type annotation, assume that it has the same underlying * type as it is later accessed, with scalar element types. */ - VectorTypeAccessChecker(const ffi::Array& params, + VectorTypeAccessChecker(const ffi::Array& params, const ffi::Map& buffer_map, bool allow_untyped_pointers = false, bool detect_scalar_read_patterns = true) @@ -1341,7 +1341,7 @@ class VectorTypeAccessChecker : public StmtExprVisitor { } // TODO(Lunderberg): Uncomment this check once it can be applied. - // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tir-buffers/10615 + // See https://discuss.tvm.apache.org/t/pre-rfc-vectorized-tirx-buffers/10615 // for discussion. // TVM_FFI_ICHECK_EQ(index_lanes * var_info.element_dtype.lanes(), value_dtype.lanes()) @@ -1615,7 +1615,7 @@ class VectorTypeRewriter : public StmtExprMutator { PrimExpr extent = op->args[3]; PrimExpr flag = op->args[4]; - PrimExpr e_dtype = tir::TypeAnnotation(info.new_element_dtype); + PrimExpr e_dtype = tirx::TypeAnnotation(info.new_element_dtype); int factor = info.factor(); extent = extent / make_const(extent.dtype(), factor); index = index / make_const(index.dtype(), factor); @@ -1721,7 +1721,7 @@ Pass StorageRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { bool enable_reuse = true; bool reuse_require_exact_matched_dtype = false; - bool merge_static_smem = ctx->GetConfig("tir.merge_static_smem", Bool(false)).value(); + bool merge_static_smem = ctx->GetConfig("tirx.merge_static_smem", Bool(false)).value(); if (merge_static_smem) { // When `merge_static_smem` is true, we will reuse and merge shared // memory in a dedicated pass `MergeSharedMemoryAllocations`. @@ -1741,27 +1741,27 @@ Pass StorageRewrite() { // Parameters may not be rewritten, but internal allocations may. return PointerValueTypeRewrite(std::move(f), true, false, false, true, true, true, false); }; - return CreatePrimFuncPass(pass_func, 0, "tir.StorageRewrite", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.StorageRewrite", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.StorageRewrite", StorageRewrite); + refl::GlobalDef().def("tirx.transform.StorageRewrite", StorageRewrite); } Pass PointerValueTypeRewrite() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { return PointerValueTypeRewrite(std::move(f)); }; - return CreatePrimFuncPass(pass_func, 0, "tir.PointerValueTypeRewrite", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.PointerValueTypeRewrite", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.PointerValueTypeRewrite", PointerValueTypeRewrite); + refl::GlobalDef().def("tirx.transform.PointerValueTypeRewrite", PointerValueTypeRewrite); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/tvm_ffi_binder.cc b/src/tirx/transform/tvm_ffi_binder.cc similarity index 98% rename from src/tir/transform/tvm_ffi_binder.cc rename to src/tirx/transform/tvm_ffi_binder.cc index c119e22e6631..40f3dc665477 100644 --- a/src/tir/transform/tvm_ffi_binder.cc +++ b/src/tirx/transform/tvm_ffi_binder.cc @@ -24,15 +24,15 @@ #include "tvm_ffi_binder.h" #include -#include -#include -#include -#include +#include +#include +#include +#include #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { using ffi::reflection::AccessPath; using ffi::reflection::AccessStep; @@ -445,7 +445,7 @@ PrimExpr TVMFFIABIBuilder::DecodeParamOpaqueHandle(int param_index, const Var& t static_assert(sizeof(TVMFFIObject) == 24); PrimExpr arg_value = LoadTVMFFIAnyUnionValue(v_packed_args_, param_index, params_[param_index].dtype()); - PrimExpr handle_from_tensor = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), + PrimExpr handle_from_tensor = Call(DataType::Handle(), tirx::builtin::handle_add_byte_offset(), {arg_value, IntImm(DataType::Int(32), object_cell_offset)}); return Select(type_index == ffi::TypeIndex::kTVMFFITensor, handle_from_tensor, arg_value); } @@ -476,7 +476,7 @@ PrimExpr TVMFFIABIBuilder::DecodeParamFloat(int param_index, const Var& type_ind type_index == ffi::TypeIndex::kTVMFFIInt || type_index == ffi::TypeIndex::kTVMFFIBool, "float"); - return tir::Select( + return tirx::Select( type_index == ffi::TypeIndex::kTVMFFIFloat, /* true_value = */ LoadTVMFFIAnyUnionValue(v_packed_args_, param_index, dtype), /* false_value = */ @@ -494,9 +494,9 @@ void TVMFFIABIBuilder::DecodeParam(int param_index) { // Extract type_index from packed_args Var type_index(param->name_hint + ".type_index", DataType::Int(32)); init_nest_.push_back( - Bind(type_index, tir::Call(DataType::Int(32), builtin::tvm_struct_get(), - {v_packed_args_, IntImm(DataType::Int(32), param_index), - IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}))); + Bind(type_index, tirx::Call(DataType::Int(32), builtin::tvm_struct_get(), + {v_packed_args_, IntImm(DataType::Int(32), param_index), + IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}))); // Type-check and load value via per-dtype dispatch PrimExpr arg_value; @@ -789,12 +789,12 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& } // mark alignment of external bufs — must be after the alignment assertion // so the compiler does not emit aligned loads before the check fires. - asserts_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, + asserts_.emplace_back(AttrStmt(vptr, tirx::attr::storage_alignment, IntImm(DataType::Int(32), buffer->data_alignment), Evaluate(0))); } else { // Even without alignment check, mark alignment for the compiler. - init_nest_.emplace_back(AttrStmt(vptr, tir::attr::storage_alignment, + init_nest_.emplace_back(AttrStmt(vptr, tirx::attr::storage_alignment, IntImm(DataType::Int(32), buffer->data_alignment), Evaluate(0))); } @@ -802,5 +802,5 @@ void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& } } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/tvm_ffi_binder.h b/src/tirx/transform/tvm_ffi_binder.h similarity index 99% rename from src/tir/transform/tvm_ffi_binder.h rename to src/tirx/transform/tvm_ffi_binder.h index 0ae8338a9b38..5f17d970dd8f 100644 --- a/src/tir/transform/tvm_ffi_binder.h +++ b/src/tirx/transform/tvm_ffi_binder.h @@ -30,9 +30,9 @@ #include #include -#include -#include -#include +#include +#include +#include #include #include @@ -40,7 +40,7 @@ #include namespace tvm { -namespace tir { +namespace tirx { /*! * \brief Helper utility to generate match and bind of packed function arguments. @@ -426,6 +426,6 @@ class TVMFFIABIBuilder { bool check_alignment_ = false; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_TVM_FFI_BINDER_H_ diff --git a/src/tir/transform/unroll_loop.cc b/src/tirx/transform/unroll_loop.cc similarity index 95% rename from src/tir/transform/unroll_loop.cc rename to src/tirx/transform/unroll_loop.cc index 87a3e2be363c..3a3a4d429f96 100644 --- a/src/tir/transform/unroll_loop.cc +++ b/src/tirx/transform/unroll_loop.cc @@ -25,10 +25,10 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -36,7 +36,7 @@ #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { struct UnrollLoopConfigNode : public AttrsNodeReflAdapter { int auto_max_step; @@ -62,7 +62,7 @@ struct UnrollLoopConfigNode : public AttrsNodeReflAdapter .def_ro("unroll_local_access", &UnrollLoopConfigNode::unroll_local_access, "Whether to always unroll local access", refl::DefaultValue(false)); } - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.transform.UnrollLoopConfig", UnrollLoopConfigNode, + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tirx.transform.UnrollLoopConfig", UnrollLoopConfigNode, BaseAttrsNode); }; @@ -73,7 +73,7 @@ class UnrollLoopConfig : public Attrs { TVM_FFI_STATIC_INIT_BLOCK() { UnrollLoopConfigNode::RegisterReflection(); } -TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); +TVM_REGISTER_PASS_CONFIG_OPTION("tirx.UnrollLoop", UnrollLoopConfig); class VarLocalAccessMarker : public ExprVisitor { public: @@ -281,22 +281,22 @@ namespace transform { Pass UnrollLoop() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto cfg = ctx->GetConfig("tir.UnrollLoop"); + auto cfg = ctx->GetConfig("tirx.UnrollLoop"); if (!cfg.defined()) { cfg = AttrsWithDefaultValues(); } n->body = UnrollLoop(std::move(f->body), cfg.value()); return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.UnrollLoop", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.UnrollLoop", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.UnrollLoop", UnrollLoop); + refl::GlobalDef().def("tirx.transform.UnrollLoop", UnrollLoop); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/unsupported_dtype_legalize.cc b/src/tirx/transform/unsupported_dtype_legalize.cc similarity index 97% rename from src/tir/transform/unsupported_dtype_legalize.cc rename to src/tirx/transform/unsupported_dtype_legalize.cc index 0655018db087..555f2bbbcc35 100644 --- a/src/tir/transform/unsupported_dtype_legalize.cc +++ b/src/tirx/transform/unsupported_dtype_legalize.cc @@ -23,10 +23,10 @@ */ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -34,7 +34,7 @@ #include "dtype_conversion.h" namespace tvm { -namespace tir { +namespace tirx { // NOTE: do not touch buffer on function boundary // remap internal fp8/bf16 buffer to f32 if they meet the following condition @@ -751,12 +751,12 @@ Pass BF16ComputeLegalize() { } return BF16ComputeLegalizer().Legalize(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.BF16ComputeLegalize", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.BF16ComputeLegalize", BF16ComputeLegalize); + refl::GlobalDef().def("tirx.transform.BF16ComputeLegalize", BF16ComputeLegalize); } Pass BF16StorageLegalize() { @@ -767,12 +767,12 @@ Pass BF16StorageLegalize() { } return BF16StorageLegalizer().Legalize(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.BF16StorageLegalize", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.BF16StorageLegalize", BF16StorageLegalize); + refl::GlobalDef().def("tirx.transform.BF16StorageLegalize", BF16StorageLegalize); } Pass FP8ComputeLegalize(ffi::String promote_dtype) { @@ -783,12 +783,12 @@ Pass FP8ComputeLegalize(ffi::String promote_dtype) { } return FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.FP8ComputeLegalize", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.FP8ComputeLegalize", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.FP8ComputeLegalize", FP8ComputeLegalize); + refl::GlobalDef().def("tirx.transform.FP8ComputeLegalize", FP8ComputeLegalize); } Pass FP8StorageLegalize() { @@ -799,14 +799,14 @@ Pass FP8StorageLegalize() { } return FP8StorageLegalizer().Legalize(f); }; - return CreatePrimFuncPass(pass_func, 0, "tir.FP8StorageLegalize", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.FP8StorageLegalize", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.FP8StorageLegalize", FP8StorageLegalize); + refl::GlobalDef().def("tirx.transform.FP8StorageLegalize", FP8StorageLegalize); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/update_pointer_storage_scope.cc b/src/tirx/transform/update_pointer_storage_scope.cc similarity index 95% rename from src/tir/transform/update_pointer_storage_scope.cc rename to src/tirx/transform/update_pointer_storage_scope.cc index f0c14413f554..0dfd7e2542af 100644 --- a/src/tir/transform/update_pointer_storage_scope.cc +++ b/src/tirx/transform/update_pointer_storage_scope.cc @@ -23,10 +23,10 @@ */ #include "update_pointer_storage_scope.h" -#include -#include -#include -#include +#include +#include +#include +#include #include #include @@ -35,7 +35,7 @@ #include "ir_utils.h" namespace tvm { -namespace tir { +namespace tirx { Var WithStorageScope(const VarNode* buffer_var, ffi::String storage_scope) { auto* ptr_type = buffer_var->type_annotation.as(); @@ -109,5 +109,5 @@ Stmt UpdatePointerStorageScope::VisitStmt_(const BufferStoreNode* op) { return UpdateBufferAccess(node); } -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/src/tir/transform/update_pointer_storage_scope.h b/src/tirx/transform/update_pointer_storage_scope.h similarity index 93% rename from src/tir/transform/update_pointer_storage_scope.h rename to src/tirx/transform/update_pointer_storage_scope.h index 1075a1ff6732..cacf667e49a2 100644 --- a/src/tir/transform/update_pointer_storage_scope.h +++ b/src/tirx/transform/update_pointer_storage_scope.h @@ -24,14 +24,14 @@ #ifndef TVM_TIR_TRANSFORM_UPDATE_POINTER_STORAGE_SCOPE_H_ #define TVM_TIR_TRANSFORM_UPDATE_POINTER_STORAGE_SCOPE_H_ -#include -#include -#include +#include +#include +#include #include namespace tvm { -namespace tir { +namespace tirx { class UpdatePointerStorageScope : public StmtExprMutator { public: @@ -54,6 +54,6 @@ class UpdatePointerStorageScope : public StmtExprMutator { std::unordered_map new_buffer_remap_; }; -} // namespace tir +} // namespace tirx } // namespace tvm #endif // TVM_TIR_TRANSFORM_UPDATE_POINTER_STORAGE_SCOPE_H_ diff --git a/src/tir/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc similarity index 98% rename from src/tir/transform/vectorize_loop.cc rename to src/tirx/transform/vectorize_loop.cc index 1862ceb1d480..735029cf1294 100644 --- a/src/tir/transform/vectorize_loop.cc +++ b/src/tirx/transform/vectorize_loop.cc @@ -25,24 +25,24 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include #include #include #include "../../src/arith/scalable_expression.h" -#include "../../tir/analysis/check_contains.h" +#include "../../tirx/analysis/check_contains.h" #include "tvm/runtime/data_type.h" -#include "tvm/tir/buffer.h" +#include "tvm/tirx/buffer.h" namespace tvm { -namespace tir { +namespace tirx { inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) { if (is_scalable) { @@ -78,7 +78,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) { bool EnableBufferLevelPredication(Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); ffi::Optional enable_buffer_predication = - pass_ctx->GetConfig("tir.enable_buffer_level_predication"); + pass_ctx->GetConfig("tirx.enable_buffer_level_predication"); if (enable_buffer_predication.defined()) { return enable_buffer_predication.value(); } @@ -692,7 +692,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorvectors[0], {{var_, tvm::IntImm(var_->dtype, 0)}}); + return tirx::Substitute(op->vectors[0], {{var_, tvm::IntImm(var_->dtype, 0)}}); } else { PrimExpr prev_ramp = ramp_; PrimExpr prev_var_lanes = var_lanes_; @@ -1006,15 +1006,15 @@ Pass VectorizeLoop(bool enable_vectorize) { } return f; }; - return CreatePrimFuncPass(pass_func, 0, "tir.VectorizeLoop", {}); + return CreatePrimFuncPass(pass_func, 0, "tirx.VectorizeLoop", {}); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tir.transform.VectorizeLoop", VectorizeLoop); + refl::GlobalDef().def("tirx.transform.VectorizeLoop", VectorizeLoop); } } // namespace transform -} // namespace tir +} // namespace tirx } // namespace tvm diff --git a/tests/cpp/arith_integer_set_test.cc b/tests/cpp/arith_integer_set_test.cc index 4454598c2049..03a7b5f893f3 100644 --- a/tests/cpp/arith_integer_set_test.cc +++ b/tests/cpp/arith_integer_set_test.cc @@ -26,8 +26,8 @@ #include "../src/arith/presburger_set.h" TEST(PresburgerSet, eval) { - auto x = tvm::tir::Var("x"); - auto y = tvm::tir::Var("y"); + auto x = tvm::tirx::Var("x"); + auto y = tvm::tirx::Var("y"); auto sub_constraint0 = (x + y < 20) && (x - y <= 0); auto sub_constraint1 = x >= 0 && x < 20 && y >= 0 && y < 20; auto constraint = sub_constraint0 && sub_constraint1; @@ -35,7 +35,7 @@ TEST(PresburgerSet, eval) { auto target = x + 2 * y; auto result = EvalSet(target, set); - ASSERT_TRUE(tvm::tir::is_zero(result.min())); - ASSERT_TRUE(tvm::tir::is_const_int(result.max(), 38)); + ASSERT_TRUE(tvm::tirx::is_zero(result.min())); + ASSERT_TRUE(tvm::tirx::is_const_int(result.max(), 38)); } #endif // TVM_MLIR_VERSION diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc index 9f6108617696..703e83c5312a 100644 --- a/tests/cpp/arith_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -28,11 +28,11 @@ TEST(Simplify, MinMax) { auto x = tvm::te::var("x"); auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)); auto e1s = ana.canonical_simplify(e1); - TVM_FFI_ICHECK(tvm::tir::is_zero(e1s)); + TVM_FFI_ICHECK(tvm::tirx::is_zero(e1s)); auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); auto e2s = ana.canonical_simplify(e2); - TVM_FFI_ICHECK(tvm::tir::is_zero(e2s)); + TVM_FFI_ICHECK(tvm::tirx::is_zero(e2s)); } TEST(Simplify, Mul) { @@ -40,7 +40,7 @@ TEST(Simplify, Mul) { auto x = tvm::te::var("x"); auto e = (x * x) - (x * x); auto es = ana.canonical_simplify(e); - TVM_FFI_ICHECK(tvm::tir::is_zero(es)); + TVM_FFI_ICHECK(tvm::tirx::is_zero(es)); } TEST(Simplify, Mod) { @@ -50,29 +50,29 @@ TEST(Simplify, Mod) { // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = ana.canonical_simplify(tvm::tir::Mod(x, y)); + auto mod = ana.canonical_simplify(tvm::tirx::Mod(x, y)); auto es = ana.canonical_simplify(mod - x); - TVM_FFI_ICHECK(tvm::tir::is_zero(es)); + TVM_FFI_ICHECK(tvm::tirx::is_zero(es)); } TEST(ConstantFold, Broadcast) { tvm::ffi::StructuralEqual checker; - auto i32x4 = tvm::tir::Broadcast(tvm::IntImm(tvm::DataType::Int(32), 10), 4); + auto i32x4 = tvm::tirx::Broadcast(tvm::IntImm(tvm::DataType::Int(32), 10), 4); auto i64x4 = tvm::cast(i32x4->dtype.with_bits(64), i32x4); - auto i64x4_expected = tvm::tir::Broadcast(tvm::IntImm(tvm::DataType::Int(64), 10), 4); + auto i64x4_expected = tvm::tirx::Broadcast(tvm::IntImm(tvm::DataType::Int(64), 10), 4); ASSERT_TRUE(checker(i64x4, i64x4_expected)); } TEST(ConstantFold, Ramp) { tvm::ffi::StructuralEqual checker; - auto i32x4 = tvm::tir::Ramp(tvm::IntImm(tvm::DataType::Int(32), 10), - tvm::IntImm(tvm::DataType::Int(32), 1), 4); + auto i32x4 = tvm::tirx::Ramp(tvm::IntImm(tvm::DataType::Int(32), 10), + tvm::IntImm(tvm::DataType::Int(32), 1), 4); auto i64x4 = tvm::cast(i32x4->dtype.with_bits(64), i32x4); - auto i64x4_expected = tvm::tir::Ramp(tvm::IntImm(tvm::DataType::Int(64), 10), - tvm::IntImm(tvm::DataType::Int(64), 1), 4); + auto i64x4_expected = tvm::tirx::Ramp(tvm::IntImm(tvm::DataType::Int(64), 10), + tvm::IntImm(tvm::DataType::Int(64), 1), 4); ASSERT_TRUE(checker(i64x4, i64x4_expected)); auto f32x4 = tvm::cast(tvm::DataType::Float(32, 4), i32x4); - auto f32x4_expected = tvm::tir::Cast(tvm::DataType::Float(32, 4), i32x4); + auto f32x4_expected = tvm::tirx::Cast(tvm::DataType::Float(32, 4), i32x4); ASSERT_TRUE(checker(f32x4, f32x4_expected)); } diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc index 1d3aa62f6629..16fe4d8dfd85 100644 --- a/tests/cpp/expr_test.cc +++ b/tests/cpp/expr_test.cc @@ -24,7 +24,7 @@ TEST(Expr, Basic) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); auto z = max(x + 1 + 2, 100); ObjectRef tmp = z; @@ -37,7 +37,7 @@ TEST(Expr, Basic) { TEST(Expr, VarTypeAnnotation) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x", DataType::Float(32)); Var y("y", PrimType(DataType::Float(32))); tvm::ffi::StructuralEqual checker; @@ -47,9 +47,9 @@ TEST(Expr, VarTypeAnnotation) { TEST(ExprNodeRef, Basic) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); PrimExpr z = max(x + 1 + 2, 100); - const tir::MaxNode* op = z.as(); + const tirx::MaxNode* op = z.as(); TVM_FFI_ICHECK(ffi::GetRef(op).same_as(z)); } diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index e7a27086d35b..cc0efbea2082 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -21,17 +21,17 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include TEST(IRF, Basic) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); auto z = x + 1; @@ -44,12 +44,12 @@ TEST(IRF, Basic) { TEST(IRF, CountVar) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; int n_var = 0; Var x("x"), y; auto z = x + 1 + y + y; - tir::PostOrderVisit(z, [&n_var](const ObjectRef& n) { + tirx::PostOrderVisit(z, [&n_var](const ObjectRef& n) { if (n.as()) ++n_var; }); TVM_FFI_ICHECK_EQ(n_var, 2); @@ -57,7 +57,7 @@ TEST(IRF, CountVar) { TEST(IRF, PreOrderVisit) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Stmt init = IfThenElse(const_true(), Evaluate(Integer(0)), Evaluate(Integer(0))); Stmt body = Evaluate(Integer(1)); SBlock block(/*iter_vars=*/{}, /*reads=*/{}, @@ -91,11 +91,11 @@ TEST(IRF, PreOrderVisit) { TEST(IRF, ExprTransform) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); auto z = x + 1; - class MyExprFunctor : public tir::ExprFunctor { + class MyExprFunctor : public tirx::ExprFunctor { public: int VisitExpr_(const VarNode* op, int b) final { return b; } int VisitExpr_(const IntImmNode* op, int b) final { return op->value; } @@ -115,12 +115,12 @@ TEST(IRF, ExprTransform) { TEST(IRF, ExprVisit) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); auto z = x + 1; - class MyVisitor : public tir::ExprFunctor, - public tir::StmtFunctor { + class MyVisitor : public tirx::ExprFunctor, + public tirx::StmtFunctor { public: int count = 0; // implementation @@ -139,7 +139,7 @@ TEST(IRF, ExprVisit) { TEST(IRF, StmtVisitor) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); class MyVisitor : public StmtExprVisitor { public: @@ -189,10 +189,10 @@ TEST(IRF, StmtVisitor) { TEST(IRF, StmtMutator) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; Var x("x"); - class MyVisitor : public tir::StmtMutator, public tir::ExprMutator { + class MyVisitor : public tirx::StmtMutator, public tirx::ExprMutator { public: using StmtMutator::operator(); using ExprMutator::operator(); @@ -328,7 +328,7 @@ TEST(IRF, StmtMutator) { TEST(IRF, Substitute) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; DataType dtype = DataType::Float(32); Var x("x", PointerType(PrimType(dtype), "")); Var n("n", DataType::Int(32)); diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc index 02b662875c62..a23be7ebdd91 100644 --- a/tests/cpp/nested_msg_test.cc +++ b/tests/cpp/nested_msg_test.cc @@ -23,7 +23,7 @@ #include #include #include -#include +#include #include #include diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc index a763b0b7002b..ab668e9a4204 100644 --- a/tests/cpp/pattern_match_test.cc +++ b/tests/cpp/pattern_match_test.cc @@ -20,13 +20,13 @@ #include "../src/arith/pattern_match.h" #include -#include +#include TEST(Pattern, Basic) { using namespace tvm; - using namespace tvm::tir; + using namespace tvm::tirx; using namespace tvm::arith; - tvm::tir::Var x("x"), y("y"), z("z"); + tvm::tirx::Var x("x"), y("y"), z("z"); PrimExpr scalable_lanes = Mul(Call(DataType::Int(32), builtin::vscale(), {}), 4); arith::PVar px, py, pz; arith::PVar pt; @@ -44,12 +44,12 @@ TEST(Pattern, Basic) { TVM_FFI_ICHECK((px + (py + px)).Match(r)); auto rr = (px + py).Eval(); - TVM_FFI_ICHECK(tir::ExprDeepEqual()(rr, 1 + y)); - TVM_FFI_ICHECK(tir::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y)); + TVM_FFI_ICHECK(tirx::ExprDeepEqual()(rr, 1 + y)); + TVM_FFI_ICHECK(tirx::ExprDeepEqual()(px.Eval() + py.Eval(), 1 + y)); } { TVM_FFI_ICHECK((px + max(py, px)).Match((x + 1) + max(y, (x + 1)))); - TVM_FFI_ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); + TVM_FFI_ICHECK(tirx::ExprDeepEqual()(px.Eval(), x + 1)); } TVM_FFI_ICHECK(!(px + min(py, px)).Match((x + 1) + max(y, (x + 1)))); @@ -68,8 +68,8 @@ TEST(Pattern, Basic) { TVM_FFI_ICHECK((px >= py && px < pz).Match(x >= y && x < z)); TVM_FFI_ICHECK((!(px > py || px != py)).Match(!(x > y || x != y))); { - TVM_FFI_ICHECK(select(px >= pz, py, py + pz).Match(tir::Select((x + 1) >= 1, y, y + 1))); - TVM_FFI_ICHECK(tir::ExprDeepEqual()(px.Eval(), x + 1)); + TVM_FFI_ICHECK(select(px >= pz, py, py + pz).Match(tirx::Select((x + 1) >= 1, y, y + 1))); + TVM_FFI_ICHECK(tirx::ExprDeepEqual()(px.Eval(), x + 1)); } // bit intrinsics { @@ -84,14 +84,14 @@ TEST(Pattern, Basic) { TVM_FFI_ICHECK((px - (~(py | (px * pz)))).Match(x - (~(2 | (x * 2))))); // select { - TVM_FFI_ICHECK(select(px > pz, py, py + pz).Match(tir::Select(x > 1, y, y + 1))); + TVM_FFI_ICHECK(select(px > pz, py, py + pz).Match(tirx::Select(x > 1, y, y + 1))); TVM_FFI_ICHECK(is_const_int(pz.Eval(), 1)); } - TVM_FFI_ICHECK(!select(px > pz, py, py + pz).Match(tir::Select(x > 2, y, y + 1))); - TVM_FFI_ICHECK(!select(px > pz, py, py).Match(tir::Select(x > 2, y, y + 1))); + TVM_FFI_ICHECK(!select(px > pz, py, py + pz).Match(tirx::Select(x > 2, y, y + 1))); + TVM_FFI_ICHECK(!select(px > pz, py, py).Match(tirx::Select(x > 2, y, y + 1))); { - TVM_FFI_ICHECK(select(px, py, pz).Match(tir::Select(x > 2, y, y + 1))); - TVM_FFI_ICHECK(tir::ExprDeepEqual()(pz.Eval(), y + 1)); + TVM_FFI_ICHECK(select(px, py, pz).Match(tirx::Select(x > 2, y, y + 1))); + TVM_FFI_ICHECK(tirx::ExprDeepEqual()(pz.Eval(), y + 1)); } // if_then_else { @@ -101,38 +101,39 @@ TEST(Pattern, Basic) { // cast pattern { TVM_FFI_ICHECK( - !cast(PConst(DataType::Int(32)), px).Match(tir::Cast(DataType::Float(64), x))); - TVM_FFI_ICHECK(cast(pt, px).Match(tir::Cast(DataType::Float(64), x))); + !cast(PConst(DataType::Int(32)), px).Match(tirx::Cast(DataType::Float(64), x))); + TVM_FFI_ICHECK(cast(pt, px).Match(tirx::Cast(DataType::Float(64), x))); TVM_FFI_ICHECK(pt.Eval() == DataType::Float(64)); auto zz = cast(pt, px).Eval(); - TVM_FFI_ICHECK((cast(pt, px) - cast(pt, py)) - .Match(tir::Cast(DataType::Float(64), x) - tir::Cast(DataType::Int(64), x))); - auto expr = tir::Cast(DataType::Int(32), tir::Cast(DataType::Float(64), x)); + TVM_FFI_ICHECK( + (cast(pt, px) - cast(pt, py)) + .Match(tirx::Cast(DataType::Float(64), x) - tirx::Cast(DataType::Int(64), x))); + auto expr = tirx::Cast(DataType::Int(32), tirx::Cast(DataType::Float(64), x)); TVM_FFI_ICHECK(!(cast(pt, cast(pt, px))).Match(expr)); } // ramp pattern { - TVM_FFI_ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, 10))); + TVM_FFI_ICHECK(ramp(px, PConst(1), planes).Match(tirx::Ramp(x, 1, 10))); TVM_FFI_ICHECK(planes.Eval().as()->value == 10); - TVM_FFI_ICHECK(ramp(px, PConst(1), planes).Match(tir::Ramp(x, 1, scalable_lanes))); + TVM_FFI_ICHECK(ramp(px, PConst(1), planes).Match(tirx::Ramp(x, 1, scalable_lanes))); TVM_FFI_ICHECK((vscale * PConst(4)).Match(planes.Eval())); - TVM_FFI_ICHECK(!ramp(px, PConst(1), planes).Match(tir::Ramp(x, 2, 10))); + TVM_FFI_ICHECK(!ramp(px, PConst(1), planes).Match(tirx::Ramp(x, 2, 10))); } // broadcast pattern { - TVM_FFI_ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10))); + TVM_FFI_ICHECK(broadcast(px, planes).Match(tirx::Broadcast(x, 10))); TVM_FFI_ICHECK(planes.Eval().as()->value == 10); - TVM_FFI_ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10))); - TVM_FFI_ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, scalable_lanes))); + TVM_FFI_ICHECK(broadcast(px * py, planes).Match(tirx::Broadcast(x * 10, 10))); + TVM_FFI_ICHECK(broadcast(px, planes).Match(tirx::Broadcast(x, scalable_lanes))); TVM_FFI_ICHECK((vscale * PConst(4)).Match(planes.Eval())); } } TEST(Pattern, IntImm) { using namespace tvm; - tir::Var tx, ty; + tirx::Var tx, ty; arith::PVar c; - arith::PVar v; + arith::PVar v; { // We can match integer and Var, both of which are // special case container of Expr @@ -150,20 +151,20 @@ TEST(Pattern, MatchWithType) { using namespace tvm; // match expr with specified dtype arith::PVarWithDataType> pat(DataType::Float(32)); - tir::Var x("x", DataType::Float(32)); - tir::Var y("y", DataType::Float(32)); - tir::Var x_int("x", DataType::Int(32)); - tir::Var y_int("y", DataType::Int(32)); + tirx::Var x("x", DataType::Float(32)); + tirx::Var y("y", DataType::Float(32)); + tirx::Var x_int("x", DataType::Int(32)); + tirx::Var y_int("y", DataType::Int(32)); TVM_FFI_ICHECK(pat.Match(x + y * 2.0f)); TVM_FFI_ICHECK(!pat.Match(x_int + y_int * 2)); // match vectorized expr with specified element dtype arith::PVecDataType vec_ty(DataType::Float(32)); arith::PVarWithDataType vpat(vec_ty); - tir::Var vx = tir::Var("x", DataType::Float(32, 8)); - tir::Var vy("y", DataType::Float(32, 8)); - tir::Var vx_int("x", DataType::Int(32, 8)); - tir::Var vy_int("y", DataType::Int(32, 8)); - TVM_FFI_ICHECK(vpat.Match(vx + vy * tir::Broadcast(2.0f, 8))); - TVM_FFI_ICHECK(!vpat.Match(vx_int + vy_int * tir::Broadcast(2, 8))); + tirx::Var vx = tirx::Var("x", DataType::Float(32, 8)); + tirx::Var vy("y", DataType::Float(32, 8)); + tirx::Var vx_int("x", DataType::Int(32, 8)); + tirx::Var vy_int("y", DataType::Int(32, 8)); + TVM_FFI_ICHECK(vpat.Match(vx + vy * tirx::Broadcast(2.0f, 8))); + TVM_FFI_ICHECK(!vpat.Match(vx_int + vy_int * tirx::Broadcast(2, 8))); } diff --git a/tests/cpp/tir_analysis_side_effect.cc b/tests/cpp/tir_analysis_side_effect.cc index 7ac19c28f198..bcc7128647b4 100644 --- a/tests/cpp/tir_analysis_side_effect.cc +++ b/tests/cpp/tir_analysis_side_effect.cc @@ -20,16 +20,16 @@ #include #include #include -#include -#include +#include +#include TEST(SimplePasses, SideEffect) { using namespace tvm; - auto buf = tir::decl_buffer({16}, DataType::Float(32)); - auto i = tir::Var("i", DataType::Int(32)); - TVM_FFI_ICHECK(tir::SideEffect(tir::BufferLoad(buf, {i})) == tir::CallEffectKind::kReadState); - TVM_FFI_ICHECK(tir::SideEffect(exp(tir::Cast(DataType::Float(32), i + 1))) == - tir::CallEffectKind::kPure); - TVM_FFI_ICHECK(tir::SideEffect(tir::Call(DataType::Handle(), tir::builtin::tvm_storage_sync(), - {})) == tir::CallEffectKind::kUpdateState); + auto buf = tirx::decl_buffer({16}, DataType::Float(32)); + auto i = tirx::Var("i", DataType::Int(32)); + TVM_FFI_ICHECK(tirx::SideEffect(tirx::BufferLoad(buf, {i})) == tirx::CallEffectKind::kReadState); + TVM_FFI_ICHECK(tirx::SideEffect(exp(tirx::Cast(DataType::Float(32), i + 1))) == + tirx::CallEffectKind::kPure); + TVM_FFI_ICHECK(tirx::SideEffect(tirx::Call(DataType::Handle(), tirx::builtin::tvm_storage_sync(), + {})) == tirx::CallEffectKind::kUpdateState); } diff --git a/tests/cpp/tir_scalable_datatype.cc b/tests/cpp/tir_scalable_datatype.cc index 19e94e267723..9be9e8552e83 100644 --- a/tests/cpp/tir_scalable_datatype.cc +++ b/tests/cpp/tir_scalable_datatype.cc @@ -20,8 +20,8 @@ #include #include #include -#include -#include +#include +#include #ifdef TVM_LLVM_VERSION #include @@ -186,13 +186,13 @@ TEST(ScalableDataType, TestScalableUInt) { // ----------- TEST(ScalableDataType, TestScalableIntrinCall) { tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true); - tvm::tir::Call call = - tvm::tir::Call(scalable_type, tvm::tir::builtin::call_llvm_intrin(), + tvm::tirx::Call call = + tvm::tirx::Call(scalable_type, tvm::tirx::builtin::call_llvm_intrin(), #if TVM_LLVM_VERSION >= 200 - {tvm::IntImm(tvm::DataType::Int(32), ::llvm::Intrinsic::stepvector)}); + {tvm::IntImm(tvm::DataType::Int(32), ::llvm::Intrinsic::stepvector)}); #else - {tvm::IntImm(tvm::DataType::Int(32), - ::llvm::Intrinsic::experimental_stepvector)}); + {tvm::IntImm(tvm::DataType::Int(32), + ::llvm::Intrinsic::experimental_stepvector)}); #endif ASSERT_EQ(call->dtype, scalable_type); ASSERT_EQ(call->Script(), diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index 8e2406b85155..f2dd4c2b053c 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -23,7 +23,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T def test_get_global(): @@ -133,15 +133,15 @@ def check(arr): def test_dict_function_value_type(): - from tvm import tir # pylint: disable=import-outside-toplevel + from tvm import tirx # pylint: disable=import-outside-toplevel te_func_dict = {"add": lambda a, b: a + b} converted_dict = tvm.runtime.convert(te_func_dict) f = converted_dict["add"] - a = tir.Var("a", "float32") - b = tir.Var("b", "float32") - tvm.ir.assert_structural_equal(f(a, b), tir.Add(a, b)) + a = tirx.Var("a", "float32") + b = tirx.Var("b", "float32") + tvm.ir.assert_structural_equal(f(a, b), tirx.Add(a, b)) if __name__ == "__main__": diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index 4ff8af6b8b72..79d3d0dfc41d 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -17,8 +17,8 @@ # ruff: noqa: E731, F841 import tvm import tvm.testing -from tvm import te, tir -from tvm.script import tir as T +from tvm import te, tirx +from tvm.script import tirx as T class CanonicalChecker: @@ -26,7 +26,7 @@ def __init__(self): self.analyzer = tvm.arith.Analyzer() def _convert(self, expr): - # TODO(Lunderberg): Make utility functions `tir.convert` and + # TODO(Lunderberg): Make utility functions `tirx.convert` and # `relax.convert` that convert to their respective IR types. # Implementation should be in C++, and should only consist of # conversions that are applied automatically through FFI. @@ -45,20 +45,20 @@ def verify(self, data, expected): def test_mul_sum_simplify(): ck = CanonicalChecker() - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") ck.verify(2 + (3 * x + z + y + 1) * 4 + x, x * 13 + z * 4 + y * 4 + 6) ck.verify(x * 3 - 4 * x + 1, 1 - x) ck.verify(y + x * 3 - 5 * x + 1 + y, y * 2 + 1 - x * 2) - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod # trucdiv ck.verify(tdiv(x + y + x + y * 3, 2), y * 2 + x) ck.verify(tmod(x + y + x + y * 3, 2), 0) # floordiv - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod ck.verify(flm(x + x + y * 3, 2), flm(y * 3, 2)) ck.verify(fld(x + y + x + y * 3, 2), y * 2 + x) ck.verify(flm(x + y + x + y * 3, 2), 0) @@ -67,11 +67,11 @@ def test_mul_sum_simplify(): def test_split_index_simplify(): ck = CanonicalChecker() - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") # trucdiv - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod # split div const ck.verify(tdiv(x, 3) * 3 + tmod(x, 3), x) @@ -97,8 +97,8 @@ def test_split_index_simplify(): ck.verify(tdiv(x * 4 + y, 2) * 2 + tmod(x * 4 + y, 2), x * 4 + y) # floordiv - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod ck.verify(fld(x * 5, 2), fld(x * 5, 2)) ck.verify(fld(x, 3) * 3 + flm(x, 3), x) ck.verify(fld(x, 6) * 6 + flm(fld(x, 3), 2) * 3 + flm(x, 3), x) @@ -115,8 +115,8 @@ def test_split_index_simplify(): def test_div_simplify(): ck = CanonicalChecker() - x = tvm.tir.Var("x", "int32") - tdiv = tvm.tir.truncdiv + x = tvm.tirx.Var("x", "int32") + tdiv = tvm.tirx.truncdiv # truc div ck.verify(tdiv(16 + 48 * x, 16), x * 3 + 1) @@ -130,7 +130,7 @@ def test_div_simplify(): ck.verify(tdiv(17 + 47 * x, 16), tdiv(x * 47 + 17, 16)) # floordiv - fld = tvm.tir.floordiv + fld = tvm.tirx.floordiv ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 10000), True) ck.verify(fld(16 + 48 * x, 16), x * 3 + 1) ck.verify(fld(17 + 48 * x, 16), x * 3 + 1) @@ -139,9 +139,9 @@ def test_div_simplify(): def test_fp16_const_fold(): ck = CanonicalChecker() - zero = tvm.tir.const(0, "float16") - one = tvm.tir.const(1, "float16") - half = tvm.tir.const(0.5, "float16") + zero = tvm.tirx.const(0, "float16") + one = tvm.tirx.const(1, "float16") + half = tvm.tirx.const(0.5, "float16") ck.verify(zero + half, half) ck.verify(half - zero, half) @@ -155,8 +155,8 @@ def test_fp16_const_fold(): def test_floormod_simplify(): ck = CanonicalChecker() - flm = tvm.tir.floormod - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + flm = tvm.tirx.floormod + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") ck.verify(flm(flm((x * 4) + y - 466036, 24528) - 24512, 16), flm((x * 4) + y + 12, 16)) ck.verify(flm(flm((x * 4), 16), 8), flm(x, 2) * 4) @@ -165,36 +165,36 @@ def test_floormod_simplify(): def test_canonical_mixed(): ck = CanonicalChecker() - x = tvm.tir.Var("x", "int32") - z = tvm.tir.const(3, "int32") - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + x = tvm.tirx.Var("x", "int32") + z = tvm.tirx.const(3, "int32") + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod ck.verify(tdiv(x, (z * z)) - tdiv(x, (z * z)), 0) ck.verify(tdiv(x, (z + z)) - tdiv(x, (z + z)), 0) ck.verify(x - 2 < 3, x < 5) - ck.verify(tvm.tir.max(x, 1) - tvm.tir.max(x, 1), 0) - ck.verify(tvm.tir.min(x, 1) - tvm.tir.min(x, 1), 0) + ck.verify(tvm.tirx.max(x, 1) - tvm.tirx.max(x, 1), 0) + ck.verify(tvm.tirx.min(x, 1) - tvm.tirx.min(x, 1), 0) ck.verify(x * x - x * x, 0) ck.verify(tmod(tdiv(tmod(x, 20), 2) * 2, 4), tdiv(tmod(x, 4), 2) * 2) - fld = tvm.tir.floordiv + fld = tvm.tirx.floordiv ck.verify(fld(x, (z * z)) - fld(x, (z * z)), 0) ck.verify(fld(x, (z + z)) - fld(x, (z + z)), 0) def test_reduce_combiner_simplify(): ck = CanonicalChecker() - dummy = tvm.tir.Var("dummy", "int32") + dummy = tvm.tirx.Var("dummy", "int32") comm_reducer = te.comm_reducer - prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tir.const(1, t0)) + prod = comm_reducer(lambda x, y: x * y, lambda t0: tvm.tirx.const(1, t0)) sum_or_prod = comm_reducer( - lambda x, y: tvm.tir.Select(dummy < 0, x + y, x * y), - lambda t0: tvm.tir.Select(dummy < 0, tvm.tir.const(0, t0), tvm.tir.const(1, t0)), + lambda x, y: tvm.tirx.Select(dummy < 0, x + y, x * y), + lambda t0: tvm.tirx.Select(dummy < 0, tvm.tirx.const(0, t0), tvm.tirx.const(1, t0)), ) sum_and_prod = comm_reducer( lambda x, y: (x[0] + y[0], x[1] * y[1]), - lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(5, t1) - tvm.tir.const(4, t1)), + lambda t0, t1: (tvm.tirx.const(0, t0), tvm.tirx.const(5, t1) - tvm.tirx.const(4, t1)), ) some_reducer1 = comm_reducer( lambda x, y: ( @@ -205,11 +205,11 @@ def test_reduce_combiner_simplify(): 4.0, ), lambda t0, t1, t2, t3, t4: ( - tvm.tir.const(0, t0), - tvm.tir.const(1, t1), - tvm.tir.const(2, t2), - tvm.tir.const(3, t3), - tvm.tir.const(4, t4), + tvm.tirx.const(0, t0), + tvm.tirx.const(1, t1), + tvm.tirx.const(2, t2), + tvm.tirx.const(3, t3), + tvm.tirx.const(4, t4), ), ) @@ -246,7 +246,7 @@ def test_reduce_combiner_simplify(): # Test that components with side effects are not removed dummy = tvm.ir.GlobalVar("dummy") - side_effect = lambda *xs: tvm.tir.Call("int32", dummy, xs) + side_effect = lambda *xs: tvm.tirx.Call("int32", dummy, xs) ck.verify( sum_and_prod((A[k], side_effect(A[10 - k])), k)[0], sum_and_prod((A[k], side_effect(A[10 - k])), k)[0], @@ -259,23 +259,23 @@ def test_reduce_simplify(): k = te.reduce_axis((0, 10), name="k") j = te.reduce_axis((-5, 3), name="j") A = te.placeholder((10,), name="A") - ck.verify(te.sum(tvm.tir.Select(k + j < 12, k + j, 0), [k, j]), te.sum(k + j, [k, j])) + ck.verify(te.sum(tvm.tirx.Select(k + j < 12, k + j, 0), [k, j]), te.sum(k + j, [k, j])) ck.verify(te.sum(A[3], []), A[3]) - ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tir.const(1.0, dtype="float32")) + ck.verify(te.sum(A[3], [], where=k > 12, init=1.0), tvm.tirx.const(1.0, dtype="float32")) # The rule below is not typical, removed for now - ck.verify(te.sum(tvm.tir.div(k, 10), k), te.sum(tvm.tir.const(0, "int32"), k)) + ck.verify(te.sum(tvm.tirx.div(k, 10), k), te.sum(tvm.tirx.const(0, "int32"), k)) def test_simplify_if_then_else(): ck = CanonicalChecker() - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod # simplification that takes condition into account. - res = tvm.tir.if_then_else( + res = tvm.tirx.if_then_else( (x * 4 + y) >= 466036, - tvm.tir.if_then_else( + tvm.tirx.if_then_else( 24512 <= tmod(((x * 4) + y) - 466036, 24528), tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16), x, @@ -283,43 +283,43 @@ def test_simplify_if_then_else(): y, ) - res2 = tvm.tir.if_then_else( + res2 = tvm.tirx.if_then_else( (x * 4) >= 466036 - y, - tvm.tir.if_then_else( + tvm.tirx.if_then_else( 24512 <= tmod(((x * 4) + y) - 466036, 24528), tmod(tmod(((x * 4) + y) - 466036, 24528) - 24512, 16), x, ), y, ) - expected = tvm.tir.if_then_else( - tvm.tir.LE(466036, (x * 4 + y)), - tvm.tir.if_then_else( - tvm.tir.LE(24512, tmod(((x * 4) + y) - 4, 24528)), tmod(((x * 4) + y) - 4, 16), x + expected = tvm.tirx.if_then_else( + tvm.tirx.LE(466036, (x * 4 + y)), + tvm.tirx.if_then_else( + tvm.tirx.LE(24512, tmod(((x * 4) + y) - 4, 24528)), tmod(((x * 4) + y) - 4, 16), x ), y, ) ck.verify(res, expected) ck.verify(res2, expected) # can only simplify if condition - res = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3)) - expected = tvm.tir.Select(tvm.tir.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3)) + res = tvm.tirx.Select(tvm.tirx.all(x >= -1, y >= 0), tmod(x + y + 100, 3), tmod(x + 100, 3)) + expected = tvm.tirx.Select(tvm.tirx.all(x >= -1, y >= 0), tmod(x + y + 1, 3), tmod(x + 100, 3)) ck.verify(res, ck.analyzer.canonical_simplify(expected)) - res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) > 2, x, 0), 0) - expected = tvm.tir.Select(x >= 10, x, 0) + res = tvm.tirx.Select(x >= 10, tvm.tirx.if_then_else(tdiv(x, 3) > 2, x, 0), 0) + expected = tvm.tirx.Select(x >= 10, x, 0) ck.verify(res, ck.analyzer.canonical_simplify(expected)) - res = tvm.tir.Select(x >= 10, tvm.tir.if_then_else(tdiv(x, 3) < 2, x, 0), 0) + res = tvm.tirx.Select(x >= 10, tvm.tirx.if_then_else(tdiv(x, 3) < 2, x, 0), 0) ck.verify(res, 0) def test_complex_cases(): ck = CanonicalChecker() - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod res2 = ( tdiv(tdiv(tmod(x * 128 + y, 1296), 36) * 2 + 1, 2) * 36 + tdiv(tmod((x * 128) + y, 36) * 2 + 1, 2) @@ -346,66 +346,66 @@ def test_complex_cases(): def test_simplify_cast(): ck = CanonicalChecker() - tcast = tvm.tir.Cast - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + tcast = tvm.tirx.Cast + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod # cast(i64, i + j + 1) - cast(i64, i) - i = tvm.tir.Var("i", "int32") - j = tvm.tir.Var("j", "int32") + i = tvm.tirx.Var("i", "int32") + j = tvm.tirx.Var("j", "int32") res = tcast("int64", i + j + 1) - tcast("int64", i) - ck.verify(res, tcast("int64", j) + tvm.tir.const(1, "int64")) + ck.verify(res, tcast("int64", j) + tvm.tirx.const(1, "int64")) # cast(i32, i + j + 1) - cast(i32, i) - i = tvm.tir.Var("i", "int64") - j = tvm.tir.Var("j", "int64") + i = tvm.tirx.Var("i", "int64") + j = tvm.tirx.Var("j", "int64") ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10)) ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10)) res = tcast("int32", i + j + 1) - tcast("int32", i) ck.verify(res, tcast("int32", j) + 1) # cast(i32, i + j - 100) - i = tvm.tir.Var("i", "int64") - j = tvm.tir.Var("j", "int64") + i = tvm.tirx.Var("i", "int64") + j = tvm.tirx.Var("j", "int64") ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2**31 - 1)) ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10)) res = tcast("int32", i + j - 100) ck.verify(res, res) # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32 # - cast(i32, flm(axis, 7i64) * 2i64) - axis = tvm.tir.Var("axis", "int64") + axis = tvm.tirx.Var("axis", "int64") ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42)) res = ( tcast( "int32", - flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64") - + tvm.tir.const(1, "int64"), + flm(axis, tvm.tirx.const(7, "int64")) * tvm.tirx.const(2, "int64") + + tvm.tirx.const(1, "int64"), ) - + tvm.tir.const(1, "int32") - - tcast("int32", flm(axis, tvm.tir.const(7, "int64")) * tvm.tir.const(2, "int64")) + + tvm.tirx.const(1, "int32") + - tcast("int32", flm(axis, tvm.tirx.const(7, "int64")) * tvm.tirx.const(2, "int64")) ) ck.verify(res, 2) def test_simplify_normalize_min_value_expr(): ck = CanonicalChecker() - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") - ck.verify(tvm.tir.min_value("int32") - x == 0, x == tvm.tir.min_value("int32")) - ck.verify(tvm.tir.min_value("int32") + x == 0, tir.const(False)) - ck.verify(0 == tvm.tir.min_value("int32") - x, x == tvm.tir.min_value("int32")) - ck.verify(0 == tvm.tir.min_value("int32") + x, tir.const(False)) - ck.verify(-x + tvm.tir.min_value("int32") == 0, x == tvm.tir.min_value("int32")) - ck.verify(x + tvm.tir.min_value("int32") == 0, tir.const(False)) - ck.verify(0 == -x + tvm.tir.min_value("int32"), x == tvm.tir.min_value("int32")) - ck.verify(0 == x + tvm.tir.min_value("int32"), tir.const(False)) + ck.verify(tvm.tirx.min_value("int32") - x == 0, x == tvm.tirx.min_value("int32")) + ck.verify(tvm.tirx.min_value("int32") + x == 0, tirx.const(False)) + ck.verify(0 == tvm.tirx.min_value("int32") - x, x == tvm.tirx.min_value("int32")) + ck.verify(0 == tvm.tirx.min_value("int32") + x, tirx.const(False)) + ck.verify(-x + tvm.tirx.min_value("int32") == 0, x == tvm.tirx.min_value("int32")) + ck.verify(x + tvm.tirx.min_value("int32") == 0, tirx.const(False)) + ck.verify(0 == -x + tvm.tirx.min_value("int32"), x == tvm.tirx.min_value("int32")) + ck.verify(0 == x + tvm.tirx.min_value("int32"), tirx.const(False)) def test_proddiv_simplify(): ck = CanonicalChecker() - flm = tvm.tir.floormod - fld = tvm.tir.floordiv - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + flm = tvm.tirx.floormod + fld = tvm.tirx.floordiv + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("y", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("y", "int32") ck.verify(flm(x * 32 * x, x), 0) ck.verify(flm(z * x * 32 * x * y, x * z), 0) @@ -429,15 +429,15 @@ def test_proddiv_simplify(): def test_floormod_two(): ck = CanonicalChecker() - flm = tvm.tir.floormod - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + flm = tvm.tirx.floormod + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1) def test_simplify_le(): ck = CanonicalChecker() # Case 1. Ignore the extra expr if it's small than the division number - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") ck.analyzer.bind(y, tvm.ir.Range(0, 8)) ck.analyzer.bind(z, tvm.ir.Range(0, 2)) ck.verify(x * 8 + y < 16, x < 2) @@ -450,16 +450,16 @@ def test_simplify_le(): ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16) - n = tvm.tir.SizeVar("n", "int32") + n = tvm.tirx.SizeVar("n", "int32") ck.verify(x * 8 + y < n, x * 8 + y < n) # Case 2. Simplify the extra expr x1, x2, ty, tx, vec = ( - tvm.tir.Var("x1", "int32"), - tvm.tir.Var("x2", "int32"), - tvm.tir.Var("ty", "int32"), - tvm.tir.Var("tx", "int32"), - tvm.tir.Var("vec", "int32"), + tvm.tirx.Var("x1", "int32"), + tvm.tirx.Var("x2", "int32"), + tvm.tirx.Var("ty", "int32"), + tvm.tirx.Var("tx", "int32"), + tvm.tirx.Var("vec", "int32"), ) ck.analyzer.bind(x1, tvm.ir.Range(0, 2)) ck.analyzer.bind(x2, tvm.ir.Range(0, 3)) @@ -473,7 +473,7 @@ def test_simplify_le(): ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8) # Case 3. No failure - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") ck.analyzer.bind(y, tvm.ir.Range(0, 1024)) ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0) diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 7ac4983d07a4..d15453226771 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -68,16 +68,16 @@ def test_const_bounds(self, test_case): class TestDataType(BaseCompare): test_case = tvm.testing.parameter( - TestCase(tvm.tir.Var("x", "int64"), (NEG_INF, POS_INF)), - TestCase(tvm.tir.Var("x", "int8"), (-128, 127)), - TestCase(tvm.tir.Var("x", "uint8"), (0, 255)), - TestCase(tvm.tir.SizeVar("x", "int32"), (0, POS_INF)), + TestCase(tvm.tirx.Var("x", "int64"), (NEG_INF, POS_INF)), + TestCase(tvm.tirx.Var("x", "int8"), (-128, 127)), + TestCase(tvm.tirx.Var("x", "uint8"), (0, 255)), + TestCase(tvm.tirx.SizeVar("x", "int32"), (0, POS_INF)), ) class TestCastBound(BaseCompare): - x = tvm.tir.Var("x", "int8") - tmod = tvm.tir.truncmod + x = tvm.tirx.Var("x", "int8") + tmod = tvm.tirx.truncmod test_case = tvm.testing.parameter( TestCase(tmod(x, 3).astype("uint32"), (0, 2)), @@ -86,8 +86,8 @@ class TestCastBound(BaseCompare): class TestAddSubBound(BaseCompare): - x = tvm.tir.Var("x", "int64") - y = tvm.tir.Var("y", "int64") + x = tvm.tirx.Var("x", "int64") + y = tvm.tirx.Var("y", "int64") test_case = tvm.testing.parameter( TestCase(x + y, (NEG_INF, POS_INF)), @@ -118,7 +118,7 @@ class TestBoundsUsingReciprocals(BaseCompare): achieve its minimum while `A*B` simultaneously achieves its maximum. """ - A, B, C = [tvm.tir.Var(letter, "int64") for letter in "ABC"] + A, B, C = [tvm.tirx.Var(letter, "int64") for letter in "ABC"] symmetric_bounds = {A: (1, 4095), B: (1, 4095), C: (2048, 2048)} asymmetric_bounds = {A: (1, 1024), B: (1, POS_INF), C: (2048, 2048)} @@ -136,7 +136,7 @@ class TestBoundsUsingReciprocals(BaseCompare): class TestMulBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( TestCase(x * y + 20, (0, 60), {x: (-2, 4), y: (4, 10)}), @@ -146,9 +146,9 @@ class TestMulBound(BaseCompare): class TestTruncDivBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") - expr = tvm.tir.truncdiv(x, y) + expr = tvm.tirx.truncdiv(x, y) test_case = tvm.testing.parameter( TestCase(expr, (-2, None), {x: (-9, 4), y: (4, 10)}), @@ -159,9 +159,9 @@ class TestTruncDivBound(BaseCompare): class TestTruncModBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") - expr = tvm.tir.truncmod(x, y) + expr = tvm.tirx.truncmod(x, y) test_case = tvm.testing.parameter( TestCase(expr, (-9, 4), {x: (-9, 4), y: (4, 10)}), @@ -171,9 +171,9 @@ class TestTruncModBound(BaseCompare): class TestFloorDivBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - ux = tvm.tir.Var("x", "uint32") - uy = tvm.tir.Var("y", "uint32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + ux = tvm.tirx.Var("x", "uint32") + uy = tvm.tirx.Var("y", "uint32") test_case = tvm.testing.parameter( TestCase(x // y, (-9 // 4, None), {x: (-9, 4), y: (4, 10)}), @@ -185,7 +185,7 @@ class TestFloorDivBound(BaseCompare): class TestFloorModBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( TestCase(x % y, (0, 9), {x: (-9, 4), y: (4, 10)}), @@ -195,22 +195,22 @@ class TestFloorModBound(BaseCompare): class TestMinMaxBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( - TestCase(tvm.tir.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}), - TestCase(tvm.tir.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y: (4, 10)}), - TestCase(tvm.tir.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y: (4, 10)}), - TestCase(tvm.tir.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4, 10)}), + TestCase(tvm.tirx.min(x, y), (-9, 10), {x: (-9, 11), y: (4, 10)}), + TestCase(tvm.tirx.min(x, y), (NEG_INF, 10), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(tvm.tirx.max(x, y), (4, POS_INF), {x: (NEG_INF, POS_INF), y: (4, 10)}), + TestCase(tvm.tirx.max(x, y), (4, POS_INF), {x: (1, POS_INF), y: (4, 10)}), ) class TestSelectBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( TestCase( - tvm.tir.Select(x > 1, (y < 0).astype("int32"), y + 1), + tvm.tirx.Select(x > 1, (y < 0).astype("int32"), y + 1), (0, 11), {x: (-9, 11), y: (4, 10)}, ), @@ -218,7 +218,7 @@ class TestSelectBound(BaseCompare): class TestShiftAndBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( TestCase(x >> y, (-3, 2), {x: (-9, 11), y: (2, 10)}), @@ -228,9 +228,9 @@ class TestShiftAndBound(BaseCompare): class TestMixIndexBound(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + tdiv = tvm.tirx.truncdiv + tmod = tvm.tirx.truncmod test_case = tvm.testing.parameter( TestCase(tmod(x, 8) + tdiv(x, 8) * 8, (0, 24 - 1), {x: (0, 24 - 1), y: (0, 3 - 1)}), @@ -242,15 +242,15 @@ class TestMixIndexBound(BaseCompare): class TestLetBound(BaseCompare): - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") test_case = tvm.testing.parameter( - TestCase(tvm.tir.Let(x, 1, x + 1), (2, 2)), + TestCase(tvm.tirx.Let(x, 1, x + 1), (2, 2)), ) class TestFloorModNegativeDivisor(BaseCompare): - flm, fld = tvm.tir.floormod, tvm.tir.floordiv - a, b = tvm.tir.Var("a", "int32"), tvm.tir.Var("b", "int32") + flm, fld = tvm.tirx.floormod, tvm.tirx.floordiv + a, b = tvm.tirx.Var("a", "int32"), tvm.tirx.Var("b", "int32") test_case = tvm.testing.parameter( TestCase(a % b, (-4, 6), {a: (0, 6), b: (-5, 7)}), @@ -263,7 +263,7 @@ class TestDivModAssumeNoZeroDivisor(BaseCompare): from symbolic shape programs """ - a, b = tvm.tir.Var("a", "int32"), tvm.tir.Var("b", "int32") + a, b = tvm.tirx.Var("a", "int32"), tvm.tirx.Var("b", "int32") test_case = tvm.testing.parameter( TestCase(a // b, (0, 6), {a: (0, 6), b: (0, POS_INF)}), @@ -272,35 +272,35 @@ class TestDivModAssumeNoZeroDivisor(BaseCompare): class TestMultipleCondition(BaseCompare): - a = tvm.tir.Var("a", "int32") + a = tvm.tirx.Var("a", "int32") test_case = tvm.testing.parameter( TestCase( a % 58 - 1, (0, None), known_bounds={a: (0, 128)}, - constraint=tvm.tir.all(1 <= a % 58, a % 58 < 57), + constraint=tvm.tirx.all(1 <= a % 58, a % 58 < 57), ), ) class TestBroadcastBound(BaseCompare): - a = tvm.tir.Var("a", "int32") + a = tvm.tirx.Var("a", "int32") test_case = tvm.testing.parameter( - TestCase(tvm.tir.Broadcast(a, 4), (0, 128), {a: (0, 128)}), + TestCase(tvm.tirx.Broadcast(a, 4), (0, 128), {a: (0, 128)}), ) class TestRampBound(BaseCompare): - a = tvm.tir.Var("a", "int32") + a = tvm.tirx.Var("a", "int32") test_case = tvm.testing.parameter( - TestCase(tvm.tir.Ramp(a, 2, 4) + 2, (2, 128 + 2 * 3 + 2), {a: (0, 128)}), + TestCase(tvm.tirx.Ramp(a, 2, 4) + 2, (2, 128 + 2 * 3 + 2), {a: (0, 128)}), ) class TestModularSetBound(BaseCompare): analyzer = tvm.arith.Analyzer() - tx = tvm.tir.Var("tx", "int32") - bx = tvm.tir.Var("bx", "int32") + tx = tvm.tirx.Var("tx", "int32") + bx = tvm.tirx.Var("bx", "int32") expr = (bx * 2048 + tx * 16) % 7168 diff --git a/tests/python/arith/test_arith_deduce_bound.py b/tests/python/arith/test_arith_deduce_bound.py index 43dd04485529..00aa8d94984e 100644 --- a/tests/python/arith/test_arith_deduce_bound.py +++ b/tests/python/arith/test_arith_deduce_bound.py @@ -19,21 +19,21 @@ import tvm import tvm.testing -from tvm.tir.buffer import decl_buffer +from tvm.tirx.buffer import decl_buffer def test_deduce(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") - c = tvm.tir.Var("c", "int32") - d = tvm.tir.Var("d", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") + c = tvm.tirx.Var("c", "int32") + d = tvm.tirx.Var("d", "int32") b_s = tvm.arith.IntervalSet(2, 3) c_s = tvm.arith.IntervalSet(10, 15) d_s = tvm.arith.IntervalSet(-3, -1) - zero = tvm.tir.const(0, "int32") + zero = tvm.tirx.const(0, "int32") - fdiv = tvm.tir.floordiv + fdiv = tvm.tirx.floordiv e0 = (-b) * a + c - d res0 = tvm.arith.deduce_bound(a, e0 >= 0, {b: b_s, c: c_s, d: d_s}, {}) @@ -63,13 +63,13 @@ def test_deduce(): res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) - e2 = tvm.tir.max(5, a * 4) < 0 + e2 = tvm.tirx.max(5, a * 4) < 0 res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max_value) == "neg_inf" assert str(res2.min_value) == "pos_inf" # expression containing variable a is on rhs - e2 = zero < tvm.tir.max(5, a * 4) + e2 = zero < tvm.tirx.max(5, a * 4) res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {}) assert str(res2.max_value) == "neg_inf" assert str(res2.min_value) == "pos_inf" @@ -122,10 +122,10 @@ def test_deduce(): def test_check(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") - c = tvm.tir.Var("c", "int32") - d = tvm.tir.Var("d", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") + c = tvm.tirx.Var("c", "int32") + d = tvm.tirx.Var("d", "int32") b_s = tvm.arith.IntervalSet(2, 3) c_s = tvm.arith.IntervalSet(5, 7) @@ -146,8 +146,8 @@ def test_check(): def test_deduce_basic(): def test_basic(a1, a2, coff): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") b_s = tvm.arith.IntervalSet(a1, a2) e0 = b + a * coff + 3 @@ -156,12 +156,12 @@ def test_basic(a1, a2, coff): tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True) # expression containing variable a is on rhs - res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") < e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.tirx.const(17, "int32") < e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True) # expression containing variable a is on rhs - res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") >= e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.tirx.const(17, "int32") >= e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True) @@ -180,8 +180,8 @@ def test_basic(a1, a2, coff): def test_deduce_complex(): def test_complex(a1, a2, coff): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") b_s = tvm.arith.IntervalSet(a1, a2) e0 = (b * 3 + a * coff) * 4 @@ -190,7 +190,7 @@ def test_complex(a1, a2, coff): tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) < 63, True) # expression containing variable a is on rhs - res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") >= e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.tirx.const(63, "int32") >= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) <= 63, True) @@ -199,7 +199,7 @@ def test_complex(a1, a2, coff): tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) > 63, True) # expression containing variable a is on rhs - res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) + res1 = tvm.arith.deduce_bound(a, tvm.tirx.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] tvm.testing.assert_prim_expr_equal(((x * 3 + t * coff) * 4) >= 63, True) @@ -212,28 +212,28 @@ def test_complex(a1, a2, coff): def test_deduce_non_support(): - a = tvm.tir.Var("a", "int32") + a = tvm.tirx.Var("a", "int32") def test_non_support(lhs): res = tvm.arith.deduce_bound(a, lhs < 10, {}, {}) assert res.is_nothing() - test_non_support(tvm.tir.floormod(a, 16)) - test_non_support(tvm.tir.Min(a, 16)) - test_non_support(tvm.tir.Max(a, 16)) - test_non_support(tvm.tir.LE(a, 16)) - test_non_support(tvm.tir.LT(a, 16)) - test_non_support(tvm.tir.GE(a, 16)) - test_non_support(tvm.tir.GT(a, 16)) - test_non_support(tvm.tir.EQ(a, 16)) - test_non_support(tvm.tir.NE(a, 16)) - test_non_support(tvm.tir.log(a)) - test_non_support(tvm.tir.BufferLoad(decl_buffer([16], "int32"), [a])) + test_non_support(tvm.tirx.floormod(a, 16)) + test_non_support(tvm.tirx.Min(a, 16)) + test_non_support(tvm.tirx.Max(a, 16)) + test_non_support(tvm.tirx.LE(a, 16)) + test_non_support(tvm.tirx.LT(a, 16)) + test_non_support(tvm.tirx.GE(a, 16)) + test_non_support(tvm.tirx.GT(a, 16)) + test_non_support(tvm.tirx.EQ(a, 16)) + test_non_support(tvm.tirx.NE(a, 16)) + test_non_support(tvm.tirx.log(a)) + test_non_support(tvm.tirx.BufferLoad(decl_buffer([16], "int32"), [a])) def test_deduce_floordiv(): def do_test(gen_expr, dom_map, expect_min, expect_max): - a = tvm.tir.Var("a", "int32") + a = tvm.tirx.Var("a", "int32") expr = gen_expr(a) res = tvm.arith.deduce_bound(a, expr, dom_map, dom_map) if isinstance(expect_min, str): @@ -260,7 +260,7 @@ def do_test(gen_expr, dom_map, expect_min, expect_max): do_test(lambda a: 8 // a >= 2, {}, "pos_inf", "neg_inf") # test nested cases - b = tvm.tir.Var("b", "int32") + b = tvm.tirx.Var("b", "int32") bs = {b: tvm.arith.IntervalSet(2, 6)} do_test(lambda a: b * 3 + a // 8 < 63, bs, "neg_inf", 359) do_test(lambda a: b * 3 + a // 8 <= 63, bs, "neg_inf", 367) diff --git a/tests/python/arith/test_arith_detect_clip_bound.py b/tests/python/arith/test_arith_detect_clip_bound.py index 830c6d48112e..c93130491cbe 100644 --- a/tests/python/arith/test_arith_detect_clip_bound.py +++ b/tests/python/arith/test_arith_detect_clip_bound.py @@ -19,32 +19,32 @@ def test_basic(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") - c = tvm.tir.Var("c", "int32") - m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a]) + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") + c = tvm.tirx.Var("c", "int32") + m = tvm.arith.detect_clip_bound(tvm.tirx.all(a * 1 < b * 6, a - 1 > 0), [a]) tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1) assert m[0].value == 2 - m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a, b]) + m = tvm.arith.detect_clip_bound(tvm.tirx.all(a * 1 < b * 6, a - 1 > 0), [a, b]) assert len(m) == 0 - m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20, b - 1 > 0), [a, b]) + m = tvm.arith.detect_clip_bound(tvm.tirx.all(a + 10 * c <= 20, b - 1 > 0), [a, b]) tvm.testing.assert_prim_expr_equal(m[1], 20 - 10 * c) tvm.testing.assert_prim_expr_equal(m[2], 2) - m = tvm.arith.detect_clip_bound(tvm.tir.all(tvm.tir.Not(a * 1 > b * 6), a - 1 > 0), [a]) + m = tvm.arith.detect_clip_bound(tvm.tirx.all(tvm.tirx.Not(a * 1 > b * 6), a - 1 > 0), [a]) tvm.testing.assert_prim_expr_equal(m[1], b * 6) - m = tvm.arith.detect_clip_bound(tvm.tir.all(tvm.tir.Min(a, b) > 3, a - 10 < 0), [a, b]) + m = tvm.arith.detect_clip_bound(tvm.tirx.all(tvm.tirx.Min(a, b) > 3, a - 10 < 0), [a, b]) tvm.testing.assert_prim_expr_equal(m[0], 4) tvm.testing.assert_prim_expr_equal(m[1], 9) tvm.testing.assert_prim_expr_equal(m[2], 4) def test_trivial_eq(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") m = tvm.arith.detect_clip_bound(b == 3, [a, b]) tvm.testing.assert_prim_expr_equal(m[2], 3) tvm.testing.assert_prim_expr_equal(m[3], 3) - m = tvm.arith.detect_clip_bound(tvm.tir.all(a == 4, b == 3), [a, b]) + m = tvm.arith.detect_clip_bound(tvm.tirx.all(a == 4, b == 3), [a, b]) tvm.testing.assert_prim_expr_equal(m[0], 4) tvm.testing.assert_prim_expr_equal(m[1], 4) tvm.testing.assert_prim_expr_equal(m[2], 3) diff --git a/tests/python/arith/test_arith_detect_linear_equation.py b/tests/python/arith/test_arith_detect_linear_equation.py index e20ab5156906..08332cec9760 100644 --- a/tests/python/arith/test_arith_detect_linear_equation.py +++ b/tests/python/arith/test_arith_detect_linear_equation.py @@ -19,8 +19,8 @@ def test_basic(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a]) assert m[0].value == 4 tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7) @@ -42,14 +42,14 @@ def test_basic(): assert len(m) == 1 tvm.testing.assert_prim_expr_equal(m[0], b * 7) - c = tvm.tir.Var("c", "uint32") + c = tvm.tirx.Var("c", "uint32") m = tvm.arith.detect_linear_equation(128 - c, [c]) assert m[0].value == -1 def test_multivariate(): - v = [tvm.tir.Var(f"v{i}", "int32") for i in range(4)] - b = tvm.tir.Var("b", "int32") + v = [tvm.tirx.Var(f"v{i}", "int32") for i in range(4)] + b = tvm.tirx.Var("b", "int32") m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v) tvm.testing.assert_prim_expr_equal(m[0], b + 5) diff --git a/tests/python/arith/test_arith_domain_touched.py b/tests/python/arith/test_arith_domain_touched.py index 284a715bb2c1..c1791ac184f9 100644 --- a/tests/python/arith/test_arith_domain_touched.py +++ b/tests/python/arith/test_arith_domain_touched.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -51,7 +51,7 @@ def test_domain_touched(): assert a_domain_rw[0].min.value == -1 assert a_domain_rw[0].extent.value == 101 assert a_domain_rw[1].min.value == -1 - assert isinstance(a_domain_rw[1].extent, tvm.tir.Add) + assert isinstance(a_domain_rw[1].extent, tvm.tirx.Add) assert a_domain_rw[1].extent.a.name == "m" assert a_domain_rw[1].extent.b.value == 1 diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py index 78871d065cb4..49e09191d62b 100644 --- a/tests/python/arith/test_arith_intset.py +++ b/tests/python/arith/test_arith_intset.py @@ -17,7 +17,7 @@ # ruff: noqa: F841 import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.arith.analyzer import Analyzer @@ -49,14 +49,14 @@ def test_vector(): base = 10 stride = 3 lanes = 2 - s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes)) + s = tvm.arith.IntSet.vector(tvm.tirx.Ramp(base, stride, lanes)) assert s.min_value.value == base assert s.max_value.value == base + stride * (lanes - 1) def test_scalable_vector(): base = 5 - s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, 2, tvm.tir.vscale() * 4)) + s = tvm.arith.IntSet.vector(tvm.tirx.Ramp(base, 2, tvm.tirx.vscale() * 4)) assert s.min_value.value == base assert s.max_value.same_as(tvm.arith.int_set.pos_inf()) @@ -64,7 +64,7 @@ def test_scalable_vector(): def test_add_sub(): ck = IntSetChecker() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10)}, (y, 10 + y)) ck.verify(x + y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (1, 21)) ck.verify(x - y, {x: tvm.arith.IntervalSet(0, 10), y: tvm.arith.IntervalSet(1, 11)}, (-11, 9)) @@ -72,9 +72,9 @@ def test_add_sub(): def test_mul_div(): ck = IntSetChecker() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") - tdiv = tvm.tir.truncdiv + tdiv = tvm.tirx.truncdiv ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.verify(x * y, {x: tvm.arith.IntervalSet(0, 10)}, (0, 10 * y)) ck.verify(x * 2, {x: tvm.arith.IntervalSet(1, 10)}, (2, 20)) @@ -83,20 +83,20 @@ def test_mul_div(): ck.verify(tdiv(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, tdiv(10, y))) ck.verify(tdiv(x, 2), {x: tvm.arith.IntervalSet(1, 10)}, (0, 5)) - fld = tvm.tir.floordiv + fld = tvm.tirx.floordiv ck.verify(fld(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, fld(10, y))) ck.verify(fld(x, 2), {x: tvm.arith.IntervalSet(-1, 10)}, (-1, 5)) def test_mod(): ck = IntSetChecker() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - tmod = tvm.tir.truncmod + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + tmod = tvm.tirx.truncmod ck.analyzer.update(y, tvm.arith.ConstIntBound(1, 100), override=True) ck.verify(tmod(x, y), {x: tvm.arith.IntervalSet(0, 10)}, (0, y - 1)) ck.verify(tmod(x, 10), {x: tvm.arith.IntervalSet(1, 10)}, (0, 9)) - flm = tvm.tir.floormod + flm = tvm.tirx.floormod ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(-10, 10)}, (0, 9)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 5)}, (3, 5)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(13, 15)}, (3, 5)) @@ -104,8 +104,8 @@ def test_mod(): ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(3, 11)}, (0, 9)) ck.verify(flm(x, 10), {x: tvm.arith.IntervalSet(1, 21)}, (0, 9)) - fld = tvm.tir.floordiv - z = tvm.tir.Var("z", "int32") + fld = tvm.tirx.floordiv + z = tvm.tirx.Var("z", "int32") ck.analyzer.bind(x, tvm.ir.Range.from_min_extent(0, 3)) ck.verify( flm(y, 8), @@ -124,17 +124,17 @@ def test_mod(): def test_max_min(): ck = IntSetChecker() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - ck.verify(tvm.tir.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11)) - ck.verify(tvm.tir.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9)) - ck.verify(tvm.tir.min(x, y), {}, (tvm.tir.min(x, y), tvm.tir.min(x, y))) - ck.verify(tvm.tir.max(x, y), {}, (tvm.tir.max(x, y), tvm.tir.max(x, y))) + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + ck.verify(tvm.tirx.max(x, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (1, 11)) + ck.verify(tvm.tirx.min(x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 9)) + ck.verify(tvm.tirx.min(x, y), {}, (tvm.tirx.min(x, y), tvm.tirx.min(x, y))) + ck.verify(tvm.tirx.max(x, y), {}, (tvm.tirx.max(x, y), tvm.tirx.max(x, y))) def test_select(): ck = IntSetChecker() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - ck.verify(tvm.tir.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11)) + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + ck.verify(tvm.tirx.Select(x > 0, x - 1, x + 1), {x: tvm.arith.IntervalSet(0, 10)}, (-1, 11)) def check_region_bound(expect_region, var_dom, mode, predicate=None): @@ -156,7 +156,7 @@ def check_region_bound(expect_region, var_dom, mode, predicate=None): Extra predicate, defaults to True. """ if predicate is None: - predicate = tvm.tir.IntImm("bool", 1) + predicate = tvm.tirx.IntImm("bool", 1) region = [] expect = [] for k, v in expect_region.items(): @@ -210,7 +210,7 @@ def check_region_bound(expect_region, var_dom, mode, predicate=None): def test_region_bound_not_independent(): # (i, i+2) and (i+2, i+4) are dependent, this the lowerbound is not available - i = tvm.tir.Var("i", "int32") + i = tvm.tirx.Var("i", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=64), } @@ -218,20 +218,20 @@ def test_region_bound_not_independent(): check_region_bound({(i, i + 2): (0, 65), (i + 2, i + 4): (2, 67)}, var_dom, mode="upperbound") # when only a subset of access indices are affine - i, j, k = tvm.tir.Var("i", "int32"), tvm.tir.Var("j", "int32"), tvm.tir.Var("k", "int32") + i, j, k = tvm.tirx.Var("i", "int32"), tvm.tirx.Var("j", "int32"), tvm.tirx.Var("k", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=16), j: tvm.ir.Range(begin=0, end=16), k: tvm.ir.Range(begin=0, end=16), } check_region_bound( - {i // 4: None, j * 4 + i % 4: None, tir.truncdiv(k, 2): None}, + {i // 4: None, j * 4 + i % 4: None, tirx.truncdiv(k, 2): None}, var_dom, predicate=j * 4 + i % 4 > 3, mode="lowerbound", ) check_region_bound( - {i // 4: (0, 4), j * 4 + i % 4: (4, 64), tir.truncdiv(k, 2): (0, 8)}, + {i // 4: (0, 4), j * 4 + i % 4: (4, 64), tirx.truncdiv(k, 2): (0, 8)}, var_dom, predicate=j * 4 + i % 4 > 3, mode="upperbound", @@ -239,14 +239,14 @@ def test_region_bound_not_independent(): def test_region_bound_stride_too_wide(): - i = tvm.tir.Var("i", "int32") + i = tvm.tirx.Var("i", "int32") var_dom = {i: tvm.ir.Range(begin=0, end=64)} check_region_bound({(i * 4, i * 4 + 2): None}, var_dom, mode="lowerbound") check_region_bound({(i * 4, i * 4 + 2): (0, 254)}, var_dom, mode="upperbound") def test_region_bound_small_stride(): - i = tvm.tir.Var("i", "int32") + i = tvm.tirx.Var("i", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=64), } @@ -254,8 +254,8 @@ def test_region_bound_small_stride(): def test_region_lower_bound_split_predicate(): - x_o = tvm.tir.Var("xo", "int32") - x_i = tvm.tir.Var("xi", "int32") + x_o = tvm.tirx.Var("xo", "int32") + x_i = tvm.tirx.Var("xi", "int32") x = x_o * 4 + x_i var_dom = { x_o: tvm.ir.Range(begin=0, end=16), @@ -272,10 +272,10 @@ def test_region_lower_bound_split_predicate(): def test_region_lower_bound_multiple_variables(): - div = tvm.tir.floordiv - mod = tvm.tir.floormod - x = tvm.tir.Var("x", "int32") - wid = tvm.tir.Var("wid", "int32") + div = tvm.tirx.floordiv + mod = tvm.tirx.floormod + x = tvm.tirx.Var("x", "int32") + wid = tvm.tirx.Var("wid", "int32") i = div(x, 16) j = div(mod(x, 16), 4) * 8 + mod(x, 4) + div(wid, 32) * 4 k = wid % 32 @@ -287,8 +287,8 @@ def test_region_lower_bound_multiple_variables(): def test_region_lower_bound_negative_scale(): - i = tvm.tir.Var("i", "int32") - j = tvm.tir.Var("j", "int32") + i = tvm.tirx.Var("i", "int32") + j = tvm.tirx.Var("j", "int32") var_dom = { i: tvm.ir.Range(begin=0, end=4), j: tvm.ir.Range(begin=0, end=4), @@ -299,9 +299,9 @@ def test_region_lower_bound_negative_scale(): def test_region_lower_bound_for_non_perfect_tile(): - h1 = tvm.tir.Var("h1", "int32") - h2 = tvm.tir.Var("h2", "int32") - h3 = tvm.tir.Var("h3", "int32") + h1 = tvm.tirx.Var("h1", "int32") + h2 = tvm.tirx.Var("h2", "int32") + h3 = tvm.tirx.Var("h3", "int32") # non-uniform tiling, single inner variable var_dom = { @@ -311,8 +311,8 @@ def test_region_lower_bound_for_non_perfect_tile(): { h3 * 8 + h2: { (): ( - tvm.tir.max(h3 * 8, 1), - tvm.tir.min(0, h3 * 8 - 214) + 224, + tvm.tirx.max(h3 * 8, 1), + tvm.tirx.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), # h3 == 0: region is [1, 10) ((h3, 10),): (h3 * 8, h3 * 8 + 10), # 0 < h3 <= 26: region is [h3 * 8, h3 * 8 + 10) @@ -320,7 +320,7 @@ def test_region_lower_bound_for_non_perfect_tile(): } }, var_dom, - predicate=tvm.tir.all(1 <= h3 * 8 + h2, h3 * 8 + h2 < 224), + predicate=tvm.tirx.all(1 <= h3 * 8 + h2, h3 * 8 + h2 < 224), mode="lowerbound", ) @@ -333,8 +333,8 @@ def test_region_lower_bound_for_non_perfect_tile(): { h3 * 8 + h2 * 5 + h1: { (): ( - tvm.tir.max(h3 * 8, 1), - tvm.tir.min(0, h3 * 8 - 214) + 224, + tvm.tirx.max(h3 * 8, 1), + tvm.tirx.min(0, h3 * 8 - 214) + 224, ), ((h3, 0),): (1, 10), ((h3, 10),): (h3 * 8, h3 * 8 + 10), @@ -342,7 +342,7 @@ def test_region_lower_bound_for_non_perfect_tile(): } }, var_dom, - predicate=tvm.tir.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224), + predicate=tvm.tirx.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h2 * 5 + h1 < 224), mode="lowerbound", ) @@ -350,21 +350,21 @@ def test_region_lower_bound_for_non_perfect_tile(): check_region_bound( {h3 * 8 + h2 * 5 + h1: None}, var_dom, - predicate=tvm.tir.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224), + predicate=tvm.tirx.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224), mode="lowerbound", ) check_region_bound( {h3 * 8 + h2 * 5 + h1: (h3 * 8, h3 * 8 + 10)}, var_dom, - predicate=tvm.tir.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224), + predicate=tvm.tirx.all(1 <= h3 * 8 + h2 * 5 + h1, h3 * 8 + h1 * 2 + h2 < 224), mode="upperbound", ) def test_region_lower_bound_unfusable(): var_dom = { - tvm.tir.Var("i", "int32"): tvm.ir.Range(8), - tvm.tir.Var("j", "int32"): tvm.ir.Range(4), + tvm.tirx.Var("i", "int32"): tvm.ir.Range(8), + tvm.tirx.Var("j", "int32"): tvm.ir.Range(4), } i, j = var_dom check_region_bound({(i + j) // 2: (0, 6)}, var_dom, mode="lowerbound") @@ -386,8 +386,8 @@ def test_union_lower_bound(): def test_modular_set(): ck = IntSetChecker() - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") expr = (x * 2048 + y * 16) % 7168 ck.verify( expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152) diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index 28bc94db3b2f..9cb4f790db08 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -17,8 +17,8 @@ # ruff: noqa: F841 import tvm import tvm.testing -from tvm.script import tir as T -from tvm.tir import floordiv, floormod +from tvm.script import tirx as T +from tvm.tirx import floordiv, floormod def ifuse(inputs, pred_extent=None): @@ -32,8 +32,8 @@ def ifuse(inputs, pred_extent=None): def isplit(axis, factor): """Split iterators""" - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod return [ (fld(axis[0], factor), fld(axis[1] + (factor - 1), factor)), (flm(axis[0], factor), factor), @@ -116,9 +116,9 @@ def assert_iter_sum_failure(iters, dom_map, predicate=True, check_level="surject def test_trivial(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") dom_map = var_dom([(x, 3), (y, 4), (z, 1)]) assert_iter_sum_pattern({x: (3, 0), y: (4, 0), 3: (1, 3)}, dom_map) @@ -137,10 +137,10 @@ def test_trivial(): def test_fuse(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - c = tvm.tir.SizeVar("c", "int32") - c0 = tvm.tir.SizeVar("c0", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + c = tvm.tirx.SizeVar("c", "int32") + c0 = tvm.tirx.SizeVar("c0", "int32") assert_iter_sum_pattern({y * 3 + 1 + c + x: (12, 1 + c)}, var_dom([(x, 3), (y, 4)])) @@ -166,12 +166,12 @@ def test_fuse(): def test_split(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - c0 = tvm.tir.SizeVar("c0", "int32") - c1 = tvm.tir.SizeVar("c1", "int32") - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + c0 = tvm.tirx.SizeVar("c0", "int32") + c1 = tvm.tirx.SizeVar("c1", "int32") + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod assert_iter_sum_pattern({fld(x, 3): (8, 0), flm(x, 3) * 2 + c1: (3, c1, 2)}, var_dom([(x, 24)])) @@ -202,8 +202,8 @@ def test_split(): def test_compound(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") xo, xi = isplit((x, 10), 5) yo, yi = isplit((y, 9), 3) @@ -224,9 +224,9 @@ def test_compound(): def test_compound_floormod_two_regression(): - x = tvm.tir.Var("x", "int32") - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + x = tvm.tirx.Var("x", "int32") + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod # regression # extent of 2 of negative scale cannot be normalized assert_iter_sum_failure( @@ -236,9 +236,9 @@ def test_compound_floormod_two_regression(): def test_predicate(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") # available contraints # upper bound only @@ -263,29 +263,29 @@ def test_predicate(): assert_iter_sum_pattern( {x * 10 + y: (122, 6)}, var_dom([(x, 13), (y, 10)]), - predicate=tvm.tir.And(x * 10 + y > 5, x * 10 + y < 128), + predicate=tvm.tirx.And(x * 10 + y > 5, x * 10 + y < 128), ) assert_iter_sum_pattern( {x * 10 + y: (122, 6)}, var_dom([(x, 13), (y, 10)]), - predicate=tvm.tir.And(x * 10 + y >= 6, x * 10 + y <= 127), + predicate=tvm.tirx.And(x * 10 + y >= 6, x * 10 + y <= 127), ) assert_iter_sum_pattern( {x * 64 + y * 4 + z: (16, 16)}, var_dom([(x, 16), (y, 16), (z, 4)]), - predicate=tvm.tir.And(x * 64 + y * 4 + z < 32, 4 <= x * 16 + y), + predicate=tvm.tirx.And(x * 64 + y * 4 + z < 32, 4 <= x * 16 + y), ) # constraint on one fused iter - i = tvm.tir.Var("i", "int32") - j = tvm.tir.Var("j", "int32") - k = tvm.tir.Var("k", "int32") + i = tvm.tirx.Var("i", "int32") + j = tvm.tirx.Var("j", "int32") + k = tvm.tirx.Var("k", "int32") assert_iter_sum_pattern( {i * 8 + j * 2 + k: (88, 1)}, var_dom([(i, 11), (j, 5), (k, 2)]), - predicate=tvm.tir.all(1 <= j * 2 + k, j * 2 + k < 9), + predicate=tvm.tirx.all(1 <= j * 2 + k, j * 2 + k < 9), ) # constraint on single var @@ -295,7 +295,7 @@ def test_predicate(): assert_iter_sum_failure( [i, j, k], var_dom([(i, 128), (j, 128), (k, 128)]), - predicate=tvm.tir.all(i * 16384 + j * 128 + k < 100), + predicate=tvm.tirx.all(i * 16384 + j * 128 + k < 100), ) # iterations are subparts of constraint, invalid case 2 @@ -312,7 +312,7 @@ def test_predicate(): assert_iter_sum_pattern( {i * 8 + j * 2 + k: (22, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - predicate=tvm.tir.all( + predicate=tvm.tirx.all( 1 <= j * 2 + k, j * 2 + k < 9, 3 <= i * 8 + j * 2 + k, i * 8 + j * 2 + k < 25 ), ) @@ -321,14 +321,14 @@ def test_predicate(): assert_iter_sum_pattern( {i * 6 + j * 2 + k: (66, 2)}, var_dom([(i, 11), (j, 5), (k, 2)]), - predicate=tvm.tir.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), + predicate=tvm.tirx.all(1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, j * 2 + k < 9), ) # duplicate constraint on nested fused iters assert_iter_sum_pattern( {i * 6 + j * 2 + k: (15, 3)}, var_dom([(i, 11), (j, 5), (k, 2)]), - predicate=tvm.tir.all( + predicate=tvm.tirx.all( 1 <= j * 2 + k, 2 <= j * 2 + k, j * 2 + k < 8, @@ -344,7 +344,7 @@ def test_predicate(): assert_iter_sum_failure( [i * 8 + j * 2 + k], var_dom([(i, 11), (j, 5), (k, 2)]), - predicate=tvm.tir.all(2 <= j * 2 + k, 0 <= i * 4 + j), + predicate=tvm.tirx.all(2 <= j * 2 + k, 0 <= i * 4 + j), ) # constraint with differnent lower bound @@ -354,15 +354,15 @@ def test_predicate(): 64, 0, 1, - (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tir.IntImm("int32", -15)), + (i * 16 + j) // 23 * 8 + ((i * 16 + j) % 23 + tvm.tirx.IntImm("int32", -15)), ) }, var_dom([(i, 12), (j, 16)]), - predicate=tvm.tir.And( - tvm.tir.And( - i * 16 + j < 184, tvm.tir.LE(tvm.tir.IntImm("int32", 8), (i * 16 + j) % 23) + predicate=tvm.tirx.And( + tvm.tirx.And( + i * 16 + j < 184, tvm.tirx.LE(tvm.tirx.IntImm("int32", 8), (i * 16 + j) % 23) ), - tvm.tir.LE(tvm.tir.IntImm("int32", 15), (i * 16 + j) % 23), + tvm.tirx.LE(tvm.tirx.IntImm("int32", 15), (i * 16 + j) % 23), ), ) @@ -370,23 +370,23 @@ def test_predicate(): # i4 * 6 + i5 in [3, 9), extent=6 (= scale of i2) # i2 * 30 + i3 * 15 in [30, 90), extent=60 (= scale of i1) # i1 * 60 in [60, 240), extent=180 (= scale of i0) - i0 = tvm.tir.Var("i0", "int32") - i1 = tvm.tir.Var("i1", "int32") - i2 = tvm.tir.Var("i2", "int32") - i3 = tvm.tir.Var("i3", "int32") - i4 = tvm.tir.Var("i4", "int32") - i5 = tvm.tir.Var("i5", "int32") + i0 = tvm.tirx.Var("i0", "int32") + i1 = tvm.tirx.Var("i1", "int32") + i2 = tvm.tirx.Var("i2", "int32") + i3 = tvm.tirx.Var("i3", "int32") + i4 = tvm.tirx.Var("i4", "int32") + i5 = tvm.tirx.Var("i5", "int32") assert_iter_sum_pattern( {i0 * 180 + i1 * 60 + i2 * 30 + i3 * 15 + i4 * 6 + i5: (540, 93)}, var_dom([(i0, 3), (i1, 4), (i2, 3), (i3, 2), (i4, 3), (i5, 6)]), - predicate=tvm.tir.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), + predicate=tvm.tirx.all(1 <= i1, 2 <= i2 * 2 + i3, 3 <= i4 * 6 + i5), ) # constraint on many disjoint fused iters, case 2 assert_iter_sum_pattern( {i0 * 45 + i1 * 45 + i2 * 9 + i3 * 4 + i4: (135, 28)}, var_dom([(i0, 3), (i1, 2), (i2, 5), (i3, 3), (i4, 4)]), - predicate=tvm.tir.all( + predicate=tvm.tirx.all( 3 <= i1 * 5 + i2, i1 * 5 + i2 < 8, 1 <= i3 * 4 + i4, i3 * 4 + i4 < 10 ), ) @@ -395,7 +395,7 @@ def test_predicate(): assert_iter_sum_pattern( {i % 16: (7, 3), i // 16: (8, 4)}, var_dom([(i, 1024)]), - predicate=tvm.tir.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), + predicate=tvm.tirx.all(3 <= i % 16, i % 16 < 10, 4 <= i // 16, i // 16 < 12), check_level="bijective", ) @@ -403,7 +403,7 @@ def test_predicate(): assert_iter_sum_pattern( {(i * 32 + j) % 16: (7, 3)}, var_dom([(i, 5), (j, 32)]), - predicate=tvm.tir.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), + predicate=tvm.tirx.all(3 <= (i * 32 + j) % 16, (i * 32 + j) % 16 < 10), ) # constraint on split iters, nested case 2 @@ -412,18 +412,18 @@ def test_predicate(): (i * 32 + j) % 16, ], var_dom([(i, 5), (j, 32)]), - predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + predicate=tvm.tirx.all(1 <= i * 32 + j, i * 32 + j <= 32), check_level="bijective", ) assert_iter_sum_pattern( {(i * 32 + j) % 16: (16, 0)}, var_dom([(i, 5), (j, 32)]), - predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 32), + predicate=tvm.tirx.all(1 <= i * 32 + j, i * 32 + j <= 32), ) assert_iter_sum_pattern( {(i * 32 + j - 1) % 16: (16, 0), (i * 32 + j - 1) // 16: (4, 0)}, var_dom([(i, 5), (j, 32)]), - predicate=tvm.tir.all(1 <= i * 32 + j, i * 32 + j <= 64), + predicate=tvm.tirx.all(1 <= i * 32 + j, i * 32 + j <= 64), ) # non-standard form of predicate @@ -435,7 +435,7 @@ def test_predicate(): assert_iter_sum_pattern( {x * 10 + y: (64, 0)}, var_dom([(x, 13), (y, 10)]), - predicate=tvm.tir.all(x * 10 + y < 128, x * 10 + y < 64), + predicate=tvm.tirx.all(x * 10 + y < 128, x * 10 + y < 64), ) # useless constraint @@ -443,15 +443,15 @@ def test_predicate(): {x * 10 + y: (130, 0)}, var_dom([(x, 13), (y, 10)]), predicate=x * 10 + y < 140 ) - i1 = tvm.tir.Var("i1", "int32") - i2 = tvm.tir.Var("i2", "int32") - i3 = tvm.tir.Var("i3", "int32") - i4 = tvm.tir.Var("i4", "int32") + i1 = tvm.tirx.Var("i1", "int32") + i2 = tvm.tirx.Var("i2", "int32") + i3 = tvm.tirx.Var("i3", "int32") + i4 = tvm.tirx.Var("i4", "int32") assert_iter_sum_pattern( {i1 * 20 + i2 * 10 + i3 * 3 + i4: (128, 0)}, var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( - tvm.tir.all( + tvm.tirx.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i3 * 3 + i4 < 10, @@ -464,7 +464,7 @@ def test_predicate(): [i1 * 20 + i2 * 10 + i3 * 3 + i4], var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( - tvm.tir.all( + tvm.tirx.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i3 * 3 + i4 < 7, @@ -477,7 +477,7 @@ def test_predicate(): [i1 * 20 + i2 * 10 + i3 * 3 + i4], var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( - tvm.tir.all( + tvm.tirx.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i3 * 3 + i4 < 10, @@ -489,7 +489,7 @@ def test_predicate(): [i1 * 20 + i2 * 10 + i3 * 3 + i4], var_dom([(i1, 7), (i2, 2), (i3, 4), (i4, 3)]), predicate=( - tvm.tir.all( + tvm.tirx.all( i1 * 2 + i2 < 13, i1 * 20 + i2 * 10 + i3 * 3 + i4 < 128, i1 * 4 + i3 < 20, @@ -498,9 +498,9 @@ def test_predicate(): ) # zero iter - xo = tvm.tir.Var("xo", "int32") - xi = tvm.tir.Var("xi", "int32") - y = tvm.tir.Var("y", "int32") + xo = tvm.tirx.Var("xo", "int32") + xi = tvm.tirx.Var("xi", "int32") + y = tvm.tirx.Var("y", "int32") assert_iter_sum_pattern( {xo * 129 + xi: (128, 0), y: (128, 0)}, var_dom([(xo, 1), (xi, 129), (y, 128)]), @@ -531,14 +531,14 @@ def convert_division(divisions): def create_iter(name, extent): - return tvm.tir.Var(name, "int32"), extent + return tvm.tirx.Var(name, "int32"), extent def test_subspace_division(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") - c = tvm.tir.SizeVar("c", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") + c = tvm.tirx.SizeVar("c", "int32") # simple 1.1 res = tvm.arith.subspace_divide( @@ -639,7 +639,7 @@ def test_subspace_division(): # compound 1.6 res = tvm.arith.subspace_divide( - [k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7) + [k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tirx.all(k0[0] < 7, k1[0] < 7) ) res = convert_division(res) assert len(res) == 0 @@ -706,7 +706,7 @@ def test_subspace_division(): [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]], - tvm.tir.all(i0[0] < 7, i2[0] < 8), + tvm.tirx.all(i0[0] < 7, i2[0] < 8), ) res = convert_division(res) assert len(res) == 4 @@ -733,9 +733,9 @@ def test_subspace_division(): def test_subspace_divide_trivial_iters(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") # trivial 1.1 res = tvm.arith.subspace_divide( @@ -836,7 +836,7 @@ def test_complex(): assert_iter_sum_pattern( {i0[0]: (200, 0, 1, i0_final), i1[0]: (50, 0, 1, i1_final)}, var_dom([l0, l1, n0, n1, m1, l3]), - predicate=tvm.tir.all( + predicate=tvm.tirx.all( i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15 ), ) @@ -845,7 +845,7 @@ def test_complex(): assert_iter_sum_failure( [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), - tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), + tvm.tirx.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), ) # subspace_division @@ -853,7 +853,7 @@ def test_complex(): [i0[0], i1[0]], var_dom([l0, l1, n0, n1, m1, l3]), [n0[0], n1[0], m1[0], l3[0]], - tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), + tvm.tirx.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), ) res = convert_division(res) assert len(res) == 3 @@ -868,7 +868,7 @@ def test_complex(): tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7) tvm.ir.assert_structural_equal( res[2][1], - tvm.tir.all( + tvm.tirx.all( n0[0] * 4 + n1[0] < 6, (n0[0] * 4 + n1[0]) * 3 + m1[0] < 16, floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15, @@ -882,11 +882,11 @@ def test_complex(): def test_normalize_iter_map_to_expr(): - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") xo, xi = isplit((x, 10), 5) yo, yi = isplit((y, 9), 3) @@ -918,7 +918,7 @@ def test_inverse_affine_iter_map(): iter_map = tvm.arith.detect_iter_map( [l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1]) ).indices - outputs = [tvm.tir.Var(f"output_{i}", "int32") for i in range(len(iter_map))] + outputs = [tvm.tirx.Var(f"output_{i}", "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 2 l0_inverse = floordiv(outputs[0], 4) + outputs[1] * 16 @@ -937,7 +937,7 @@ def test_inverse_affine_iter_map(): iter_map = tvm.arith.detect_iter_map( [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2]) ).indices - outputs = [tvm.tir.Var(f"output_{i}", "int32") for i in range(len(iter_map))] + outputs = [tvm.tirx.Var(f"output_{i}", "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 l0_inverse = floordiv(outputs[0], 64) + outputs[1] * 16 @@ -957,7 +957,7 @@ def test_inverse_affine_iter_map(): l2 = ifuse([l1_1, l1_0]) iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])).indices - outputs = [tvm.tir.Var(f"output_{i}", "int32") for i in range(len(iter_map))] + outputs = [tvm.tirx.Var(f"output_{i}", "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 1 l1_inverse = floormod(outputs[0], 8) * 8 + floordiv(outputs[0], 8) @@ -971,7 +971,7 @@ def test_inverse_affine_map_trivial_iter(): l0 = create_iter("l0", 64) l1 = create_iter("l1", 64) iter_map = tvm.arith.detect_iter_map([0, l0[0], l1[0]], var_dom([l0, l1])).indices - outputs = [tvm.tir.Var(f"output_{i}", "int32") for i in range(len(iter_map))] + outputs = [tvm.tirx.Var(f"output_{i}", "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) # output_0 is expected to be constant and it is not included in the inverse map assert len(res) == 2 @@ -980,9 +980,9 @@ def test_inverse_affine_map_trivial_iter(): def test_free_variables(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") # illegal iter if z is within dom assert_iter_sum_failure([z * 19 + y * 3 + x], var_dom([(x, 3), (y, 3), (z, 3)])) @@ -1009,10 +1009,10 @@ def test_free_variables(): class TestPadding: - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - fld = tvm.tir.floordiv - flm = tvm.tir.floormod + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + fld = tvm.tirx.floordiv + flm = tvm.tirx.floormod positive_test_case = tvm.testing.parameter( # left padding only, offset divisible @@ -1090,11 +1090,11 @@ def test_padding_error(self, negative_test_case): def test_overlapped_fuse(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") - a = tvm.tir.Var("x", "int32") - b = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") + a = tvm.tirx.Var("x", "int32") + b = tvm.tirx.Var("y", "int32") # non-bijective fuse of two assert_iter_sum_pattern( @@ -1126,7 +1126,7 @@ def test_overlapped_fuse(): a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1), }, var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]), - predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10), + predicate=tvm.tirx.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10), check_level="surjective", ) @@ -1141,11 +1141,11 @@ def test_overlapped_fuse(): def test_iter_map_simplify_symbolic_case(): """Test itermap simplify""" - x = tvm.tir.Var("x", "int64") - y = tvm.tir.Var("y", "int64") + x = tvm.tirx.Var("x", "int64") + y = tvm.tirx.Var("y", "int64") z = x * 32 + y - n = tvm.tir.SizeVar("n", "int64") + n = tvm.tirx.SizeVar("n", "int64") def simple_fuse0(x): return (x // n) * n + x % n @@ -1176,10 +1176,10 @@ def fsymbolic_fuse2(i): def test_iter_map_simplify_symbolic_predicate(): """Test itermap simplify""" - x = tvm.tir.Var("x", "int64") - y = tvm.tir.Var("y", "int64") + x = tvm.tirx.Var("x", "int64") + y = tvm.tirx.Var("y", "int64") - n = tvm.tir.SizeVar("n", "int64") + n = tvm.tirx.SizeVar("n", "int64") def simple_fuse0(x): return (x // n) * n + x % n @@ -1201,8 +1201,8 @@ def fsymbolic_fuse2(i): def test_iter_map_simplify_symbolic_reshape(): - n = tvm.tir.Var("n", "int64") - fused = tvm.tir.Var("fused", "int64") + n = tvm.tirx.Var("n", "int64") + fused = tvm.tirx.Var("fused", "int64") ax0 = (fused // 4096) // n ax1 = (fused // 4096) % n @@ -1215,9 +1215,9 @@ def test_iter_map_simplify_symbolic_reshape(): def test_iter_map_simplify_unit_loop_order(): """Test itermap simplify""" - x = tvm.tir.Var("x", "int64") - y = tvm.tir.Var("y", "int64") - z = tvm.tir.Var("z", "int64") + x = tvm.tirx.Var("x", "int64") + y = tvm.tirx.Var("y", "int64") + z = tvm.tirx.Var("z", "int64") # trivial iterators can be found at any when comparing via scale # ensure order unchange @@ -1252,7 +1252,7 @@ def assert_normalize_to_iter_sum(index, input_iters, args, base): Parameters ---------- - index : tvm.tir.PrimExpr + index : tvm.tirx.PrimExpr The index to be normalized input_iters : Mapping[Var, Range] The input iterators @@ -1260,7 +1260,7 @@ def assert_normalize_to_iter_sum(index, input_iters, args, base): The expected result. Ordered list of args of the expected IterSumExpr. Each arg can be either IterSplitExpr or a tuple of (PrimExpr, PrimExpr) where the first element is the iterator normalized to PrimExpr and the second element is the scale. - base : tvm.tir.PrimExpr + base : tvm.tirx.PrimExpr The expected base """ res = tvm.arith.normalize_to_iter_sum(index, input_iters) @@ -1279,12 +1279,12 @@ def assert_normalize_to_iter_sum(index, input_iters, args, base): def test_normalize_to_iter_sum(): - x = tvm.tir.Var("x", "int64") - y = tvm.tir.Var("y", "int64") - z = tvm.tir.Var("z", "int64") - a = tvm.tir.Var("a", "int64") - n = tvm.tir.Var("n", "int64") - flm = tvm.tir.floormod + x = tvm.tirx.Var("x", "int64") + y = tvm.tirx.Var("y", "int64") + z = tvm.tirx.Var("z", "int64") + a = tvm.tirx.Var("a", "int64") + n = tvm.tirx.Var("n", "int64") + flm = tvm.tirx.floormod assert_normalize_to_iter_sum( z + ((y + x * 4 + 2) * n) + 3, @@ -1295,10 +1295,10 @@ def test_normalize_to_iter_sum(): # max cannot detected so it goes into base assert_normalize_to_iter_sum( - tvm.tir.max(z, a) + ((y + x * 4 + 2) * n) + 3, + tvm.tirx.max(z, a) + ((y + x * 4 + 2) * n) + 3, var_dom([(x, 9), (y, 4), (z, 3)]), [(x, n * 4), (y, n)], - tvm.tir.max(z, a) + 2 * n + 3, + tvm.tirx.max(z, a) + 2 * n + 3, ) # order by symbolc prod @@ -1332,9 +1332,9 @@ def test_normalize_to_iter_sum(): [ tvm.arith.IterSplitExpr( tvm.arith.IterMark(x, 4096), - lower_factor=tvm.tir.const(5, "int64"), - extent=tvm.tir.const(820, "int64"), - scale=tvm.tir.const(1, "int64"), + lower_factor=tvm.tirx.const(5, "int64"), + extent=tvm.tirx.const(820, "int64"), + scale=tvm.tirx.const(1, "int64"), ) ], 0, @@ -1350,19 +1350,19 @@ def test_normalize_to_iter_sum(): def test_detect_iter_map_with_bufferload_recursion(): - n = tvm.tir.Var("n", "int32") - m = tvm.tir.Var("m", "int32") - divisor = tvm.tir.Var("divisor", "int32") + n = tvm.tirx.Var("n", "int32") + m = tvm.tirx.Var("m", "int32") + divisor = tvm.tirx.Var("divisor", "int32") - i = tvm.tir.Var("i", "int32") - j = tvm.tir.Var("j", "int32") + i = tvm.tirx.Var("i", "int32") + j = tvm.tirx.Var("j", "int32") - buffer = tvm.tir.decl_buffer((n,), "int32", name="seqlen") + buffer = tvm.tirx.decl_buffer((n,), "int32", name="seqlen") indices = [(buffer[i] + j) // divisor] iter_vars = { - i: tvm.ir.Range(tvm.tir.const(0, "int32"), n), - j: tvm.ir.Range(tvm.tir.const(0, "int32"), m), + i: tvm.ir.Range(tvm.tirx.const(0, "int32"), n), + j: tvm.ir.Range(tvm.tirx.const(0, "int32"), m), } result = tvm.arith.detect_iter_map(indices, iter_vars) diff --git a/tests/python/arith/test_arith_modular_set.py b/tests/python/arith/test_arith_modular_set.py index a103c2faee54..142a1b0d615d 100644 --- a/tests/python/arith/test_arith_modular_set.py +++ b/tests/python/arith/test_arith_modular_set.py @@ -21,7 +21,7 @@ def test_cast(): analyzer = tvm.arith.Analyzer() - x = tvm.tir.Var("x", "int8") + x = tvm.tirx.Var("x", "int8") m = analyzer.modular_set((x * 3).astype("uint32")) assert m.coeff == 3 assert m.base == 0 @@ -32,7 +32,7 @@ def test_cast(): def test_add_sub(): analyzer = tvm.arith.Analyzer() - x, y = tvm.tir.Var("x", "int64"), tvm.tir.Var("y", "int64") + x, y = tvm.tirx.Var("x", "int64"), tvm.tirx.Var("y", "int64") m = analyzer.modular_set(x * 6 + y * 4) assert m.coeff == 2 assert m.base == 0 @@ -45,7 +45,7 @@ def test_add_sub(): def test_mul(): analyzer = tvm.arith.Analyzer() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") m = analyzer.modular_set((x * 4 + 2) * (y * 6 + 1)) assert m.coeff == 4 assert m.base == 2 @@ -53,17 +53,17 @@ def test_mul(): def test_floormod(): analyzer = tvm.arith.Analyzer() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - m = analyzer.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256)) + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + m = analyzer.modular_set(tvm.tirx.floormod(x * 128 + y * 4, 256)) assert m.coeff == 4 assert m.base == 0 def test_div_shift(): analyzer = tvm.arith.Analyzer() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") # not sure if x is non-negative - tdiv = tvm.tir.truncdiv + tdiv = tvm.tirx.truncdiv m = analyzer.modular_set(tdiv(x * 4 + 2, 2)) assert m.coeff == 1 assert m.base == 0 @@ -71,7 +71,7 @@ def test_div_shift(): m = analyzer.modular_set((x * 4 + 2) >> 1) assert m.coeff == 2 assert m.base == 1 - fld = tvm.tir.floordiv + fld = tvm.tirx.floordiv m = analyzer.modular_set(fld(x * 4 + 2, 2)) assert m.coeff == 2 assert m.base == 1 @@ -84,9 +84,9 @@ def test_div_shift(): def test_mod(): analyzer = tvm.arith.Analyzer() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - tmod = tvm.tir.truncmod - fmod = tvm.tir.floormod + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + tmod = tvm.tirx.truncmod + fmod = tvm.tirx.floormod # not sure if x is non-negative m = analyzer.modular_set(tmod(x * 4 + 1, 4)) assert m.coeff == 1 @@ -111,25 +111,25 @@ def test_mod(): def test_min_max_select(): analyzer = tvm.arith.Analyzer() - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - m = analyzer.modular_set(tvm.tir.min(x * 3, y * 9)) + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + m = analyzer.modular_set(tvm.tirx.min(x * 3, y * 9)) assert m.coeff == 3 assert m.base == 0 - m = analyzer.modular_set(tvm.tir.max(x * 3 + 1, y * 9 + 4)) + m = analyzer.modular_set(tvm.tirx.max(x * 3 + 1, y * 9 + 4)) assert m.coeff == 3 assert m.base == 1 - m = analyzer.modular_set(tvm.tir.Select(x > 0, x * 3 + 1, y * 9 + 2)) + m = analyzer.modular_set(tvm.tirx.Select(x > 0, x * 3 + 1, y * 9 + 2)) assert m.coeff == 1 assert m.base == 0 def test_mix_index(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") analyzer = tvm.arith.Analyzer() - tdiv = tvm.tir.truncdiv + tdiv = tvm.tirx.truncdiv m = analyzer.modular_set(a * 4 + b * 6 + 7) assert m.coeff == 2 assert m.base == 1 @@ -150,16 +150,16 @@ def test_mix_index(): assert m.coeff == 3 assert m.base == 2 - m = analyzer.modular_set(a * 12 + tvm.tir.min(b * 3 * 7, 2)) + m = analyzer.modular_set(a * 12 + tvm.tirx.min(b * 3 * 7, 2)) assert m.coeff == 1 assert m.base == 0 def test_constraint_scope(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") analyzer = tvm.arith.Analyzer() - tmod = tvm.tir.truncmod + tmod = tvm.tirx.truncmod with analyzer.constraint_scope(tmod(b, 4) == 2): m = analyzer.modular_set(b + 1) @@ -179,9 +179,9 @@ def test_constraint_scope(): def test_intersect(): - a = tvm.tir.Var("a", "int32") + a = tvm.tirx.Var("a", "int32") analyzer = tvm.arith.Analyzer() - tmod = tvm.tir.truncmod + tmod = tvm.tirx.truncmod with analyzer.constraint_scope(tmod(a, 4) == 1): with analyzer.constraint_scope(tmod(a, 3) == 1): m = analyzer.modular_set(a) @@ -198,17 +198,17 @@ def test_intersect(): def test_let(): analyzer = tvm.arith.Analyzer() - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - m = analyzer.modular_set(tvm.tir.Let(x, y * 10, x + 1)) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + m = analyzer.modular_set(tvm.tirx.Let(x, y * 10, x + 1)) assert m.coeff == 10 assert m.base == 1 def test_bitwise_and(): analyzer = tvm.arith.Analyzer() - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") # RHS of bitwise_and is 2^p - 1 m = analyzer.modular_set((x * 16 + y * 4) & 31) diff --git a/tests/python/arith/test_arith_narrow_predicate_expression.py b/tests/python/arith/test_arith_narrow_predicate_expression.py index 16a0b85489d8..ea54d87dab92 100644 --- a/tests/python/arith/test_arith_narrow_predicate_expression.py +++ b/tests/python/arith/test_arith_narrow_predicate_expression.py @@ -18,19 +18,19 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.runtime import convert -from tvm.script import tir as T +from tvm.script import tirx as T -i = tir.Var("i", "int32") -j = tir.Var("j", "int32") -n = tir.Var("n", "int32") -m = tir.Var("m", "int32") -b = tir.Var("b", "bool") -buf = tir.decl_buffer(16, "int32", "buf") +i = tirx.Var("i", "int32") +j = tirx.Var("j", "int32") +n = tirx.Var("n", "int32") +m = tirx.Var("m", "int32") +b = tirx.Var("b", "bool") +buf = tirx.decl_buffer(16, "int32", "buf") -tir_false = tir.IntImm("bool", False) -tir_true = tir.IntImm("bool", True) +tir_false = tirx.IntImm("bool", False) +tir_true = tirx.IntImm("bool", True) before, expected = tvm.testing.parameters( # General arithmatic @@ -46,17 +46,17 @@ [n < i, T.int32(7) < i], [n <= i, T.int32(7) <= i], [n >= i, T.int32(0) >= i], - [i == n, tir.all(i <= 0, T.int32(7) <= i)], - [n == i, tir.all(T.int32(7) <= i, i <= 0)], - [i != n, tir.any(i < 0, T.int32(7) < i)], - [n != i, tir.any(T.int32(7) < i, i < 0)], + [i == n, tirx.all(i <= 0, T.int32(7) <= i)], + [n == i, tirx.all(T.int32(7) <= i, i <= 0)], + [i != n, tirx.any(i < 0, T.int32(7) < i)], + [n != i, tirx.any(T.int32(7) < i, i < 0)], [i // 4 > n, i // 4 > 7], [n < i // 4, T.int32(7) < i // 4], - [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], - [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, T.int32(0) <= tir.Add(i, 0) // 4)], + [(i + n) // 4 > 0, tirx.Add(i, 0) // 4 > 0], + [(i + n) // 4 == 0, tirx.all(tirx.Add(i, 7) // 4 <= 0, T.int32(0) <= tirx.Add(i, 0) // 4)], [i + n < 10, i + 7 < 10], - [i - n < 10, tir.Sub(i, 0) < 10], - [tir.Not(i < n), tir.Not(i < 7)], + [i - n < 10, tirx.Sub(i, 0) < 10], + [tirx.Not(i < n), tirx.Not(i < 7)], # Use of FloorMod should make the narrowing strategy bail out, as # it is non-monotonic. [i % 8 == n, tir_false], @@ -67,9 +67,9 @@ [buf.vload(0) > 0, tir_false], [buf.vload(0) > i, tir_false], [buf.vload(i) > 0, tir_false], - [tir.And(buf.vload(i) > 0, i <= 0), tir.And(tir_false, i <= 0)], - [tir.Or(buf.vload(i) > 0, i <= n), tir.Or(tir_false, i <= 0)], - [tir.Or(tir.Not(buf.vload(i) > 0), i <= n), tir.Or(tir_false, i <= 0)], + [tirx.And(buf.vload(i) > 0, i <= 0), tirx.And(tir_false, i <= 0)], + [tirx.Or(buf.vload(i) > 0, i <= n), tirx.Or(tir_false, i <= 0)], + [tirx.Or(tirx.Not(buf.vload(i) > 0), i <= n), tirx.Or(tir_false, i <= 0)], ) diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 2dc27e578520..ee3f67d60fe9 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -22,19 +22,19 @@ import tvm import tvm.testing -from tvm import tir -from tvm.script import tir as T -from tvm.tir import floordiv as fld -from tvm.tir import floormod as flm -from tvm.tir import truncdiv as tdiv -from tvm.tir import truncmod as tmod +from tvm import tirx +from tvm.script import tirx as T +from tvm.tirx import floordiv as fld +from tvm.tirx import floormod as flm +from tvm.tirx import truncdiv as tdiv +from tvm.tirx import truncmod as tmod class TestCase: def __init__(self, before, expected, preconditions=None): - if isinstance(before, tir.expr.EqualOp): + if isinstance(before, tirx.expr.EqualOp): before = before.asobject() - if isinstance(expected, tir.expr.EqualOp): + if isinstance(expected, tirx.expr.EqualOp): expected = expected.asobject() self.before = self._convert(before) @@ -43,7 +43,7 @@ def __init__(self, before, expected, preconditions=None): @staticmethod def _convert(expr): - if isinstance(expr, tir.expr.EqualOp): + if isinstance(expr, tirx.expr.EqualOp): return expr.asobject() elif isinstance(expr, int): return T.int32(expr) @@ -59,7 +59,7 @@ def constraint(self): elif isinstance(self.preconditions, tvm.ir.PrimExpr): return self.preconditions else: - return tvm.tir.all(*self.preconditions) + return tvm.tirx.all(*self.preconditions) @property def __name__(self): @@ -90,314 +90,327 @@ def test_simplify(self, test_case): class TestVector(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") - x64 = tvm.tir.Var("x", "int64") - vx = tvm.tir.Var("vx", "int32x2") - vc = tvm.tir.Var("vc", "bool") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") + x64 = tvm.tirx.Var("x", "int64") + vx = tvm.tirx.Var("vx", "int32x2") + vc = tvm.tirx.Var("vc", "bool") test_case = tvm.testing.parameter( # Add rules - TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x + y, 3, 4)), - TestCase(tvm.tir.Ramp(x, 1, 2) + y, tvm.tir.Ramp(x + y, 1, 2)), - TestCase(y + tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y + x, 1, 2)), + TestCase(tvm.tirx.Ramp(x, 1, 4) + tvm.tirx.Ramp(y, 2, 4), tvm.tirx.Ramp(x + y, 3, 4)), + TestCase(tvm.tirx.Ramp(x, 1, 2) + y, tvm.tirx.Ramp(x + y, 1, 2)), + TestCase(y + tvm.tirx.Ramp(x, 1, 2), tvm.tirx.Ramp(y + x, 1, 2)), TestCase( - tvm.tir.Ramp(x, 1, tir.vscale() * 4) + tvm.tir.Ramp(y, 2, tir.vscale() * 4), - tvm.tir.Ramp(x + y, 3, tir.vscale() * 4), + tvm.tirx.Ramp(x, 1, tirx.vscale() * 4) + tvm.tirx.Ramp(y, 2, tirx.vscale() * 4), + tvm.tirx.Ramp(x + y, 3, tirx.vscale() * 4), ), TestCase(y.astype("int32x2") + x.astype("int32x2"), (y + x).astype("int32x2")), - TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)), + TestCase(tvm.tirx.Broadcast(0, 4) + y, tvm.tirx.Broadcast(y, 4)), # int64 lanes TestCase( - tvm.tir.Broadcast(x, 4) + tvm.tir.Ramp(0, 1, tvm.tir.IntImm(dtype="int64", value=4)), - tvm.tir.Ramp(x, 1, 4), + tvm.tirx.Broadcast(x, 4) + tvm.tirx.Ramp(0, 1, tvm.tirx.IntImm(dtype="int64", value=4)), + tvm.tirx.Ramp(x, 1, 4), ), TestCase( - tvm.tir.Broadcast(x, tvm.tir.IntImm(dtype="int64", value=4)) + tvm.tir.Ramp(0, 1, 4), - tvm.tir.Ramp(x, 1, 4), + tvm.tirx.Broadcast(x, tvm.tirx.IntImm(dtype="int64", value=4)) + tvm.tirx.Ramp(0, 1, 4), + tvm.tirx.Ramp(x, 1, 4), ), # int64 iterators with int32 lanes TestCase( - tvm.tir.Broadcast(x64, 4) + tvm.tir.Ramp(tvm.tir.IntImm(dtype="int64", value=0), 1, 4), - tvm.tir.Ramp(x64, 1, 4), + tvm.tirx.Broadcast(x64, 4) + + tvm.tirx.Ramp(tvm.tirx.IntImm(dtype="int64", value=0), 1, 4), + tvm.tirx.Ramp(x64, 1, 4), ), TestCase( - tvm.tir.Broadcast(0, tir.vscale() * 8) + y, tvm.tir.Broadcast(y, tir.vscale() * 8) + tvm.tirx.Broadcast(0, tirx.vscale() * 8) + y, tvm.tirx.Broadcast(y, tirx.vscale() * 8) ), TestCase( - tvm.tir.Ramp(x, 1, 4).astype("float32x4") + tvm.tir.Broadcast(0.0, 4), - tvm.tir.Ramp(x, 1, 4).astype("float32x4"), + tvm.tirx.Ramp(x, 1, 4).astype("float32x4") + tvm.tirx.Broadcast(0.0, 4), + tvm.tirx.Ramp(x, 1, 4).astype("float32x4"), ), # Sub rules - TestCase(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x - y, 2, 4)), - TestCase(tvm.tir.Ramp(x, 1, 2) - y, tvm.tir.Ramp(x - y, 1, 2)), - TestCase(y - tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y - x, -1, 2)), + TestCase(tvm.tirx.Ramp(x, 4, 4) - tvm.tirx.Ramp(y, 2, 4), tvm.tirx.Ramp(x - y, 2, 4)), + TestCase(tvm.tirx.Ramp(x, 1, 2) - y, tvm.tirx.Ramp(x - y, 1, 2)), + TestCase(y - tvm.tirx.Ramp(x, 1, 2), tvm.tirx.Ramp(y - x, -1, 2)), TestCase(y.astype("int32x2") - x.astype("int32x2"), (y - x).astype("int32x2")), # Mul rules TestCase(y.astype("int32x2") * x.astype("int32x2"), (y * x).astype("int32x2")), - TestCase(tvm.tir.Ramp(x, 4, 4) * 2, tvm.tir.Ramp(x * 2, 8, 4)), - TestCase(2 * tvm.tir.Ramp(x, 4, 4), tvm.tir.Ramp(x * 2, 8, 4)), - TestCase(tvm.tir.Broadcast(0, 4) * x, tvm.tir.Broadcast(0, 4)), - TestCase(tvm.tir.Broadcast(0.0, 4) * x, tvm.tir.Broadcast(0.0, 4)), + TestCase(tvm.tirx.Ramp(x, 4, 4) * 2, tvm.tirx.Ramp(x * 2, 8, 4)), + TestCase(2 * tvm.tirx.Ramp(x, 4, 4), tvm.tirx.Ramp(x * 2, 8, 4)), + TestCase(tvm.tirx.Broadcast(0, 4) * x, tvm.tirx.Broadcast(0, 4)), + TestCase(tvm.tirx.Broadcast(0.0, 4) * x, tvm.tirx.Broadcast(0.0, 4)), ## DivMod rules # trunc div TestCase(tdiv(y.astype("int32x2"), x.astype("int32x2")), tdiv(y, x).astype("int32x2")), - TestCase(tdiv(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(tdiv(x, 2), 2, 4)), + TestCase(tdiv(tvm.tirx.Ramp(x, 4, 4), 2), tvm.tirx.Ramp(tdiv(x, 2), 2, 4)), TestCase( - tdiv(tvm.tir.Ramp(x, 4, tir.vscale() * 5), 2), - tvm.tir.Ramp(tdiv(x, 2), 2, tir.vscale() * 5), + tdiv(tvm.tirx.Ramp(x, 4, tirx.vscale() * 5), 2), + tvm.tirx.Ramp(tdiv(x, 2), 2, tirx.vscale() * 5), + ), + TestCase(tdiv(tvm.tirx.Ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"), x >= 0), + TestCase( + tdiv(tvm.tirx.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tirx.Ramp(x * 8 + 15, 1, 4), 8) ), - TestCase(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"), x >= 0), - TestCase(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)), # trunc mod TestCase(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y, x).astype("int32x2")), - TestCase(tmod(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(tmod(x, 2), 4)), - TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4), x >= 0), + TestCase(tmod(tvm.tirx.Ramp(x, 4, 4), 2), tvm.tirx.Broadcast(tmod(x, 2), 4)), + TestCase(tmod(tvm.tirx.Ramp(x * 8 + 1, 1, 4), 8), tvm.tirx.Ramp(1, 1, 4), x >= 0), TestCase( - tmod(tvm.tir.Ramp(x * 8 + 1, 1, tir.vscale() * 4), 8), - tmod(tvm.tir.Ramp(1, 1, tir.vscale() * 4), 8), + tmod(tvm.tirx.Ramp(x * 8 + 1, 1, tirx.vscale() * 4), 8), + tmod(tvm.tirx.Ramp(1, 1, tirx.vscale() * 4), 8), x >= 0, ), - TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tir.Ramp(1, 15, 4), 8), x >= 0), + TestCase( + tmod(tvm.tirx.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tirx.Ramp(1, 15, 4), 8), x >= 0 + ), # floor div TestCase(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y, x).astype("int32x2")), - TestCase(fld(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(fld(x, 2), 2, 4)), + TestCase(fld(tvm.tirx.Ramp(x, 4, 4), 2), tvm.tirx.Ramp(fld(x, 2), 2, 4)), TestCase( - fld(tvm.tir.Ramp(x, 4, tir.vscale() * 4), 2), - tvm.tir.Ramp(fld(x, 2), 2, tir.vscale() * 4), + fld(tvm.tirx.Ramp(x, 4, tirx.vscale() * 4), 2), + tvm.tirx.Ramp(fld(x, 2), 2, tirx.vscale() * 4), ), - TestCase(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")), - TestCase(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8)), + TestCase(fld(tvm.tirx.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")), + TestCase(fld(tvm.tirx.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tirx.Ramp(x * 8 + 15, 1, 4), 8)), TestCase( - fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5) + fld(tvm.tirx.Ramp(x, 8, 5), tvm.tirx.Broadcast(4, 5)), tvm.tirx.Ramp(fld(x, 4), 2, 5) ), TestCase( - fld(tvm.tir.Ramp(x, 8, tir.vscale() * 4), tvm.tir.Broadcast(4, tir.vscale() * 4)), - tvm.tir.Ramp(fld(x, 4), 2, tir.vscale() * 4), + fld(tvm.tirx.Ramp(x, 8, tirx.vscale() * 4), tvm.tirx.Broadcast(4, tirx.vscale() * 4)), + tvm.tirx.Ramp(fld(x, 4), 2, tirx.vscale() * 4), ), TestCase( - fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), - tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4), + fld(tvm.tirx.Ramp(flm(x * 4, 256), 1, 4), tvm.tirx.Broadcast(8, 4)), + tvm.tirx.Broadcast(fld(flm(x * 4, 256), 8), 4), ), TestCase( - fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), - fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), + fld(tvm.tirx.Ramp(x, 7, 4), tvm.tirx.Broadcast(4, 4)), + fld(tvm.tirx.Ramp(x, 7, 4), tvm.tirx.Broadcast(4, 4)), ), TestCase( - fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(x * 2, 4) + fld(tvm.tirx.Ramp(x * 8, 1, 4), tvm.tirx.Broadcast(4, 4)), tvm.tirx.Broadcast(x * 2, 4) ), TestCase( - fld(tvm.tir.Ramp(x * 8, 1, tir.vscale() * 4), tvm.tir.Broadcast(4, tir.vscale() * 4)), - fld(tvm.tir.Ramp(x * 8, 1, tir.vscale() * 4), tvm.tir.Broadcast(4, tir.vscale() * 4)), + fld( + tvm.tirx.Ramp(x * 8, 1, tirx.vscale() * 4), tvm.tirx.Broadcast(4, tirx.vscale() * 4) + ), + fld( + tvm.tirx.Ramp(x * 8, 1, tirx.vscale() * 4), tvm.tirx.Broadcast(4, tirx.vscale() * 4) + ), ), TestCase( - fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), - fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)), + fld(tvm.tirx.Ramp(x * 8, 3, 4), tvm.tirx.Broadcast(4, 4)), + fld(tvm.tirx.Ramp(x * 8, 3, 4), tvm.tirx.Broadcast(4, 4)), ), TestCase( - fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)), - fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)), + fld(tvm.tirx.Ramp(x * 8 + 15, 1, 4), tvm.tirx.Broadcast(4, 4)), + fld(tvm.tirx.Ramp(x * 8 + 15, 1, 4), tvm.tirx.Broadcast(4, 4)), ), TestCase( - fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Broadcast(fld(x, 16), 4), + fld(tvm.tirx.Ramp(x * 4, 1, 4), tvm.tirx.Broadcast(64, 4)), + tvm.tirx.Broadcast(fld(x, 16), 4), ), TestCase( - fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Broadcast(fld(x, 8), 4), + fld(tvm.tirx.Ramp(x * 8, 2, 4), tvm.tirx.Broadcast(64, 4)), + tvm.tirx.Broadcast(fld(x, 8), 4), ), TestCase( - fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), - fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + fld(tvm.tirx.Ramp(x * 4, 1, 5), tvm.tirx.Broadcast(64, 5)), + fld(tvm.tirx.Ramp(x * 4, 1, 5), tvm.tirx.Broadcast(64, 5)), ), # Example negative case: x = 15; [60, 61, 62, 63, 64] / 64 = [0, 0, 0, 0, 1] TestCase( - fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), + fld(tvm.tirx.Ramp(x * 4 + 3, 1, 4), tvm.tirx.Broadcast(64, 4)), + fld(tvm.tirx.Ramp(x * 4 + 3, 1, 4), tvm.tirx.Broadcast(64, 4)), ), # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [0, 1, 1, 1] TestCase( - fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), + fld(tvm.tirx.Ramp(x * 7, 1, 4), tvm.tirx.Broadcast(64, 4)), + fld(tvm.tirx.Ramp(x * 7, 1, 4), tvm.tirx.Broadcast(64, 4)), ), # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [0, 1, 1, 1] # floor mod TestCase(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y, x).astype("int32x2")), - TestCase(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2), 4)), + TestCase(flm(tvm.tirx.Ramp(x, 4, 4), 2), tvm.tirx.Broadcast(flm(x, 2), 4)), TestCase( - flm(tvm.tir.Ramp(x, 4, tir.vscale() * 8), 2), - tvm.tir.Broadcast(flm(x, 2), tir.vscale() * 8), + flm(tvm.tirx.Ramp(x, 4, tirx.vscale() * 8), 2), + tvm.tirx.Broadcast(flm(x, 2), tirx.vscale() * 8), ), - TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4)), + TestCase(flm(tvm.tirx.Ramp(x * 8 + 1, 1, 4), 8), tvm.tirx.Ramp(1, 1, 4)), TestCase( - flm(tvm.tir.Ramp(x * 8 + 1, 1, tir.vscale() * 4), 8), - flm(tvm.tir.Ramp(1, 1, tir.vscale() * 4), 8), + flm(tvm.tirx.Ramp(x * 8 + 1, 1, tirx.vscale() * 4), 8), + flm(tvm.tirx.Ramp(1, 1, tirx.vscale() * 4), 8), ), - TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1, 15, 4), 8)), + TestCase(flm(tvm.tirx.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tirx.Ramp(1, 15, 4), 8)), TestCase( - flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Broadcast(flm(x, 4), 4) + flm(tvm.tirx.Ramp(x, 8, 4), tvm.tirx.Broadcast(4, 4)), tvm.tirx.Broadcast(flm(x, 4), 4) ), TestCase( - flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), - flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), + flm(tvm.tirx.Ramp(x, 7, 4), tvm.tirx.Broadcast(4, 4)), + flm(tvm.tirx.Ramp(x, 7, 4), tvm.tirx.Broadcast(4, 4)), ), - TestCase(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)), tvm.tir.Ramp(0, 1, 4)), + TestCase(flm(tvm.tirx.Ramp(x * 8, 1, 4), tvm.tirx.Broadcast(4, 4)), tvm.tirx.Ramp(0, 1, 4)), TestCase( - flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)), - flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5)), + flm(tvm.tirx.Ramp(x * 8, 1, 5), tvm.tirx.Broadcast(4, 5)), + flm(tvm.tirx.Ramp(0, 1, 5), tvm.tirx.Broadcast(4, 5)), ), TestCase( - flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)), - flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)), + flm(tvm.tirx.Ramp(x * 8 + 7, 1, 4), tvm.tirx.Broadcast(4, 4)), + flm(tvm.tirx.Ramp(3, 1, 4), tvm.tirx.Broadcast(4, 4)), ), TestCase( - flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Ramp(flm(x * 4, 64), 1, 4), + flm(tvm.tirx.Ramp(x * 4, 1, 4), tvm.tirx.Broadcast(64, 4)), + tvm.tirx.Ramp(flm(x * 4, 64), 1, 4), ), TestCase( - flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), - tvm.tir.Ramp(flm(x * 8, 64), 2, 4), + flm(tvm.tirx.Ramp(x * 8, 2, 4), tvm.tirx.Broadcast(64, 4)), + tvm.tirx.Ramp(flm(x * 8, 64), 2, 4), ), TestCase( - flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), - flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), + flm(tvm.tirx.Ramp(x * 4, 1, 5), tvm.tirx.Broadcast(64, 5)), + flm(tvm.tirx.Ramp(x * 4, 1, 5), tvm.tirx.Broadcast(64, 5)), ), # Example negative case: x = 15; [60, 61, 62, 63, 64] % 64 = [60, 61, 62, 63, 0] TestCase( - flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), - flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)), + flm(tvm.tirx.Ramp(x * 4 + 3, 1, 4), tvm.tirx.Broadcast(64, 4)), + flm(tvm.tirx.Ramp(x * 4 + 3, 1, 4), tvm.tirx.Broadcast(64, 4)), ), # Example negative case: x = 15; [63, 64, 65, 66] % 64 = [63, 0, 1, 2] TestCase( - flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), - flm(tvm.tir.Ramp(x * 2, 1, 8), tvm.tir.Broadcast(20, 8)), + flm(tvm.tirx.Ramp(x * 2, 1, 8), tvm.tirx.Broadcast(20, 8)), + flm(tvm.tirx.Ramp(x * 2, 1, 8), tvm.tirx.Broadcast(20, 8)), ), # Example negative case: x = 9; [18, 19, 20, ..., 25] % 20 = [18, 19, 0, 1, ..., 5] TestCase( - flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), - flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)), + flm(tvm.tirx.Ramp(x * 7, 1, 4), tvm.tirx.Broadcast(64, 4)), + flm(tvm.tirx.Ramp(x * 7, 1, 4), tvm.tirx.Broadcast(64, 4)), ), # Example negative case: x = 9; [63, 70, 77, 84] % 64 = [63, 6, 13, 20] # Min/Max rules TestCase( - tvm.tir.min(y.astype("int32x2"), x.astype("int32x2")), - tvm.tir.min(y, x).astype("int32x2"), + tvm.tirx.min(y.astype("int32x2"), x.astype("int32x2")), + tvm.tirx.min(y, x).astype("int32x2"), ), TestCase( - tvm.tir.min(tvm.tir.min(vx, y.astype("int32x2")), x.astype("int32x2")), - tvm.tir.min(vx, tvm.tir.min(y, x).astype("int32x2")), + tvm.tirx.min(tvm.tirx.min(vx, y.astype("int32x2")), x.astype("int32x2")), + tvm.tirx.min(vx, tvm.tirx.min(y, x).astype("int32x2")), ), TestCase( - tvm.tir.max(y.astype("int32x2"), x.astype("int32x2")), - tvm.tir.max(y, x).astype("int32x2"), + tvm.tirx.max(y.astype("int32x2"), x.astype("int32x2")), + tvm.tirx.max(y, x).astype("int32x2"), ), TestCase( - tvm.tir.max(tvm.tir.max(vx, y.astype("int32x2")), x.astype("int32x2")), - tvm.tir.max(vx, tvm.tir.max(y, x).astype("int32x2")), + tvm.tirx.max(tvm.tirx.max(vx, y.astype("int32x2")), x.astype("int32x2")), + tvm.tirx.max(vx, tvm.tirx.max(y, x).astype("int32x2")), ), ## Logical rules TestCase(y.astype("int32x2").equal(x.astype("int32x2")), (y.equal(x)).astype("boolx2")), TestCase( - tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))), - (tvm.tir.NE(y, x)).astype("boolx2"), + tvm.tirx.NE(y.astype("int32x2"), (x.astype("int32x2"))), + (tvm.tirx.NE(y, x)).astype("boolx2"), ), TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < y).astype("boolx2")), TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= y).astype("boolx2")), TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < x).astype("boolx2")), TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= x).astype("boolx2")), TestCase( - tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")), - (tvm.tir.And(y <= x, vc)).astype("boolx2"), + tvm.tirx.And(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")), + (tvm.tirx.And(y <= x, vc)).astype("boolx2"), ), TestCase( - tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")), - (tvm.tir.Or(y <= x, vc)).astype("boolx2"), + tvm.tirx.Or(y.astype("int32x2") <= x.astype("int32x2"), vc.astype("boolx2")), + (tvm.tirx.Or(y <= x, vc)).astype("boolx2"), ), ) class TestSelect(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( # Add rules TestCase( - tvm.tir.Select(x < 0, y, 0) + tvm.tir.Select(x < 0, 1, z), - tvm.tir.Select(x < 0, y + 1, z), + tvm.tirx.Select(x < 0, y, 0) + tvm.tirx.Select(x < 0, 1, z), + tvm.tirx.Select(x < 0, y + 1, z), ), TestCase( - tvm.tir.Select(x < 0, y, 1) - tvm.tir.Select(x < 0, 1, z), - tvm.tir.Select(x < 0, y + (-1), 1 - z), + tvm.tirx.Select(x < 0, y, 1) - tvm.tirx.Select(x < 0, 1, z), + tvm.tirx.Select(x < 0, y + (-1), 1 - z), ), - TestCase(tvm.tir.Select(x < 0, y, z) - y, tvm.tir.Select(x < 0, 0, z - y)), - TestCase(tvm.tir.Select(x < 0, y, z) - z, tvm.tir.Select(x < 0, y - z, 0)), + TestCase(tvm.tirx.Select(x < 0, y, z) - y, tvm.tirx.Select(x < 0, 0, z - y)), + TestCase(tvm.tirx.Select(x < 0, y, z) - z, tvm.tirx.Select(x < 0, y - z, 0)), TestCase( - tvm.tir.min(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)), - tvm.tir.Select(x < 0, tvm.tir.min(y, 1), tvm.tir.min(0, z)), + tvm.tirx.min(tvm.tirx.Select(x < 0, y, 0), tvm.tirx.Select(x < 0, 1, z)), + tvm.tirx.Select(x < 0, tvm.tirx.min(y, 1), tvm.tirx.min(0, z)), ), TestCase( - tvm.tir.max(tvm.tir.Select(x < 0, y, 0), tvm.tir.Select(x < 0, 1, z)), - tvm.tir.Select(x < 0, tvm.tir.max(y, 1), tvm.tir.max(0, z)), + tvm.tirx.max(tvm.tirx.Select(x < 0, y, 0), tvm.tirx.Select(x < 0, 1, z)), + tvm.tirx.Select(x < 0, tvm.tirx.max(y, 1), tvm.tirx.max(0, z)), ), - TestCase(tvm.tir.Select(x * 3 + 1 != 0, y, z), y), - TestCase(tvm.tir.Select(x * 3 + 1 == 0, y, z), z), - TestCase(tvm.tir.Select(x > 0, y + 1, y + 1), y + 1), + TestCase(tvm.tirx.Select(x * 3 + 1 != 0, y, z), y), + TestCase(tvm.tirx.Select(x * 3 + 1 == 0, y, z), z), + TestCase(tvm.tirx.Select(x > 0, y + 1, y + 1), y + 1), ) class TestCancellation(BaseCompare): - var_int8 = tir.Var("var_int8", "int8") - var_int32 = tir.Var("var_int32", "int32") - var_int64 = tir.Var("var_int64", "int64") - var_uint8 = tir.Var("var_uint8", "uint8") - var_uint32 = tir.Var("var_uint32", "uint32") - var_uint64 = tir.Var("var_uint64", "uint64") + var_int8 = tirx.Var("var_int8", "int8") + var_int32 = tirx.Var("var_int32", "int32") + var_int64 = tirx.Var("var_int64", "int64") + var_uint8 = tirx.Var("var_uint8", "uint8") + var_uint32 = tirx.Var("var_uint32", "uint32") + var_uint64 = tirx.Var("var_uint64", "uint64") test_case = tvm.testing.parameter( - TestCase(tir.const(5, "int64") - tir.const(5, "int64"), tir.const(0, "int64")), - TestCase(tir.const(5, "uint8") - tir.const(5, "uint8"), tir.const(0, "uint8")), - TestCase(var_int8 - var_int8, tir.const(0, "int8")), - TestCase(var_int32 - var_int32, tir.const(0, "int32")), - TestCase(var_int64 - var_int64, tir.const(0, "int64")), - TestCase(var_uint8 - var_uint8, tir.const(0, "uint8")), - TestCase(var_uint32 - var_uint32, tir.const(0, "uint32")), - TestCase(var_uint64 - var_uint64, tir.const(0, "uint64")), - TestCase(tir.EQ(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(True, "bool")), - TestCase(tir.EQ(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(True, "bool")), - TestCase(tir.EQ(var_int8, var_int8), tir.const(True, "bool")), - TestCase(tir.EQ(var_int32, var_int32), tir.const(True, "bool")), - TestCase(tir.EQ(var_int64, var_int64), tir.const(True, "bool")), - TestCase(tir.EQ(var_uint8, var_uint8), tir.const(True, "bool")), - TestCase(tir.EQ(var_uint32, var_uint32), tir.const(True, "bool")), - TestCase(tir.EQ(var_uint64, var_uint64), tir.const(True, "bool")), - TestCase(tir.NE(tir.const(5, "int64"), tir.const(5, "int64")), tir.const(False, "bool")), - TestCase(tir.NE(tir.const(5, "uint8"), tir.const(5, "uint8")), tir.const(False, "bool")), - TestCase(tir.NE(var_int8, var_int8), tir.const(False, "bool")), - TestCase(tir.NE(var_int32, var_int32), tir.const(False, "bool")), - TestCase(tir.NE(var_int64, var_int64), tir.const(False, "bool")), - TestCase(tir.NE(var_uint8, var_uint8), tir.const(False, "bool")), - TestCase(tir.NE(var_uint32, var_uint32), tir.const(False, "bool")), - TestCase(tir.NE(var_uint64, var_uint64), tir.const(False, "bool")), + TestCase(tirx.const(5, "int64") - tirx.const(5, "int64"), tirx.const(0, "int64")), + TestCase(tirx.const(5, "uint8") - tirx.const(5, "uint8"), tirx.const(0, "uint8")), + TestCase(var_int8 - var_int8, tirx.const(0, "int8")), + TestCase(var_int32 - var_int32, tirx.const(0, "int32")), + TestCase(var_int64 - var_int64, tirx.const(0, "int64")), + TestCase(var_uint8 - var_uint8, tirx.const(0, "uint8")), + TestCase(var_uint32 - var_uint32, tirx.const(0, "uint32")), + TestCase(var_uint64 - var_uint64, tirx.const(0, "uint64")), + TestCase(tirx.EQ(tirx.const(5, "int64"), tirx.const(5, "int64")), tirx.const(True, "bool")), + TestCase(tirx.EQ(tirx.const(5, "uint8"), tirx.const(5, "uint8")), tirx.const(True, "bool")), + TestCase(tirx.EQ(var_int8, var_int8), tirx.const(True, "bool")), + TestCase(tirx.EQ(var_int32, var_int32), tirx.const(True, "bool")), + TestCase(tirx.EQ(var_int64, var_int64), tirx.const(True, "bool")), + TestCase(tirx.EQ(var_uint8, var_uint8), tirx.const(True, "bool")), + TestCase(tirx.EQ(var_uint32, var_uint32), tirx.const(True, "bool")), + TestCase(tirx.EQ(var_uint64, var_uint64), tirx.const(True, "bool")), + TestCase( + tirx.NE(tirx.const(5, "int64"), tirx.const(5, "int64")), tirx.const(False, "bool") + ), + TestCase( + tirx.NE(tirx.const(5, "uint8"), tirx.const(5, "uint8")), tirx.const(False, "bool") + ), + TestCase(tirx.NE(var_int8, var_int8), tirx.const(False, "bool")), + TestCase(tirx.NE(var_int32, var_int32), tirx.const(False, "bool")), + TestCase(tirx.NE(var_int64, var_int64), tirx.const(False, "bool")), + TestCase(tirx.NE(var_uint8, var_uint8), tirx.const(False, "bool")), + TestCase(tirx.NE(var_uint32, var_uint32), tirx.const(False, "bool")), + TestCase(tirx.NE(var_uint64, var_uint64), tirx.const(False, "bool")), ) class TestAddIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( TestCase(x + (y - x), y), TestCase(x - (y + 1) + (y + 1), x), TestCase((x - 10) + (10 - z), x - z), TestCase((x - y) + (z - x), z - y), - TestCase(tvm.tir.min(x, y - z) + z, tvm.tir.min(x + z, y)), - TestCase(tvm.tir.min(x - z, y) + z, tvm.tir.min(x, y + z)), - TestCase(tvm.tir.max(x, y - 10) + 10, tvm.tir.max(x + 10, y)), - TestCase(tvm.tir.max(x - 11, y) + 11, tvm.tir.max(x, y + 11)), - TestCase(tvm.tir.max(x, y * 2) + tvm.tir.min(x, y * 2), x + y * 2), - TestCase(tvm.tir.min(x, y * 2) + tvm.tir.max(x, y * 2), x + y * 2), - TestCase(tvm.tir.max(x, y + 2) + (-2), tvm.tir.max(x + (-2), y)), - TestCase(tvm.tir.min(x, y + 2) + (-2), tvm.tir.min(x + (-2), y)), - TestCase(tvm.tir.min(x + 2, y + 3) + (-2), tvm.tir.min(x, y + 1)), - TestCase(tvm.tir.max(0, 1 - x * 4) + x * 4, tvm.tir.max(x * 4, 1)), - TestCase(tvm.tir.max(2 - x * 4, 0) + x * 4, tvm.tir.max(x * 4, 2)), - TestCase(tvm.tir.min(0, 1 - x * 4) + x * 4, tvm.tir.min(x * 4, 1)), - TestCase(tvm.tir.min(2 - x * 4, 0) + x * 4, tvm.tir.min(x * 4, 2)), + TestCase(tvm.tirx.min(x, y - z) + z, tvm.tirx.min(x + z, y)), + TestCase(tvm.tirx.min(x - z, y) + z, tvm.tirx.min(x, y + z)), + TestCase(tvm.tirx.max(x, y - 10) + 10, tvm.tirx.max(x + 10, y)), + TestCase(tvm.tirx.max(x - 11, y) + 11, tvm.tirx.max(x, y + 11)), + TestCase(tvm.tirx.max(x, y * 2) + tvm.tirx.min(x, y * 2), x + y * 2), + TestCase(tvm.tirx.min(x, y * 2) + tvm.tirx.max(x, y * 2), x + y * 2), + TestCase(tvm.tirx.max(x, y + 2) + (-2), tvm.tirx.max(x + (-2), y)), + TestCase(tvm.tirx.min(x, y + 2) + (-2), tvm.tirx.min(x + (-2), y)), + TestCase(tvm.tirx.min(x + 2, y + 3) + (-2), tvm.tirx.min(x, y + 1)), + TestCase(tvm.tirx.max(0, 1 - x * 4) + x * 4, tvm.tirx.max(x * 4, 1)), + TestCase(tvm.tirx.max(2 - x * 4, 0) + x * 4, tvm.tirx.max(x * 4, 2)), + TestCase(tvm.tirx.min(0, 1 - x * 4) + x * 4, tvm.tirx.min(x * 4, 1)), + TestCase(tvm.tirx.min(2 - x * 4, 0) + x * 4, tvm.tirx.min(x * 4, 2)), TestCase(x * y + x * 10, (y + 10) * x), TestCase(y * x + x * 10, (y + 10) * x), TestCase(y * x + 10 * x, (y + 10) * x), TestCase(x * y + 10 * x, (y + 10) * x), - TestCase((2 * z) + tvm.tir.min(x, y - (2 * z)), tvm.tir.min(x + (z * 2), y)), + TestCase((2 * z) + tvm.tirx.min(x, y - (2 * z)), tvm.tirx.min(x + (z * 2), y)), TestCase(y * x + x, (y + 1) * x), TestCase(x * y + x, (y + 1) * x), TestCase((x + 10) + 13, x + 23), @@ -421,21 +434,21 @@ class TestAddIndex(BaseCompare): class TestSubIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( TestCase(x + y - y, x), TestCase(x + y - x, y), TestCase(x - (y + x), 0 - y), TestCase(x - (x + y), 0 - y), - TestCase(tvm.tir.min(x, y) - x, tvm.tir.min(0, y - x)), - TestCase(tvm.tir.min(x, y) - y, tvm.tir.min(x - y, 0)), - TestCase(tvm.tir.max(x, y) - x, tvm.tir.max(0, y - x)), - TestCase(tvm.tir.max(x, y) - y, tvm.tir.max(x - y, 0)), - TestCase(x - tvm.tir.min(x, y), tvm.tir.max(0, x - y)), - TestCase(y - tvm.tir.min(x, y), tvm.tir.max(y - x, 0)), - TestCase(x - tvm.tir.max(x, y), tvm.tir.min(0, x - y)), - TestCase(y - tvm.tir.max(x, y), tvm.tir.min(y - x, 0)), + TestCase(tvm.tirx.min(x, y) - x, tvm.tirx.min(0, y - x)), + TestCase(tvm.tirx.min(x, y) - y, tvm.tirx.min(x - y, 0)), + TestCase(tvm.tirx.max(x, y) - x, tvm.tirx.max(0, y - x)), + TestCase(tvm.tirx.max(x, y) - y, tvm.tirx.max(x - y, 0)), + TestCase(x - tvm.tirx.min(x, y), tvm.tirx.max(0, x - y)), + TestCase(y - tvm.tirx.min(x, y), tvm.tirx.max(y - x, 0)), + TestCase(x - tvm.tirx.max(x, y), tvm.tirx.min(0, x - y)), + TestCase(y - tvm.tirx.max(x, y), tvm.tirx.min(y - x, 0)), # mul co-efficient foldng TestCase(x - x, 0), TestCase(x * y - x, (y + (-1)) * x), @@ -448,26 +461,26 @@ class TestSubIndex(BaseCompare): TestCase((y + x) - (x + z), y - z), TestCase((x + y) - (z + x), y - z), TestCase((y + x) - (z + x), y - z), - TestCase(tvm.tir.min(x + y, z) - x, tvm.tir.min(y, z - x)), - TestCase(tvm.tir.min(y + x, z) - x, tvm.tir.min(y, z - x)), - TestCase(tvm.tir.min(z, x + y) - x, tvm.tir.min(z - x, y)), - TestCase(tvm.tir.min(z, y + x) - x, tvm.tir.min(z - x, y)), - TestCase(tvm.tir.max(x + y, z) - x, tvm.tir.max(y, z - x)), - TestCase(tvm.tir.max(y + x, z) - x, tvm.tir.max(y, z - x)), - TestCase(tvm.tir.max(z, x + y) - x, tvm.tir.max(z - x, y)), - TestCase(tvm.tir.max(z, y + x) - x, tvm.tir.max(z - x, y)), - TestCase(x - tvm.tir.min(x + y, z), tvm.tir.max(0 - y, x - z)), - TestCase(x - tvm.tir.min(y + x, z), tvm.tir.max(0 - y, x - z)), - TestCase(x - tvm.tir.min(z, x + y), tvm.tir.max(x - z, 0 - y)), - TestCase(x - tvm.tir.min(z, y + x), tvm.tir.max(x - z, 0 - y)), - TestCase(tvm.tir.min(x, y) - tvm.tir.min(y, x), 0), - TestCase(tvm.tir.max(x, y) - tvm.tir.max(y, x), 0), - TestCase(tvm.tir.min(x, y) - tvm.tir.min(x + 10, y + 10), -10), - TestCase(tvm.tir.min(x + 10, y + 1) - tvm.tir.min(x, y - 9), 10), - TestCase(x - tvm.tir.max(x + y, 0), tvm.tir.min(0 - y, x)), - TestCase(x - tvm.tir.max(0, x + y), tvm.tir.min(x, 0 - y)), - TestCase(x - tvm.tir.min(x + y, 0), tvm.tir.max(0 - y, x)), - TestCase(x - tvm.tir.min(0, x + y), tvm.tir.max(x, 0 - y)), + TestCase(tvm.tirx.min(x + y, z) - x, tvm.tirx.min(y, z - x)), + TestCase(tvm.tirx.min(y + x, z) - x, tvm.tirx.min(y, z - x)), + TestCase(tvm.tirx.min(z, x + y) - x, tvm.tirx.min(z - x, y)), + TestCase(tvm.tirx.min(z, y + x) - x, tvm.tirx.min(z - x, y)), + TestCase(tvm.tirx.max(x + y, z) - x, tvm.tirx.max(y, z - x)), + TestCase(tvm.tirx.max(y + x, z) - x, tvm.tirx.max(y, z - x)), + TestCase(tvm.tirx.max(z, x + y) - x, tvm.tirx.max(z - x, y)), + TestCase(tvm.tirx.max(z, y + x) - x, tvm.tirx.max(z - x, y)), + TestCase(x - tvm.tirx.min(x + y, z), tvm.tirx.max(0 - y, x - z)), + TestCase(x - tvm.tirx.min(y + x, z), tvm.tirx.max(0 - y, x - z)), + TestCase(x - tvm.tirx.min(z, x + y), tvm.tirx.max(x - z, 0 - y)), + TestCase(x - tvm.tirx.min(z, y + x), tvm.tirx.max(x - z, 0 - y)), + TestCase(tvm.tirx.min(x, y) - tvm.tirx.min(y, x), 0), + TestCase(tvm.tirx.max(x, y) - tvm.tirx.max(y, x), 0), + TestCase(tvm.tirx.min(x, y) - tvm.tirx.min(x + 10, y + 10), -10), + TestCase(tvm.tirx.min(x + 10, y + 1) - tvm.tirx.min(x, y - 9), 10), + TestCase(x - tvm.tirx.max(x + y, 0), tvm.tirx.min(0 - y, x)), + TestCase(x - tvm.tirx.max(0, x + y), tvm.tirx.min(x, 0 - y)), + TestCase(x - tvm.tirx.min(x + y, 0), tvm.tirx.max(0 - y, x)), + TestCase(x - tvm.tirx.min(0, x + y), tvm.tirx.max(x, 0 - y)), # DivMod patterns # truc div TestCase(x - tdiv(x, 3) * 3, tmod(x, 3)), @@ -516,18 +529,18 @@ class TestSubIndex(BaseCompare): class TestMulIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( TestCase((x + 2) * 3, x * 3 + 6), TestCase((x * 2) * 3, x * 6), - TestCase(tvm.tir.min(x, y) * tvm.tir.max(x, y), x * y), - TestCase(tvm.tir.max(x, y) * tvm.tir.min(x, y), x * y), + TestCase(tvm.tirx.min(x, y) * tvm.tirx.max(x, y), x * y), + TestCase(tvm.tirx.max(x, y) * tvm.tirx.min(x, y), x * y), TestCase((x - y) * (-2), (y - x) * 2), ) class TestDivIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") non_negative = [x >= 0, y >= 0, z >= 0] test_case = tvm.testing.parameter( @@ -537,11 +550,11 @@ class TestDivIndex(BaseCompare): TestCase(tdiv(x * 2, 4), tdiv(x, 2)), TestCase(tdiv(x * 4, 2), x * 2), TestCase(tdiv(x * 4 + y, 2), x * 2 + tdiv(y, 2), non_negative), - TestCase(tdiv(tvm.tir.min(x * 6, y), 2), tvm.tir.min(x * 3, tdiv(y, 2)), non_negative), - TestCase(tdiv(tvm.tir.max(x * 6, y), 2), tvm.tir.max(x * 3, tdiv(y, 2)), non_negative), + TestCase(tdiv(tvm.tirx.min(x * 6, y), 2), tvm.tirx.min(x * 3, tdiv(y, 2)), non_negative), + TestCase(tdiv(tvm.tirx.max(x * 6, y), 2), tvm.tirx.max(x * 3, tdiv(y, 2)), non_negative), TestCase(tdiv(y + x * 4, 2), tdiv(y, 2) + x * 2, non_negative), - TestCase(tdiv(tvm.tir.min(y, x * 6), 2), tvm.tir.min(tdiv(y, 2), x * 3), non_negative), - TestCase(tdiv(tvm.tir.max(y, x * 6), 2), tvm.tir.max(tdiv(y, 2), x * 3), non_negative), + TestCase(tdiv(tvm.tirx.min(y, x * 6), 2), tvm.tirx.min(tdiv(y, 2), x * 3), non_negative), + TestCase(tdiv(tvm.tirx.max(y, x * 6), 2), tvm.tirx.max(tdiv(y, 2), x * 3), non_negative), # 3-operands TestCase(tdiv(x * 6 + y + z, 2), x * 3 + tdiv(y + z, 2), non_negative), TestCase(tdiv(x * 6 - y + (y + 3), 2), x * 3 + 1, non_negative), @@ -564,7 +577,7 @@ class TestDivIndex(BaseCompare): class TestFloordivIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( TestCase(fld(fld(x, 2), 3), fld(x, 6)), @@ -583,11 +596,11 @@ class TestFloordivIndex(BaseCompare): TestCase(fld(x * 360 + y, 25), x * 14, [x >= 0, x < 2, y >= 0, y < 7]), TestCase(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)), TestCase(fld(x * 4 + y, 2), x * 2 + fld(y, 2)), - TestCase(fld(tvm.tir.min(x * 6, y), 2), tvm.tir.min(x * 3, fld(y, 2))), - TestCase(fld(tvm.tir.max(x * 6, y), 2), tvm.tir.max(x * 3, fld(y, 2))), + TestCase(fld(tvm.tirx.min(x * 6, y), 2), tvm.tirx.min(x * 3, fld(y, 2))), + TestCase(fld(tvm.tirx.max(x * 6, y), 2), tvm.tirx.max(x * 3, fld(y, 2))), TestCase(fld(y + x * 4, 2), x * 2 + fld(y, 2)), - TestCase(fld(tvm.tir.min(y, x * 6), 2), tvm.tir.min(fld(y, 2), x * 3)), - TestCase(fld(tvm.tir.max(y, x * 6), 2), tvm.tir.max(fld(y, 2), x * 3)), + TestCase(fld(tvm.tirx.min(y, x * 6), 2), tvm.tirx.min(fld(y, 2), x * 3)), + TestCase(fld(tvm.tirx.max(y, x * 6), 2), tvm.tirx.max(fld(y, 2), x * 3)), # 3-operands # # TODO(Lunderberg): Remove the necessity for the preconditions @@ -618,11 +631,11 @@ class TestFloordivIndex(BaseCompare): class TestModIndex(BaseCompare): x, y, nx, ny, z = ( - tvm.tir.Var("x", "int32"), - tvm.tir.Var("y", "int32"), - tvm.tir.Var("nx", "int32"), - tvm.tir.Var("ny", "int32"), - tvm.tir.Var("z", "int32"), + tvm.tirx.Var("x", "int32"), + tvm.tirx.Var("y", "int32"), + tvm.tirx.Var("nx", "int32"), + tvm.tirx.Var("ny", "int32"), + tvm.tirx.Var("z", "int32"), ) test_case = tvm.testing.parameter( @@ -655,7 +668,7 @@ class TestModIndex(BaseCompare): class TestFloormodIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( TestCase(flm(x * 10, 2), 0), @@ -693,7 +706,7 @@ class TestFloorModTwo(BaseCompare): however during simplification """ - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( # Removing offsets from floormod TestCase(flm(x, 2) + flm(x + 1, 2), 1), @@ -722,7 +735,7 @@ class TestFloorModPadded(BaseCompare): such that (x - x % k) must be divisible by k """ - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( TestCase(flm(x - flm(x, 9), 9), 0), TestCase(flm(x - flm(x, -9), 9), 0), @@ -735,189 +748,199 @@ class TestFloorModPadded(BaseCompare): class TestMinIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( # const int bound - TestCase(tvm.tir.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)), - TestCase(tvm.tir.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)), - TestCase(tvm.tir.min(x + 1, x + 10), x + 1), - TestCase(tvm.tir.min(x + 111, x + 10), x + 10), - TestCase(tvm.tir.min(x + 1, x), x), - TestCase(tvm.tir.min(x, x + 2), x), - TestCase(tvm.tir.min(1 - x, 2 - x), 1 - x), - TestCase(tvm.tir.min(3 - x, 2 - x), 2 - x), - TestCase(tvm.tir.min(tvm.tir.max(x, y), tvm.tir.min(x, y)), tvm.tir.min(x, y)), - TestCase(tvm.tir.min(tvm.tir.max(x, y), tvm.tir.min(y, x)), tvm.tir.min(x, y)), - TestCase(tvm.tir.min(tvm.tir.max(x, y), x), x), - TestCase(tvm.tir.min(tvm.tir.max(y, x), x), x), - TestCase(tvm.tir.min(tvm.tir.min(x, y), x), tvm.tir.min(x, y)), - TestCase(tvm.tir.min(tvm.tir.min(x, y), y), tvm.tir.min(x, y)), - TestCase(tvm.tir.min(x, tvm.tir.max(x, y)), x), - TestCase(tvm.tir.min(x, tvm.tir.max(y, x)), x), - TestCase(tvm.tir.min(x, tvm.tir.min(x, y)), tvm.tir.min(x, y)), - TestCase(tvm.tir.min(y, tvm.tir.min(x, y)), tvm.tir.min(x, y)), - TestCase( - tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), y), tvm.tir.min(tvm.tir.min(x, y), z) - ), - TestCase( - tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2), y), - tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2), - ), - TestCase( - tvm.tir.min( - tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2), z * 2), y + TestCase(tvm.tirx.min(tmod(x, 2), tmod(y, 2) + 10), tmod(x, 2)), + TestCase(tvm.tirx.min(flm(x, 2), flm(y, 2) + 10), flm(x, 2)), + TestCase(tvm.tirx.min(x + 1, x + 10), x + 1), + TestCase(tvm.tirx.min(x + 111, x + 10), x + 10), + TestCase(tvm.tirx.min(x + 1, x), x), + TestCase(tvm.tirx.min(x, x + 2), x), + TestCase(tvm.tirx.min(1 - x, 2 - x), 1 - x), + TestCase(tvm.tirx.min(3 - x, 2 - x), 2 - x), + TestCase(tvm.tirx.min(tvm.tirx.max(x, y), tvm.tirx.min(x, y)), tvm.tirx.min(x, y)), + TestCase(tvm.tirx.min(tvm.tirx.max(x, y), tvm.tirx.min(y, x)), tvm.tirx.min(x, y)), + TestCase(tvm.tirx.min(tvm.tirx.max(x, y), x), x), + TestCase(tvm.tirx.min(tvm.tirx.max(y, x), x), x), + TestCase(tvm.tirx.min(tvm.tirx.min(x, y), x), tvm.tirx.min(x, y)), + TestCase(tvm.tirx.min(tvm.tirx.min(x, y), y), tvm.tirx.min(x, y)), + TestCase(tvm.tirx.min(x, tvm.tirx.max(x, y)), x), + TestCase(tvm.tirx.min(x, tvm.tirx.max(y, x)), x), + TestCase(tvm.tirx.min(x, tvm.tirx.min(x, y)), tvm.tirx.min(x, y)), + TestCase(tvm.tirx.min(y, tvm.tirx.min(x, y)), tvm.tirx.min(x, y)), + TestCase( + tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(x, y), z), y), + tvm.tirx.min(tvm.tirx.min(x, y), z), + ), + TestCase( + tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(x, y), z), x * 2), y), + tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(x, y), z), x * 2), + ), + TestCase( + tvm.tirx.min( + tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(x, y), z), x * 2), z * 2), y ), - tvm.tir.min(tvm.tir.min(tvm.tir.min(tvm.tir.min(x, y), z), x * 2), z * 2), + tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(tvm.tirx.min(x, y), z), x * 2), z * 2), ), TestCase( - tvm.tir.min(tvm.tir.max(x, y), tvm.tir.max(x, z)), tvm.tir.max(tvm.tir.min(y, z), x) + tvm.tirx.min(tvm.tirx.max(x, y), tvm.tirx.max(x, z)), + tvm.tirx.max(tvm.tirx.min(y, z), x), ), TestCase( - tvm.tir.min(tvm.tir.max(x, y), tvm.tir.max(z, x)), tvm.tir.max(tvm.tir.min(y, z), x) + tvm.tirx.min(tvm.tirx.max(x, y), tvm.tirx.max(z, x)), + tvm.tirx.max(tvm.tirx.min(y, z), x), ), TestCase( - tvm.tir.min(tvm.tir.max(y, x), tvm.tir.max(x, z)), tvm.tir.max(tvm.tir.min(y, z), x) + tvm.tirx.min(tvm.tirx.max(y, x), tvm.tirx.max(x, z)), + tvm.tirx.max(tvm.tirx.min(y, z), x), ), TestCase( - tvm.tir.min(tvm.tir.max(y, x), tvm.tir.max(z, x)), tvm.tir.max(tvm.tir.min(y, z), x) + tvm.tirx.min(tvm.tirx.max(y, x), tvm.tirx.max(z, x)), + tvm.tirx.max(tvm.tirx.min(y, z), x), ), - TestCase(tvm.tir.min(y + x, z + x), tvm.tir.min(y, z) + x), - TestCase(tvm.tir.min(y + x, x + z), tvm.tir.min(y, z) + x), - TestCase(tvm.tir.min(x + y, z + x), tvm.tir.min(y, z) + x), - TestCase(tvm.tir.min(x + y, x + z), tvm.tir.min(y, z) + x), - TestCase(tvm.tir.min(x - y, x - z), x - tvm.tir.max(y, z)), - TestCase(tvm.tir.min(y - x, z - x), tvm.tir.min(y, z) - x), - TestCase(tvm.tir.min(tvm.tir.min(x, 1), 10), tvm.tir.min(x, 1)), - TestCase(tvm.tir.min(tvm.tir.min(x, 11), 10), tvm.tir.min(x, 10)), - TestCase(tvm.tir.min(x * 3, 9), tvm.tir.min(x, 3) * 3), - TestCase(tvm.tir.min(x * 2, 0), tvm.tir.min(x, 0) * 2), - TestCase(tvm.tir.min(0 - x * 2, 0), tvm.tir.max(x, 0) * -2), - TestCase(tvm.tir.min(3 - x, 2), 3 - tvm.tir.max(x, 1)), - TestCase(tvm.tir.min(x * (-2), -4), tvm.tir.max(x, 2) * -2), - TestCase(tvm.tir.min(x * (-2), 4), tvm.tir.max(x, -2) * -2), - TestCase(tvm.tir.min(x * (0), 4), 0), - TestCase(tvm.tir.min(x * (0), -4), -4), + TestCase(tvm.tirx.min(y + x, z + x), tvm.tirx.min(y, z) + x), + TestCase(tvm.tirx.min(y + x, x + z), tvm.tirx.min(y, z) + x), + TestCase(tvm.tirx.min(x + y, z + x), tvm.tirx.min(y, z) + x), + TestCase(tvm.tirx.min(x + y, x + z), tvm.tirx.min(y, z) + x), + TestCase(tvm.tirx.min(x - y, x - z), x - tvm.tirx.max(y, z)), + TestCase(tvm.tirx.min(y - x, z - x), tvm.tirx.min(y, z) - x), + TestCase(tvm.tirx.min(tvm.tirx.min(x, 1), 10), tvm.tirx.min(x, 1)), + TestCase(tvm.tirx.min(tvm.tirx.min(x, 11), 10), tvm.tirx.min(x, 10)), + TestCase(tvm.tirx.min(x * 3, 9), tvm.tirx.min(x, 3) * 3), + TestCase(tvm.tirx.min(x * 2, 0), tvm.tirx.min(x, 0) * 2), + TestCase(tvm.tirx.min(0 - x * 2, 0), tvm.tirx.max(x, 0) * -2), + TestCase(tvm.tirx.min(3 - x, 2), 3 - tvm.tirx.max(x, 1)), + TestCase(tvm.tirx.min(x * (-2), -4), tvm.tirx.max(x, 2) * -2), + TestCase(tvm.tirx.min(x * (-2), 4), tvm.tirx.max(x, -2) * -2), + TestCase(tvm.tirx.min(x * (0), 4), 0), + TestCase(tvm.tirx.min(x * (0), -4), -4), # DivMod rules # truc div - TestCase(tvm.tir.min(tdiv(x + 3, 4) * 4, x), x), - TestCase(tvm.tir.min(x, tdiv(x + 3, 4) * 4), x), - TestCase(tvm.tir.min(tdiv(x + 3, 4) * 4, tvm.tir.max(x, 4)), tvm.tir.max(x, 4), x > 0), - TestCase(tvm.tir.min(tvm.tir.max(x, 4), tdiv(x + 3, 4) * 4), tvm.tir.max(x, 4), x > 0), - TestCase(tvm.tir.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tir.min(x, y), 10)), - TestCase(tvm.tir.min(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.tir.max(x, y), (-10))), + TestCase(tvm.tirx.min(tdiv(x + 3, 4) * 4, x), x), + TestCase(tvm.tirx.min(x, tdiv(x + 3, 4) * 4), x), + TestCase(tvm.tirx.min(tdiv(x + 3, 4) * 4, tvm.tirx.max(x, 4)), tvm.tirx.max(x, 4), x > 0), + TestCase(tvm.tirx.min(tvm.tirx.max(x, 4), tdiv(x + 3, 4) * 4), tvm.tirx.max(x, 4), x > 0), + TestCase(tvm.tirx.min(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tirx.min(x, y), 10)), + TestCase(tvm.tirx.min(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.tirx.max(x, y), (-10))), # floor div - TestCase(tvm.tir.min(fld(x + 3, 4) * 4, x), x), - TestCase(tvm.tir.min(x, fld(x + 3, 4) * 4), x), - TestCase(tvm.tir.min(x, fld(x, 4) * 4), fld(x, 4) * 4), - TestCase(tvm.tir.min(fld(x + 3, 4) * 4, tvm.tir.max(x, 4)), tvm.tir.max(x, 4), x > 0), - TestCase(tvm.tir.min(tvm.tir.max(x, 4), fld(x + 3, 4) * 4), tvm.tir.max(x, 4), x > 0), - TestCase(tvm.tir.min(fld(x, 10), fld(y, 10)), fld(tvm.tir.min(x, y), 10)), - TestCase(tvm.tir.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.tir.max(x, y), (-10))), + TestCase(tvm.tirx.min(fld(x + 3, 4) * 4, x), x), + TestCase(tvm.tirx.min(x, fld(x + 3, 4) * 4), x), + TestCase(tvm.tirx.min(x, fld(x, 4) * 4), fld(x, 4) * 4), + TestCase(tvm.tirx.min(fld(x + 3, 4) * 4, tvm.tirx.max(x, 4)), tvm.tirx.max(x, 4), x > 0), + TestCase(tvm.tirx.min(tvm.tirx.max(x, 4), fld(x + 3, 4) * 4), tvm.tirx.max(x, 4), x > 0), + TestCase(tvm.tirx.min(fld(x, 10), fld(y, 10)), fld(tvm.tirx.min(x, y), 10)), + TestCase(tvm.tirx.min(fld(x, (-10)), fld(y, (-10))), fld(tvm.tirx.max(x, y), (-10))), ) class TestMaxIndex(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( # const int bound - TestCase(tvm.tir.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10), - TestCase(tvm.tir.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10), - TestCase(tvm.tir.max(x + 1, x + 10), x + 10), - TestCase(tvm.tir.max(x + 111, x + 10), x + 111), - TestCase(tvm.tir.max(x + 1, x), x + 1), - TestCase(tvm.tir.max(x, x + 2), x + 2), - TestCase(tvm.tir.max(1 - x, 2 - x), 2 - x), - TestCase(tvm.tir.max(3 - x, 2 - x), 3 - x), - TestCase(tvm.tir.max(tvm.tir.min(x, y), tvm.tir.max(x, y)), tvm.tir.max(x, y)), - TestCase(tvm.tir.max(tvm.tir.min(x, y), tvm.tir.max(y, x)), tvm.tir.max(x, y)), - TestCase(tvm.tir.max(tvm.tir.min(x, y), x), x), - TestCase(tvm.tir.max(tvm.tir.min(y, x), x), x), - TestCase(tvm.tir.max(tvm.tir.max(x, y), x), tvm.tir.max(x, y)), - TestCase(tvm.tir.max(tvm.tir.max(x, y), y), tvm.tir.max(x, y)), - TestCase(tvm.tir.max(x, tvm.tir.min(x, y)), x), - TestCase(tvm.tir.max(x, tvm.tir.min(y, x)), x), - TestCase(tvm.tir.max(x, tvm.tir.max(x, y)), tvm.tir.max(x, y)), - TestCase(tvm.tir.max(y, tvm.tir.max(x, y)), tvm.tir.max(x, y)), - TestCase( - tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), y), tvm.tir.max(tvm.tir.max(x, y), z) - ), - TestCase( - tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2), y), - tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2), - ), - TestCase( - tvm.tir.max( - tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2), z * 2), y + TestCase(tvm.tirx.max(tmod(x, 2), tmod(y, 2) + 10), tmod(y, 2) + 10), + TestCase(tvm.tirx.max(flm(x, 2), flm(y, 2) + 10), flm(y, 2) + 10), + TestCase(tvm.tirx.max(x + 1, x + 10), x + 10), + TestCase(tvm.tirx.max(x + 111, x + 10), x + 111), + TestCase(tvm.tirx.max(x + 1, x), x + 1), + TestCase(tvm.tirx.max(x, x + 2), x + 2), + TestCase(tvm.tirx.max(1 - x, 2 - x), 2 - x), + TestCase(tvm.tirx.max(3 - x, 2 - x), 3 - x), + TestCase(tvm.tirx.max(tvm.tirx.min(x, y), tvm.tirx.max(x, y)), tvm.tirx.max(x, y)), + TestCase(tvm.tirx.max(tvm.tirx.min(x, y), tvm.tirx.max(y, x)), tvm.tirx.max(x, y)), + TestCase(tvm.tirx.max(tvm.tirx.min(x, y), x), x), + TestCase(tvm.tirx.max(tvm.tirx.min(y, x), x), x), + TestCase(tvm.tirx.max(tvm.tirx.max(x, y), x), tvm.tirx.max(x, y)), + TestCase(tvm.tirx.max(tvm.tirx.max(x, y), y), tvm.tirx.max(x, y)), + TestCase(tvm.tirx.max(x, tvm.tirx.min(x, y)), x), + TestCase(tvm.tirx.max(x, tvm.tirx.min(y, x)), x), + TestCase(tvm.tirx.max(x, tvm.tirx.max(x, y)), tvm.tirx.max(x, y)), + TestCase(tvm.tirx.max(y, tvm.tirx.max(x, y)), tvm.tirx.max(x, y)), + TestCase( + tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(x, y), z), y), + tvm.tirx.max(tvm.tirx.max(x, y), z), + ), + TestCase( + tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(x, y), z), x * 2), y), + tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(x, y), z), x * 2), + ), + TestCase( + tvm.tirx.max( + tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(x, y), z), x * 2), z * 2), y ), - tvm.tir.max(tvm.tir.max(tvm.tir.max(tvm.tir.max(x, y), z), x * 2), z * 2), + tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(tvm.tirx.max(x, y), z), x * 2), z * 2), ), TestCase( - tvm.tir.max(tvm.tir.min(x, y), tvm.tir.min(x, z)), tvm.tir.min(tvm.tir.max(y, z), x) + tvm.tirx.max(tvm.tirx.min(x, y), tvm.tirx.min(x, z)), + tvm.tirx.min(tvm.tirx.max(y, z), x), ), TestCase( - tvm.tir.max(tvm.tir.min(x, y), tvm.tir.min(z, x)), tvm.tir.min(tvm.tir.max(y, z), x) + tvm.tirx.max(tvm.tirx.min(x, y), tvm.tirx.min(z, x)), + tvm.tirx.min(tvm.tirx.max(y, z), x), ), TestCase( - tvm.tir.max(tvm.tir.min(y, x), tvm.tir.min(x, z)), tvm.tir.min(tvm.tir.max(y, z), x) + tvm.tirx.max(tvm.tirx.min(y, x), tvm.tirx.min(x, z)), + tvm.tirx.min(tvm.tirx.max(y, z), x), ), TestCase( - tvm.tir.max(tvm.tir.min(y, x), tvm.tir.min(z, x)), tvm.tir.min(tvm.tir.max(y, z), x) + tvm.tirx.max(tvm.tirx.min(y, x), tvm.tirx.min(z, x)), + tvm.tirx.min(tvm.tirx.max(y, z), x), ), - TestCase(tvm.tir.max(y + x, z + x), tvm.tir.max(y, z) + x), - TestCase(tvm.tir.max(y + x, x + z), tvm.tir.max(y, z) + x), - TestCase(tvm.tir.max(x + y, z + x), tvm.tir.max(y, z) + x), - TestCase(tvm.tir.max(x + y, x + z), tvm.tir.max(y, z) + x), - TestCase(tvm.tir.max(x - y, x - z), x - tvm.tir.min(y, z)), - TestCase(tvm.tir.max(y - x, z - x), tvm.tir.max(y, z) - x), - TestCase(tvm.tir.max(tvm.tir.max(x, 1), 10), tvm.tir.max(x, 10)), - TestCase(tvm.tir.max(tvm.tir.max(x, 11), 10), tvm.tir.max(x, 11)), - TestCase(tvm.tir.max(x * 3, 9), tvm.tir.max(x, 3) * 3), - TestCase(tvm.tir.max(3 - x, 1), 3 - tvm.tir.min(x, 2)), - TestCase(tvm.tir.max(x * 2, 0), tvm.tir.max(x, 0) * 2), - TestCase(tvm.tir.max(0 - x * 2, 0), tvm.tir.min(x, 0) * -2), - TestCase(tvm.tir.max(x * (-2), -4), tvm.tir.min(x, 2) * -2), - TestCase(tvm.tir.max(x * (-2), 4), tvm.tir.min(x, -2) * -2), - TestCase(tvm.tir.max(x * (0), 4), 4), - TestCase(tvm.tir.max(x * (0), -4), 0), + TestCase(tvm.tirx.max(y + x, z + x), tvm.tirx.max(y, z) + x), + TestCase(tvm.tirx.max(y + x, x + z), tvm.tirx.max(y, z) + x), + TestCase(tvm.tirx.max(x + y, z + x), tvm.tirx.max(y, z) + x), + TestCase(tvm.tirx.max(x + y, x + z), tvm.tirx.max(y, z) + x), + TestCase(tvm.tirx.max(x - y, x - z), x - tvm.tirx.min(y, z)), + TestCase(tvm.tirx.max(y - x, z - x), tvm.tirx.max(y, z) - x), + TestCase(tvm.tirx.max(tvm.tirx.max(x, 1), 10), tvm.tirx.max(x, 10)), + TestCase(tvm.tirx.max(tvm.tirx.max(x, 11), 10), tvm.tirx.max(x, 11)), + TestCase(tvm.tirx.max(x * 3, 9), tvm.tirx.max(x, 3) * 3), + TestCase(tvm.tirx.max(3 - x, 1), 3 - tvm.tirx.min(x, 2)), + TestCase(tvm.tirx.max(x * 2, 0), tvm.tirx.max(x, 0) * 2), + TestCase(tvm.tirx.max(0 - x * 2, 0), tvm.tirx.min(x, 0) * -2), + TestCase(tvm.tirx.max(x * (-2), -4), tvm.tirx.min(x, 2) * -2), + TestCase(tvm.tirx.max(x * (-2), 4), tvm.tirx.min(x, -2) * -2), + TestCase(tvm.tirx.max(x * (0), 4), 4), + TestCase(tvm.tirx.max(x * (0), -4), 0), # DivMod rules # truc div - TestCase(tvm.tir.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tir.max(x, y), 10)), - TestCase(tvm.tir.max(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.tir.min(x, y), (-10))), - TestCase(tvm.tir.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4), + TestCase(tvm.tirx.max(tdiv(x, 10), tdiv(y, 10)), tdiv(tvm.tirx.max(x, y), 10)), + TestCase(tvm.tirx.max(tdiv(x, (-10)), tdiv(y, (-10))), tdiv(tvm.tirx.min(x, y), (-10))), + TestCase(tvm.tirx.max(tdiv(x + 3, 4) * 4, x), tdiv(x + 3, 4) * 4), # floordiv - TestCase(tvm.tir.max(fld(x, 10), fld(y, 10)), fld(tvm.tir.max(x, y), 10)), - TestCase(tvm.tir.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.tir.min(x, y), (-10))), - TestCase(tvm.tir.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4), - TestCase(tvm.tir.max(fld(x, 4) * 4, x), x), - TestCase(tvm.tir.max(x, fld(x, 4) * 4), x), + TestCase(tvm.tirx.max(fld(x, 10), fld(y, 10)), fld(tvm.tirx.max(x, y), 10)), + TestCase(tvm.tirx.max(fld(x, (-10)), fld(y, (-10))), fld(tvm.tirx.min(x, y), (-10))), + TestCase(tvm.tirx.max(fld(x + 3, 4) * 4, x), fld(x + 3, 4) * 4), + TestCase(tvm.tirx.max(fld(x, 4) * 4, x), x), + TestCase(tvm.tirx.max(x, fld(x, 4) * 4), x), ) class TestScalableIndex(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") test_case = tvm.testing.parameter( # MinNode - TestCase(tvm.tir.min(x + tir.vscale() * 4, x), x), - TestCase(tvm.tir.min(x - tir.vscale() * 4, x), x + tir.vscale() * -4), - TestCase(tvm.tir.min(x + tir.vscale() * 4, x + tir.vscale() * 8), tir.vscale() * 4 + x), - TestCase(tvm.tir.min(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), x), - TestCase(tvm.tir.min(tir.vscale() * x, tir.vscale() * y), tir.vscale() * x, x < y), + TestCase(tvm.tirx.min(x + tirx.vscale() * 4, x), x), + TestCase(tvm.tirx.min(x - tirx.vscale() * 4, x), x + tirx.vscale() * -4), + TestCase(tvm.tirx.min(x + tirx.vscale() * 4, x + tirx.vscale() * 8), tirx.vscale() * 4 + x), + TestCase(tvm.tirx.min(x + tirx.vscale() * 4 - flm(4, tirx.vscale() * 4), x), x), + TestCase(tvm.tirx.min(tirx.vscale() * x, tirx.vscale() * y), tirx.vscale() * x, x < y), # MaxNode - TestCase(tvm.tir.max(x + tir.vscale() * 4, x), x + tir.vscale() * 4), - TestCase(tvm.tir.max(x - tir.vscale() * 4, x), x), - TestCase(tvm.tir.max(x + tir.vscale() * 4, x + tir.vscale() * 4), x + tir.vscale() * 4), + TestCase(tvm.tirx.max(x + tirx.vscale() * 4, x), x + tirx.vscale() * 4), + TestCase(tvm.tirx.max(x - tirx.vscale() * 4, x), x), + TestCase(tvm.tirx.max(x + tirx.vscale() * 4, x + tirx.vscale() * 4), x + tirx.vscale() * 4), TestCase( - tvm.tir.max(x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), x), - x + tir.vscale() * 4 - flm(4, tir.vscale() * 4), + tvm.tirx.max(x + tirx.vscale() * 4 - flm(4, tirx.vscale() * 4), x), + x + tirx.vscale() * 4 - flm(4, tirx.vscale() * 4), ), - TestCase(tvm.tir.max(tir.vscale() * x, tir.vscale() * y), tir.vscale() * x, x > y), + TestCase(tvm.tirx.max(tirx.vscale() * x, tirx.vscale() * y), tirx.vscale() * x, x > y), # FloorDiv - TestCase(fld(x * tir.vscale() * 4 + y, tir.vscale() * 4), x + fld(y, tir.vscale() * 4)), - TestCase(fld(x, tir.vscale() * 4), 0, [x >= 0, x < tir.vscale() * 4]), + TestCase(fld(x * tirx.vscale() * 4 + y, tirx.vscale() * 4), x + fld(y, tirx.vscale() * 4)), + TestCase(fld(x, tirx.vscale() * 4), 0, [x >= 0, x < tirx.vscale() * 4]), # FloorMod - TestCase(flm(x * tir.vscale() * 4 + y, tir.vscale() * 4), flm(y, tir.vscale() * 4)), - TestCase(flm(x, tir.vscale() * 4), x, [x >= 0, x < tir.vscale() * 4]), + TestCase(flm(x * tirx.vscale() * 4 + y, tirx.vscale() * 4), flm(y, tirx.vscale() * 4)), + TestCase(flm(x, tirx.vscale() * 4), x, [x >= 0, x < tirx.vscale() * 4]), ) def test_simplify(self, test_case): @@ -926,22 +949,22 @@ def test_simplify(self, test_case): class TestComparisons(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( # const int bound - TestCase((tmod(x, 2) + 10).equal(0), tvm.tir.const(0, "bool")), - TestCase(tvm.tir.NE(tmod(x, 2) + 10, 0), tvm.tir.const(1, "bool")), - TestCase(tmod(x, 2) + 10 > 1, tvm.tir.const(1, "bool")), - TestCase(tmod(x, 2) + 10 <= 1, tvm.tir.const(0, "bool")), - TestCase(flm(x, 2) + 2 > 1, tvm.tir.const(1, "bool")), - TestCase(flm(x, 2) + 10 <= 1, tvm.tir.const(0, "bool")), - TestCase(x * 3 + 10 == 0, tvm.tir.const(0, "bool")), - TestCase(x * 3 + 10 != 0, tvm.tir.const(1, "bool")), + TestCase((tmod(x, 2) + 10).equal(0), tvm.tirx.const(0, "bool")), + TestCase(tvm.tirx.NE(tmod(x, 2) + 10, 0), tvm.tirx.const(1, "bool")), + TestCase(tmod(x, 2) + 10 > 1, tvm.tirx.const(1, "bool")), + TestCase(tmod(x, 2) + 10 <= 1, tvm.tirx.const(0, "bool")), + TestCase(flm(x, 2) + 2 > 1, tvm.tirx.const(1, "bool")), + TestCase(flm(x, 2) + 10 <= 1, tvm.tirx.const(0, "bool")), + TestCase(x * 3 + 10 == 0, tvm.tirx.const(0, "bool")), + TestCase(x * 3 + 10 != 0, tvm.tirx.const(1, "bool")), # canonicalization TestCase((x - 10).equal(0), x.equal(10)), TestCase((10 - x).equal(0), x.equal(10)), - TestCase((x * y).equal(0), tvm.tir.Or(x.equal(0), y.equal(0))), + TestCase((x * y).equal(0), tvm.tirx.Or(x.equal(0), y.equal(0))), # Write LT as LE for integer arguments, if possible TestCase(x - 1 < y, x <= y), TestCase(x + (-1) < y, x <= y), @@ -966,153 +989,153 @@ class TestComparisons(BaseCompare): TestCase(y + x < z + x, y < z), TestCase(y - x < z - x, y < z), TestCase(x - y < x - z, z < y), - TestCase(x < z + x, tvm.tir.LT(0, z)), - TestCase(x < x + z, tvm.tir.LT(0, z)), - TestCase(100 < x + 1, tvm.tir.LT(99, x)), - TestCase(1 < 100 - x, tvm.tir.LT(x, 99)), + TestCase(x < z + x, tvm.tirx.LT(0, z)), + TestCase(x < x + z, tvm.tirx.LT(0, z)), + TestCase(100 < x + 1, tvm.tirx.LT(99, x)), + TestCase(1 < 100 - x, tvm.tirx.LT(x, 99)), TestCase(x * 3 < y * 3, x < y), TestCase(x * (-3) < y * (-3), y < x), TestCase(x * 3 >= y * 3, y <= x), - TestCase(x * 4 >= 2, tvm.tir.LE(1, x)), - TestCase(x * 2 >= 50, tvm.tir.LE(25, x)), + TestCase(x * 4 >= 2, tvm.tirx.LE(1, x)), + TestCase(x * 2 >= 50, tvm.tirx.LE(25, x)), TestCase(x * 4 <= 2, x <= 0), - TestCase((0 - x * 3) <= 0, tvm.tir.LE(0, x)), - TestCase((0 - x * 3) >= 0, tvm.tir.LE(x, 0)), + TestCase((0 - x * 3) <= 0, tvm.tirx.LE(0, x)), + TestCase((0 - x * 3) >= 0, tvm.tirx.LE(x, 0)), TestCase(2 * x <= 0, x <= 0), - TestCase(x * 2 >= 3, tvm.tir.LE(2, x)), - TestCase(x * 2 >= 2, tvm.tir.LE(1, x)), - TestCase(x * 2 >= 1, tvm.tir.LE(1, x)), - TestCase(x * 2 >= 0, tvm.tir.LE(0, x)), - TestCase(x * 2 >= -1, tvm.tir.LE(0, x)), - TestCase(x * 2 >= -2, tvm.tir.LE(-1, x)), - TestCase(x * 2 >= -3, tvm.tir.LE(-1, x)), - TestCase(x * 2 <= 3, tvm.tir.LE(x, 1)), - TestCase(x * 2 <= 2, tvm.tir.LE(x, 1)), - TestCase(x * 2 <= 1, tvm.tir.LE(x, 0)), - TestCase(x * 2 <= 0, tvm.tir.LE(x, 0)), - TestCase(x * 2 <= -1, tvm.tir.LE(x, -1)), - TestCase(x * 2 <= -2, tvm.tir.LE(x, -1)), - TestCase(x * 2 <= -3, tvm.tir.LE(x, -2)), - TestCase(x * (-2) >= 3, tvm.tir.LE(x, -2)), - TestCase(x * (-2) >= 2, tvm.tir.LE(x, -1)), - TestCase(x * (-2) >= 1, tvm.tir.LE(x, -1)), - TestCase(x * (-2) >= 0, tvm.tir.LE(x, 0)), - TestCase(x * (-2) >= -1, tvm.tir.LE(x, 0)), - TestCase(x * (-2) >= -2, tvm.tir.LE(x, 1)), - TestCase(x * (-2) >= -3, tvm.tir.LE(x, 1)), - TestCase(x * (-2) <= 3, tvm.tir.LE(-1, x)), - TestCase(x * (-2) <= 2, tvm.tir.LE(-1, x)), - TestCase(x * (-2) <= 1, tvm.tir.LE(0, x)), - TestCase(x * (-2) <= 0, tvm.tir.LE(0, x)), - TestCase(x * (-2) <= -1, tvm.tir.LE(1, x)), - TestCase(x * (-2) <= -2, tvm.tir.LE(1, x)), - TestCase(x * (-2) <= -3, tvm.tir.LE(2, x)), + TestCase(x * 2 >= 3, tvm.tirx.LE(2, x)), + TestCase(x * 2 >= 2, tvm.tirx.LE(1, x)), + TestCase(x * 2 >= 1, tvm.tirx.LE(1, x)), + TestCase(x * 2 >= 0, tvm.tirx.LE(0, x)), + TestCase(x * 2 >= -1, tvm.tirx.LE(0, x)), + TestCase(x * 2 >= -2, tvm.tirx.LE(-1, x)), + TestCase(x * 2 >= -3, tvm.tirx.LE(-1, x)), + TestCase(x * 2 <= 3, tvm.tirx.LE(x, 1)), + TestCase(x * 2 <= 2, tvm.tirx.LE(x, 1)), + TestCase(x * 2 <= 1, tvm.tirx.LE(x, 0)), + TestCase(x * 2 <= 0, tvm.tirx.LE(x, 0)), + TestCase(x * 2 <= -1, tvm.tirx.LE(x, -1)), + TestCase(x * 2 <= -2, tvm.tirx.LE(x, -1)), + TestCase(x * 2 <= -3, tvm.tirx.LE(x, -2)), + TestCase(x * (-2) >= 3, tvm.tirx.LE(x, -2)), + TestCase(x * (-2) >= 2, tvm.tirx.LE(x, -1)), + TestCase(x * (-2) >= 1, tvm.tirx.LE(x, -1)), + TestCase(x * (-2) >= 0, tvm.tirx.LE(x, 0)), + TestCase(x * (-2) >= -1, tvm.tirx.LE(x, 0)), + TestCase(x * (-2) >= -2, tvm.tirx.LE(x, 1)), + TestCase(x * (-2) >= -3, tvm.tirx.LE(x, 1)), + TestCase(x * (-2) <= 3, tvm.tirx.LE(-1, x)), + TestCase(x * (-2) <= 2, tvm.tirx.LE(-1, x)), + TestCase(x * (-2) <= 1, tvm.tirx.LE(0, x)), + TestCase(x * (-2) <= 0, tvm.tirx.LE(0, x)), + TestCase(x * (-2) <= -1, tvm.tirx.LE(1, x)), + TestCase(x * (-2) <= -2, tvm.tirx.LE(1, x)), + TestCase(x * (-2) <= -3, tvm.tirx.LE(2, x)), # DivMod rules # truc div TestCase(tdiv(x, 2) < 3, x < 6), - TestCase(3 < tdiv(x, 2), tvm.tir.LT(7, x)), - TestCase(tdiv(x, 3) >= 0, tvm.tir.LE(-2, x)), - TestCase(tdiv(x, 2) >= 1, tvm.tir.LE(2, x)), - TestCase(tdiv(x, 2) >= 0, tvm.tir.LE(-1, x)), - TestCase(tdiv(x, 2) >= -1, tvm.tir.LE(-3, x)), - TestCase(tdiv(x, 2) <= 1, tvm.tir.LE(x, 3)), - TestCase(tdiv(x, 2) <= 0, tvm.tir.LE(x, 1)), - TestCase(tdiv(x, 2) <= -1, tvm.tir.LE(x, -2)), - TestCase(tdiv(x, 4) * 4 < x, tvm.tir.LT(0, tmod(x, 4))), - TestCase(tdiv(x, 4) * 4 >= x, tvm.tir.LE(tmod(x, 4), 0)), - TestCase(tdiv(x, 4) * 4 < x + y, tvm.tir.LT(0, tmod(x, 4) + y)), - TestCase(tdiv(x, 4) * 4 < x - y, tvm.tir.LT(y, tmod(x, 4))), - TestCase(tdiv(x + 2, 4) * 4 >= x, tvm.tir.LE(tmod(x + 2, 4), 2)), - TestCase(tdiv(x + 2, 4) * 4 >= x + y, tvm.tir.LE(tmod(x + 2, 4) + y, 2)), - TestCase(tdiv(x + 2, 4) * 4 >= x - y, tvm.tir.LE(tmod(x + 2, 4), y + 2)), + TestCase(3 < tdiv(x, 2), tvm.tirx.LT(7, x)), + TestCase(tdiv(x, 3) >= 0, tvm.tirx.LE(-2, x)), + TestCase(tdiv(x, 2) >= 1, tvm.tirx.LE(2, x)), + TestCase(tdiv(x, 2) >= 0, tvm.tirx.LE(-1, x)), + TestCase(tdiv(x, 2) >= -1, tvm.tirx.LE(-3, x)), + TestCase(tdiv(x, 2) <= 1, tvm.tirx.LE(x, 3)), + TestCase(tdiv(x, 2) <= 0, tvm.tirx.LE(x, 1)), + TestCase(tdiv(x, 2) <= -1, tvm.tirx.LE(x, -2)), + TestCase(tdiv(x, 4) * 4 < x, tvm.tirx.LT(0, tmod(x, 4))), + TestCase(tdiv(x, 4) * 4 >= x, tvm.tirx.LE(tmod(x, 4), 0)), + TestCase(tdiv(x, 4) * 4 < x + y, tvm.tirx.LT(0, tmod(x, 4) + y)), + TestCase(tdiv(x, 4) * 4 < x - y, tvm.tirx.LT(y, tmod(x, 4))), + TestCase(tdiv(x + 2, 4) * 4 >= x, tvm.tirx.LE(tmod(x + 2, 4), 2)), + TestCase(tdiv(x + 2, 4) * 4 >= x + y, tvm.tirx.LE(tmod(x + 2, 4) + y, 2)), + TestCase(tdiv(x + 2, 4) * 4 >= x - y, tvm.tirx.LE(tmod(x + 2, 4), y + 2)), # floor div TestCase(fld(x, 2) < 3, x < 6), - TestCase(3 < fld(x, 2), tvm.tir.LT(7, x)), - TestCase(-3 < fld(x, 2), tvm.tir.LT(-5, x)), - TestCase(fld(x, 3) >= 0, tvm.tir.LE(0, x)), - TestCase(fld(x, 2) >= 1, tvm.tir.LE(2, x)), - TestCase(fld(x, 2) >= 0, tvm.tir.LE(0, x)), - TestCase(fld(x, 2) >= -1, tvm.tir.LE(-2, x)), - TestCase(fld(x, 2) <= 1, tvm.tir.LE(x, 3)), - TestCase(fld(x, 2) <= 0, tvm.tir.LE(x, 1)), - TestCase(fld(x, 2) <= -1, tvm.tir.LE(x, -1)), - TestCase(fld(x, 4) * 4 < x, tvm.tir.LT(0, flm(x, 4))), - TestCase(fld(x, 4) * 4 >= x, tvm.tir.EQ(flm(x, 4), 0)), - TestCase(fld(x, 4) * 4 < x + y, tvm.tir.LT(0, flm(x, 4) + y)), - TestCase(fld(x, 4) * 4 < x - y, tvm.tir.LT(y, flm(x, 4))), - TestCase(fld(x + 2, 4) * 4 >= x, tvm.tir.LE(flm(x + 2, 4), 2)), - TestCase(fld(x + 2, 4) * 4 >= x + y, tvm.tir.LE(flm(x + 2, 4) + y, 2)), - TestCase(fld(x + 2, 4) * 4 >= x - y, tvm.tir.LE(flm(x + 2, 4), y + 2)), + TestCase(3 < fld(x, 2), tvm.tirx.LT(7, x)), + TestCase(-3 < fld(x, 2), tvm.tirx.LT(-5, x)), + TestCase(fld(x, 3) >= 0, tvm.tirx.LE(0, x)), + TestCase(fld(x, 2) >= 1, tvm.tirx.LE(2, x)), + TestCase(fld(x, 2) >= 0, tvm.tirx.LE(0, x)), + TestCase(fld(x, 2) >= -1, tvm.tirx.LE(-2, x)), + TestCase(fld(x, 2) <= 1, tvm.tirx.LE(x, 3)), + TestCase(fld(x, 2) <= 0, tvm.tirx.LE(x, 1)), + TestCase(fld(x, 2) <= -1, tvm.tirx.LE(x, -1)), + TestCase(fld(x, 4) * 4 < x, tvm.tirx.LT(0, flm(x, 4))), + TestCase(fld(x, 4) * 4 >= x, tvm.tirx.EQ(flm(x, 4), 0)), + TestCase(fld(x, 4) * 4 < x + y, tvm.tirx.LT(0, flm(x, 4) + y)), + TestCase(fld(x, 4) * 4 < x - y, tvm.tirx.LT(y, flm(x, 4))), + TestCase(fld(x + 2, 4) * 4 >= x, tvm.tirx.LE(flm(x + 2, 4), 2)), + TestCase(fld(x + 2, 4) * 4 >= x + y, tvm.tirx.LE(flm(x + 2, 4) + y, 2)), + TestCase(fld(x + 2, 4) * 4 >= x - y, tvm.tirx.LE(flm(x + 2, 4), y + 2)), # End DivMod Rules # merging flm/fld into known value - TestCase(tir.all(fld(x, 8) == 3, flm(x, 8) == 4), x == 28), - TestCase(tir.all(flm(x, 8) == 4, fld(x, 8) == 3), x == 28), - TestCase(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20), - TestCase(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20), + TestCase(tirx.all(fld(x, 8) == 3, flm(x, 8) == 4), x == 28), + TestCase(tirx.all(flm(x, 8) == 4, fld(x, 8) == 3), x == 28), + TestCase(tirx.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20), + TestCase(tirx.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20), # Rewrite based on definition of integer division - TestCase(tir.all(T.int32(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), - TestCase(tir.all(x - y * 5 < 5, T.int32(0) <= x - y * 5), y == fld(x, 5)), + TestCase(tirx.all(T.int32(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), + TestCase(tirx.all(x - y * 5 < 5, T.int32(0) <= x - y * 5), y == fld(x, 5)), # Narrow upper bound using floormod - TestCase(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), - TestCase(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), - TestCase(tir.all(x <= 19, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), - TestCase(tir.all(x <= 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), - TestCase(tir.all(x < -20, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2)), - TestCase(tir.all(x < 18 - 40, flm(x, 5) < 2), tir.all(x < 17 - 40, flm(x, 5) < 2)), - TestCase(tir.all(x <= -21, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2)), - TestCase(tir.all(x <= -22, flm(x, 5) < 2), tir.all(x < -23, flm(x, 5) < 2)), + TestCase(tirx.all(x < 20, flm(x, 5) < 2), tirx.all(x < 17, flm(x, 5) < 2)), + TestCase(tirx.all(x < 18, flm(x, 5) < 2), tirx.all(x < 17, flm(x, 5) < 2)), + TestCase(tirx.all(x <= 19, flm(x, 5) < 2), tirx.all(x < 17, flm(x, 5) < 2)), + TestCase(tirx.all(x <= 18, flm(x, 5) < 2), tirx.all(x < 17, flm(x, 5) < 2)), + TestCase(tirx.all(x < -20, flm(x, 5) < 2), tirx.all(x < -23, flm(x, 5) < 2)), + TestCase(tirx.all(x < 18 - 40, flm(x, 5) < 2), tirx.all(x < 17 - 40, flm(x, 5) < 2)), + TestCase(tirx.all(x <= -21, flm(x, 5) < 2), tirx.all(x < -23, flm(x, 5) < 2)), + TestCase(tirx.all(x <= -22, flm(x, 5) < 2), tirx.all(x < -23, flm(x, 5) < 2)), # No change if the floormod cannot help narrow the upper bound - TestCase(tir.all(x < 16, flm(x, 5) < 2), tir.all(x < 16, flm(x, 5) < 2)), - TestCase(tir.all(x <= 15, flm(x, 5) < 2), tir.all(x <= 15, flm(x, 5) < 2)), + TestCase(tirx.all(x < 16, flm(x, 5) < 2), tirx.all(x < 16, flm(x, 5) < 2)), + TestCase(tirx.all(x <= 15, flm(x, 5) < 2), tirx.all(x <= 15, flm(x, 5) < 2)), # Merge a known floordiv and an upper bound of floormod into a value range TestCase( - tir.all(fld(x, 10) == 5, flm(x, 10) < 7), - tir.all(T.int32(50) <= x, x < 57), + tirx.all(fld(x, 10) == 5, flm(x, 10) < 7), + tirx.all(T.int32(50) <= x, x < 57), ), TestCase( - tir.all(fld(x, 10) == 5, flm(x, 10) <= 7), - tir.all(T.int32(50) <= x, x <= 57), + tirx.all(fld(x, 10) == 5, flm(x, 10) <= 7), + tirx.all(T.int32(50) <= x, x <= 57), ), TestCase( - tir.all(fld(x, 10) == -5, flm(x, 10) < 7), - tir.all(T.int32(-50) <= x, x < -43), + tirx.all(fld(x, 10) == -5, flm(x, 10) < 7), + tirx.all(T.int32(-50) <= x, x < -43), ), TestCase( - tir.all(fld(x, 10) == -5, flm(x, 10) <= 7), - tir.all(T.int32(-50) <= x, x <= -43), + tirx.all(fld(x, 10) == -5, flm(x, 10) <= 7), + tirx.all(T.int32(-50) <= x, x <= -43), ), # Merge a known floordiv and an lower bound of floormod into a value range TestCase( - tir.all(fld(x, 10) == 5, T.int32(7) < flm(x, 10)), - tir.all(T.int32(57) < x, x < 60), + tirx.all(fld(x, 10) == 5, T.int32(7) < flm(x, 10)), + tirx.all(T.int32(57) < x, x < 60), ), TestCase( - tir.all(fld(x, 10) == 5, T.int32(7) <= flm(x, 10)), - tir.all(T.int32(57) <= x, x < 60), + tirx.all(fld(x, 10) == 5, T.int32(7) <= flm(x, 10)), + tirx.all(T.int32(57) <= x, x < 60), ), TestCase( - tir.all(fld(x, 10) == -5, T.int32(7) < flm(x, 10)), - tir.all(T.int32(-43) < x, x < -40), + tirx.all(fld(x, 10) == -5, T.int32(7) < flm(x, 10)), + tirx.all(T.int32(-43) < x, x < -40), ), TestCase( - tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)), - tir.all(T.int32(-43) <= x, x < -40), + tirx.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)), + tirx.all(T.int32(-43) <= x, x < -40), ), - TestCase(tvm.tir.min(x, 11) < 10, x < 10), - TestCase(tvm.tir.min(x, 8) < 10, tvm.tir.const(1, "bool")), - TestCase(tvm.tir.max(8, x) > 10, tvm.tir.LT(10, x)), - TestCase(x + 1 < tvm.tir.max(8, x), x < 7), - TestCase(x < 11, tvm.tir.const(1, "bool"), x <= 10), - TestCase(x <= 10, tvm.tir.const(1, "bool"), x <= 10), - TestCase(z <= 5, tvm.tir.const(1, "bool"), z <= 5), - TestCase(x + y <= 10, tvm.tir.const(1, "bool"), [x <= 10, y <= 0]), - TestCase(x + y >= -10, tvm.tir.const(1, "bool"), [x >= 0, y >= -10]), - TestCase(z - 5 <= y + 10, tvm.tir.const(1, "bool"), [z <= 5, y >= -10]), - TestCase(tvm.tir.all(x > -1, z <= x + 5), tvm.tir.const(1, "bool"), [x >= 0, z <= 5]), - TestCase(x * y <= 0, tvm.tir.const(1, "bool"), [x >= 0, y <= 0]), - TestCase((x + 1) * (y - 1) < 0, tvm.tir.const(1, "bool"), [x >= 0, y <= 0]), - TestCase(y * y >= 0, tvm.tir.const(1, "bool"), y <= 0), - TestCase(x * 6 <= -3, tvm.tir.const(0, "bool"), x >= 0), + TestCase(tvm.tirx.min(x, 11) < 10, x < 10), + TestCase(tvm.tirx.min(x, 8) < 10, tvm.tirx.const(1, "bool")), + TestCase(tvm.tirx.max(8, x) > 10, tvm.tirx.LT(10, x)), + TestCase(x + 1 < tvm.tirx.max(8, x), x < 7), + TestCase(x < 11, tvm.tirx.const(1, "bool"), x <= 10), + TestCase(x <= 10, tvm.tirx.const(1, "bool"), x <= 10), + TestCase(z <= 5, tvm.tirx.const(1, "bool"), z <= 5), + TestCase(x + y <= 10, tvm.tirx.const(1, "bool"), [x <= 10, y <= 0]), + TestCase(x + y >= -10, tvm.tirx.const(1, "bool"), [x >= 0, y >= -10]), + TestCase(z - 5 <= y + 10, tvm.tirx.const(1, "bool"), [z <= 5, y >= -10]), + TestCase(tvm.tirx.all(x > -1, z <= x + 5), tvm.tirx.const(1, "bool"), [x >= 0, z <= 5]), + TestCase(x * y <= 0, tvm.tirx.const(1, "bool"), [x >= 0, y <= 0]), + TestCase((x + 1) * (y - 1) < 0, tvm.tirx.const(1, "bool"), [x >= 0, y <= 0]), + TestCase(y * y >= 0, tvm.tirx.const(1, "bool"), y <= 0), + TestCase(x * 6 <= -3, tvm.tirx.const(0, "bool"), x >= 0), TestCase(tmod(y - 1, 3) == 0, tmod(y + (-1), 3) == 0), ) @@ -1120,84 +1143,84 @@ class TestComparisons(BaseCompare): class TestComparisonOfProductAndSum(BaseCompare): extensions = tvm.arith.Extension.ComparisonOfProductAndSum - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( # Special inequality cases TestCase( x * y < (x + y) * 2048, - tvm.tir.const(1, "bool"), + tvm.tirx.const(1, "bool"), [x > 0, y > 0, x < 2048], ), TestCase( x * y < (x + y) * 2048, - tvm.tir.const(1, "bool"), + tvm.tirx.const(1, "bool"), [x > 0, y > 0, x < 4096, y < 4096], ), TestCase( # Both sides are divisible by 8192 x * y * 8192 < (y + x) * 16777216, - tvm.tir.const(1, "bool"), + tvm.tirx.const(1, "bool"), [x > 0, y > 0, x < 4096, y < 4096], ), TestCase( # The two sides have co-prime factors, but the bounds are # still sufficient to prove the inequality. x * y * 59 < (y + x) * 176128, - tvm.tir.const(1, "bool"), + tvm.tirx.const(1, "bool"), [x > 0, y > 0, x < 4096, y < 4096], ), ) class TestLogical(BaseCompare): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") test_case = tvm.testing.parameter( - TestCase(tvm.tir.And(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x > 1, tvm.tir.Not(x > 1)), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x <= y, y < x), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(y < x, x <= y), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x < 1, 0 < x), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x < 0, 1 < x), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x < 1, 1 <= x), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x <= 1, 1 < x), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(1 <= x, x < 1), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(1 < x, x <= 1), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.And(x == 1, x != 2), x == 1), - TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")), - TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x <= y, y < x), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(y < x, y >= x), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x < 1, 0 < x), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(0 < x, x < 1), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x < 1, 1 <= x), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x <= 1, 1 < x), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(1 <= x, x < 1), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(1 < x, x <= 1), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")), - TestCase(tvm.tir.Or(x != 1, x == 2), x != 1), - TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")), - TestCase( - tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)), - tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1), - ), - TestCase( - tvm.tir.And(x == 1, tvm.tir.And(y == 1, z == 1)), - tvm.tir.And(tvm.tir.And(x == 1, y == 1), z == 1), + TestCase(tvm.tirx.And(tvm.tirx.EQ(x, y), tvm.tirx.NE(x, y)), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(tvm.tirx.NE(x, y), tvm.tirx.EQ(x, y)), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x > 1, tvm.tirx.Not(x > 1)), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x <= y, y < x), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(y < x, x <= y), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x < 1, 0 < x), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x < 0, 1 < x), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x < 1, 1 <= x), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x <= 1, 1 < x), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(1 <= x, x < 1), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(1 < x, x <= 1), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x <= 1, 2 <= x), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(2 <= x, x <= 1), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.And(x == 1, x != 2), x == 1), + TestCase(tvm.tirx.And(x == 1, x == 2), tvm.tirx.const(False, "bool")), + TestCase(tvm.tirx.Or(tvm.tirx.EQ(x, y), tvm.tirx.NE(x, y)), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(tvm.tirx.NE(x, y), tvm.tirx.EQ(x, y)), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x > y, tvm.tirx.Not(x > y)), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x <= y, y < x), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(y < x, y >= x), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x < 1, 0 < x), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(0 < x, x < 1), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x < 1, 1 <= x), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x <= 1, 1 < x), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(1 <= x, x < 1), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(1 < x, x <= 1), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x <= 1, 2 <= x), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(2 <= x, x <= 1), tvm.tirx.const(True, "bool")), + TestCase(tvm.tirx.Or(x != 1, x == 2), x != 1), + TestCase(tvm.tirx.Or(x != 1, x != 2), tvm.tirx.const(True, "bool")), + TestCase( + tvm.tirx.Or(x == 1, tvm.tirx.Or(y == 1, z == 1)), + tvm.tirx.Or(tvm.tirx.Or(x == 1, y == 1), z == 1), + ), + TestCase( + tvm.tirx.And(x == 1, tvm.tirx.And(y == 1, z == 1)), + tvm.tirx.And(tvm.tirx.And(x == 1, y == 1), z == 1), ), ) class TestLet(BaseCompare): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") - z = tvm.tir.Let(x, 1, x + 1) + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") + z = tvm.tirx.Let(x, 1, x + 1) test_case = tvm.testing.parameter( TestCase(z + z, 4), @@ -1206,43 +1229,44 @@ class TestLet(BaseCompare): class TestCast(BaseCompare): def _generate_tests(): - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") dtypes = ["float32", "float16", "int32", "int8", "bool"] for dtype1 in dtypes: - yield TestCase(tvm.tir.Cast(dtype1, x - x), tvm.tir.const(0, dtype1)) - yield TestCase(tvm.tir.Cast(dtype1, x == x), tvm.tir.const(1, dtype1)) + yield TestCase(tvm.tirx.Cast(dtype1, x - x), tvm.tirx.const(0, dtype1)) + yield TestCase(tvm.tirx.Cast(dtype1, x == x), tvm.tirx.const(1, dtype1)) for dtype2 in dtypes: for i in [0, 1, 2, 3]: if i <= 1 or (dtype1 != "bool" and dtype2 != "bool"): yield TestCase( - tvm.tir.Cast(dtype1, tvm.tir.const(i, dtype2)), tvm.tir.const(i, dtype1) + tvm.tirx.Cast(dtype1, tvm.tirx.const(i, dtype2)), + tvm.tirx.const(i, dtype1), ) test_case = tvm.testing.parameter(*_generate_tests()) class TestShiftLeft(BaseCompare): - z = tvm.tir.op.call_intrin("int32", "tir.shift_left", 1, 10) + z = tvm.tirx.op.call_intrin("int32", "tirx.shift_left", 1, 10) test_case = tvm.testing.parameter( - TestCase(z, tvm.tir.const(1 << 10, "int32")), + TestCase(z, tvm.tirx.const(1 << 10, "int32")), ) class TestDivZero(BaseCompare): - ramp = tvm.tir.Ramp(1, 1, 2) - broadcast = tvm.tir.Broadcast(0, 2) + ramp = tvm.tirx.Ramp(1, 1, 2) + broadcast = tvm.tirx.Broadcast(0, 2) test_case = tvm.testing.parameter( - TestCase(tvm.tir.Div(ramp, broadcast), tvm.error.TVMError), - TestCase(tvm.tir.Mod(ramp, broadcast), tvm.error.TVMError), - TestCase(tvm.tir.FloorDiv(ramp, broadcast), tvm.error.TVMError), - TestCase(tvm.tir.FloorMod(ramp, broadcast), tvm.error.TVMError), + TestCase(tvm.tirx.Div(ramp, broadcast), tvm.error.TVMError), + TestCase(tvm.tirx.Mod(ramp, broadcast), tvm.error.TVMError), + TestCase(tvm.tirx.FloorDiv(ramp, broadcast), tvm.error.TVMError), + TestCase(tvm.tirx.FloorMod(ramp, broadcast), tvm.error.TVMError), ) class TestSubBufferload(BaseCompare): - buf = tvm.tir.decl_buffer([1], dtype="float32") - load = tvm.tir.BufferLoad(buf, [0]) + buf = tvm.tirx.decl_buffer([1], dtype="float32") + load = tvm.tirx.BufferLoad(buf, [0]) test_case = tvm.testing.parameter( TestCase(load - load, 0.0), @@ -1250,31 +1274,37 @@ class TestSubBufferload(BaseCompare): class TestIfThenElse(BaseCompare): - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") test_case = tvm.testing.parameter( TestCase( - tvm.tir.if_then_else(x < 5, tvm.tir.if_then_else(x > 1, 1, 0), 0), - tvm.tir.if_then_else(tvm.tir.And(tvm.tir.LT(x, 5), tvm.tir.LT(1, x)), 1, 0), + tvm.tirx.if_then_else(x < 5, tvm.tirx.if_then_else(x > 1, 1, 0), 0), + tvm.tirx.if_then_else(tvm.tirx.And(tvm.tirx.LT(x, 5), tvm.tirx.LT(1, x)), 1, 0), ), TestCase( - tvm.tir.if_then_else(x > 2, tvm.tir.if_then_else(x > 1, 1, 0), 0), - tvm.tir.if_then_else(tvm.tir.LT(2, x), 1, 0), + tvm.tirx.if_then_else(x > 2, tvm.tirx.if_then_else(x > 1, 1, 0), 0), + tvm.tirx.if_then_else(tvm.tirx.LT(2, x), 1, 0), ), ) class TestCLZ(BaseCompare): test_case = tvm.testing.parameter( - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), T.int32(32)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), T.int32(31)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), T.int32(30)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), T.int32(24)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), T.int32(64)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), T.int32(63)), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), T.int32(62)), - TestCase( - tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), T.int32(56) + TestCase(tvm.tirx.call_intrin("int32", "tirx.clz", 0), T.int32(32)), + TestCase(tvm.tirx.call_intrin("int32", "tirx.clz", 1), T.int32(31)), + TestCase(tvm.tirx.call_intrin("int32", "tirx.clz", 2), T.int32(30)), + TestCase(tvm.tirx.call_intrin("int32", "tirx.clz", 128), T.int32(24)), + TestCase( + tvm.tirx.call_intrin("int32", "tirx.clz", tvm.tirx.IntImm("int64", 0)), T.int32(64) + ), + TestCase( + tvm.tirx.call_intrin("int32", "tirx.clz", tvm.tirx.IntImm("int64", 1)), T.int32(63) + ), + TestCase( + tvm.tirx.call_intrin("int32", "tirx.clz", tvm.tirx.IntImm("int64", 2)), T.int32(62) + ), + TestCase( + tvm.tirx.call_intrin("int32", "tirx.clz", tvm.tirx.IntImm("int64", 128)), T.int32(56) ), ) diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index f316508d49de..b367735c1f36 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -20,15 +20,15 @@ import tvm import tvm.ir import tvm.testing -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T def test_simplify_reshape_flattened_index(): ana = tvm.arith.Analyzer() - i0 = tir.Var("i0", "int64") - i1 = tir.Var("i1", "int64") + i0 = tirx.Var("i0", "int64") + i1 = tirx.Var("i1", "int64") ana.bind(i0, tvm.ir.Range(0, 8)) ana.bind(i1, tvm.ir.Range(0, 3)) @@ -57,23 +57,23 @@ def test_simplify_reshape_flattened_index(): def test_can_prove_self_identity(dtype): ana = tvm.arith.Analyzer() - n = tir.Var("n", dtype) + n = tirx.Var("n", dtype) assert ana.can_prove(n == n) def test_can_prove_self_equal_to_self(dtype): ana = tvm.arith.Analyzer() - n = tir.Var("n", dtype) + n = tirx.Var("n", dtype) assert ana.can_prove_equal(n, n) def test_simplify_symbolic_comparison(): ana = tvm.arith.Analyzer() - i0 = tir.Var("i0", "int64") - i1 = tir.Var("i1", "int64") - n, m = tvm.tir.SizeVar("n", "int64"), tvm.tir.SizeVar("m", "int64") + i0 = tirx.Var("i0", "int64") + i1 = tirx.Var("i1", "int64") + n, m = tvm.tirx.SizeVar("n", "int64"), tvm.tirx.SizeVar("m", "int64") outer = (n + 31) // 32 ana.bind(i0, tvm.ir.Range(0, outer)) ana.bind(i1, tvm.ir.Range(0, 32)) @@ -105,7 +105,7 @@ def test_simplify_vscale_comparison_with_sve_target(expression): def test_simplify_vscale_comparison_without_sve_target(capfd): ana = tvm.arith.Analyzer() - vs = tvm.tir.vscale() + vs = tvm.tirx.vscale() with pytest.raises(AssertionError): with tvm.target.Target({"kind": "llvm", "mtriple": "aarch64-linux-gnu"}): @@ -124,9 +124,9 @@ def test_simplify_vscale_comparison_without_sve_target(capfd): def test_regression_simplify_inf_recursion(): ana = tvm.arith.Analyzer() - cond = tir.Var("cond", "int32") + cond = tirx.Var("cond", "int32") - res = (tvm.tir.NE(cond, 0).astype("int8") - tvm.tir.NE(cond, 0).astype("int8")).astype( + res = (tvm.tirx.NE(cond, 0).astype("int8") - tvm.tirx.NE(cond, 0).astype("int8")).astype( "int32" ) == 0 # regression in a previous case @@ -139,19 +139,19 @@ def test_simplify_floor_mod_with_linear_offset(): Test that the floor_mod is simplified correctly when the offset is linear. """ ana = tvm.arith.Analyzer() - past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64") + past_decoder_sequence_length = tirx.Var("past_decoder_sequence_length", "int64") expr1 = (past_decoder_sequence_length + 1) * 64 divisor1 = (past_decoder_sequence_length + 1) * 32 - assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0) + assert ana.can_prove_equal(tvm.tirx.floormod(expr1, divisor1), 0) divisor2 = 32 * (past_decoder_sequence_length + 1) - assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0) + assert ana.can_prove_equal(tvm.tirx.floormod(expr1, divisor2), 0) def test_simplify_float_division(): # Test for the discussion: # https://discuss.tvm.apache.org/t/discuss-is-constant-division-to-multiplication-rewrite-in-tvm-necessary/18615 ana = tvm.arith.Analyzer() - x = tir.Var("x", "float32") + x = tirx.Var("x", "float32") ry = x / 27 # in old version, the division will be rewritten into x * T.float32(1 / 27) sy = ana.rewrite_simplify(ry) diff --git a/tests/python/arith/test_arith_solve_linear_equations.py b/tests/python/arith/test_arith_solve_linear_equations.py index 067738f2807c..d1218fc3518c 100644 --- a/tests/python/arith/test_arith_solve_linear_equations.py +++ b/tests/python/arith/test_arith_solve_linear_equations.py @@ -21,8 +21,8 @@ import pytest import tvm -from tvm import arith, ir, testing, tir -from tvm.script import tir as T +from tvm import arith, ir, testing, tirx +from tvm.script import tirx as T def test_solution_consistency(): @@ -34,7 +34,7 @@ def test_solution_consistency(): random.seed(seed) def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): - variables = [tvm.tir.Var("x" + str(i), "int32") for i in range(num_vars)] + variables = [tvm.tirx.Var("x" + str(i), "int32") for i in range(num_vars)] relations = [] for i in range(num_formulas): @@ -43,10 +43,10 @@ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): s2 = sum([v * random.randint(coef[0], coef[1]) for v in variables]) s2 += random.randint(coef[0], coef[1]) if random.random() < 0.7: - op = tvm.tir.EQ + op = tvm.tirx.EQ else: # we also make sure it can correctly handle inequalities - op = random.choice([tvm.tir.LE, tvm.tir.LT, tvm.tir.GE, tvm.tir.GT]) + op = random.choice([tvm.tirx.LE, tvm.tirx.LT, tvm.tirx.GE, tvm.tirx.GT]) relations.append(op(s1, s2)) vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in variables} @@ -88,10 +88,10 @@ def _check(num_vars, num_formulas, coef=(-5, 5), bounds=(-20, 20)): def test_empty_var_to_solve(): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") equations = [ - tvm.tir.EQ(x + y, 20), - tvm.tir.EQ(x - y, 10), + tvm.tirx.EQ(x + y, 20), + tvm.tirx.EQ(x - y, 10), ] solution = arith.solve_linear_equations(equations) assert len(solution.src_to_dst) == 0 @@ -103,12 +103,12 @@ def test_empty_var_to_solve(): def test_unique_solution(): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") solution = arith.solve_linear_equations( [ - tvm.tir.EQ(x + y, 20), - tvm.tir.EQ(x - y, 10), + tvm.tirx.EQ(x + y, 20), + tvm.tirx.EQ(x - y, 10), ], [x, y], ) @@ -118,13 +118,13 @@ def test_unique_solution(): def test_low_rank(): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") ranges = {} solution = arith.solve_linear_equations( [ - tvm.tir.EQ(x + y + z, 15), - tvm.tir.EQ(x + y, 10), + tvm.tirx.EQ(x + y + z, 15), + tvm.tirx.EQ(x + y, 10), ], [x, y, z], ranges, @@ -136,7 +136,7 @@ def test_low_rank(): def test_infer_range(): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") ranges = { x: tvm.ir.Range.from_min_extent(-5, 10), y: tvm.ir.Range.from_min_extent(0, 10), @@ -144,7 +144,7 @@ def test_infer_range(): solution = arith.solve_linear_equations( [ - tvm.tir.EQ(x + y, 0), + tvm.tirx.EQ(x + y, 0), ], [x, y], ranges, @@ -157,26 +157,26 @@ def test_infer_range(): assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10)) # additional inequality is added into the system for x [ineq] = solution.dst.relations - assert isinstance(ineq, tvm.tir.LE) + assert isinstance(ineq, tvm.tirx.LE) assert ir.structural_equal(ineq.a, T.int32(-5)) assert ir.structural_equal(ineq.b, n0) def test_ill_formed(): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") solution = arith.solve_linear_equations( [ - tvm.tir.EQ(x + y, 0), - tvm.tir.EQ(x - y, 0), - tvm.tir.EQ(x, 5), + tvm.tirx.EQ(x + y, 0), + tvm.tirx.EQ(x - y, 0), + tvm.tirx.EQ(x, 5), ], [x, y], {}, ) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - ir.assert_structural_equal(rel, tir.const(False)) + ir.assert_structural_equal(rel, tirx.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py b/tests/python/arith/test_arith_solve_linear_inequality.py index b8c31702eae2..8050f73b4469 100644 --- a/tests/python/arith/test_arith_solve_linear_inequality.py +++ b/tests/python/arith/test_arith_solve_linear_inequality.py @@ -20,8 +20,8 @@ import pytest import tvm -from tvm import arith, ir, testing, tir -from tvm.script import tir as T +from tvm import arith, ir, testing, tirx +from tvm.script import tirx as T @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11458") @@ -34,7 +34,7 @@ def test_solution_consistency(): random.seed(seed) def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): - vs = [tvm.tir.Var("x" + str(i), "int32") for i in range(variables)] + vs = [tvm.tirx.Var("x" + str(i), "int32") for i in range(variables)] fs = [] for i in range(formulas): @@ -42,13 +42,15 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): s1 += random.randint(coef[0], coef[1]) s2 = sum([v * random.randint(coef[0], coef[1]) for v in vs]) s2 += random.randint(coef[0], coef[1]) - op = random.choice([tir.expr.EQ, tir.expr.LE, tir.expr.LT, tir.expr.GE, tir.expr.GT]) + op = random.choice( + [tirx.expr.EQ, tirx.expr.LE, tirx.expr.LT, tirx.expr.GE, tirx.expr.GT] + ) fs.append(op(s1, s2)) vranges = {v: tvm.ir.expr.Range(bounds[0], bounds[1] + 1) for v in vs} - before = tvm.tir.all(tir.const(1, "bool"), *fs) + before = tvm.tirx.all(tirx.const(1, "bool"), *fs) after = arith._ffi_api.SolveInequalitiesAsCondition(vs, vranges, fs) - after = tvm.tir.all(tir.const(1, "bool"), *after) + after = tvm.tirx.all(tirx.const(1, "bool"), *after) testing.check_bool_expr_is_true(before == after, vranges) solution = arith.solve_linear_inequalities(fs, vs, vranges, deskew_range=True) @@ -83,7 +85,7 @@ def _check(variables, formulas, coef=(-5, 5), bounds=(-20, 20)): def test_dual_variable(): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") variables = [x, y] ranges = { @@ -91,8 +93,8 @@ def test_dual_variable(): y: tvm.ir.Range(0, 10), } problem = [ - tvm.tir.LE(x + y, 20), - tvm.tir.GE(x - y, 10), + tvm.tirx.LE(x + y, 20), + tvm.tirx.GE(x - y, 10), ] # solution as conditions @@ -127,11 +129,11 @@ def test_dual_variable(): def test_equal(): - x, y = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32") + x, y = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32") problem = [ - tvm.tir.GE(x + y, 10), - tvm.tir.GE(x - y, 2), - tvm.tir.LE(x, 6), + tvm.tirx.GE(x + y, 10), + tvm.tirx.GE(x - y, 2), + tvm.tirx.LE(x, 6), ] solution = arith.solve_linear_inequalities(problem, [x, y]) @@ -149,12 +151,12 @@ def test_equal(): def test_multi_equal(): - x, y, z = tvm.tir.Var("x", "int32"), tvm.tir.Var("y", "int32"), tvm.tir.Var("z", "int32") + x, y, z = tvm.tirx.Var("x", "int32"), tvm.tirx.Var("y", "int32"), tvm.tirx.Var("z", "int32") problem = [ - tvm.tir.LE(x, 6), - tvm.tir.GE(x, 6), - tvm.tir.GE(x - z * y, 0), - tvm.tir.LE(x - z * y, 0), + tvm.tirx.LE(x, 6), + tvm.tirx.GE(x, 6), + tvm.tirx.GE(x - z * y, 0), + tvm.tirx.LE(x - z * y, 0), ] solution = arith.solve_linear_inequalities(problem, [x, y, z]) @@ -163,9 +165,9 @@ def test_multi_equal(): assert len(solution.relations) == 3 assert ir.structural_equal(solution.relations[0], x == z * y) - assert isinstance(solution.relations[1], tvm.tir.LE) + assert isinstance(solution.relations[1], tvm.tirx.LE) assert solution.relations[1].b == 0 - assert isinstance(solution.relations[2], tvm.tir.LE) + assert isinstance(solution.relations[2], tvm.tirx.LE) assert solution.relations[2].b == 0 # (z*y - 6) <= 0 && (6 - z*y) <= 0 ana = tvm.arith.Analyzer() @@ -181,14 +183,14 @@ def test_multi_equal(): def test_no_solution(): - x = tvm.tir.Var("x0", "int32") + x = tvm.tirx.Var("x0", "int32") vranges = {x: tvm.ir.Range.from_min_extent(-20, 41)} problem = [-x - 4 <= -5 * x + 2, x * 4 + 5 <= x * 5] solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - ir.assert_structural_equal(rel, tir.const(False)) + ir.assert_structural_equal(rel, tirx.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 @@ -200,9 +202,11 @@ def test_no_solution(): def test_unbound_var_range(): - x = tvm.tir.Var("x0", "int32") - free_var = tvm.tir.Var("fv", "int32") - vranges = {x: tvm.ir.Range.from_min_extent(0, tvm.tir.Cast("int32", 1 + tvm.tir.log(free_var)))} + x = tvm.tirx.Var("x0", "int32") + free_var = tvm.tirx.Var("fv", "int32") + vranges = { + x: tvm.ir.Range.from_min_extent(0, tvm.tirx.Cast("int32", 1 + tvm.tirx.log(free_var))) + } problem = [x > 3] solution = arith.solve_linear_inequalities( problem, diff --git a/tests/python/codegen/test_codegen_assert.py b/tests/python/codegen/test_codegen_assert.py index 21c3ac5cde2b..0c50d4bb222f 100644 --- a/tests/python/codegen/test_codegen_assert.py +++ b/tests/python/codegen/test_codegen_assert.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T codegen_target = tvm.testing.parameter("llvm", "c") diff --git a/tests/python/codegen/test_codegen_error_handling.py b/tests/python/codegen/test_codegen_error_handling.py index b882c3a795c4..88c53410e350 100644 --- a/tests/python/codegen/test_codegen_error_handling.py +++ b/tests/python/codegen/test_codegen_error_handling.py @@ -28,7 +28,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T # Parameterize over both LLVM and C backends codegen_target = tvm.testing.parameter("llvm", "c") diff --git a/tests/python/codegen/test_gpu_codegen_allreduce.py b/tests/python/codegen/test_gpu_codegen_allreduce.py index 4ffa7b765dc6..c958b01373d4 100644 --- a/tests/python/codegen/test_gpu_codegen_allreduce.py +++ b/tests/python/codegen/test_gpu_codegen_allreduce.py @@ -22,7 +22,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def _reduce_sum_module(d1, d2, d3): diff --git a/tests/python/codegen/test_inject_ptx_ldg32.py b/tests/python/codegen/test_inject_ptx_ldg32.py index 38313988503d..10a29b3582f8 100644 --- a/tests/python/codegen/test_inject_ptx_ldg32.py +++ b/tests/python/codegen/test_inject_ptx_ldg32.py @@ -18,12 +18,12 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func def vector_add(A: T.Buffer((16), "float32"), B: T.Buffer((32), "float32")) -> None: - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -46,7 +46,7 @@ def test_inject_ptx_intrin(): if major < 8: # Require at least SM80 return - with tvm.transform.PassContext(config={"tir.ptx_ldg32": True}): + with tvm.transform.PassContext(config={"tirx.ptx_ldg32": True}): mod = tvm.compile(f, target="cuda") A_np = np.random.rand(16).astype("float32") B_np = np.zeros(32).astype("float32") diff --git a/tests/python/codegen/test_target_codegen.py b/tests/python/codegen/test_target_codegen.py index cb7aa065f3d0..ec41a4d6a28a 100644 --- a/tests/python/codegen/test_target_codegen.py +++ b/tests/python/codegen/test_target_codegen.py @@ -20,7 +20,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.parametrize_targets("c") diff --git a/tests/python/codegen/test_target_codegen_aarch64.py b/tests/python/codegen/test_target_codegen_aarch64.py index f238766cf89a..b258d826307c 100644 --- a/tests/python/codegen/test_target_codegen_aarch64.py +++ b/tests/python/codegen/test_target_codegen_aarch64.py @@ -25,7 +25,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target.codegen import llvm_version_major @@ -43,7 +43,7 @@ def test_mul(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -56,7 +56,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] * B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -82,7 +82,7 @@ def test_add(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -95,7 +95,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] + B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -121,7 +121,7 @@ def test_sub(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -134,7 +134,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] - B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -160,7 +160,7 @@ def test_muladd(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_D: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -174,7 +174,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_D: T.handle): D[v_i] = A[v_i] * B[v_i] + C[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -200,7 +200,7 @@ def test_max(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -213,7 +213,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = T.max(A[v_i], B[v_i]) with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -243,7 +243,7 @@ def test_min(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -256,7 +256,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = T.min(A[v_i], B[v_i]) with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -286,7 +286,7 @@ def test_div(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -296,10 +296,10 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): v_i = T.axis.spatial(m, i) T.reads(A[v_i], B[v_i]) T.writes(C[v_i]) - C[v_i] = tvm.tir.div(A[v_i], B[v_i]) + C[v_i] = tvm.tirx.div(A[v_i], B[v_i]) with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -324,7 +324,7 @@ def test_mod(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -337,7 +337,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = T.floormod(A[v_i], B[v_i]) with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -363,7 +363,7 @@ def test_eq(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -376,7 +376,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] == B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -402,7 +402,7 @@ def test_neq(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -415,7 +415,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] != B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -440,7 +440,7 @@ def test_or(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -453,7 +453,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] | B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -478,7 +478,7 @@ def test_and(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), dtype=dtype) @@ -491,7 +491,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[v_i] & B[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -516,7 +516,7 @@ def test_not(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) C = T.match_buffer(var_C, (m,), dtype=dtype) @@ -528,7 +528,7 @@ def main(var_A: T.handle, var_C: T.handle): C[v_i] = ~A[v_i] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -557,7 +557,7 @@ def test_memcpy(dtype): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,), dtype=dtype) B = T.match_buffer(var_B, (m,), "int32") @@ -570,7 +570,7 @@ def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): C[v_i] = A[B[v_i]] with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) assembly = f.inspect_source("asm") loads = re.findall("ld1[whdb] { z", assembly) @@ -598,7 +598,7 @@ def test_vscale_range_function_attribute(mattr, expect_attr): class Module: @T.prim_func def main(var_A: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (m,)) C = T.match_buffer(var_C, (m,)) @@ -610,7 +610,7 @@ def main(var_A: T.handle, var_C: T.handle): C[v_i] = A[v_i] + T.float32(1) with tvm.target.Target(target): - f = tvm.tir.build(Module) + f = tvm.tirx.build(Module) # Check if the vscale_range() attribute exists ll = f.inspect_source("ll") diff --git a/tests/python/codegen/test_target_codegen_arm.py b/tests/python/codegen/test_target_codegen_arm.py index d67cb3233b4c..7cd1140a1507 100644 --- a/tests/python/codegen/test_target_codegen_arm.py +++ b/tests/python/codegen/test_target_codegen_arm.py @@ -18,7 +18,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_popcount(): @@ -34,7 +34,7 @@ def check_correct_assembly(type, elements, counts): class Module: @T.prim_func def main(A: T.Buffer((elements,), type), B: T.Buffer((elements,), type)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.vectorized(elements): with T.sblock("B"): v_i = T.axis.spatial(elements, i) @@ -42,7 +42,7 @@ def main(A: T.Buffer((elements,), type), B: T.Buffer((elements,), type)): T.writes(B[v_i]) B[v_i] = T.popcount(A[v_i]) - f = tvm.tir.build(Module, target=target) + f = tvm.tirx.build(Module, target=target) # Verify we see the correct number of vpaddl and vcnt instructions in the assembly assembly = f.inspect_source("asm") matches = re.findall("vpaddl", assembly) @@ -70,7 +70,7 @@ def check_correct_assembly(N): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, C: T.Buffer((N,), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) K = T.int32(is_size_var=True) A = T.match_buffer(var_A, (K, N), "int8") B = T.match_buffer(var_B, (K, N), "int8") @@ -86,7 +86,7 @@ def main(var_A: T.handle, var_B: T.handle, C: T.Buffer((N,), "int32")): "int32", B[v_rv, v_n] ) - f = tvm.tir.build(Module, target=target) + f = tvm.tirx.build(Module, target=target) # Verify we see the correct number of vmlal.s16 instructions assembly = f.inspect_source("asm") @@ -103,7 +103,7 @@ def check_broadcast_correct_assembly(N): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, C: T.Buffer((N,), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) K = T.int32(is_size_var=True) A = T.match_buffer(var_A, (K, N), "int8") B = T.match_buffer(var_B, (K,), "int8") @@ -119,7 +119,7 @@ def main(var_A: T.handle, var_B: T.handle, C: T.Buffer((N,), "int32")): "int32", B[v_rv] ) - f = tvm.tir.build(Module, target=target) + f = tvm.tirx.build(Module, target=target) # Verify we see the correct number of vmlal.s16 instructions assembly = f.inspect_source("asm") diff --git a/tests/python/codegen/test_target_codegen_blob.py b/tests/python/codegen/test_target_codegen_blob.py index 4a8e2dfdf949..41339a4cd36b 100644 --- a/tests/python/codegen/test_target_codegen_blob.py +++ b/tests/python/codegen/test_target_codegen_blob.py @@ -24,7 +24,7 @@ import tvm.testing from tvm.contrib import cc, popen_pool, tar, utils from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.uses_gpu diff --git a/tests/python/codegen/test_target_codegen_bool.py b/tests/python/codegen/test_target_codegen_bool.py index b69fa15e7cf9..0d0a5f79d96b 100644 --- a/tests/python/codegen/test_target_codegen_bool.py +++ b/tests/python/codegen/test_target_codegen_bool.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.uses_gpu @@ -34,7 +34,7 @@ def main( B: T.Buffer((32,), "float32"), D: T.Buffer((32,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) C = T.sblock_alloc_buffer((32,), "bool") for i0_0 in T.thread_binding(8, thread="blockIdx.x"): for i0_1 in T.thread_binding(4, thread="blockIdx.x"): @@ -59,7 +59,7 @@ def main( B: T.Buffer((32,), "float32"), D: T.Buffer((32,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) C = T.sblock_alloc_buffer((32,), "bool") for i0 in range(32): with T.sblock("C"): diff --git a/tests/python/codegen/test_target_codegen_c_host.py b/tests/python/codegen/test_target_codegen_c_host.py index 42a6034a0529..d021cd46e75b 100644 --- a/tests/python/codegen/test_target_codegen_c_host.py +++ b/tests/python/codegen/test_target_codegen_c_host.py @@ -21,7 +21,7 @@ import tvm.testing from tvm.contrib import utils from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_add(): @@ -35,7 +35,7 @@ def test_fadd( B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in range(1024): with T.sblock("C"): v_i0 = T.axis.spatial(1024, i0) @@ -71,7 +71,7 @@ def test_reinterpret( A: T.Buffer((1024,), "int32"), B: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in range(1024): with T.sblock("B"): v_i0 = T.axis.spatial(1024, i0) @@ -106,7 +106,7 @@ def test_ceil( A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in range(1024): with T.sblock("B"): v_i0 = T.axis.spatial(1024, i0) @@ -141,7 +141,7 @@ def test_floor( A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in range(1024): with T.sblock("B"): v_i0 = T.axis.spatial(1024, i0) @@ -176,7 +176,7 @@ def test_round( A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in range(1024): with T.sblock("B"): v_i0 = T.axis.spatial(1024, i0) @@ -213,7 +213,7 @@ def subroutine(A_data: T.handle("float32")): A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 42.0 - built = tvm.tir.build(Module, target="c") + built = tvm.tirx.build(Module, target="c") source = built.inspect_source() assert source.count("__tvm_ffi_main(void*") == 2, ( diff --git a/tests/python/codegen/test_target_codegen_cross_llvm.py b/tests/python/codegen/test_target_codegen_cross_llvm.py index 8d1d45cd04f8..b782391fb9c4 100644 --- a/tests/python/codegen/test_target_codegen_cross_llvm.py +++ b/tests/python/codegen/test_target_codegen_cross_llvm.py @@ -27,7 +27,7 @@ from tvm import rpc from tvm.contrib import cc, utils from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module @@ -38,7 +38,7 @@ def main( B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0_0 in T.parallel(256): for i0_1 in T.vectorized(4): with T.sblock("C"): @@ -66,7 +66,7 @@ def build_arm(): print(f"Skip because {target} is not enabled..") return temp = utils.tempdir() - f = tvm.tir.build(AddModule, target=target) + f = tvm.tirx.build(AddModule, target=target) path = temp.relpath("myadd.o") f.write_to_file(path) verify_elf(path, 0x28) diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 60b387b98eb2..31fa12d2cdc3 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -25,7 +25,7 @@ import tvm.testing from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8 from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.fixture(autouse=True, params=["nvcc", "nvrtc"]) @@ -66,14 +66,14 @@ def check_cuda(dtype, n, lanes): print("skip because gpu does not support int8") return vec_dtype = f"{dtype}x{lanes}" - one = tvm.tir.const(1, vec_dtype) + one = tvm.tirx.const(1, vec_dtype) num_blocks = (n + num_thread - 1) // num_thread @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(num_thread, thread="threadIdx.x"): with T.sblock("B"): @@ -130,13 +130,13 @@ def np_bf162np_float(arr): def check_cuda(n, lanes): vec_dtype = f"bfloat16x{lanes}" num_blocks = n // num_thread - one = tvm.tir.Broadcast(tvm.tir.const(1, "bfloat16"), lanes) + one = tvm.tirx.Broadcast(tvm.tirx.const(1, "bfloat16"), lanes) @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(num_thread, thread="threadIdx.x"): with T.sblock("B"): @@ -146,7 +146,7 @@ def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): B[v_i] = A[v_i] + one with tvm.transform.PassContext( - disabled_pass=["tir.BF16Promote", "tir.BF16CastElimination", "tir.BF16TypeLowering"] + disabled_pass=["tirx.BF16Promote", "tirx.BF16CastElimination", "tirx.BF16TypeLowering"] ): fun = tvm.compile(Module, target="cuda") dev = tvm.cuda(0) @@ -185,7 +185,7 @@ def main( C: T.Buffer((n,), "int32"), D: T.Buffer((n,), "int32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(num_thread, thread="threadIdx.x"): with T.sblock("D"): @@ -225,7 +225,7 @@ def check_cuda(dtype, n, lanes): class Module: @T.prim_func def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(num_thread, thread="threadIdx.x"): with T.sblock("B"): @@ -255,13 +255,13 @@ def test_cuda_make_int8(): def check_cuda(n, value, lanes): dtype = "int8" dev = tvm.cuda(0) - const_value = tvm.tir.const(value, dtype=dtype) + const_value = tvm.tirx.const(value, dtype=dtype) @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((n, lanes), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(n, thread="blockIdx.x"): for j in T.vectorized(lanes): with T.sblock("A"): @@ -294,13 +294,13 @@ def test_cuda_inf_nan(): target = "cuda" def check_inf_nan(dev, n, value, dtype): - inf_value = tvm.tir.const(value, dtype=dtype) + inf_value = tvm.tirx.const(value, dtype=dtype) @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((n,), dtype), C: T.Buffer((n,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(1, thread="blockIdx.x"): for i_1 in T.thread_binding(8, thread="threadIdx.x"): with T.sblock("C"): @@ -334,7 +334,7 @@ def sched(nthd): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, m = T.int32(), T.int32() A = T.match_buffer(var_A, (n, m)) B = T.match_buffer(var_B, (n,)) @@ -378,7 +378,7 @@ def sched(nthdx, nthdy): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, k0, k1 = T.int32(), T.int32(), T.int32() A = T.match_buffer(var_A, (n, k0, k1)) B = T.match_buffer(var_B, (n,)) @@ -434,7 +434,7 @@ def test_cuda_reduction_binding(): class Module: @T.prim_func def main(A: T.Buffer((96, 32), "float32"), B: T.Buffer((96,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for k in range(32): for m_0 in T.thread_binding(3, thread="blockIdx.x"): for m_1 in range(32): @@ -456,13 +456,13 @@ def test_cuda_const_float_to_half(): # This import is required to use nvcc to perform code gen; # otherwise it is found that the code gen is done by nvrtc. - half_const = tvm.tir.const(0.5, dtype="float16") + half_const = tvm.tirx.const(0.5, dtype="float16") @I.ir_module class Module: @T.prim_func def main(a: T.Buffer((2, 3, 4), "float16"), C: T.Buffer((2, 3, 4), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_j_k_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_k_fused_1 in T.thread_binding(64, thread="threadIdx.x"): with T.sblock("C"): @@ -498,7 +498,7 @@ def test_cuda_floordiv_with_vectorization(): class Module: @T.prim_func def main(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(1, thread="blockIdx.x"): for i_1_0 in T.thread_binding(64, thread="threadIdx.x"): for i_1_1 in T.vectorized(4): @@ -531,7 +531,7 @@ def test_cuda_floormod_with_vectorization(): class Module: @T.prim_func def main(A: T.Buffer((256,), "float32"), B: T.Buffer((256,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(1, thread="blockIdx.x"): for i_1_0 in T.thread_binding(64, thread="threadIdx.x"): for i_1_1 in T.vectorized(4): @@ -567,7 +567,7 @@ def check(t0, t1, factor): class Module: @T.prim_func def main(A: T.Buffer((n,), t0), B: T.Buffer((n,), t1), C: T.Buffer((n,), t0)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_thread, thread="threadIdx.x"): for i_1 in T.vectorized(factor): with T.sblock("C"): @@ -633,7 +633,7 @@ def sched(compute_fn, dtype, n=128): class Module: @T.prim_func def main(A: T.Buffer((n,), dtype), B: T.Buffer((n,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_1_0 in T.thread_binding(32, thread="threadIdx.x"): for i0_1_1_0 in range(1): @@ -651,25 +651,25 @@ def main(A: T.Buffer((n,), dtype), B: T.Buffer((n,), dtype)): @tvm.testing.requires_cuda def test_vectorized_intrin1(): test_funcs = [ - (tvm.tir.floor, lambda x: np.floor(x)), - (tvm.tir.ceil, lambda x: np.ceil(x)), - (tvm.tir.trunc, lambda x: np.trunc(x)), - (tvm.tir.abs, lambda x: np.fabs(x)), - (tvm.tir.round, lambda x: np.round(x)), - (tvm.tir.exp, lambda x: np.exp(x)), - (tvm.tir.exp2, lambda x: np.exp2(x)), - (tvm.tir.exp10, lambda x: np.power(10, x)), - (tvm.tir.log, lambda x: np.log(x)), - (tvm.tir.log2, lambda x: np.log2(x)), - (tvm.tir.log10, lambda x: np.log10(x)), - (tvm.tir.tan, lambda x: np.tan(x)), - (tvm.tir.cos, lambda x: np.cos(x)), - (tvm.tir.cosh, lambda x: np.cosh(x)), - (tvm.tir.sin, lambda x: np.sin(x)), - (tvm.tir.sinh, lambda x: np.sinh(x)), - (tvm.tir.atan, lambda x: np.arctan(x)), - (tvm.tir.tanh, lambda x: np.tanh(x)), - (tvm.tir.sqrt, lambda x: np.sqrt(x)), + (tvm.tirx.floor, lambda x: np.floor(x)), + (tvm.tirx.ceil, lambda x: np.ceil(x)), + (tvm.tirx.trunc, lambda x: np.trunc(x)), + (tvm.tirx.abs, lambda x: np.fabs(x)), + (tvm.tirx.round, lambda x: np.round(x)), + (tvm.tirx.exp, lambda x: np.exp(x)), + (tvm.tirx.exp2, lambda x: np.exp2(x)), + (tvm.tirx.exp10, lambda x: np.power(10, x)), + (tvm.tirx.log, lambda x: np.log(x)), + (tvm.tirx.log2, lambda x: np.log2(x)), + (tvm.tirx.log10, lambda x: np.log10(x)), + (tvm.tirx.tan, lambda x: np.tan(x)), + (tvm.tirx.cos, lambda x: np.cos(x)), + (tvm.tirx.cosh, lambda x: np.cosh(x)), + (tvm.tirx.sin, lambda x: np.sin(x)), + (tvm.tirx.sinh, lambda x: np.sinh(x)), + (tvm.tirx.atan, lambda x: np.arctan(x)), + (tvm.tirx.tanh, lambda x: np.tanh(x)), + (tvm.tirx.sqrt, lambda x: np.sqrt(x)), ] def run_test(tvm_intrin, np_func, dtype): @@ -678,13 +678,13 @@ def run_test(tvm_intrin, np_func, dtype): return # set of intrinsics does not support fp16 yet. skip_set = { - tvm.tir.abs, - tvm.tir.round, - tvm.tir.tan, - tvm.tir.atan, - tvm.tir.tanh, - tvm.tir.cosh, - tvm.tir.sinh, + tvm.tirx.abs, + tvm.tirx.round, + tvm.tirx.tan, + tvm.tirx.atan, + tvm.tirx.tanh, + tvm.tirx.cosh, + tvm.tirx.sinh, } if dtype == "float16" and tvm_intrin in skip_set: print(f"Skip because '{tvm_intrin.__name__}' does not support fp16 yet") @@ -706,10 +706,10 @@ def run_test(tvm_intrin, np_func, dtype): @tvm.testing.requires_gpu @tvm.testing.requires_cuda def test_vectorized_intrin2(dtype="float32"): - c2 = tvm.tir.const(2, dtype=dtype) + c2 = tvm.tirx.const(2, dtype=dtype) test_funcs = [ - (tvm.tir.power, lambda x: np.power(x, 2.0)), - (tvm.tir.fmod, lambda x: np.fmod(x, 2.0)), + (tvm.tirx.power, lambda x: np.power(x, 2.0)), + (tvm.tirx.fmod, lambda x: np.fmod(x, 2.0)), ] def run_test(tvm_intrin, np_func): @@ -737,7 +737,7 @@ def ref_popcount(x): def run_test(dtype): n = 128 - f = sched(lambda x: tvm.tir.popcount(x), dtype, n) + f = sched(lambda x: tvm.tirx.popcount(x), dtype, n) dev = tvm.cuda(0) a = tvm.runtime.tensor(np.random.randint(0, 100000, size=n).astype(dtype), dev) b = tvm.runtime.tensor(np.zeros(shape=(n,)).astype(dtype), dev) @@ -758,7 +758,7 @@ def check_cuda(dtype, n, l, padding, lanes): return dev = tvm.cuda(0) - zero = tvm.tir.const(0, dtype) + zero = tvm.tirx.const(0, dtype) dim0 = n // lanes dim1 = l + 2 * padding @@ -766,7 +766,7 @@ def check_cuda(dtype, n, l, padding, lanes): class Module: @T.prim_func def main(A: T.Buffer((n, l), dtype), B: T.Buffer((dim0, dim1, lanes), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(dim0, thread="blockIdx.x"): for j in T.thread_binding(dim1, thread="threadIdx.x"): for k in T.vectorized(lanes): @@ -809,7 +809,7 @@ def build(N, C_N, offset): class Module: @T.prim_func def main(A: T.Buffer((N,), "float16"), C: T.Buffer((C_N,), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(C_N // 2, thread="threadIdx.x"): for i_1 in T.vectorized(2): with T.sblock("C"): @@ -818,7 +818,7 @@ def main(A: T.Buffer((N,), "float16"), C: T.Buffer((C_N,), "float16")): T.writes(C[v_i]) C[v_i] = A[v_i + offset] - f = tvm.tir.build(Module, target="cuda") + f = tvm.tirx.build(Module, target="cuda") kernel_source = f.imports[0].inspect_source() dev = tvm.cuda() @@ -898,7 +898,7 @@ def test_invalid_reinterpret(): @T.prim_func def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None: for tx in T.thread_binding(4, "threadIdx.x"): - B[tx] = T.call_intrin("uint8", "tir.reinterpret", A[tx]) + B[tx] = T.call_intrin("uint8", "tirx.reinterpret", A[tx]) with pytest.raises(tvm.error.TVMError): tvm.compile(func, target="cuda") @@ -997,7 +997,7 @@ def main( C[bx, tx] = Module.add(A[bx, tx], B[bx, tx]) # Call from device # 1. If we set host to llvm, it will raise an error of - # "the tir.ret should be transformed to return zero before the llvm code generation." + # "the tirx.ret should be transformed to return zero before the llvm code generation." # Need to revisit this. # 2. We set a dummy mcpu value for testing purpose, # in order to avoid checking a function is host or device based on the "cpu" substring. diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index 2730dab20157..3088a67873d4 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -23,7 +23,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T try: from ml_dtypes import float4_e2m1fn @@ -47,7 +47,7 @@ def main( B: T.Buffer((vector_length,), native_dtype), C: T.Buffer((vector_length,), native_dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(vector_length // 32, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): with T.sblock("C"): @@ -117,7 +117,7 @@ def main( A: T.Buffer((n // num_elem_per_storage,), "uint32"), B: T.Buffer((n,), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): for i_2 in T.vectorized(vector_length): @@ -156,7 +156,7 @@ def main( A: T.Buffer((n // num_elem_per_storage,), "uint32"), B: T.Buffer((n,), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): for i_2 in T.vectorized(vector_length): diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py index b35e5d793a3b..730349973313 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp8.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py @@ -23,11 +23,11 @@ import tvm import tvm.testing -from tvm import DataType, DataTypeCode, IRModule, relax, te, tir, topi +from tvm import DataType, DataTypeCode, IRModule, relax, te, tirx, topi from tvm.s_tir import dlight as dl from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T try: import ml_dtypes @@ -55,7 +55,7 @@ def main( B: T.Buffer((64,), dtype), C: T.Buffer((64,), dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(2, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): with T.sblock("C"): @@ -70,7 +70,7 @@ def main( mod = _create_mod(dtype) target = "cuda" - fadd = tvm.tir.build(mod, target=target) + fadd = tvm.tirx.build(mod, target=target) cuda_src = fadd.imports[0].inspect_source() assert nv_dtype in cuda_src, f"{nv_dtype} datatype not found in generated CUDA" @@ -106,7 +106,7 @@ def main( R: T.Buffer((length,), packed_dtype), B: T.Buffer((length,), native_dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(2, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): with T.sblock("R"): @@ -169,7 +169,7 @@ def main( B: T.Buffer((64,), native_dtype), C: T.Buffer((64,), native_dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(2, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): with T.sblock("C"): @@ -185,7 +185,7 @@ def main( mod = _create_mod(native_dtype, promoted_dtype) target = "cuda" - fadd = tvm.tir.build(mod, target=target) + fadd = tvm.tirx.build(mod, target=target) cuda_src = fadd.imports[0].inspect_source() dev = tvm.device(target, 0) @@ -302,7 +302,7 @@ def main( B: T.Buffer((64,), "float16x4"), C: T.Buffer((64,), "float16x4"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(2, thread="blockIdx.x"): for i_1 in T.thread_binding(32, thread="threadIdx.x"): with T.sblock("C"): @@ -417,7 +417,7 @@ def create_dequantize_func( @classmethod def quantize_fp8x4_e4m3( # pylint: disable=too-many-locals cls, - weight_shape: list[tir.PrimExpr], + weight_shape: list[tirx.PrimExpr], model_dtype, quantize_dtype, storage_dtype, @@ -428,24 +428,24 @@ def quantize_fp8x4_e4m3( # pylint: disable=too-many-locals output_transpose: bool = False, ) -> tuple[te.Tensor, te.Tensor]: """Group quantization for weight tensor, defined in tensor expression.""" - max_int = tir.const(max_int_value, model_dtype) + max_int = tirx.const(max_int_value, model_dtype) shape = weight_shape # pylint: disable=invalid-name axis = axis if axis >= 0 else len(shape) + axis k = shape[axis] quantize_dtype = DataType(quantize_dtype) # compute scale per group r = te.reduce_axis((0, group_size), name="r") # pylint: disable=invalid-name - num_group = tir.ceildiv(k, group_size) + num_group = tirx.ceildiv(k, group_size) # (4096, 4096) -> quantize axis = 0, group size = 32 -> (128, 4096) # for channel quant group_size = 4096 -> (1, 4096) scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :]) def compute_scale(weight: te.Tensor): - min_scaling_factor = tir.const(1.0 / (max_int_value * 512.0), model_dtype) + min_scaling_factor = tirx.const(1.0 / (max_int_value * 512.0), model_dtype) max_abs = te.compute( shape=scale_shape, fcompute=lambda *idx: te.max( - tir.if_then_else( + tirx.if_then_else( idx[axis] * group_size + r < k, te.abs(weight(*idx[:axis], idx[axis] * group_size + r, *idx[axis + 1 :])), te.min_value(model_dtype), @@ -504,7 +504,7 @@ def compute_transpose(quantized_weight: te.Tensor, scale: te.Tensor): @classmethod def dequantize_fp8x4_e4m3( # pylint: disable=too-many-locals cls, - packed_weight_shape: list[tir.PrimExpr], + packed_weight_shape: list[tirx.PrimExpr], scale_shape, dequant_shape, model_dtype, @@ -613,7 +613,7 @@ def dequant( scale: T.Buffer(scale_shape, model_dtype), dequantize: T.Buffer(out_shape, model_dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(packed_weight_shape[0]), T.int64(packed_weight_shape[1])): with T.sblock("dequantize"): @@ -700,7 +700,7 @@ def compile_quant_and_dequant_by_scale( def print_cuda(target, mod, name=None): if name: mod = mod[name] - f = tvm.tir.build(mod, target=target) + f = tvm.tirx.build(mod, target=target) cuda_src = f.imports[0].inspect_source() print(cuda_src) @@ -871,7 +871,7 @@ def moe_dequantize_gemv( indptr: T.Buffer((1, 2), "int32"), o: T.Buffer((2, spatial_size), "float16"), ): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) num_seq = T.int64() x = T.match_buffer(x_handle, (num_seq, reduce_size), "float16") for expert_id in T.thread_binding(2, thread="blockIdx.y"): @@ -989,7 +989,7 @@ def main( mod = _create_mod(vec_length, dtype) device = tvm.cuda() target = tvm.target.Target.from_device(device) - f = tvm.tir.build(mod, target=target) + f = tvm.tirx.build(mod, target=target) a_np = np.random.rand(128).astype("float8_e4m3fn") b_np = np.random.rand(128).astype(dtype) diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py index 87b09a6ed2c8..36586eb37c0e 100644 --- a/tests/python/codegen/test_target_codegen_device.py +++ b/tests/python/codegen/test_target_codegen_device.py @@ -19,19 +19,19 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.requires_gpu def test_large_uint_imm(): value = (1 << 63) + 123 - value_const = tvm.tir.const(value, "uint64") + value_const = tvm.tirx.const(value, "uint64") @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((12,), "uint64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0_0 in T.thread_binding(6, thread="blockIdx.x"): for i0_1 in T.thread_binding(2, thread="threadIdx.x"): with T.sblock("A"): @@ -61,7 +61,7 @@ def test_add_pipeline(): class Module: @T.prim_func def main(var_A: T.handle, B: T.Buffer((), "float32"), var_D: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32(is_size_var=True) A = T.match_buffer(var_A, (n,)) D = T.match_buffer(var_D, (n,)) @@ -88,7 +88,7 @@ def check_target(device, host): return dev = tvm.device(device, 0) target = tvm.target.Target(device, host) - mhost = tvm.tir.build(Module, target=target) + mhost = tvm.tirx.build(Module, target=target) f = mhost.main # launch the kernel. n = 1027 diff --git a/tests/python/codegen/test_target_codegen_extern.py b/tests/python/codegen/test_target_codegen_extern.py index efe043452d75..50b6996ec301 100644 --- a/tests/python/codegen/test_target_codegen_extern.py +++ b/tests/python/codegen/test_target_codegen_extern.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.uses_gpu diff --git a/tests/python/codegen/test_target_codegen_gpu_common.py b/tests/python/codegen/test_target_codegen_gpu_common.py index 94a1c2ea588e..baf069fc3a77 100644 --- a/tests/python/codegen/test_target_codegen_gpu_common.py +++ b/tests/python/codegen/test_target_codegen_gpu_common.py @@ -22,7 +22,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.requires_gpu @@ -45,7 +45,7 @@ def main( A: T.Buffer((n,), dtype), B: T.Buffer((n,), dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in T.thread_binding(n, thread="threadIdx.x"): with T.sblock("B"): v_i0 = T.axis.spatial(n, i0) diff --git a/tests/python/codegen/test_target_codegen_hexagon.py b/tests/python/codegen/test_target_codegen_hexagon.py index b6561b740ded..a7e5b7003ef8 100644 --- a/tests/python/codegen/test_target_codegen_hexagon.py +++ b/tests/python/codegen/test_target_codegen_hexagon.py @@ -23,7 +23,7 @@ import tvm.contrib.hexagon as hexagon import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.fixture(autouse=True) @@ -48,7 +48,7 @@ def main( A: T.Buffer((128,), "uint8"), A_1: T.Buffer((128,), "uint8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(128): with T.sblock("C"): v_i = T.axis.spatial(128, i) @@ -70,7 +70,7 @@ def test_llvm_target_features(): class Module: @T.prim_func def add_one(C: T.Buffer((128,), "int32"), A: T.Buffer((128,), "uint8")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(128): with T.sblock("C"): v_i = T.axis.spatial(128, i) @@ -103,7 +103,7 @@ def test_llvm_options(): class Module: @T.prim_func def main(compute: T.Buffer((10,), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for _ in range(10): with T.sblock("compute"): v__ = T.axis.spatial(10, _) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 38b094d97f83..4612f34557b8 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -25,7 +25,7 @@ import tvm.testing from tvm.contrib import clang, utils from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target.codegen import llvm_get_intrinsic_name, llvm_lookup_intrinsic_id @@ -36,7 +36,7 @@ class Module: @T.prim_func def main(A: T.handle("float32")): A_buf = T.decl_buffer((4,), "float32", data=A) - T.evaluate(T.Call("void", "tir.prefetch", [T.address_of(A_buf[0]), 0, 3, 1])) + T.evaluate(T.Call("void", "tirx.prefetch", [T.address_of(A_buf[0]), 0, 3, 1])) fcode = tvm.compile(Module) @@ -69,7 +69,7 @@ def test_llvm_overloaded_intrin(): return # int1 is the type for the is_zero_undef parameter - int1_zero = tvm.tir.const(0, "int1") + int1_zero = tvm.tirx.const(0, "int1") @I.ir_module class Module: @@ -98,13 +98,13 @@ def main(A: T.handle("uint8x8")): @tvm.testing.requires_llvm def test_llvm_large_uintimm(): value = (1 << 63) + 123 - large_val = tvm.tir.const(value, "uint64") + large_val = tvm.tirx.const(value, "uint64") @I.ir_module class Module: @T.prim_func def main(A: T.Buffer((), "uint64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) with T.sblock("A"): vi = T.axis.spatial(1, 0) T.reads() @@ -124,7 +124,7 @@ def test_llvm_multi_parallel(): class Module: @T.prim_func def main(A: T.Buffer((128,), "float32"), C: T.Buffer((128,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) B = T.sblock_alloc_buffer((128,)) for i0_0_0 in T.parallel(1): for ax0 in range(128): @@ -157,7 +157,7 @@ def check_llvm(nn, base): class Module: @T.prim_func def main(A: T.Buffer((nn + base,), "float32"), C: T.Buffer((nn,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.parallel((nn + 3) // 4): for i_1 in T.vectorized(4): with T.sblock("C"): @@ -186,7 +186,7 @@ def test_llvm_vadd_pipeline(): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32(is_size_var=True) A = T.match_buffer(var_A, (n,)) B = T.match_buffer(var_B, (n,)) @@ -220,7 +220,7 @@ def main( A: T.Buffer((nn + base, stride), "float32"), C: T.Buffer((nn, stride), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.parallel((nn + 3) // 4): for i_1 in T.vectorized(4): for j in range(stride): @@ -242,7 +242,7 @@ def main( check_llvm(64, 0, 2) check_llvm(4, 0, 1) - with tvm.transform.PassContext(config={"tir.noalias": False}): + with tvm.transform.PassContext(config={"tirx.noalias": False}): check_llvm(4, 0, 3) @@ -252,7 +252,7 @@ def test_llvm_temp_space(): class Module: @T.prim_func def main(A: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) B = T.sblock_alloc_buffer((1024,)) for i in range(1024): with T.sblock("B"): @@ -282,7 +282,7 @@ def test_multiple_func(): class Module: @T.prim_func def fadd1(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32(is_size_var=True) A = T.match_buffer(var_A, (n,)) B = T.match_buffer(var_B, (n,)) @@ -296,7 +296,7 @@ def fadd1(var_A: T.handle, var_B: T.handle, var_C: T.handle): @T.prim_func def fadd2(var_A: T.handle, var_B: T.handle, var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32(is_size_var=True) A = T.match_buffer(var_A, (n,)) B = T.match_buffer(var_B, (n,)) @@ -327,7 +327,7 @@ def test_llvm_condition(): class Module: @T.prim_func def main(A: T.Buffer((64,), "float32"), C: T.Buffer((64,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(64): with T.sblock("C"): v_i = T.axis.spatial(64, i) @@ -353,7 +353,7 @@ def test_llvm_bool(): class Module: @T.prim_func def main(A: T.Buffer((64,), "int32"), C: T.Buffer((64,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(64): with T.sblock("C"): v_i = T.axis.spatial(64, i) @@ -377,7 +377,7 @@ def test_llvm_cast_float_to_bool(): class Module: @T.prim_func def main(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(4): with T.sblock("C"): v_i = T.axis.spatial(4, i) @@ -405,7 +405,7 @@ def main( scale: T.Buffer((), "float32"), compute: T.Buffer((), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) C = T.sblock_alloc_buffer(()) for k in range(64): with T.sblock("C"): @@ -442,7 +442,7 @@ def main( scale: T.Buffer((), "float32"), compute: T.Buffer((), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) C = T.sblock_alloc_buffer(()) for k in range(64): with T.sblock("C"): @@ -459,7 +459,7 @@ def main( compute[()] = C[()] + T.float32(1.0) n = 64 - with tvm.transform.PassContext(config={"tir.instrument_bound_checkers": True}): + with tvm.transform.PassContext(config={"tirx.instrument_bound_checkers": True}): f = tvm.compile(Module, target="llvm") dev = tvm.cpu(0) a = tvm.runtime.tensor(np.random.randint(0, 2, size=(n,)).astype("float32"), dev) @@ -476,7 +476,7 @@ def test_alignment(): class Module: @T.prim_func def test_alignment(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in range(128): for i_1 in T.vectorized(8): with T.sblock("B"): @@ -485,7 +485,7 @@ def test_alignment(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float3 T.writes(B[v_i]) B[v_i] = A[v_i] * T.float32(3.0) - f = tvm.tir.build(Module, target="llvm") + f = tvm.tirx.build(Module, target="llvm") lines = f.inspect_source().split("\n") @@ -526,14 +526,14 @@ def check(start, end, dstart, dend, dtype, floor_div=False): a_size = end - start + 1 b_size = dend - dstart + 1 - div_fn = tvm.tir.floordiv if floor_div else tvm.tir.truncdiv - mod_fn = tvm.tir.floormod if floor_div else tvm.tir.truncmod + div_fn = tvm.tirx.floordiv if floor_div else tvm.tirx.truncdiv + mod_fn = tvm.tirx.floormod if floor_div else tvm.tirx.truncmod # Build clipping helpers — capture TIR const values from env - _start = tvm.tir.const(start, dtype) - _end = tvm.tir.const(end, dtype) - _dstart = tvm.tir.const(dstart, dtype) - _dend = tvm.tir.const(dend, dtype) + _start = tvm.tirx.const(start, dtype) + _end = tvm.tirx.const(end, dtype) + _dstart = tvm.tirx.const(dstart, dtype) + _dend = tvm.tirx.const(dend, dtype) if start == end: clipa = lambda x: _start @@ -554,7 +554,7 @@ def main( D: T.Buffer((a_size, b_size), dtype), M: T.Buffer((a_size, b_size), dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(a_size, b_size): with T.sblock("D"): v_i, v_j = T.axis.remap("SS", [i, j]) @@ -664,7 +664,7 @@ def test_llvm_fp_math(): class RecipModule: @T.prim_func def main(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32(is_size_var=True) A = T.match_buffer(var_A, (n,)) B = T.match_buffer(var_B, (n,)) @@ -689,7 +689,7 @@ def main(var_A: T.handle, var_B: T.handle): class SigmoidModule: @T.prim_func def main(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32(is_size_var=True) A = T.match_buffer(var_A, (n,)) B = T.match_buffer(var_B, (n,)) @@ -719,7 +719,7 @@ def main( B: T.Buffer((1024,), "float32"), C: T.Buffer((1024,), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0_0 in T.parallel(256): for i0_1 in T.vectorized(4): with T.sblock("C"): @@ -778,7 +778,7 @@ def check_llvm_ir(): "fadd2": Module["main"].with_attr("global_symbol", "fadd2"), } ) - m = tvm.tir.build(mod, target={"kind": "llvm", "mtriple": "aarch64-linux-gnu"}) + m = tvm.tirx.build(mod, target={"kind": "llvm", "mtriple": "aarch64-linux-gnu"}) ll = m.inspect_source("ll") # On non-Darwin OS, don't explicitly specify DWARF version. @@ -788,7 +788,7 @@ def check_llvm_ir(): assert re.search(r"""llvm.dbg.value""", ll) # Try Darwin, require DWARF-2 - m = tvm.tir.build(mod, target={"kind": "llvm", "mtriple": "x86_64-apple-darwin-macho"}) + m = tvm.tirx.build(mod, target={"kind": "llvm", "mtriple": "x86_64-apple-darwin-macho"}) ll = m.inspect_source("ll") assert re.search(r"""i32 4, !"Dwarf Version", i32 2""", ll) assert re.search(r"""llvm.dbg.value""", ll) @@ -810,7 +810,7 @@ def main( B: T.Buffer((32,), "bfloat16"), D: T.Buffer((32,), "bfloat16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for x in loop_kind(32): with T.sblock("D"): v_x = T.axis.spatial(32, x) @@ -845,7 +845,7 @@ def main( B: T.Buffer((32,), "bfloat16"), C: T.Buffer((32,), "bfloat16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for x in range(32): with T.sblock("compute"): v_x = T.axis.spatial(32, x) @@ -853,7 +853,7 @@ def main( T.writes(C[v_x]) C[v_x] = A[v_x] + B[v_x] - module = tvm.tir.build( + module = tvm.tirx.build( Module.with_attr("system_lib_prefix", ""), target=tvm.target.Target("llvm"), ) @@ -882,7 +882,7 @@ def Sammy(v: T.float32) -> T.float32: def Kirby(v: T.float32) -> T.float32: T.ret(T.call_extern("float32", "Fred", v)) - ir_text = tvm.tir.build(Module, target="llvm").inspect_source("ll") + ir_text = tvm.tirx.build(Module, target="llvm").inspect_source("ll") # Skip functions whose names start with _. matches = re.findall(r"^define[^@]*@([a-zA-Z][a-zA-Z0-9_]*)", ir_text, re.MULTILINE) assert matches == sorted(matches) @@ -912,7 +912,7 @@ def check_llvm(use_file): class Module: @T.prim_func def main(A: T.Buffer((10,), "float32"), B: T.Buffer((10,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.serial(10, annotations={"pragma_import_llvm": import_val}): with T.sblock("B"): v_i = T.axis.spatial(10, i) @@ -941,7 +941,7 @@ def main(x: T.int32, y: T.int32, buffer: T.Buffer((1,), "int32x2")): # This will crash in LLVM codegen if CodeGenLLVM::CreateVecConcat doesn't convert # scalars to single-lane LLVM vectors. - with tvm.transform.PassContext(config={"tir.disable_assert": True}): + with tvm.transform.PassContext(config={"tirx.disable_assert": True}): m = tvm.compile(Module, target="llvm") @@ -951,7 +951,7 @@ def test_raise_exception_during_codegen(): class Module: @T.prim_func def main(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")) -> None: - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.parallel(4): for j in T.parallel(4): B[i, j] = A[i, j] * 2.0 @@ -972,7 +972,7 @@ def test_llvm_target_attributes(): class Module: @T.prim_func def test_func(var_A: T.handle, var_B: T.handle, var_C: T.handle, tindex: T.int32): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (tindex,)) B = T.match_buffer(var_B, (tindex,)) C = T.match_buffer(var_C, (tindex,)) @@ -998,7 +998,7 @@ def test_func(var_A: T.handle, var_B: T.handle, var_C: T.handle, tindex: T.int32 "mattr": ["+avx512f"], } target = tvm.target.Target(target_llvm, host=target_llvm) - module = tvm.tir.build(Module, target=target) + module = tvm.tirx.build(Module, target=target) llvm_ir = module.inspect_source() llvm_ir_lines = llvm_ir.split("\n") @@ -1031,7 +1031,7 @@ def test_func(var_A: T.handle, var_B: T.handle, var_C: T.handle, tindex: T.int32 @tvm.testing.requires_llvm def test_llvm_assume(): """ - Check that LLVM does not error out when generating code with tir.assume. + Check that LLVM does not error out when generating code with tirx.assume. Verifying for llvm.assume being generated is not easy as the intrinsic and its related instructions get removed during optimizations """ @@ -1040,7 +1040,7 @@ def test_llvm_assume(): class Module: @T.prim_func def main(A: T.Buffer((4, 4), "int32"), B: T.Buffer((14,), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A_1 = T.decl_buffer((16,), "int32", data=A.data) for axis0, axis1 in T.grid(4, 4): T.assume(axis0 < 3 or axis1 < 2 or A_1[axis0 * 4 + axis1] == 0) @@ -1121,7 +1121,7 @@ class Module: def main(): T.Call( "void", - tvm.ir.Op.get("tir.tvm_call_packed"), + tvm.ir.Op.get("tirx.tvm_call_packed"), ["dummy_function_name"], ) @@ -1144,7 +1144,7 @@ def test_call_packed_without_string_arg(): class Module: @T.prim_func def main(A: T.Buffer(1, "float32")): - T.Call("int32", tvm.ir.Op.get("tir.tvm_call_packed"), [A.data]) + T.Call("int32", tvm.ir.Op.get("tirx.tvm_call_packed"), [A.data]) with pytest.raises(tvm.TVMError): built = tvm.compile(Module, target="llvm") @@ -1158,7 +1158,7 @@ def test_call_extern_returning_void(): class Module: @T.prim_func def main(): - T.Call("void", tvm.ir.Op.get("tir.call_extern"), ["dummy_function_name"]) + T.Call("void", tvm.ir.Op.get("tirx.call_extern"), ["dummy_function_name"]) built = tvm.compile(Module, target="llvm") @@ -1169,7 +1169,7 @@ class Module: @T.prim_func def main(b: T.handle): B = T.match_buffer(b, [4]) - A = T.alloc_buffer((4,), annotations={"tir.volatile": True}) + A = T.alloc_buffer((4,), annotations={"tirx.volatile": True}) B[0:4] = A.vload([T.Ramp(0, 1, 4)], predicate=T.Broadcast(T.bool(True), 4)) err_msg = "The masked load intrinsic does not support declaring load as volatile." @@ -1183,7 +1183,7 @@ def test_invalid_volatile_masked_buffer_store(): class Module: @T.prim_func def main(): - A = T.alloc_buffer((4,), annotations={"tir.volatile": True}) + A = T.alloc_buffer((4,), annotations={"tirx.volatile": True}) A.vstore( [T.Ramp(0, 1, 4)], T.Broadcast(0.0, 4), diff --git a/tests/python/codegen/test_target_codegen_llvm_vla.py b/tests/python/codegen/test_target_codegen_llvm_vla.py index ae30376f6e62..6b1ea4bddef8 100644 --- a/tests/python/codegen/test_target_codegen_llvm_vla.py +++ b/tests/python/codegen/test_target_codegen_llvm_vla.py @@ -24,7 +24,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target.codegen import llvm_version_major @@ -42,7 +42,7 @@ }, ) def test_codegen_vscale(target): - vscale = tvm.tir.vscale() + vscale = tvm.tirx.vscale() @T.prim_func def main(A: T.Buffer((5,), "int32")): @@ -50,7 +50,7 @@ def main(A: T.Buffer((5,), "int32")): A[i] = 2 * vscale with tvm.target.Target(target): - build_mod = tvm.tir.build(main) + build_mod = tvm.tirx.build(main) llvm = build_mod.inspect_source() assert re.findall(r"llvm.vscale.i32", llvm), "No vscale in generated LLVM." @@ -74,11 +74,11 @@ def test_scalable_buffer_load_store(target): def my_func(a: T.handle, b: T.handle): A = T.match_buffer(a, (128,), "float32") B = T.match_buffer(b, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] with tvm.target.Target(target): - mod = tvm.tir.build(my_func) + mod = tvm.tirx.build(my_func) llvm = mod.inspect_source("ll") assert re.findall(r"load ", llvm), "No scalable load in generated LLVM." @@ -102,11 +102,11 @@ def test_scalable_broadcast(target): @T.prim_func def my_func(a: T.handle): A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) with tvm.target.Target(target): - mod = tvm.tir.build(my_func) + mod = tvm.tirx.build(my_func) llvm = mod.inspect_source("ll") assert re.findall( @@ -136,7 +136,7 @@ def before(a: T.handle): A[i : i + T.vscale() * 4] = T.get_active_lane_mask("uint1xvscalex4", i, 30) with tvm.target.Target(target): - out = tvm.tir.build(before) + out = tvm.tirx.build(before) ll = out.inspect_source("ll") assert "get.active.lane.mask" in ll @@ -160,14 +160,14 @@ def test_predicated_scalable_buffer(target): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in T.serial(T.ceildiv(16, 4 * T.vscale())): for i_1 in T.vectorized(4 * T.vscale()): if i_0 * 4 * T.vscale() + i_1 < 14: B[i_0 * 4 * T.vscale() + i_1] = A[i_0 * 4 * T.vscale() + i_1] + 1.0 with tvm.target.Target(target): - out = tvm.tir.build(before) + out = tvm.tirx.build(before) ll = out.inspect_source("ll") assert "get.active.lane.mask" in ll diff --git a/tests/python/codegen/test_target_codegen_metal.py b/tests/python/codegen/test_target_codegen_metal.py index fc51daa3d682..781fa15d01c4 100644 --- a/tests/python/codegen/test_target_codegen_metal.py +++ b/tests/python/codegen/test_target_codegen_metal.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.requires_gpu @@ -35,7 +35,7 @@ def main( A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i = T.axis.spatial(1, i) @@ -97,7 +97,7 @@ def main( A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i0 = T.axis.spatial(1, i0) @@ -202,7 +202,7 @@ def compile_metal(src, target): mod = tvm.IRModule({"main": func}) - f = tvm.tir.build(mod, target="metal") + f = tvm.tirx.build(mod, target="metal") src: str = f.imports[0].inspect_source() occurrences = src.count("struct func_kernel_args_t") assert occurrences == 1, occurrences diff --git a/tests/python/codegen/test_target_codegen_opencl.py b/tests/python/codegen/test_target_codegen_opencl.py index 92a7992ec157..e3c76d475d21 100644 --- a/tests/python/codegen/test_target_codegen_opencl.py +++ b/tests/python/codegen/test_target_codegen_opencl.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T target = "opencl" @@ -33,7 +33,7 @@ def check_if_then_else(dev, n, dtype): class Module: @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i = T.axis.spatial(1, i) @@ -48,7 +48,7 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): ), ) - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) a = tvm.runtime.empty((n,), dtype, dev) c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here @@ -59,7 +59,7 @@ def check_select(dev, n, dtype): class Module: @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i = T.axis.spatial(1, i) @@ -74,7 +74,7 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): ), ) - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) a = tvm.runtime.empty((n,), dtype, dev) c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here @@ -100,7 +100,7 @@ def check_inf_nan(dev, n, value, dtype): class Module: @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i = T.axis.spatial(1, i) @@ -108,7 +108,7 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): T.writes(C[v_i]) C[v_i] = T.Cast(dtype, value) - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) a = tvm.runtime.empty((n,), dtype, dev) c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here @@ -132,7 +132,7 @@ def check_max(dev, n, dtype): class Module: @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i = T.axis.spatial(1, i) @@ -140,7 +140,7 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): T.writes(C[v_i]) C[v_i] = T.max(A[0] + T.Cast(dtype, 1), T.Cast(dtype, 0)) - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) a = tvm.runtime.empty((n,), dtype, dev) c = tvm.runtime.empty((n,), dtype, dev) # Only need to test compiling here @@ -162,7 +162,7 @@ def check_erf(dev, n, dtype): class Module: @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in T.thread_binding(1, thread="threadIdx.x"): with T.sblock("C"): v_i0 = T.axis.spatial(1, i0) @@ -170,7 +170,7 @@ def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): T.writes(C[v_i0]) C[v_i0] = T.erf(A[v_i0]) - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) source_str = fun.imports[0].inspect_source() matches = re.findall("erf", source_str) @@ -190,7 +190,7 @@ def test_opencl_type_casting(): class Module: @T.prim_func def main(C: T.Buffer((32,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(8, thread="threadIdx.x"): for i_1 in T.vectorized(4): with T.sblock("C"): @@ -202,7 +202,7 @@ def main(C: T.Buffer((32,), "float32")): ) def check_type_casting(ctx, n, dtype): - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) c = tvm.runtime.empty((n,), dtype, ctx) assembly = fun.imports[0].inspect_source() lcond = "convert_int4(((convert_uint4(((uint4)(((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3), ((convert_int(get_local_id(0))) == 3)))))" @@ -231,7 +231,7 @@ def _check(target, n, dtype): class Module: @T.prim_func def main(C: T.Buffer((n,), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in T.thread_binding(n, thread="threadIdx.x"): with T.sblock("C"): v_i = T.axis.spatial(n, i) @@ -239,7 +239,7 @@ def main(C: T.Buffer((n,), "int32")): T.writes(C[v_i]) C[v_i] = T.Cast("int32", T.ceil(T.log2(T.Cast(inter_dtype, v_i)))) - fun = tvm.tir.build(Module, target=target) + fun = tvm.tirx.build(Module, target=target) assembly = fun.imports[0].inspect_source() if is_adreno: pattern = "convert_float" @@ -254,7 +254,7 @@ def _get_maximum_kernel_args(source): def get_kernel_args(source): import re - p = re.tir.build(r"__kernel void .+\((.*)\)") + p = re.tirx.build(r"__kernel void .+\((.*)\)") args = p.findall(source) return args diff --git a/tests/python/codegen/test_target_codegen_riscv.py b/tests/python/codegen/test_target_codegen_riscv.py index e4fe54303f97..08edc487251a 100644 --- a/tests/python/codegen/test_target_codegen_riscv.py +++ b/tests/python/codegen/test_target_codegen_riscv.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target.codegen import target_has_features @@ -61,7 +61,7 @@ def load_vec(A: T.Buffer((N,), "int8")): for j in T.vectorized(0, extent): A[j] = 1 - f = tvm.tir.build(load_vec, target) + f = tvm.tirx.build(load_vec, target) # Check RVV `vsetvli` prensence assembly = f.inspect_source("asm") if target_has_features("v"): @@ -108,7 +108,7 @@ def rvv_with_vscale(A_handle: T.handle, B_handle: T.handle, C_handle: T.handle): # tvm.error.InternalError: Can't fetch the lanes of a scalable vector at a compile time. with tvm.target.Target(target): - f = tvm.tir.build(rvv_with_vscale, target) + f = tvm.tirx.build(rvv_with_vscale, target) if __name__ == "__main__": diff --git a/tests/python/codegen/test_target_codegen_rocm.py b/tests/python/codegen/test_target_codegen_rocm.py index 5e9e1eb6ad33..2b62f0c08b7a 100644 --- a/tests/python/codegen/test_target_codegen_rocm.py +++ b/tests/python/codegen/test_target_codegen_rocm.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.requires_rocm @@ -30,7 +30,7 @@ def check_inf_nan(dev, n, value, dtype): class Module: @T.prim_func def main(A: T.Buffer((1,), dtype), C: T.Buffer((1,), dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(1, thread="blockIdx.x"): for i_1 in T.thread_binding(128, thread="threadIdx.x"): with T.sblock("C"): @@ -83,7 +83,7 @@ def check_rocm(dtype, n, lanes): class Module: @T.prim_func def main(A: T.Buffer((n,), vec_dtype), B: T.Buffer((n,), vec_dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"): for i_1 in T.thread_binding(4, thread="threadIdx.x"): with T.sblock("B"): diff --git a/tests/python/codegen/test_target_codegen_static_init.py b/tests/python/codegen/test_target_codegen_static_init.py index 5dbe5d1315a8..008c601cf240 100644 --- a/tests/python/codegen/test_target_codegen_static_init.py +++ b/tests/python/codegen/test_target_codegen_static_init.py @@ -20,7 +20,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_static_init(): @@ -38,7 +38,7 @@ def ramp(A: T.handle): Ab = T.match_buffer(A, (n,), "int64") T.call_packed( "test_static_callback", - T.call_intrin("handle", "tir.tvm_static_handle"), + T.call_intrin("handle", "tirx.tvm_static_handle"), Ab.data, ) diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index 38830ae96f30..c975073922d7 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -24,10 +24,10 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import ir as I_builder -from tvm.script.ir_builder import tir as T_builder +from tvm.script.ir_builder import tirx as T_builder dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") fuzz_seed = tvm.testing.parameter(range(25)) @@ -47,8 +47,8 @@ ) def test_vector_comparison(target, dev, dtype): target = tvm.target.Target(target) - zero = tvm.tir.const(0, dtype) - one = tvm.tir.const(1, dtype) + zero = tvm.tirx.const(0, dtype) + one = tvm.tirx.const(1, dtype) @I.ir_module class Module: @@ -62,7 +62,7 @@ def main(A: T.Buffer((1024,), dtype), B: T.Buffer((1024,), dtype)): B[v_i] = T.Select(A[v_i] >= zero, A[v_i] + one, zero) # Build - f = tvm.tir.build(Module, target=target) + f = tvm.tirx.build(Module, target=target) # Verify we generate the boolx4 type declaration and the OpSelect # v4{float,half,int} instruction @@ -95,7 +95,7 @@ def test_array_vectorize_add(target, dev, dtype): pytest.xfail("Opencl target does not support float16") vec_dtype = f"{dtype}x{lanes}" - one = tvm.tir.const(1, vec_dtype) + one = tvm.tirx.const(1, vec_dtype) @I.ir_module class Module: @@ -178,11 +178,11 @@ def test_vulkan_constant_passing(target, dev, vulkan_parameter_impl, vulkan_para T_builder.func_name("main") scalar_vars = [] for i in range(num_int_params): - v = T_builder.arg(f"scale{i}", tvm.tir.Var("", dtype)) + v = T_builder.arg(f"scale{i}", tvm.tirx.Var("", dtype)) scalar_vars.append(v) var_A = T_builder.arg("var_A", T_builder.handle()) var_B = T_builder.arg("var_B", T_builder.handle()) - T_builder.func_attr({"tir.noalias": True}) + T_builder.func_attr({"tirx.noalias": True}) n_var = T_builder.int32(is_size_var=True) A = T_builder.match_buffer(var_A, (n_var,), dtype) B = T_builder.match_buffer(var_B, (n_var,), dtype) @@ -190,7 +190,7 @@ def test_vulkan_constant_passing(target, dev, vulkan_parameter_impl, vulkan_para for s in scalar_vars[1:]: scalar_sum = scalar_sum + s with T_builder.thread_binding( - tvm.tir.ceildiv(n_var, 64), thread="blockIdx.x" + tvm.tirx.ceildiv(n_var, 64), thread="blockIdx.x" ) as i_0: with T_builder.thread_binding(64, thread="threadIdx.x") as i_1: with T_builder.sblock("B"): @@ -288,13 +288,13 @@ def local_threadidx_func(A: T.Buffer((32,), "int32"), B: T.Buffer((32,), "int32" def test_vectorized_index_ramp(target, dev): """Test vectorized copy with ramp indices (load N values, write to N locations)""" n = 4 - ramp_index = tvm.tir.Ramp(0, 1, 4) + ramp_index = tvm.tirx.Ramp(0, 1, 4) @I.ir_module class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (n,), "int32", offset_factor=1) B = T.match_buffer(var_B, (n,), "int32", offset_factor=1) with T.sblock("compute"): @@ -318,14 +318,14 @@ def main(var_A: T.handle, var_B: T.handle): def test_vectorized_index_broadcast(target, dev): """Test broadcast index (load 1 value, write to N locations)""" n = 4 - broadcast_index = tvm.tir.Broadcast(0, 4) - ramp_index = tvm.tir.Ramp(0, 1, 4) + broadcast_index = tvm.tirx.Broadcast(0, 4) + ramp_index = tvm.tirx.Ramp(0, 1, 4) @I.ir_module class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (n,), "int32", offset_factor=1) B = T.match_buffer(var_B, (n,), "int32", offset_factor=1) with T.sblock("compute"): @@ -404,7 +404,7 @@ def test_cooperative_matrix(out_dtype): class Module: @T.prim_func def main(X: T.Buffer((16, 32), "float16"), W: T.Buffer((32, 16), "float16"), compute: T.Buffer((16, 16), out_dtype)): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) X_shared = T.sblock_alloc_buffer((16, 32), "float16", scope="shared") W_shared = T.sblock_alloc_buffer((32, 16), "float16", scope="shared") X_shared_wmma_matrix_a = T.sblock_alloc_buffer((16, 32), "float16", scope="wmma.matrix_a") @@ -506,7 +506,7 @@ def test_codegen_decl_buffer(): class Module: @T.prim_func def kernel(): - T.func_attr({"calling_conv": 2, "global_symbol": "kernel", "tir.noalias": True}) + T.func_attr({"calling_conv": 2, "global_symbol": "kernel", "tirx.noalias": True}) A = T.alloc_buffer((256,), dtype="float32", scope="local") A_buf = T.decl_buffer([256], dtype="float32", scope="local", data=A.data) @@ -519,18 +519,18 @@ def kernel(): @tvm.testing.requires_vulkan def test_unary(): test_funcs = [ - (tvm.tir.sin, lambda x: np.sin(x)), - (tvm.tir.cos, lambda x: np.cos(x)), - (tvm.tir.tan, lambda x: np.tan(x)), - (tvm.tir.sinh, lambda x: np.sinh(x)), - (tvm.tir.cosh, lambda x: np.cosh(x)), - (tvm.tir.tanh, lambda x: np.tanh(x)), - (tvm.tir.asin, lambda x: np.arcsin(x)), - (tvm.tir.acos, lambda x: np.arccos(x)), - (tvm.tir.atan, lambda x: np.arctan(x)), - (tvm.tir.asinh, lambda x: np.arcsinh(x)), - (tvm.tir.acosh, lambda x: np.arccosh(x)), - (tvm.tir.atanh, lambda x: np.arctanh(x)), + (tvm.tirx.sin, lambda x: np.sin(x)), + (tvm.tirx.cos, lambda x: np.cos(x)), + (tvm.tirx.tan, lambda x: np.tan(x)), + (tvm.tirx.sinh, lambda x: np.sinh(x)), + (tvm.tirx.cosh, lambda x: np.cosh(x)), + (tvm.tirx.tanh, lambda x: np.tanh(x)), + (tvm.tirx.asin, lambda x: np.arcsin(x)), + (tvm.tirx.acos, lambda x: np.arccos(x)), + (tvm.tirx.atan, lambda x: np.arctan(x)), + (tvm.tirx.asinh, lambda x: np.arcsinh(x)), + (tvm.tirx.acosh, lambda x: np.arccosh(x)), + (tvm.tirx.atanh, lambda x: np.arctanh(x)), ] def run_test(tvm_intrin, np_func): @@ -556,11 +556,11 @@ def main(var_A: T.handle, var_B: T.handle): dev = tvm.device(target.kind.name, 0) func = tvm.compile(Module, target=target) - if tvm_intrin in [tvm.tir.asin, tvm.tir.acos]: + if tvm_intrin in [tvm.tirx.asin, tvm.tirx.acos]: data = np.random.uniform(-1.0, 1.0, size=n) - elif tvm_intrin == tvm.tir.atanh: + elif tvm_intrin == tvm.tirx.atanh: data = np.random.uniform(-0.999, 0.999, size=n) - elif tvm_intrin == tvm.tir.acosh: + elif tvm_intrin == tvm.tirx.acosh: data = np.random.uniform(1.0, 5.0, size=n) else: data = np.random.uniform(0.1, 0.9, size=n) diff --git a/tests/python/codegen/test_target_codegen_x86.py b/tests/python/codegen/test_target_codegen_x86.py index 483fafe8ff03..9421ac14e03b 100644 --- a/tests/python/codegen/test_target_codegen_x86.py +++ b/tests/python/codegen/test_target_codegen_x86.py @@ -22,7 +22,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T llvm_version = tvm.target.codegen.llvm_version_major() machine = platform.machine() @@ -44,7 +44,7 @@ def main( A: T.Buffer((elements, width), "float16"), B: T.Buffer((elements, width), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in range(elements): for i1 in T.vectorized(width): with T.sblock("B"): @@ -53,7 +53,7 @@ def main( T.writes(B[v_i0, v_i1]) B[v_i0, v_i1] = T.Cast("float32", A[v_i0, v_i1]) - f = tvm.tir.build(Module, target=target) + f = tvm.tirx.build(Module, target=target) assembly = f.inspect_source("asm").splitlines() if match: diff --git a/tests/python/contrib/test_android/test_meta_schedule.py b/tests/python/contrib/test_android/test_meta_schedule.py index f3a2d9eb11d5..56097580de47 100644 --- a/tests/python/contrib/test_android/test_meta_schedule.py +++ b/tests/python/contrib/test_android/test_meta_schedule.py @@ -27,7 +27,7 @@ import tvm.topi.testing from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.builder import LocalBuilder -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_android_gpu_target, get_rpc_runner diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index 006805f96baf..57332eb9f152 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -27,7 +27,7 @@ def ceildiv(o, d): assert o >= 0 assert d >= 0 - return tvm.tir.floordiv(o + d - 1, d) + return tvm.tirx.floordiv(o + d - 1, d) # defines inner block shape: 8h8w32c diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py index 13b3b07354fc..c5e4db1c3a8d 100644 --- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py +++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py @@ -32,7 +32,7 @@ from tvm import te from tvm.contrib.hexagon import allocate_hexagon_array from tvm.contrib.hexagon.pytest_plugin import requires_hexagon_toolchain -from tvm.tir.stmt_functor import post_order_visit +from tvm.tirx.stmt_functor import post_order_visit from .infrastructure import get_hexagon_target @@ -171,7 +171,7 @@ def extract_buffers(stmt): buffers = [] def visitor(node): - if isinstance(node, tvm.tir.BufferLoad | tvm.tir.BufferStore): + if isinstance(node, tvm.tirx.BufferLoad | tvm.tirx.BufferStore): buffers.append(node.buffer) post_order_visit(stmt, visitor) diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index 0d9083718210..7aa923787c19 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -22,7 +22,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tirx as T VRMPY_SIZE_B = 128 VRMPY_SIZE_INT32 = 32 @@ -43,7 +43,7 @@ def conv2d_async_non_contig( """Non contiguous memory access is used in this conv2d taken from MS.""" # pylint: disable=no-self-argument # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") p0_global_vtcm = T.sblock_alloc_buffer( @@ -223,7 +223,7 @@ def conv_approximation(size_a, size_w): @T.prim_func def operator(a_input: T.handle, b_input: T.handle, c_output: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a_input, a_shape, dtype="uint8") w_buffer = T.match_buffer(b_input, w_shape, dtype="uint8") c_buffer = T.match_buffer(c_output, out_shape, dtype="int32") @@ -274,8 +274,8 @@ def evaluate( target_hexagon = tvm.target.Target("qcom/hexagon-v68") with tvm.transform.PassContext( config={ - "tir.use_async_copy": use_async_copy, - "tir.experimental_dma_bypass_cache": 1, + "tirx.use_async_copy": use_async_copy, + "tirx.experimental_dma_bypass_cache": 1, } ): func_tir = tvm.compile( @@ -528,7 +528,7 @@ def test_loading_vtcm_for_vrmpy( ) -# from tvm.script import tir as T +# from tvm.script import tirx as T @tvm.script.ir_module class ModulePipelined: """Pipelined module class.""" @@ -542,7 +542,7 @@ def main( ) -> None: # pylint: disable=missing-function-docstring # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") conv2d_nchwc_int8 = T.sblock_alloc_buffer( @@ -685,7 +685,7 @@ def main( ] -# from tvm.script import tir as T +# from tvm.script import tirx as T @tvm.script.ir_module class ModuleBase: """Base module test class.""" @@ -699,7 +699,7 @@ def main( ) -> None: # pylint: disable=missing-function-docstring # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # buffer definition # body # with T.sblock("root") diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py index 75ea71e3ebda..52e1f8a2386f 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py @@ -27,7 +27,7 @@ import tvm.script import tvm.testing from tvm.contrib.hexagon.session import Session -from tvm.script import tir as T +from tvm.script import tirx as T from . import benchmark_util as bu from .infrastructure import get_hexagon_target @@ -144,7 +144,7 @@ class BenchmarkModule: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle): # We exchange data between function by handles, which are similar to pointer. - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, shape, dtype=dtype) B = T.match_buffer(b, shape, dtype=dtype) diff --git a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py index 374f47300827..475b624bd5ed 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_maxpool2d.py @@ -50,7 +50,7 @@ import pytest import tvm.testing -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.contrib.hexagon import allocate_hexagon_array from tvm.contrib.hexagon.session import Session from tvm.topi import testing diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py index 4acdcf1c528b..5f3b2d65020b 100644 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py @@ -28,7 +28,7 @@ from tvm import relax from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T # pylint: disable=invalid-name, missing-class-docstring, missing-function-docstring, no-self-argument diff --git a/tests/python/contrib/test_hexagon/test_memory_alloc.py b/tests/python/contrib/test_hexagon/test_memory_alloc.py index 5ff80aea1977..da380199ad12 100644 --- a/tests/python/contrib/test_hexagon/test_memory_alloc.py +++ b/tests/python/contrib/test_hexagon/test_memory_alloc.py @@ -20,7 +20,7 @@ import tvm from tvm.contrib.hexagon import allocate_hexagon_array -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_hexagon_target diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index c498dacfbdcf..0b4a8335360a 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -35,8 +35,8 @@ from tvm.s_tir.meta_schedule.builder import BuilderInput from tvm.s_tir.meta_schedule.runner import RunnerInput from tvm.s_tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN -from tvm.script import tir as T -from tvm.tir import FloatImm +from tvm.script import tirx as T +from tvm.tirx import FloatImm from .infrastructure import get_hexagon_target @@ -52,7 +52,7 @@ class MatmulModule: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore # pylint: disable=missing-function-docstring - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, (16, 16), "float32") b_buffer = T.match_buffer(b, (16, 16), "float32") c_buffer = T.match_buffer(c, (16, 16), "float32") @@ -118,7 +118,7 @@ def dense_compute(m, n, k): lambda i, j: te.sum( X[i, axis_k].astype("int32") * packed_width[ - tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(axis_k, 4), j % 32, axis_k % 4 + tvm.tirx.indexdiv(j, 32), tvm.tirx.indexdiv(axis_k, 4), j % 32, axis_k % 4 ].astype("int32"), axis=axis_k, ), @@ -248,7 +248,7 @@ def main( # type: ignore compute: T.Buffer((128, 768), "int32"), # type: ignore ) -> None: # pylint: disable=missing-function-docstring - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0_0_i1_0_0_fused in T.parallel( 512, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} ): diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 1ca01a1f3a74..bd9abf6b50f9 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -22,7 +22,7 @@ import numpy as np import tvm -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_hexagon_target @@ -77,7 +77,7 @@ def get_vmpy_operator(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8") b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") c_buffer = T.match_buffer(c, [operations, 128], dtype="int16") @@ -99,7 +99,7 @@ def get_vadd_operator(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8") b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") c_buffer = T.match_buffer(c, [operations, 128], dtype="int16") @@ -121,7 +121,7 @@ def get_vrmpy_operator(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8") b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8") c_buffer = T.match_buffer(c, [operations, 32], dtype="int32") diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index e13ad8d460d4..580e027c4644 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -20,7 +20,7 @@ import numpy as np import tvm -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_hexagon_target @@ -79,7 +79,7 @@ def vrmpy(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128) b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128) c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128) @@ -101,7 +101,7 @@ def preloaded_vrmpy(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer( a, [T.cast(operations, "int32") * 128], @@ -145,7 +145,7 @@ def preallocated_vrmpy(operations): def operator( a: T.handle, b: T.handle, c: T.handle, a_v: T.handle, b_v: T.handle, c_v: T.handle ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global") b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global") c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global") @@ -199,7 +199,7 @@ def operator( b_v: T.handle, c_v: T.handle, ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations, 128], dtype="uint8", align=128, scope="global") b_buffer = T.match_buffer(b, [operations, 128], dtype="uint8", align=128, scope="global") c_buffer = T.match_buffer(c, [operations, 32], dtype="int32", align=128, scope="global") diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index 8146f81e21ec..31ab24d9454e 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -20,7 +20,7 @@ import numpy as np import tvm -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_hexagon_target @@ -36,7 +36,7 @@ def get_add_operator(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations], dtype="float64") b_buffer = T.match_buffer(b, [operations], dtype="float64") c_buffer = T.match_buffer(c, [operations], dtype="float64") @@ -53,7 +53,7 @@ def get_multiply_operator(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations], dtype="float64") b_buffer = T.match_buffer(b, [operations], dtype="float64") c_buffer = T.match_buffer(c, [operations], dtype="float64") @@ -70,7 +70,7 @@ def get_sub_operator(operations): @T.prim_func def operator(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) a_buffer = T.match_buffer(a, [operations], dtype="float64") b_buffer = T.match_buffer(b, [operations], dtype="float64") c_buffer = T.match_buffer(c, [operations], dtype="float64") diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py index a950302c64e0..98a109d966bd 100644 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py @@ -25,7 +25,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=missing-docstring,no-self-argument,invalid-name diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py index ce5fba81dc4b..805f0b3477da 100644 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -21,7 +21,7 @@ import tvm import tvm.testing -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.contrib.hexagon import allocate_hexagon_array from .infrastructure import get_hexagon_target diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 270fc65b37ca..176793efd94d 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -20,8 +20,8 @@ import numpy as np import tvm -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T from .infrastructure import get_hexagon_target @@ -180,8 +180,8 @@ def test_async_software_pipeline( with tvm.transform.PassContext( config={ - "tir.use_async_copy": 1, - "tir.experimental_dma_bypass_cache": 1, + "tirx.use_async_copy": 1, + "tirx.experimental_dma_bypass_cache": 1, } ): func = tvm.compile(schedule.mod["main"], target=get_hexagon_target("v68")) diff --git a/tests/python/contrib/test_hexagon/test_take.py b/tests/python/contrib/test_hexagon/test_take.py index 6fc5150f715d..63d8036baaea 100644 --- a/tests/python/contrib/test_hexagon/test_take.py +++ b/tests/python/contrib/test_hexagon/test_take.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.contrib.hexagon import generate_take_op, hexagon_unary_ops from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import quantize_np @@ -58,7 +58,7 @@ def tanh( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.tanh"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.tanh"}}) @tvm.script.ir_module @@ -89,7 +89,7 @@ def sqrt( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sqrt"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.sqrt"}}) @tvm.script.ir_module @@ -120,7 +120,7 @@ def rsqrt( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.rsqrt"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.rsqrt"}}) @tvm.script.ir_module @@ -151,7 +151,7 @@ def exp( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.exp"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.exp"}}) @tvm.script.ir_module @@ -182,7 +182,7 @@ def erf( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.erf"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.erf"}}) @tvm.script.ir_module @@ -213,7 +213,7 @@ def sigmoid( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sigmoid"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.sigmoid"}}) @tvm.script.ir_module @@ -244,7 +244,7 @@ def hardswish( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.hardswish"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.hardswish"}}) @tvm.script.ir_module @@ -275,7 +275,7 @@ def log( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.log"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.log"}}) @tvm.script.ir_module @@ -306,7 +306,7 @@ def abs( rxplaceholder_4: T.Buffer((), "int32"), compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"), ): - T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.abs"}}) + T.func_attr({"tirx.noalias": True, "op_attrs": {"op_name": "qnn.abs"}}) # data = np.random.random([1, 2, 2, 2]).astype("float32") : Need to hadcode the data diff --git a/tests/python/contrib/test_hexagon/test_thread_pool.py b/tests/python/contrib/test_hexagon/test_thread_pool.py index b0d5951f1041..245b856c3c28 100644 --- a/tests/python/contrib/test_hexagon/test_thread_pool.py +++ b/tests/python/contrib/test_hexagon/test_thread_pool.py @@ -24,7 +24,7 @@ import tvm.script import tvm.testing from tvm.contrib.hexagon.session import Session -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_hexagon_target @@ -36,7 +36,7 @@ class ElemwiseSumIRModule: # pylint: disable=no-self-argument,invalid-name,missing-function-docstring @T.prim_func def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32): - T.func_attr({"global_symbol": "elemwise_sum_serial", "tir.noalias": True}) + T.func_attr({"global_symbol": "elemwise_sum_serial", "tirx.noalias": True}) A = T.match_buffer(a, (n,), dtype="float32") B = T.match_buffer(b, (n,), dtype="float32") C = T.match_buffer(c, (n,), dtype="float32") @@ -47,7 +47,7 @@ def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32): @T.prim_func def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32): - T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True}) + T.func_attr({"global_symbol": "elemwise_sum_parallel", "tirx.noalias": True}) A = T.match_buffer(a, (n,), dtype="float32") B = T.match_buffer(b, (n,), dtype="float32") C = T.match_buffer(c, (n,), dtype="float32") diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index 56d9555fff9d..8844fa029d56 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -20,8 +20,8 @@ import pytest import tvm.testing -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T from .infrastructure import get_hexagon_target @@ -80,7 +80,7 @@ def _raises_exception(f): "Case 2 - with.VTCM memory allocation limiter does not work correctly " ) - with tvm.transform.PassContext(config={"tir.vtcm_capacity": vtcm_capacity}): + with tvm.transform.PassContext(config={"tirx.vtcm_capacity": vtcm_capacity}): assert ( _raises_exception( lambda: tvm.compile(sch.mod, target=get_hexagon_target("v68", vtcm_capacity=0)) diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index b4736bf62943..301b38507c6e 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -22,7 +22,7 @@ import tvm from tvm.s_tir.tensor_intrin.hexagon import DMA_READ_128_i8 -from tvm.script import tir as T +from tvm.script import tirx as T from .infrastructure import get_hexagon_target diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index 019af981f652..33d37ccdd372 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -37,7 +37,7 @@ def test_sort(): out = te.extern( data.shape, [data, sort_num], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend ), dtype="int32", @@ -74,7 +74,7 @@ def test_sort_np(): out = te.extern( data.shape, [data, sort_num], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.sort.argsort_nms", ins[0], ins[1], outs[0], axis, is_ascend ), dtype="int32", diff --git a/tests/python/contrib/test_tir_triton_integration.py b/tests/python/contrib/test_tir_triton_integration.py index 4a59d52bed0c..556b3a558729 100644 --- a/tests/python/contrib/test_tir_triton_integration.py +++ b/tests/python/contrib/test_tir_triton_integration.py @@ -26,7 +26,7 @@ from tvm.relax.frontend import nn from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T try: import triton diff --git a/tests/python/disco/test_callback.py b/tests/python/disco/test_callback.py index ed2b8226e32e..9fdf9aaae62e 100644 --- a/tests/python/disco/test_callback.py +++ b/tests/python/disco/test_callback.py @@ -25,7 +25,7 @@ import tvm import tvm.testing from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.testing.requires_nccl diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py index df285c9c65c3..a5a1b819760a 100644 --- a/tests/python/disco/test_nvshmem.py +++ b/tests/python/disco/test_nvshmem.py @@ -38,7 +38,7 @@ from tvm.runtime import disco as di from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T _SOCKET_SESSION_TESTER = None diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index b490cd0b040d..afb25921dafd 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -34,7 +34,7 @@ from tvm.runtime import disco as di from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device): diff --git a/tests/python/driver/test_compile.py b/tests/python/driver/test_compile.py index 25c71b16dd6f..014cb7173410 100644 --- a/tests/python/driver/test_compile.py +++ b/tests/python/driver/test_compile.py @@ -23,7 +23,7 @@ from tvm.runtime import Executable from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_compile_tir(): @@ -109,7 +109,7 @@ def main(x: R.Tensor((4,), "float32")): dev = tvm.cpu(0) x = tvm.runtime.tensor(np.array([1, 2, 3, 4], dtype=np.float32), dev) y = tvm.runtime.tensor(np.zeros(4, dtype=np.float32), dev) - # For tir function, we can directly call the function + # For tirx function, we can directly call the function ex["add_one"](x, y) tvm.testing.assert_allclose(y.numpy(), x.numpy() + 1) # For relax function, we need to use the vm to call the function diff --git a/tests/python/ir/analysis/test_collect_call_map.py b/tests/python/ir/analysis/test_collect_call_map.py index 8990aea7f78d..f1c2f3f52040 100644 --- a/tests/python/ir/analysis/test_collect_call_map.py +++ b/tests/python/ir/analysis/test_collect_call_map.py @@ -22,7 +22,7 @@ from tvm.ir.analysis import collect_call_map from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _build_str_map(call_map: dict[GlobalVar, list[GlobalVar]]) -> dict[str, list[str]]: diff --git a/tests/python/ir/test_datatype_nv_fp4.py b/tests/python/ir/test_datatype_nv_fp4.py index da3098c4ae15..653ce2205056 100644 --- a/tests/python/ir/test_datatype_nv_fp4.py +++ b/tests/python/ir/test_datatype_nv_fp4.py @@ -19,9 +19,9 @@ import tvm import tvm.testing -import tvm.tir as tir +import tvm.tirx as tirx from tvm import te -from tvm.script import tir as T +from tvm.script import tirx as T try: from ml_dtypes import float4_e2m1fn @@ -45,7 +45,7 @@ def test_create_nv_fp4_nd_array(np_dtype, dtype_str): def test_nv_fp4_buffer(np_dtype, dtype_str): m = te.size_var("m") n = te.size_var("n") - A = tvm.tir.decl_buffer((m, n), dtype_str) + A = tvm.tirx.decl_buffer((m, n), dtype_str) assert A.dtype == dtype_str diff --git a/tests/python/ir/test_datatype_nv_fp8.py b/tests/python/ir/test_datatype_nv_fp8.py index 7e6b30d02025..949abe27b913 100644 --- a/tests/python/ir/test_datatype_nv_fp8.py +++ b/tests/python/ir/test_datatype_nv_fp8.py @@ -19,9 +19,9 @@ import tvm import tvm.testing -import tvm.tir as tir +import tvm.tirx as tirx from tvm import te -from tvm.script import tir as T +from tvm.script import tirx as T try: from ml_dtypes import ( @@ -122,7 +122,7 @@ def test_fp8_unary_op(np_dtype, dtype_str): def test_nv_fp8_buffer(np_dtype, dtype_str): m = te.size_var("m") n = te.size_var("n") - A = tvm.tir.decl_buffer((m, n), dtype_str) + A = tvm.tirx.decl_buffer((m, n), dtype_str) assert A.dtype == dtype_str diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index 6ece81e65a15..67a6e857f5f3 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -46,8 +46,8 @@ def test_dir_array(): def test_map(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") amap = tvm.runtime.convert({a: 2, b: 3}) assert a in amap assert len(amap) == 2 @@ -71,8 +71,8 @@ def test_str_map(): def test_map_save_load_json(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") amap = tvm.runtime.convert({a: 2, b: 3}) json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) @@ -82,15 +82,15 @@ def test_map_save_load_json(): def test_dir_map(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") amap = tvm.runtime.convert({a: 2, b: 3}) assert dir(amap) def test_getattr_map(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") amap = tvm.runtime.convert({a: 2, b: 3}) assert isinstance(amap, tvm_ffi.Map) @@ -112,7 +112,7 @@ def test_tensor_container(): def test_return_variant_type(): func = tvm.get_global_func("testing.ReturnsVariant") res_even = func(42) - assert isinstance(res_even, tvm.tir.IntImm) + assert isinstance(res_even, tvm.tirx.IntImm) assert res_even == 21 res_odd = func(17) @@ -128,7 +128,7 @@ def test_pass_variant_type(): def test_pass_incorrect_variant_type(): func = tvm.get_global_func("testing.AcceptsVariant") - float_arg = tvm.tir.FloatImm("float32", 0.5) + float_arg = tvm.tirx.FloatImm("float32", 0.5) with pytest.raises(Exception): func(float_arg) diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index 9170f33f6d2e..8fee003b7618 100644 --- a/tests/python/ir/test_ir_type.py +++ b/tests/python/ir/test_ir_type.py @@ -18,7 +18,7 @@ """Test type nodes in the IR""" import tvm -from tvm.script import tir as T +from tvm.script import tirx as T def check_json_roundtrip(node): diff --git a/tests/python/ir/test_node_reflection.py b/tests/python/ir/test_node_reflection.py index 8223f293488e..4cc0769d58cb 100644 --- a/tests/python/ir/test_node_reflection.py +++ b/tests/python/ir/test_node_reflection.py @@ -27,8 +27,8 @@ def test_const_saveload_json(): # save load json - x = tvm.tir.const(1, "int32") - y = tvm.tir.const(10, "int32") + x = tvm.tirx.const(1, "int32") + y = tvm.tirx.const(10, "int32") z = x + y z = z + z json_str = tvm.ir.save_json(z) @@ -37,7 +37,7 @@ def test_const_saveload_json(): def _test_infinity_value(value, dtype): - x = tvm.tir.const(value, dtype) + x = tvm.tirx.const(value, dtype) json_str = tvm.ir.save_json(x) tvm.ir.assert_structural_equal(x, tvm.ir.load_json(json_str)) @@ -55,15 +55,15 @@ def _test_minmax_value(value): def test_minmax_value(): - _test_minmax_value(tvm.tir.min_value("float32")) - _test_minmax_value(tvm.tir.max_value("float32")) + _test_minmax_value(tvm.tirx.min_value("float32")) + _test_minmax_value(tvm.tirx.max_value("float32")) def test_make_smap(): # save load json - x = tvm.tir.const(1, "int32") - y = tvm.tir.const(10, "int32") - z = tvm.tir.Add(x, y) + x = tvm.tirx.const(1, "int32") + y = tvm.tirx.const(10, "int32") + z = tvm.tirx.Add(x, y) smap = tvm.runtime.convert({"z": z, "x": x}) json_str = tvm.ir.save_json(tvm.runtime.convert([smap])) arr = tvm.ir.load_json(json_str) @@ -74,7 +74,7 @@ def test_make_smap(): def test_make_node(): x = tvm.ir.make_node("ir.IntImm", dtype="int32", value=10, span=None) - assert isinstance(x, tvm.tir.IntImm) + assert isinstance(x, tvm.tirx.IntImm) assert x.value == 10 A = te.placeholder((10,), name="A") AA = tvm.ir.make_node( @@ -136,20 +136,20 @@ def test_pass_config(): cfg = tvm.transform.PassContext( opt_level=1, config={ - "tir.UnrollLoop": { + "tirx.UnrollLoop": { "auto_max_step": 10, } }, ) cfg.opt_level == 1 - assert cfg.config["tir.UnrollLoop"].auto_max_step == 10 + assert cfg.config["tirx.UnrollLoop"].auto_max_step == 10 # default option - assert cfg.config["tir.UnrollLoop"].explicit_unroll == True + assert cfg.config["tirx.UnrollLoop"].explicit_unroll == True # schema checking for specific config key with pytest.raises(TypeError): - cfg = tvm.transform.PassContext(config={"tir.UnrollLoop": {"invalid": 1}}) + cfg = tvm.transform.PassContext(config={"tirx.UnrollLoop": {"invalid": 1}}) # schema check for un-registered config with pytest.raises(AttributeError): @@ -157,11 +157,11 @@ def test_pass_config(): # schema check for wrong type with pytest.raises(AttributeError): - cfg = tvm.transform.PassContext(config={"tir.UnrollLoop": 1}) + cfg = tvm.transform.PassContext(config={"tirx.UnrollLoop": 1}) def test_dict(): - x = tvm.tir.const(1) # a class that has Python-defined methods + x = tvm.tirx.const(1) # a class that has Python-defined methods # instances should see the full class dict assert set(dir(x.__class__)) <= set(dir(x)) @@ -185,9 +185,9 @@ def test_tensor_dict(): def test_free_var_equal(): - x = tvm.tir.Var("x", dtype="int32") - y = tvm.tir.Var("y", dtype="int32") - z = tvm.tir.Var("z", dtype="int32") + x = tvm.tirx.Var("x", dtype="int32") + y = tvm.tirx.Var("y", dtype="int32") + z = tvm.tirx.Var("z", dtype="int32") v1 = x + y v1 = y + z tvm.ir.assert_structural_equal(x, z, map_free_vars=True) diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py index 96ff239c4135..5814cd59b942 100644 --- a/tests/python/ir/test_pass_instrument.py +++ b/tests/python/ir/test_pass_instrument.py @@ -22,7 +22,7 @@ from tvm.ir.instrument import PrintAfterAll, PrintBeforeAll from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,missing-function-docstring,no-value-for-parameter @@ -42,7 +42,7 @@ def func(a: T.handle, b: T.handle) -> None: all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output - assert "pass name: tir." in all_passes_output + assert "pass name: tirx." in all_passes_output def test_relax_print_all_passes(capsys): diff --git a/tests/python/ir/test_transform_replace_global_var.py b/tests/python/ir/test_transform_replace_global_var.py index 5c67a4ecfc1c..ad83099515db 100644 --- a/tests/python/ir/test_transform_replace_global_var.py +++ b/tests/python/ir/test_transform_replace_global_var.py @@ -18,7 +18,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _get_before_module(): diff --git a/tests/python/nightly/test_nnapi/test_ops.py b/tests/python/nightly/test_nnapi/test_ops.py index 21453c360ee4..a4a866e4677e 100644 --- a/tests/python/nightly/test_nnapi/test_ops.py +++ b/tests/python/nightly/test_nnapi/test_ops.py @@ -23,7 +23,7 @@ import tvm import tvm.script import tvm.script.relax as R -import tvm.script.tir as T +import tvm.script.tirx as T from test_nnapi.conftest import remote from test_nnapi.infrastructure import build_and_run diff --git a/tests/python/relax/backend/adreno/mod_utils.py b/tests/python/relax/backend/adreno/mod_utils.py index d84f16115f47..c6521d44168c 100644 --- a/tests/python/relax/backend/adreno/mod_utils.py +++ b/tests/python/relax/backend/adreno/mod_utils.py @@ -28,7 +28,7 @@ from tvm.relax.backend.adreno import clml from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder @@ -750,7 +750,7 @@ def main( @T.prim_func def dequantize(weight: T.handle, scale: T.handle, var_dequantize: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) + T.func_attr({"tirx.noalias": T.bool(True)}) lm_head_q_weight1 = T.match_buffer(weight, (T.int64(K // 8), T.int64(N)), "uint32") lm_head_q_scale1 = T.match_buffer(scale, (T.int64(K // 32), T.int64(N)), "float16") dequantize = T.match_buffer(var_dequantize, (T.int64(K), T.int64(N)), "float16") @@ -808,7 +808,7 @@ def main( @T.prim_func def dequantize(weight: T.handle, scale: T.handle, var_dequantize: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) + T.func_attr({"tirx.noalias": T.bool(True)}) vocab_size = T.int64() lm_head_q_weight1 = T.match_buffer(weight, (T.int64(K // 8), vocab_size), "uint32") lm_head_q_scale1 = T.match_buffer(scale, (T.int64(K // 32), vocab_size), "float16") diff --git a/tests/python/relax/backend/adreno/test_clml_ops.py b/tests/python/relax/backend/adreno/test_clml_ops.py index 934f5fb6b41f..9c0e1808e0c8 100644 --- a/tests/python/relax/backend/adreno/test_clml_ops.py +++ b/tests/python/relax/backend/adreno/test_clml_ops.py @@ -53,7 +53,7 @@ from tvm.relax.backend.adreno.clml import OpenCLMLOffLoad, OpenCLMLOffLoadForLLM from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder diff --git a/tests/python/relax/backend/adreno/test_texture_network.py b/tests/python/relax/backend/adreno/test_texture_network.py index 11188d7e2f3c..7690629b5710 100644 --- a/tests/python/relax/backend/adreno/test_texture_network.py +++ b/tests/python/relax/backend/adreno/test_texture_network.py @@ -31,7 +31,7 @@ from tvm.relax.transform.legalize_ops import adreno as legalize_adreno from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder diff --git a/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py b/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py index 2352c5bd67c4..8d364628fbf7 100644 --- a/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py +++ b/tests/python/relax/backend/adreno/test_transform_annotate_custom_scope.py @@ -24,7 +24,7 @@ from tvm.relax.transform.legalize_ops import adreno as legalize_adreno from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T @visitor @@ -83,7 +83,7 @@ def verify(mod, expected): # "relax.nn.layer_norm", ] with tgt: - mod = tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) + mod = tvm.tirx.transform.BindTarget(tvm.target.Target.current(allow_none=False))(mod) mod = tvm.relax.transform.DecomposeOpsForInference()(mod) mod = tvm.relax.transform.FoldConstant()(mod) desired_layouts = {"relax.nn.conv2d": ["NCHW4c", "OIHW4o", "NCHW4c"]} diff --git a/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py b/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py index e85299d2f347..7af632288654 100644 --- a/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py +++ b/tests/python/relax/backend/adreno/test_transform_fold_vdevice_scope_change.py @@ -22,7 +22,7 @@ from tvm.ir.module import IRModule from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def verify(input, expected): diff --git a/tests/python/relax/backend/adreno/utils.py b/tests/python/relax/backend/adreno/utils.py index c4e7d53ef61a..da52ffd1d21b 100644 --- a/tests/python/relax/backend/adreno/utils.py +++ b/tests/python/relax/backend/adreno/utils.py @@ -197,7 +197,7 @@ def build_and_run(mod, inputs, tgt): tgt = tvm.target.Target(tgt, host={"kind": "llvm"}) relax_pipeline = relax.pipeline.get_default_pipeline(tgt) - tir_pipeline = tvm.tir.get_default_tir_pipeline(tgt) + tir_pipeline = tvm.tirx.get_default_tir_pipeline(tgt) mod = relax_pipeline(mod) ex = tvm.compile(mod, tgt, tir_pipeline=tir_pipeline) diff --git a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py b/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py index 795b365c01ec..c9b38aeaa687 100644 --- a/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py +++ b/tests/python/relax/distributed/test_distributed_dtensor_sinfo.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm import TVMError, tir +from tvm import TVMError, tirx from tvm import relax as rx from tvm.ir import Range, structural_equal @@ -42,7 +42,7 @@ def _check_json_roundtrip(x): def test_dtensor_struct_info(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") tensor_s0 = rx.TensorStructInfo([1, n + 1, m], "float32") tensor_s1 = rx.TensorStructInfo([1, n + 1, m], "float32") diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py index 170c2893257c..4fd3e25f5353 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_distir.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_distir.py @@ -23,7 +23,7 @@ from tvm.ir import assert_structural_equal from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_mlp(): @@ -39,7 +39,7 @@ def gelu1( A: T.Buffer((T.int64(128), T.int64(64)), "float32"), T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_multiply_1 = T.sblock_alloc_buffer((T.int64(128), T.int64(64))) compute = T.sblock_alloc_buffer((T.int64(128), T.int64(64))) @@ -82,7 +82,7 @@ def matmul1( B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)): with T.sblock("matmul"): @@ -99,7 +99,7 @@ def matmul2( B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)): with T.sblock("matmul"): @@ -198,7 +198,7 @@ def gelu1( A: T.Buffer((T.int64(128), T.int64(64)), "float32"), T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_multiply_1 = T.sblock_alloc_buffer((T.int64(128), T.int64(64))) compute = T.sblock_alloc_buffer((T.int64(128), T.int64(64))) @@ -241,7 +241,7 @@ def matmul11( B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul: T.Buffer((T.int64(64), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(64), T.int64(128), T.int64(64)): with T.sblock("matmul"): @@ -258,7 +258,7 @@ def matmul2( B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)): with T.sblock("matmul"): @@ -275,7 +275,7 @@ def split11( T_split: T.Buffer((64, 64), "float32"), T_split_1: T.Buffer((64, 64), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax1, ax2 in T.grid(64, 64): with T.sblock("T_split"): diff --git a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py index e03957208bd0..bdf1375bc459 100644 --- a/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py +++ b/tests/python/relax/distributed/test_distributed_transform_lower_global_to_local_view.py @@ -23,7 +23,7 @@ from tvm.ir import assert_structural_equal from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_mlp(): @@ -39,7 +39,7 @@ def gelu( A: T.Buffer((T.int64(128), T.int64(128)), "float32"), T_multiply: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_multiply_1 = T.sblock_alloc_buffer((T.int64(128), T.int64(128))) compute = T.sblock_alloc_buffer((T.int64(128), T.int64(128))) @@ -82,7 +82,7 @@ def matmul( B: T.Buffer((T.int64(128), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(128)): with T.sblock("matmul"): @@ -128,7 +128,7 @@ def gelu1( A: T.Buffer((T.int64(128), T.int64(64)), "float32"), T_multiply: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_multiply_1 = T.sblock_alloc_buffer((T.int64(128), T.int64(64))) compute = T.sblock_alloc_buffer((T.int64(128), T.int64(64))) @@ -171,7 +171,7 @@ def matmul1( B: T.Buffer((T.int64(128), T.int64(64)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(64)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(64), T.int64(128)): with T.sblock("matmul"): @@ -188,7 +188,7 @@ def matmul2( B: T.Buffer((T.int64(64), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(64)): with T.sblock("matmul"): @@ -245,7 +245,7 @@ def add( B: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): with T.sblock("T_add"): @@ -260,7 +260,7 @@ def divide( B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), T_divide: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): with T.sblock("T_divide"): @@ -277,7 +277,7 @@ def matmul( B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(4096)): with T.sblock("matmul"): @@ -296,7 +296,7 @@ def matmul1( B: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(256), T.int64(128) @@ -318,7 +318,7 @@ def matmul2( B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(128), T.int64(256) @@ -340,7 +340,7 @@ def maximum( B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), T_maximum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): with T.sblock("T_maximum"): @@ -357,7 +357,7 @@ def minimum( B: T.Buffer((T.int64(1), T.int64(1), T.int64(256), T.int64(256)), "float16"), T_minimum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): with T.sblock("T_minimum"): @@ -373,7 +373,7 @@ def reshape( A: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_reshape"): @@ -398,7 +398,7 @@ def reshape1( A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_reshape"): @@ -424,7 +424,7 @@ def reshape2( A: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_reshape"): @@ -448,7 +448,7 @@ def reshape3( A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): with T.sblock("T_reshape"): @@ -475,7 +475,7 @@ def rms_norm( B: T.Buffer((T.int64(4096),), "float16"), rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): Ared_temp = T.sblock_alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): @@ -512,7 +512,7 @@ def rotary_embedding( C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): with T.sblock("rotary"): @@ -538,7 +538,7 @@ def softmax( (T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_softmax_maxelem = T.sblock_alloc_buffer( (T.int64(1), T.int64(32), T.int64(256)), "float16" @@ -594,7 +594,7 @@ def transpose( A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): with T.sblock("T_transpose"): @@ -608,7 +608,7 @@ def transpose1( A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(128)): with T.sblock("T_transpose"): @@ -622,7 +622,7 @@ def transpose2( A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(128), T.int64(256)): with T.sblock("T_transpose"): @@ -636,7 +636,7 @@ def transpose3( A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_transpose"): @@ -860,7 +860,7 @@ def add( B: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): with T.sblock("T_add"): @@ -875,7 +875,7 @@ def divide1( B: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), T_divide: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): with T.sblock("T_divide"): @@ -892,7 +892,7 @@ def matmul11( B: T.Buffer((T.int64(1), T.int64(16), T.int64(128), T.int64(256)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(16), T.int64(256), T.int64(256), T.int64(128) @@ -914,7 +914,7 @@ def matmul21( B: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(16), T.int64(256), T.int64(128), T.int64(256) @@ -936,7 +936,7 @@ def matmul3( B: T.Buffer((T.int64(4096), T.int64(2048)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(2048)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(2048), T.int64(4096)): with T.sblock("matmul"): @@ -955,7 +955,7 @@ def matmul4( B: T.Buffer((T.int64(2048), T.int64(4096)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(2048)): with T.sblock("matmul"): @@ -974,7 +974,7 @@ def maximum1( B: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), T_maximum: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): with T.sblock("T_maximum"): @@ -991,7 +991,7 @@ def minimum1( B: T.Buffer((T.int64(1), T.int64(1), T.int64(256), T.int64(256)), "float16"), T_minimum: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(256)): with T.sblock("T_minimum"): @@ -1007,7 +1007,7 @@ def reshape11( A: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(256), T.int64(16), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(16), T.int64(128)): with T.sblock("T_reshape"): @@ -1033,7 +1033,7 @@ def reshape21( A: T.Buffer((T.int64(256), T.int64(16), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(16), T.int64(128)): with T.sblock("T_reshape"): @@ -1057,7 +1057,7 @@ def reshape31( A: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(2048)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(2048)): with T.sblock("T_reshape"): @@ -1083,7 +1083,7 @@ def reshape4( A: T.Buffer((T.int64(1), T.int64(256), T.int64(2048)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(16), T.int64(128)): with T.sblock("T_reshape"): @@ -1109,7 +1109,7 @@ def rms_norm( B: T.Buffer((T.int64(4096),), "float16"), rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): Ared_temp = T.sblock_alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): @@ -1146,7 +1146,7 @@ def rotary_embedding( C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): with T.sblock("rotary"): @@ -1172,7 +1172,7 @@ def rotary_embedding1( C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), 256, T.int64(16), T.int64(128)), "float16"), ): - T.func_attr({"global_symbol": "rotary_embedding", "tir.noalias": True}) + T.func_attr({"global_symbol": "rotary_embedding", "tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(16), T.int64(128)): with T.sblock("rotary"): @@ -1198,7 +1198,7 @@ def softmax1( (T.int64(1), T.int64(16), T.int64(256), T.int64(256)), "float16" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_softmax_maxelem = T.sblock_alloc_buffer( (T.int64(1), T.int64(16), T.int64(256)), "float16" @@ -1254,7 +1254,7 @@ def transpose11( A: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(256), T.int64(128)): with T.sblock("T_transpose"): @@ -1268,7 +1268,7 @@ def transpose21( A: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(16), T.int64(128), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(16), T.int64(128), T.int64(256)): with T.sblock("T_transpose"): @@ -1282,7 +1282,7 @@ def transpose31( A: T.Buffer((T.int64(1), T.int64(16), T.int64(256), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(16), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(16), T.int64(128)): with T.sblock("T_transpose"): @@ -1296,7 +1296,7 @@ def transpose4( A: T.Buffer((T.int64(2048), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(2048)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(4096), T.int64(2048)): with T.sblock("T_transpose"): @@ -1310,7 +1310,7 @@ def transpose5( A: T.Buffer((T.int64(4096), T.int64(2048)), "float16"), T_transpose: T.Buffer((T.int64(2048), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(2048), T.int64(4096)): with T.sblock("T_transpose"): diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index 326dfea5fb81..5fc7aba39f46 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -24,7 +24,7 @@ from tvm.ir import assert_structural_equal from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_mlp(): @@ -93,7 +93,7 @@ class MLPWithTuple: @T.prim_func(private=True) def split1(var_A: T.handle, var_T_split: T.handle, var_T_split_1: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (128, 128), "float32") T_split = T.match_buffer(var_T_split, (64, 128), "float32") T_split_1 = T.match_buffer(var_T_split_1, (64, 128), "float32") @@ -142,7 +142,7 @@ def split1( T_split: T.Buffer((64, 128), "float32"), T_split_1: T.Buffer((64, 128), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax1, ax2 in T.grid(64, 128): with T.sblock("T_split"): @@ -390,7 +390,7 @@ class LlamaAttentionLayer: def rms_norm( var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (T.int64(1), 256, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), 256, T.int64(4096)), "float16") @@ -430,7 +430,7 @@ def rotary_embedding( C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), var_rotary: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (T.int64(1), 256, T.int64(32), T.int64(128)), "float16") rotary = T.match_buffer( @@ -603,7 +603,7 @@ def rms_norm( B: T.Buffer((T.int64(4096),), "float16"), rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): Ared_temp = T.sblock_alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): @@ -640,7 +640,7 @@ def rotary_embedding( C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): with T.sblock("rotary"): @@ -819,7 +819,7 @@ def add( B: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), T_add: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): with T.sblock("T_add"): @@ -837,7 +837,7 @@ def divide( B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), T_divide: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): with T.sblock("T_divide"): @@ -855,7 +855,7 @@ def matmul( B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(256), T.int64(4096), T.int64(4096)): with T.sblock("matmul"): @@ -875,7 +875,7 @@ def matmul1( B: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(256), T.int64(128) @@ -898,7 +898,7 @@ def matmul2( B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid( T.int64(1), T.int64(32), T.int64(256), T.int64(128), T.int64(256) @@ -921,7 +921,7 @@ def maximum( B: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), T_maximum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): with T.sblock("T_maximum"): @@ -939,7 +939,7 @@ def minimum( B: T.Buffer((T.int64(1), T.int64(1), T.int64(256), T.int64(256)), "float16"), T_minimum: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(256)): with T.sblock("T_minimum"): @@ -958,7 +958,7 @@ def reshape( A: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_reshape"): @@ -975,7 +975,7 @@ def reshape1( A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_reshape"): @@ -989,7 +989,7 @@ def reshape2( A: T.Buffer((T.int64(256), T.int64(32), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_reshape"): @@ -1004,7 +1004,7 @@ def reshape3( A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), T_reshape: T.Buffer((T.int64(1), T.int64(256), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(256), T.int64(4096)): with T.sblock("T_reshape"): @@ -1022,7 +1022,7 @@ def rms_norm( B: T.Buffer((T.int64(4096),), "float16"), rms_norm_1: T.Buffer((T.int64(1), 256, T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): Ared_temp = T.sblock_alloc_buffer((T.int64(1), 256)) for bsz, i, k in T.grid(T.int64(1), 256, T.int64(4096)): @@ -1061,7 +1061,7 @@ def rotary_embedding( C: T.Buffer((T.int64(2048), T.int64(128)), "float16"), rotary: T.Buffer((T.int64(1), 256, T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3 in T.grid(T.int64(1), 256, T.int64(32), T.int64(128)): with T.sblock("rotary"): @@ -1093,7 +1093,7 @@ def softmax( (T.int64(1), T.int64(32), T.int64(256), T.int64(256)), "float16" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_softmax_maxelem = T.sblock_alloc_buffer( (T.int64(1), T.int64(32), T.int64(256)), "float16" @@ -1158,7 +1158,7 @@ def transpose( A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(4096), T.int64(4096)): with T.sblock("T_transpose"): @@ -1172,7 +1172,7 @@ def transpose1( A: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(256), T.int64(128)): with T.sblock("T_transpose"): @@ -1189,7 +1189,7 @@ def transpose2( A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(32), T.int64(128), T.int64(256)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(128), T.int64(256)): with T.sblock("T_transpose"): @@ -1206,7 +1206,7 @@ def transpose3( A: T.Buffer((T.int64(1), T.int64(32), T.int64(256), T.int64(128)), "float16"), T_transpose: T.Buffer((T.int64(1), T.int64(256), T.int64(32), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(256), T.int64(32), T.int64(128)): with T.sblock("T_transpose"): @@ -1603,7 +1603,7 @@ class LlamaAttentionLayerDynamicShape: def rms_norm( var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") @@ -1644,7 +1644,7 @@ def rotary_embedding( var_rotary: T.handle, m: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") rotary = T.match_buffer( @@ -1811,7 +1811,7 @@ class ShardedLlamaAttentionLayerDynamicShape: def rms_norm( var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") @@ -1852,7 +1852,7 @@ def rotary_embedding( var_rotary: T.handle, m: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") rotary = T.match_buffer( diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_parser.py b/tests/python/relax/distributed/test_distributed_tvmscript_parser.py index c9c2c50bd85b..5e81dfed5af0 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_parser.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_parser.py @@ -23,13 +23,13 @@ import tvm import tvm.script import tvm.testing -from tvm import IRModule, relax, tir, topi +from tvm import IRModule, relax, tirx, topi from tvm.ir import Range from tvm.relax import Call, SeqExpr, VarBinding from tvm.relax.distributed import DeviceMesh from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def _check( @@ -61,7 +61,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -124,7 +124,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -164,7 +164,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index 47daadb314bd..486e4c5d39c7 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -22,7 +22,7 @@ from tvm.relax.distributed import DeviceMesh, DTensorStructInfo, Placement from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def _assert_print(obj, expected): @@ -90,7 +90,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -129,7 +129,7 @@ def test_module(): TestModule, """ # from tvm.script import ir as I -# from tvm.script import tir as T +# from tvm.script import tirx as T # from tvm.script import relax as R @I.ir_module @@ -138,7 +138,7 @@ class Module: I.module_global_infos({"mesh": [R.device_mesh((2, 2), I.Range(0, 4)), R.device_mesh((1,), I.Range(4, 5))]}) @T.prim_func def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(""): diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py index b9d9e76d062b..fc3d31ae95b8 100644 --- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py +++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py @@ -173,7 +173,7 @@ def set_global_func(head_dim, dtype): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) builts.append(f.main) ( @@ -219,13 +219,13 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla - ["tir", fattn_prefill_ragged], - ["tir", fattn_prefill], - ["tir", fattn_decode], - ["tir", fattn_prefill_sliding_window], - ["tir", fattn_decode_sliding_window], - ["tir", fattn_prefill_with_tree_mask_paged_kv_cache], - ["tir", fattn_prefill_with_tree_mask], + ["tirx", fattn_prefill_ragged], + ["tirx", fattn_prefill], + ["tirx", fattn_decode], + ["tirx", fattn_prefill_sliding_window], + ["tirx", fattn_decode_sliding_window], + ["tirx", fattn_prefill_with_tree_mask_paged_kv_cache], + ["tirx", fattn_prefill_with_tree_mask], [], # f_mla_prefill [fmerge_state], fsplit_rotary, diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 69bfd1f22210..bfbd16ba514a 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -22,7 +22,7 @@ import tvm import tvm.testing from tvm import relax as rx -from tvm import tir +from tvm import tirx from tvm.relax.analysis import ( all_global_vars, all_vars, @@ -36,7 +36,7 @@ ) from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def var_name_set(vars: list[rx.Var | rx.GlobalVar]) -> set[str]: @@ -44,8 +44,8 @@ def var_name_set(vars: list[rx.Var | rx.GlobalVar]) -> set[str]: def test_use_def(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", R.Tensor([m, n], "float16")) y = rx.Var("y", R.Tensor([n], "float16")) ib = rx.BlockBuilder() @@ -74,8 +74,8 @@ def test_use_def(): ids=["binary_op", "self_reference", "tuple"], ) def test_used_vars(expr_fn, expected_var_names): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", R.Tensor([m, n], "float16")) y = rx.Var("y", R.Tensor([n], "float16")) z = rx.Var("z", R.Tensor([m], "float16")) @@ -597,7 +597,7 @@ def expand_dims( rxplaceholder: T.Buffer((2, 3, 4), "float32"), expand_dims: T.Buffer((2, 1, 1, 1, 3, 1, 4, 1), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( @@ -659,7 +659,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): def test_reshape_pattern_dyn_3(): @T.prim_func def reshape(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"op_pattern": 8, "tir.noalias": True}) + T.func_attr({"op_pattern": 8, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (n, T.int64(4096)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16") @@ -678,7 +678,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): def test_reshape_pattern_dyn_4(): @T.prim_func def reshape(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"op_pattern": 8, "tir.noalias": True}) + T.func_attr({"op_pattern": 8, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") T_reshape = T.match_buffer( @@ -707,7 +707,7 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): def test_reshape_pattern_dyn_5(): @T.prim_func def reshape(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"op_pattern": 8, "tir.noalias": True}) + T.func_attr({"op_pattern": 8, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(32), T.int64(128)), "float16") T_reshape = T.match_buffer(var_T_reshape, (T.int64(1), n, T.int64(4096)), "float16") diff --git a/tests/python/relax/test_analysis_computable_at_compile_time.py b/tests/python/relax/test_analysis_computable_at_compile_time.py index fd592a989450..734986f2cff3 100644 --- a/tests/python/relax/test_analysis_computable_at_compile_time.py +++ b/tests/python/relax/test_analysis_computable_at_compile_time.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _analyze_func(func: tvm.relax.Function) -> list[str]: diff --git a/tests/python/relax/test_analysis_detect_recursion.py b/tests/python/relax/test_analysis_detect_recursion.py index 5e44128a7069..994f12546d84 100644 --- a/tests/python/relax/test_analysis_detect_recursion.py +++ b/tests/python/relax/test_analysis_detect_recursion.py @@ -20,7 +20,7 @@ from tvm import relax as rx from tvm.relax.analysis import detect_recursion from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def assert_groups(groups: list[list[rx.GlobalVar]], expected: list[list[str]]) -> None: diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py index 5f65b0670ff3..683b9940fa6c 100644 --- a/tests/python/relax/test_analysis_estimate_memory_usage.py +++ b/tests/python/relax/test_analysis_estimate_memory_usage.py @@ -20,7 +20,7 @@ import tvm.testing from tvm.relax.analysis import estimate_memory_usage from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_basic(): diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py index b539a4c32a23..dbcc94db83f2 100644 --- a/tests/python/relax/test_analysis_struct_info_analysis.py +++ b/tests/python/relax/test_analysis_struct_info_analysis.py @@ -22,10 +22,10 @@ import tvm import tvm.testing -from tvm import TVMError, ir, tir +from tvm import TVMError, ir, tirx from tvm import relax as rx from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_get_static_type_basic(): @@ -40,7 +40,7 @@ def test_get_static_type_basic(): def test_get_static_type_shape(): # shape - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s2 = rx.ShapeStructInfo([1, n + 1, m]) s3 = rx.ShapeStructInfo(ndim=2) @@ -51,7 +51,7 @@ def test_get_static_type_shape(): def test_get_static_type_tensor(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s4 = rx.TensorStructInfo([1, n + 1, m], "int64") tvm.ir.assert_structural_equal( @@ -61,7 +61,7 @@ def test_get_static_type_tensor(): def test_get_static_type_tuple(): # tuple - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s0 = rx.ObjectStructInfo() s2 = rx.ShapeStructInfo([1, n + 1, m]) s4 = rx.TensorStructInfo([1, n + 1, m], "int64") @@ -82,7 +82,7 @@ def test_get_static_type_tuple(): def test_get_static_type_func(): # tuple def fn_info(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([c, n, m], "float32") y = rx.TensorStructInfo([c, n, 1], "float32") z = rx.TensorStructInfo([c, n], "float32") @@ -109,7 +109,7 @@ def test_erase_to_well_defined_basic(): def test_erase_to_well_defined_shape(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s2 = rx.ShapeStructInfo([1, n + 1, m]) s3 = rx.ShapeStructInfo(ndim=2) @@ -132,7 +132,7 @@ def test_erase_to_well_defined_shape(): def test_erase_to_well_defined_tensor(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") rshape = rx.Var("shape", rx.ShapeStructInfo(ndim=2)) s0 = rx.TensorStructInfo(rshape, dtype="int32") @@ -172,7 +172,7 @@ def test_erase_to_well_defined_tensor(): def test_erase_to_well_defined_tuple(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s0 = rx.ObjectStructInfo() s2 = rx.ShapeStructInfo([1, m]) s4 = rx.TensorStructInfo([1, n + 1, m], "int64") @@ -194,7 +194,7 @@ def test_erase_to_well_defined_tuple(): def test_erase_to_well_defined_func(): def fn_info(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([c, n, m], "float32") y = rx.TensorStructInfo([c, n, 1], "float32") z = rx.TensorStructInfo([c, n], "float32") @@ -209,7 +209,7 @@ def test_base_check(): BR = rx.analysis.BaseCheckResult bcheck = rx.analysis.struct_info_base_check - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") obj0 = rx.ObjectStructInfo() prim0 = rx.PrimStructInfo("int32") prim1 = rx.PrimStructInfo("float32") @@ -331,7 +331,7 @@ def test_base_check(): assert bcheck(rx.TupleStructInfo([t0, t1]), rx.TupleStructInfo([t1, t0])) == BR.FAIL_L1 def fn_info_shape(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([c, n, m], "float32") y = rx.TensorStructInfo([c, n, 1], "float32") z = rx.TensorStructInfo([c, n], "float32") @@ -367,13 +367,13 @@ def test_derive_call_ret_struct_info(): obj0 = rx.ObjectStructInfo() prim0 = rx.PrimStructInfo("float32") - n, m = tir.Var("n0", "int64"), tir.Var("m0", "int64") + n, m = tirx.Var("n0", "int64"), tirx.Var("m0", "int64") bb = rx.BlockBuilder() # derivation cases with bb.testing_scope(def_vars=[n, m]): def func0(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([n, m], "float32") z = rx.TensorStructInfo([m + c, n], "float32") return rx.FuncStructInfo([x], z) @@ -420,7 +420,7 @@ def func0(c): vdev = ir.VDevice("llvm") def func1(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([n, m], "float32", vdev) z = rx.TensorStructInfo([m + c, n], "float32", vdev) return rx.FuncStructInfo([x], z) @@ -440,7 +440,7 @@ def func1(c): # recursive tuple derivation def func_tuple0(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x0 = rx.TensorStructInfo([n, c], "float32") x1 = rx.TensorStructInfo([n + c, m], "float32") z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) @@ -461,7 +461,7 @@ def func_tuple0(c): ) def func_tuple1(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x0 = rx.TensorStructInfo([n, m], "float32") x1 = rx.TensorStructInfo([n + c, c], "float32") z = rx.TupleStructInfo([rx.TensorStructInfo([m, n], "float32")]) @@ -493,7 +493,7 @@ def func_tuple1(c): # mixed shape types def func_shape_mixed(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x0 = rx.ShapeStructInfo([n, m]) f0 = func_tuple0(c) z = rx.ShapeStructInfo([m + n, c]) @@ -518,7 +518,7 @@ def _check_lca(lhs, rhs, target): def test_struct_info_lca(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") obj0 = rx.ObjectStructInfo() prim0 = rx.PrimStructInfo("int32") prim1 = rx.PrimStructInfo("float32") @@ -595,7 +595,7 @@ def test_struct_info_lca(): _check_lca(t7, rx.TupleStructInfo([]), t7) def fn_info_shape(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([c, n, m], "float32") y = rx.TensorStructInfo([c, n, 1], "float32") z = rx.TensorStructInfo([c, n], "float32") @@ -644,29 +644,29 @@ def _generate_prim_test_cases(): # The LCA of two values, each statically known to be the same # value, is known to have that value. yield ( - R.Prim(value=tir.const(0, dtype)), - R.Prim(value=tir.const(0, dtype)), - R.Prim(value=tir.const(0, dtype)), + R.Prim(value=tirx.const(0, dtype)), + R.Prim(value=tirx.const(0, dtype)), + R.Prim(value=tirx.const(0, dtype)), ) # The LCA of two values, each of which is statically known to # have a different value, no longer knows the contained value. yield ( - R.Prim(value=tir.const(0, dtype)), - R.Prim(value=tir.const(1, dtype)), + R.Prim(value=tirx.const(0, dtype)), + R.Prim(value=tirx.const(1, dtype)), R.Prim(dtype=dtype), ) # LCA of a known variable with itself yields itself - var_N = tir.Var("N", dtype) + var_N = tirx.Var("N", dtype) yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N)) # LCA of a known variable with a known static value is no # longer known to have a specific value. - yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype)) - yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype)) + yield (R.Prim(value=var_N), R.Prim(value=tirx.const(0, dtype)), R.Prim(dtype=dtype)) + yield (R.Prim(value=tirx.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype)) - var_M = tir.Var("M", dtype) + var_M = tirx.Var("M", dtype) yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype)) for dtype_a in dtypes: @@ -682,14 +682,14 @@ def _generate_prim_test_cases(): # the same value in different representations (e.g. # `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`. yield ( - R.Prim(value=tir.const(0, dtype_a)), - R.Prim(value=tir.const(0, dtype_b)), + R.Prim(value=tirx.const(0, dtype_a)), + R.Prim(value=tirx.const(0, dtype_b)), R.Object, ) # And the same is true for known variable values - var_N = tir.Var("N", dtype_a) - var_M = tir.Var("M", dtype_b) + var_N = tirx.Var("N", dtype_a) + var_M = tirx.Var("M", dtype_b) yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object) @@ -714,7 +714,7 @@ def _normalize_sinfo(sinfo): def _generate_tir_var_test_cases(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") shape0 = rx.ShapeStructInfo([1, n, 3]) shape1 = rx.ShapeStructInfo([1, 2 * n, n, m]) shape2 = rx.ShapeStructInfo([1, 2 * n, m]) @@ -751,11 +751,11 @@ def test_definable_tir_vars_in_struct_info(tir_var_test_case): def test_collect_symbolic_var_from_tensor_shape(): n, m, k, q, p = ( - tir.Var("n", "int64"), - tir.Var("m", "int64"), - tir.Var("k", "int64"), - tir.Var("q", "int64"), - tir.Var("p", "int64"), + tirx.Var("n", "int64"), + tirx.Var("m", "int64"), + tirx.Var("k", "int64"), + tirx.Var("q", "int64"), + tirx.Var("p", "int64"), ) bb = rx.BlockBuilder() x = rx.Var("x", rx.TensorStructInfo([m, m + n], "float32")) @@ -776,8 +776,8 @@ def test_collect_symbolic_var_from_tensor_shape(): def test_collect_symbolic_var_from_non_tensor_params(param_type, param_order): - tir_n = tir.Var("n", "int64") - tir_m = tir.Var("m", "int64") + tir_n = tirx.Var("n", "int64") + tir_m = tirx.Var("m", "int64") bb = rx.BlockBuilder() arg = rx.Var("arg", rx.TensorStructInfo([tir_n * tir_m])) diff --git a/tests/python/relax/test_analysis_suggest_layout_transforms.py b/tests/python/relax/test_analysis_suggest_layout_transforms.py index e09cc9c5806e..336cd867051d 100644 --- a/tests/python/relax/test_analysis_suggest_layout_transforms.py +++ b/tests/python/relax/test_analysis_suggest_layout_transforms.py @@ -19,8 +19,8 @@ import pytest import tvm.testing -from tvm import relax, tir -from tvm.script import tir as T +from tvm import relax, tirx +from tvm.script import tirx as T def apply_transformations(func, suggested_transfoms, print_transformation=False): @@ -28,13 +28,13 @@ def apply_transformations(func, suggested_transfoms, print_transformation=False) for block, per_block_transformations in suggested_transfoms.items(): blockrv = sch.get_sblock(block.name_hint) for obj, index_map in per_block_transformations.items(): - if isinstance(obj, tir.SBlock): + if isinstance(obj, tirx.SBlock): block_name = obj.name_hint if print_transformation: print("Block transformation: ", block_name, " :: ", index_map) sch.transform_block_layout(block_name, index_map) else: - assert isinstance(obj, tir.Buffer) + assert isinstance(obj, tirx.Buffer) buffer = obj if print_transformation: print("Buffer transformation: ", buffer, " :: ", index_map) @@ -622,7 +622,7 @@ def before( arg1: T.Buffer((64, 224, 224), "float32"), T_add: T.Buffer((32, 64, 224, 224), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(32, 64, 224, 224): with T.sblock("T_add"): @@ -642,7 +642,7 @@ def expected( arg1: T.Buffer((224, 224, 16, 4), "float32"), T_add: T.Buffer((32, 224, 224, 16, 4), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3, ax4 in T.grid(32, 224, 224, 16, 4): with T.sblock("T_add"): diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 00ee1f250a6a..9acb5ad752ca 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -21,13 +21,13 @@ import tvm import tvm.testing from tvm import relax as rx -from tvm import tir +from tvm import tirx from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T -m = tir.Var("m", "int64") -n = tir.Var("n", "int64") +m = tirx.Var("m", "int64") +n = tirx.Var("n", "int64") x = rx.Var("x", R.Tensor([m, n], "float32")) cond = rx.Var("cond", R.Tensor([], "bool")) @@ -136,7 +136,7 @@ def test_global_var(): def test_symbolic_var(): # Error: Symbolic Var new_s is not defined - new_s = tir.Var("new_s", "int64") + new_s = tirx.Var("new_s", "int64") gv0 = rx.Var("gv0", R.Tensor([m, new_s], "int64")) call_node = rx.op.add(x, x) bindings = [rx.VarBinding(gv0, call_node)] @@ -148,7 +148,7 @@ def test_symbolic_var(): def test_symbolic_var_across_functions(): # Error: Symbolic Var s presents across different functions - s = tir.Var("s", "int64") + s = tirx.Var("s", "int64") v0 = rx.Var("v0", R.Tensor([5, s], "float32")) v1 = rx.Var("v1", R.Tensor([s, 7], "float32")) bb = rx.BlockBuilder() @@ -164,7 +164,7 @@ def test_symbolic_var_invalid_type(): with pytest.raises( tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" ): - dim = tir.Var("dim", "float32") + dim = tirx.Var("dim", "float32") y = rx.Var("y", R.Tensor([dim], "float32")) gv0 = rx.Var("gv0", R.Tensor([dim], "float32")) call_node = rx.op.add(y, y) @@ -415,7 +415,7 @@ def test_inline_prim_func(): [ rx.VarBinding( var=x, - value=tir.PrimFunc([], tir.Evaluate(0)), + value=tirx.PrimFunc([], tirx.Evaluate(0)), ), rx.VarBinding( var=y, @@ -423,7 +423,7 @@ def test_inline_prim_func(): op=tvm.ir.Op.get("relax.call_tir"), args=[ rx.GlobalVar("GlobalVar0"), - rx.Tuple([x, tir.PrimFunc([], tir.Evaluate(0))]), + rx.Tuple([x, tirx.PrimFunc([], tirx.Evaluate(0))]), rx.ShapeExpr([]), ], ), @@ -498,8 +498,8 @@ def test_nested_dataflow(): def test_sinfo_args_tir_var_used_before_define_call_packed(): # Error: Symbolic Var m1, n1 are not defined - m1 = tir.Var("m1", "int64") - n1 = tir.Var("n1", "int64") + m1 = tirx.Var("m1", "int64") + n1 = tirx.Var("n1", "int64") call = R.call_packed("my_func", x, sinfo_args=R.Tensor((m1, n1), "float32")) func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) @@ -508,8 +508,8 @@ def test_sinfo_args_tir_var_used_before_define_call_packed(): def test_sinfo_args_tir_var_used_before_define_call_tir(): # Error: Symbolic Var m1, n1 are not defined - m1 = tir.Var("m1", "int64") - n1 = tir.Var("n1", "int64") + m1 = tirx.Var("m1", "int64") + n1 = tirx.Var("n1", "int64") call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32")) func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) @@ -526,8 +526,8 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtyp gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) return gv """ - m1 = tir.Var("m1", "int64") - n1 = tir.Var("n1", "int64") + m1 = tirx.Var("m1", "int64") + n1 = tirx.Var("n1", "int64") call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])] seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) @@ -663,7 +663,7 @@ def test_pass_dltensor_arg_to_tir(): """Relax may pass R.Tensor as DLTensor In TIR, a `DLTensor*` argument with unknown shape and dtype is - represented as a `tir.Var` with + represented as a `tirx.Var` with `tvm::PrimType(DataType::Handle())`, and with no entry in the `PrimFuncNode::buffer_map`. In Relax, this is represented as `R.Tensor`. Calls from Relax to TIR that pass a tensor of unknown @@ -682,9 +682,9 @@ def main(A: R.Tensor) -> R.Prim("bool"): @T.prim_func(private=True) def is_bfloat16_dtype(tensor: T.handle) -> T.bool: - T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.is_host_func": True}) - # From #include + # From #include kDLTensorTypeCode = T.meta_var(5) kDLTensorTypeBits = T.meta_var(6) kDLTensorTypeLanes = T.meta_var(7) diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index bba5c6aa402b..512f5ce465fc 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -23,11 +23,11 @@ import tvm import tvm.testing from tvm import relax as rx -from tvm import tir +from tvm import tirx from tvm.relax.testing import dump_ast from tvm.relax.testing.ast_printer import ASTPrinter from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # Overload dump_ast to test both struct info and type annotations dump_ast = partial(dump_ast, include_struct_info_annotations=True) @@ -104,8 +104,8 @@ def test_dataflow_var() -> None: def test_match_cast() -> None: # match_cast([16, 8], [m, n]) - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") var = rx.Var("v0", R.Shape()) b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) @@ -141,8 +141,8 @@ def test_var_binding() -> None: def test_binding_block() -> None: - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) @@ -160,8 +160,8 @@ def test_binding_block() -> None: def test_dataflow_block() -> None: - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) @@ -195,8 +195,8 @@ def test_seq_expr() -> None: def test_shape_expr() -> None: - m = tir.Var("m", dtype="int32") - n = tir.Var("n", dtype="int32") + m = tirx.Var("m", dtype="int32") + n = tirx.Var("n", dtype="int32") s = rx.ShapeExpr([m, n]) s_str = dump_ast(s) assert s_str.startswith("ShapeExpr(") @@ -292,7 +292,7 @@ def test_struct_info(): assert printer.visit_struct_info_(empty_ssi) == "ShapeStructInfo(ndim=-1)" # include some dimensions - shape_info = rx.ShapeStructInfo([tir.IntImm("int64", 1), tir.IntImm("int64", 2)]) + shape_info = rx.ShapeStructInfo([tirx.IntImm("int64", 1), tirx.IntImm("int64", 2)]) assert strip_whitespace(printer.visit_struct_info_(shape_info)) == strip_whitespace( """ ShapeStructInfo( @@ -650,7 +650,7 @@ def f(x: R.Tuple(R.Tensor((), dtype="int32"))) -> R.Tensor((), dtype="int32"): def test_prim_value(): - prim_value = rx.PrimValue(tir.IntImm("int64", 1)) + prim_value = rx.PrimValue(tirx.IntImm("int64", 1)) prim_str = strip_whitespace(dump_ast(prim_value)) assert prim_str == strip_whitespace( """ diff --git a/tests/python/relax/test_backend_dispatch_sampling.py b/tests/python/relax/test_backend_dispatch_sampling.py index 013c1dddeba9..7134f66fe9c1 100644 --- a/tests/python/relax/test_backend_dispatch_sampling.py +++ b/tests/python/relax/test_backend_dispatch_sampling.py @@ -24,7 +24,7 @@ from tvm.relax.backend import DispatchSampling from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module @@ -86,7 +86,7 @@ def test_dispatch_multinomial_from_uniform_gpu(): class Expected: @T.prim_func def parallel_sampling_from_prob(var_prob: T.handle, var_uniform_samples: T.handle, var_row_indices: T.handle, var_sampled_token_ids: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) n, vocab_size = T.int64(), T.int64() prob = T.match_buffer(var_prob, (n, vocab_size)) batch_size = T.int64() diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index dd5a65b8ecf8..2b24ab53a097 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -22,14 +22,14 @@ import tvm import tvm.script import tvm.testing -from tvm import relax, tir, topi +from tvm import relax, tirx, topi from tvm.contrib.thrust import can_use_thrust from tvm.ir.base import assert_structural_equal from tvm.relax.backend import DispatchSortScan from tvm.s_tir import dlight from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_dispatch_scanop(): @@ -88,7 +88,7 @@ def main(x: R.Tensor(("m", 3), "float32", "cuda")): target = tvm.target.Target("cuda", host="llvm") vdevices = [I.vdevice("cuda", 0)] - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) bb = relax.BlockBuilder() with target: @@ -132,7 +132,7 @@ def foo(x: R.Tensor(("m", 3), "float32", "llvm")): return gv vdevices = [I.vdevice("llvm", 0)] - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) bb = relax.BlockBuilder() @@ -229,7 +229,7 @@ def foo(x: R.Tensor(("m", 3), "float32", "llvm")): return gv vdevices = [I.vdevice("llvm", 0)] - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) bb = relax.BlockBuilder() @@ -322,7 +322,7 @@ def foo(x: R.Tensor(("m", 3), "float32", "llvm")): return gv vdevices = [I.vdevice("llvm", 0)] - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((m, 3), "float32", vdevices[0])) bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 2f5cef99d880..8acd01aa7d4f 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -23,7 +23,7 @@ from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # note: we expected RemovePurityChecking to be run first, so we force purity in most test cases @@ -194,7 +194,7 @@ class Expected: @T.prim_func(private=True) def shape_func(H: T.Buffer(T.int64(4), "int64")): # generated compute function - T.func_attr({"tir.is_host_func": True}) + T.func_attr({"tirx.is_host_func": True}) H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1) @R.function @@ -528,7 +528,7 @@ def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): @T.prim_func(private=True) def shape_func(H: T.Buffer(T.int64(2), "int64")): # generated compute function - T.func_attr({"tir.is_host_func": True}) + T.func_attr({"tirx.is_host_func": True}) H[T.int64(sindex["n * n"])] = H[T.int64(sindex["n"])] * H[T.int64(sindex["n"])] before = Before diff --git a/tests/python/relax/test_base_py_module.py b/tests/python/relax/test_base_py_module.py index a13c6be9a23c..b60c6d7aa151 100644 --- a/tests/python/relax/test_base_py_module.py +++ b/tests/python/relax/test_base_py_module.py @@ -30,10 +30,10 @@ import torch import tvm -from tvm import relax, tir +from tvm import relax, tirx from tvm.relax import BasePyModule from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T class TestBasePyModule: diff --git a/tests/python/relax/test_base_py_module_printer.py b/tests/python/relax/test_base_py_module_printer.py index 951b56ab35bc..ceac17000793 100644 --- a/tests/python/relax/test_base_py_module_printer.py +++ b/tests/python/relax/test_base_py_module_printer.py @@ -23,7 +23,7 @@ from tvm.relax.base_py_module import BasePyModule from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module @@ -129,7 +129,7 @@ def data_preprocessing(self, raw_data): @T.prim_func def extract_features(data: T.handle, features: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Data = T.match_buffer(data, (10,), "float32") Features = T.match_buffer(features, (10,), "float32") @@ -138,7 +138,7 @@ def extract_features(data: T.handle, features: T.handle): @T.prim_func def ml_inference(features: T.handle, params: T.handle, output: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Features = T.match_buffer(features, (10,), "float32") Params = T.match_buffer(params, (10,), "float32") Output = T.match_buffer(output, (5,), "float32") @@ -148,7 +148,7 @@ def ml_inference(features: T.handle, params: T.handle, output: T.handle): @T.prim_func def post_process(predictions: T.handle, final: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Predictions = T.match_buffer(predictions, (5,), "float32") Final = T.match_buffer(final, (5,), "float32") @@ -157,7 +157,7 @@ def post_process(predictions: T.handle, final: T.handle): @T.prim_func def normalize_data(data: T.handle, normalized: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Data = T.match_buffer(data, (10,), "float32") Normalized = T.match_buffer(normalized, (10,), "float32") @@ -213,7 +213,7 @@ def loop_with_break(self, data, max_iter): @T.prim_func def dummy_tir(data: T.handle, output: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Data = T.match_buffer(data, (1,), "float32") Output = T.match_buffer(output, (1,), "float32") Output[0] = Data[0] @@ -273,7 +273,7 @@ def memory_efficient_transform(self, large_tensor): @T.prim_func def vectorized_add(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(a, (10,), "float32") B = T.match_buffer(b, (10,), "float32") C = T.match_buffer(c, (10,), "float32") @@ -345,7 +345,7 @@ def multi_stage_pipeline(self, raw_input): @T.prim_func def final_transform(data: T.handle, output: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Data = T.match_buffer(data, (10, 10), "float32") Output = T.match_buffer(output, (10, 10), "float32") @@ -410,7 +410,7 @@ def graceful_degradation(self, primary_input, fallback_input): @T.prim_func def safe_transform(data: T.handle, output: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Data = T.match_buffer(data, (5,), "float32") Output = T.match_buffer(output, (5,), "float32") diff --git a/tests/python/relax/test_base_py_module_symbolic_shape.py b/tests/python/relax/test_base_py_module_symbolic_shape.py index fd6331b5cbc2..385a81045517 100644 --- a/tests/python/relax/test_base_py_module_symbolic_shape.py +++ b/tests/python/relax/test_base_py_module_symbolic_shape.py @@ -20,12 +20,12 @@ import pytest import tvm -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import IRModule from tvm.relax.base_py_module import BasePyModule from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _make_module(): @@ -36,8 +36,8 @@ def test_infer_concrete_shape_from_numpy_input(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") sym_shape = [n, m] x = np.zeros((3, 4), dtype="float32") @@ -49,7 +49,7 @@ def test_infer_concrete_shape_all_concrete_dims(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - shape = [tir.IntImm("int32", 5), 6] + shape = [tirx.IntImm("int32", 5), 6] inferred = bpm._infer_concrete_shape_from_args(shape, in_args=[]) assert inferred == [5, 6] @@ -58,7 +58,7 @@ def test_infer_concrete_shape_error_when_uninferrable(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - k = tir.Var("k", "int64") + k = tirx.Var("k", "int64") with pytest.raises(ValueError): bpm._infer_concrete_shape_from_args([k, 8], in_args=[]) @@ -106,7 +106,7 @@ def test_base_py_module_tir_symbolic_end_to_end(): a = np.random.randn(5).astype("float32") b = np.random.randn(5).astype("float32") - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") out_sinfo = relax.TensorStructInfo((n,), "float32") out = bpm.call_tir("add_tir", [a, b], out_sinfo) @@ -119,9 +119,9 @@ def test_infer_concrete_shape_multiple_symbolic_dims(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") - k = tir.Var("k", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") + k = tirx.Var("k", "int64") sym_shape = [n, m, k] x = np.zeros((2, 3, 4), dtype="float32") @@ -134,7 +134,7 @@ def test_infer_concrete_shape_mixed_concrete_symbolic(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") sym_shape = [n, 5, 10] # First dim is symbolic, others are concrete x = np.zeros((3, 5, 10), dtype="float32") @@ -152,8 +152,8 @@ def test_infer_concrete_shape_from_tvm_tensors(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") sym_shape = [n, m] inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x_tvm]) @@ -168,8 +168,8 @@ def test_infer_concrete_shape_multiple_inputs(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") sym_shape = [n, m] # Multiple inputs with different shapes - should use first matching one @@ -184,8 +184,8 @@ def test_infer_concrete_shape_wrong_ndim(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") sym_shape = [n, m] # 2D x = np.zeros((3,), dtype="float32") # 1D - wrong ndim @@ -256,7 +256,7 @@ def test_add_packed(a, b, out): a = np.random.randn(5).astype("float32") b = np.random.randn(5).astype("float32") - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") out_sinfo = relax.TensorStructInfo((n,), "float32") out = bpm.call_dps_packed("test_add_packed", [a, b], out_sinfo) @@ -317,7 +317,7 @@ def test_add_scalar_packed(x, scalar, out): x = np.random.randn(4).astype("float32") scalar = 2.5 - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") out_sinfo = relax.TensorStructInfo((n,), "float32") out = bpm.call_dps_packed("test_add_scalar_packed", [x, scalar], out_sinfo) @@ -339,8 +339,8 @@ def test_infer_concrete_shape_from_pytorch_tensors(): mod = _make_module() bpm = BasePyModule(mod, device=tvm.cpu(0), target="llvm") - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") sym_shape = [n, m] x_torch = torch.zeros((3, 4), dtype=torch.float32) diff --git a/tests/python/relax/test_bind_params.py b/tests/python/relax/test_bind_params.py index 679dea963e45..dc13bc099259 100644 --- a/tests/python/relax/test_bind_params.py +++ b/tests/python/relax/test_bind_params.py @@ -22,7 +22,7 @@ import tvm import tvm.script import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.script import relax as R param_specification = tvm.testing.parameter("by_string", "by_var") @@ -35,7 +35,7 @@ def test_bind_tensor_param(param_specification, param_shape, tensor_param_dtype) shape = [16] ndim = -1 elif param_shape == "dynamic_shape": - shape = [tir.Var("N", "int64")] + shape = [tirx.Var("N", "int64")] ndim = -1 elif param_shape == "ndim": shape = None @@ -80,7 +80,7 @@ def test_bind_shape_param(param_shape): shape = [16] ndim = -1 elif param_shape == "dynamic_shape": - shape = [tir.Var("N", "int64")] + shape = [tirx.Var("N", "int64")] ndim = -1 elif param_shape == "ndim": shape = None @@ -115,8 +115,8 @@ def test_bind_prim_value(prim_value_dtype): if prim_value_dtype != "int64": pytest.xfail(reason="Currently, only support int64 as known symbolic value") - N = tir.Var("N", prim_value_dtype) - value = tir.const(16, prim_value_dtype) + N = tirx.Var("N", prim_value_dtype) + value = tirx.const(16, prim_value_dtype) @R.function def before(A: R.Prim(value=N)) -> R.Prim(value=N): diff --git a/tests/python/relax/test_bind_symbolic_vars.py b/tests/python/relax/test_bind_symbolic_vars.py index 1113410eaf4b..b822b589eec9 100644 --- a/tests/python/relax/test_bind_symbolic_vars.py +++ b/tests/python/relax/test_bind_symbolic_vars.py @@ -21,10 +21,10 @@ import tvm import tvm.testing from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T replace_by_tir_var = tvm.testing.parameter( - by_dict={"replace-by-string": False, "replace-by-tir-var": True} + by_dict={"replace-by-string": False, "replace-by-tirx-var": True} ) @@ -60,8 +60,8 @@ def test_error_with_duplicate_var_names(): variables share the same name, the replacement map may not refer to that variable by string. """ - N1 = tvm.tir.Var("N", "int64") - N2 = tvm.tir.Var("N", "int64") + N1 = tvm.tirx.Var("N", "int64") + N2 = tvm.tirx.Var("N", "int64") @R.function(private=True) def func(A: R.Tensor((N1, N1)), B: R.Tensor((N1, N2))) -> R.Tensor((N1, N2)): @@ -79,9 +79,9 @@ def test_string_var_when_other_var_has_duplicate_var_names(): replacing variables by name only applies to those duplicate names. Other variables may still be replaced by name. """ - N1 = tvm.tir.Var("N", "int64") - N2 = tvm.tir.Var("N", "int64") - BatchSize = tvm.tir.Var("BatchSize", "int64") + N1 = tvm.tirx.Var("N", "int64") + N2 = tvm.tirx.Var("N", "int64") + BatchSize = tvm.tirx.Var("BatchSize", "int64") @R.function(private=True) def before(A: R.Tensor((BatchSize, N1, N1)), B: R.Tensor((N1, N2))) -> R.Tensor( @@ -118,7 +118,7 @@ def func(A: R.Tensor(["M", "N"])): return A with pytest.raises(tvm.TVMError): - func.bind_symbolic_vars({tvm.tir.Var("M", "int64"): 64}) + func.bind_symbolic_vars({tvm.tirx.Var("M", "int64"): 64}) def test_error_with_multiple_definitions(): @@ -142,7 +142,7 @@ def test_error_if_output_has_undefined(): def func(A: R.Tensor(["M", "N"])): return A - outside_var = tvm.tir.Var("outside_var", "int64") + outside_var = tvm.tirx.Var("outside_var", "int64") with pytest.raises(tvm.TVMError): func.bind_symbolic_vars({"M": outside_var * 2}) @@ -159,7 +159,7 @@ def before(A: R.Tensor(["M", "N"])): def expected(A: R.Tensor(["outside_var * 2", "outside_var"])): return A - outside_var = tvm.tir.Var("outside_var", "int64") + outside_var = tvm.tirx.Var("outside_var", "int64") after = before.bind_symbolic_vars({"M": outside_var * 2, "N": outside_var}) tvm.ir.assert_structural_equal(expected, after) diff --git a/tests/python/relax/test_blockbuilder_core.py b/tests/python/relax/test_blockbuilder_core.py index 6bfdf8ad278f..670822ad1877 100644 --- a/tests/python/relax/test_blockbuilder_core.py +++ b/tests/python/relax/test_blockbuilder_core.py @@ -24,13 +24,13 @@ import tvm.contrib.cblas import tvm.testing from tvm import relax as rx -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.ir.base import assert_structural_equal from tvm.relax import ExternFunc from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T -from tvm.tir.function import PrimFunc +from tvm.script import tirx as T +from tvm.tirx.function import PrimFunc @pytest.fixture(scope="module") @@ -41,8 +41,8 @@ def nop(): def test_block_builder(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -66,8 +66,8 @@ def test_block_builder(): def test_emit_with_name(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -82,8 +82,8 @@ def test_emit_with_name(): def test_function_single_block(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -108,8 +108,8 @@ def test_function_single_block(): def test_function_multi_blocks(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -143,8 +143,8 @@ def test_function_multi_blocks(): def test_multi_functions(): bb = rx.BlockBuilder() - m_1 = tir.Var("m", "int64") - n_1 = tir.Var("n", "int64") + m_1 = tirx.Var("m", "int64") + n_1 = tirx.Var("n", "int64") x_1 = rx.Var("x", rx.TensorStructInfo([m_1, n_1], "float16")) y_1 = rx.Var("y", rx.TensorStructInfo([n_1], "float16")) @@ -155,8 +155,8 @@ def test_multi_functions(): gv0 = bb.emit_output(lv0) bb.emit_func_output(gv0) - m_2 = tir.Var("m", "int64") - n_2 = tir.Var("n", "int64") + m_2 = tirx.Var("m", "int64") + n_2 = tirx.Var("n", "int64") x_2 = rx.Var("x", rx.TensorStructInfo([m_2, n_2], "float16")) y_2 = rx.Var("y", rx.TensorStructInfo([n_2], "float16")) @@ -180,9 +180,9 @@ def test_multi_functions(): def test_binary_shape_type_deduction(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") - k = tir.Var("k", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") + k = tirx.Var("k", "int64") x = rx.Var("x", rx.TensorStructInfo([m, 1], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) z = rx.Var("z", rx.TensorStructInfo([5], "float16")) @@ -216,8 +216,8 @@ def test_binary_shape_type_deduction(): def test_emit_match_cast(): - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") x = rx.Var("tensor_value", rx.TensorStructInfo(dtype="float32", ndim=-1)) y = rx.Var("shape_value", rx.ShapeStructInfo([16, 8])) bb = rx.BlockBuilder() @@ -256,7 +256,7 @@ def test_emit_match_cast_binding_in_dataflow_block(): bb = rx.BlockBuilder() x = rx.Var("x", rx.TensorStructInfo(dtype="float32", ndim=-1)) - m = tir.Var("m", dtype="int64") + m = tirx.Var("m", dtype="int64") gv = rx.Var("gv", rx.TensorStructInfo(dtype="float32", ndim=-1)) match_cast = rx.MatchCast(gv, x, rx.TensorStructInfo((m,), "float32")) @@ -278,8 +278,8 @@ def test_emit_match_cast_binding_in_dataflow_block(): def test_normalize(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) @@ -314,8 +314,8 @@ def test_normalize(): def test_tuple_indexing(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") shape_x = rx.TensorStructInfo([m, n], "float16") shape_y = rx.TensorStructInfo([n], "float16") @@ -346,7 +346,7 @@ def test_tuple_indexing(): def test_call_te(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) @@ -405,7 +405,7 @@ def test_call_te_with_unsupported_shape_arg(): def test_emit_te(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) z = rx.Var("z", rx.TensorStructInfo([n, m], "float32")) @@ -453,7 +453,7 @@ def get_tir_func(): def test_emit_te_multiple(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) y = rx.Var("y", rx.TensorStructInfo([n, m], "float32")) z = rx.Var("z", rx.TensorStructInfo([128, m], "float32")) @@ -485,7 +485,7 @@ def te_func(A): def test_emit_te_multiple_output(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) def te_func(A): @@ -499,7 +499,7 @@ def te_func(A): rx_func = bb.finalize()["rx_func"] - # check call tir output shape is a Tuple of ShapeExpr + # check call tirx output shape is a Tuple of ShapeExpr assert rx_func.params[0] == x call_node = rx_func.body.blocks[0].bindings[0].value assert call_node.args[0].name_hint == "te_func" @@ -511,7 +511,7 @@ def te_func(A): def test_emit_te_extern(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) @@ -538,7 +538,7 @@ def test_emit_te_extern(): def test_emit_te_prim_value(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", R.Tensor([n, m], "float32")) a_min = rx.PrimValue(0) a_max = rx.PrimValue(6) @@ -559,8 +559,8 @@ def test_emit_te_prim_value(): def test_nested_function_fail(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -574,8 +574,8 @@ def test_nested_function_fail(): def test_emit_func_output_twice_fail(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -588,8 +588,8 @@ def test_emit_func_output_twice_fail(): def test_func_params_twice_fail(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -601,8 +601,8 @@ def test_func_params_twice_fail(): def test_no_func_params_fail(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = rx.Var("x", rx.TensorStructInfo([m, n], "float16")) y = rx.Var("y", rx.TensorStructInfo([n], "float16")) bb = rx.BlockBuilder() @@ -616,7 +616,7 @@ def test_no_func_params_fail(): def test_block_builder_scope_recovery(): bb = rx.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.Var("x", rx.TensorStructInfo([n, m], "float32")) y = rx.Var("y", rx.TensorStructInfo([m, n], "float32")) @@ -641,8 +641,8 @@ def test_emit_nested_tuple(emit_nested_tuple): def make_function(emit_nested_tuple: bool): bb = rx.BlockBuilder() - n_sym = tir.Var("n", "int64") - m_sym = tir.Var("m", "int64") + n_sym = tirx.Var("n", "int64") + m_sym = tirx.Var("m", "int64") n = rx.Var("n", rx.PrimStructInfo(value=n_sym)) m = rx.Var("m", rx.PrimStructInfo(value=m_sym)) x = rx.Var("x", rx.TensorStructInfo([n_sym, m_sym], "float32")) @@ -691,14 +691,14 @@ def func( @pytest.mark.skip_well_formed_check_before_transform def test_finalize_public_private_name_conflict(): - # tir call + # tirx call bb = rx.BlockBuilder() def te_zero(): - return topi.full((), "int64", tir.IntImm("int64", 0)) + return topi.full((), "int64", tirx.IntImm("int64", 0)) def te_one(): - return topi.full((), "int64", tir.IntImm("int64", 1)) + return topi.full((), "int64", tirx.IntImm("int64", 1)) with bb.function("func", []): gv0 = bb.emit_te(te_zero, primfunc_name_hint="func") diff --git a/tests/python/relax/test_blockbuilder_emit_te.py b/tests/python/relax/test_blockbuilder_emit_te.py index a871f77c5349..62eb08e4b722 100644 --- a/tests/python/relax/test_blockbuilder_emit_te.py +++ b/tests/python/relax/test_blockbuilder_emit_te.py @@ -19,16 +19,16 @@ # The tests here depend on tvmscript import tvm from tvm import relax as rx -from tvm import te, tir +from tvm import te, tirx from tvm.ir.base import assert_structural_equal from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_emit_te_with_symbolic_arg(): bb = rx.BlockBuilder() - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = rx.Var("x", R.Tensor([10], "float32")) y = rx.Var("y", R.Shape([m])) @@ -49,7 +49,7 @@ def te_func( B: T.Buffer((T.int64(10),), "float32"), m: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(T.int64(10)): with T.sblock("B"): v_i = T.axis.spatial(T.int64(10), i) @@ -82,7 +82,7 @@ def te_slice(tensor, i): def from_builder(): bb = rx.BlockBuilder() A = rx.Var("A", R.Tensor([16, 16], "float32")) - tir_i = tvm.tir.Var("tir_i", "int64") + tir_i = tvm.tirx.Var("tir_i", "int64") relax_i = rx.Var("relax_i", R.Prim(value=tir_i)) with bb.function("main", params=[A, relax_i]): @@ -99,7 +99,7 @@ def te_slice( Output: T.Buffer(T.int64(16), "float32"), row_index: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(A.shape[1]): with T.sblock("slice"): diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py index b736f790fd33..4cd9d4a029be 100644 --- a/tests/python/relax/test_codegen_cublas.py +++ b/tests/python/relax/test_codegen_cublas.py @@ -73,7 +73,7 @@ def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_ def _to_concrete_shape(symbolic_shape, var_table): result = [] for dim in symbolic_shape: - if not isinstance(dim, tvm.tir.expr.Var): + if not isinstance(dim, tvm.tirx.expr.Var): result.append(dim) continue @@ -85,8 +85,8 @@ def _to_concrete_shape(symbolic_shape, var_table): _vars = { - "a": tvm.tir.expr.Var("a", "int64"), - "b": tvm.tir.expr.Var("b", "int64"), + "a": tvm.tirx.expr.Var("a", "int64"), + "b": tvm.tirx.expr.Var("b", "int64"), } diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index cab1b7dfc5b7..3009c62905f2 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -32,7 +32,7 @@ ) from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder @@ -193,7 +193,7 @@ def _to_concrete_shape(symbolic_shape, var_table=None): result.append(_to_concrete_shape(dim, var_table)) continue - if not isinstance(dim, tvm.tir.expr.Var): + if not isinstance(dim, tvm.tirx.expr.Var): result.append(dim) continue @@ -205,8 +205,8 @@ def _to_concrete_shape(symbolic_shape, var_table=None): _vars = { - "a": tvm.tir.expr.Var("a", "int64"), - "b": tvm.tir.expr.Var("b", "int64"), + "a": tvm.tirx.expr.Var("a", "int64"), + "b": tvm.tirx.expr.Var("b", "int64"), } @@ -884,7 +884,7 @@ def get_relax_attention_rewrite_module( ): from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder - from tvm.script.ir_builder import tir as T + from tvm.script.ir_builder import tirx as T with IRBuilder() as builder: with relax_builder.function(): @@ -1245,7 +1245,7 @@ def split_transform_deploy_mod(mod): if "transform_params" in gv.name_hint: transform_func_name = gv.name_hint mod_transform[gv] = func - elif isinstance(func, tvm.tir.PrimFunc): + elif isinstance(func, tvm.tirx.PrimFunc): mod_transform[gv] = func else: mod_deploy[gv] = func @@ -1263,7 +1263,7 @@ def decode( B: T.Buffer((T.int64(128),), "float16"), decode_1: T.Buffer((T.int64(64), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j in T.grid(T.int64(64), T.int64(128)): with T.sblock("decode"): @@ -1296,7 +1296,7 @@ def encode( w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"), compute: T.Buffer((T.int64(128),), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): max_abs_value = T.sblock_alloc_buffer((T.int64(128),), "float16") scale = T.sblock_alloc_buffer((T.int64(128),)) @@ -1520,7 +1520,7 @@ def decode( B: T.Buffer((T.int64(64),), "float16"), decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j in T.grid(T.int64(64), T.int64(64)): with T.sblock("decode"): @@ -1535,7 +1535,7 @@ def encode( w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"), compute: T.Buffer((T.int64(64),), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): max_abs_value = T.sblock_alloc_buffer((T.int64(64),), "float16") scale = T.sblock_alloc_buffer((T.int64(64),)) @@ -1666,7 +1666,7 @@ def rms_norm( B: T.Buffer((T.int64(4096),), "float16"), rms_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): Ared_temp = T.sblock_alloc_buffer((T.int64(1), T.int64(1))) for bsz, i, k in T.grid(T.int64(1), T.int64(1), T.int64(4096)): @@ -1799,7 +1799,7 @@ def decode( B: T.Buffer((T.int64(64),), "float16"), decode_1: T.Buffer((T.int64(64), T.int64(64)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j in T.grid(T.int64(64), T.int64(64)): with T.sblock("decode"): @@ -1814,7 +1814,7 @@ def encode( w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"), compute: T.Buffer((T.int64(64),), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): max_abs_value = T.sblock_alloc_buffer((T.int64(64),), "float16") scale = T.sblock_alloc_buffer((T.int64(64),)) @@ -1932,7 +1932,7 @@ def decode( B: T.Buffer((T.int64(2), T.int64(128)), "float16"), decode_1: T.Buffer((T.int64(128), T.int64(128)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock("decode"): v_i, v_j = T.axis.remap("SS", [i, j]) @@ -1952,7 +1952,7 @@ def encode( "float16", ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) max_abs_value = T.sblock_alloc_buffer( ( T.int64(2), diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py index 286acc44f1f1..c6dd7222bef9 100644 --- a/tests/python/relax/test_codegen_hipblas.py +++ b/tests/python/relax/test_codegen_hipblas.py @@ -59,7 +59,7 @@ def get_result_with_relax_cublas_offload(mod, np_inputs): def _to_concrete_shape(symbolic_shape, var_table): result = [] for dim in symbolic_shape: - if not isinstance(dim, tvm.tir.expr.Var): + if not isinstance(dim, tvm.tirx.expr.Var): result.append(dim) continue @@ -71,8 +71,8 @@ def _to_concrete_shape(symbolic_shape, var_table): _vars = { - "a": tvm.tir.expr.Var("a", "int64"), - "b": tvm.tir.expr.Var("b", "int64"), + "a": tvm.tirx.expr.Var("a", "int64"), + "b": tvm.tirx.expr.Var("b", "int64"), } diff --git a/tests/python/relax/test_contrib_vllm.py b/tests/python/relax/test_contrib_vllm.py index d5b500248221..6ec59072e8a1 100644 --- a/tests/python/relax/test_contrib_vllm.py +++ b/tests/python/relax/test_contrib_vllm.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T has_vllm = tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 816d48597c98..7bbdcac75f8b 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -30,7 +30,7 @@ from tvm.relax.transform import DataflowUseInplaceCalls from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_liveness_analysis(): @@ -380,7 +380,7 @@ def expected_add( A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(2), T.int64(3)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -397,7 +397,7 @@ def expected_add( @T.prim_func(private=True) def expected_silu(A: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) compute = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("compute"): @@ -448,7 +448,7 @@ def add_inplace( A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(1), T.int64(3)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -461,7 +461,7 @@ def multiply_inplace( A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(1), T.int64(3)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -474,7 +474,7 @@ def subtract_inplace( A: T.Buffer((T.int64(1), T.int64(3)), "float32"), B: T.Buffer((T.int64(1), T.int64(3)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(1), T.int64(3)): with T.sblock("T_subtract"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -563,7 +563,7 @@ def main( class Expected: @T.prim_func(private=True) def add_inplace(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b = T.int64(), T.int64() A = T.match_buffer(var_A, (a, b)) B = T.match_buffer(var_B, (a, b)) @@ -576,7 +576,7 @@ def add_inplace(var_A: T.handle, var_B: T.handle): @T.prim_func(private=True) def subtract_inplace(var_A: T.handle, var_B: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b = T.int64(), T.int64() A = T.match_buffer(var_A, (a, b)) B = T.match_buffer(var_B, (a, b)) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 11d3094ef412..6d797969af8d 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -23,11 +23,11 @@ import tvm.testing from tvm import relax as rx -from tvm import tir +from tvm import tirx from tvm.relax.analysis import get_var2val from tvm.relax.dpl import * from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -242,7 +242,7 @@ def test_shape_pattern(): tvm.ir.structural_equal(pattern.shape, shape) assert pattern.match(bindings[0].var) assert wildcard().has_shape([32, 32]).match(bindings[0].var) - n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") + n, m = tirx.Var("n", dtype="int64"), tirx.Var("m", dtype="int64") symsh_var = rx.Var("x", R.Tensor([n, m, n + m], "float32")) assert wildcard().has_shape([n, m, n + m]).match(symsh_var) assert wildcard().has_shape([n, m, m + n]).match(symsh_var) # + is commutative. @@ -261,7 +261,7 @@ def test_prim_arr_pattern(): assert pattern[1] == 32 assert isinstance(pattern, PrimArrPattern) assert pattern.match(rx.get_shape_of(bindings[0].var)) - n, m = tir.Var("n", dtype="int64"), tir.Var("m", dtype="int64") + n, m = tirx.Var("n", dtype="int64"), tirx.Var("m", dtype="int64") symbolic_shape = rx.ShapeExpr([n, m, n + m]) assert is_shape([n, m, n + m]).match(symbolic_shape) assert not is_shape([n, m, n * m]).match(symbolic_shape) @@ -1629,9 +1629,9 @@ def rewriter(expr, matches): size = arg.struct_info.shape[0] if ( - isinstance(size, tir.IntImm) - and isinstance(begin, tir.IntImm) - and isinstance(end, tir.IntImm) + isinstance(size, tirx.IntImm) + and isinstance(begin, tirx.IntImm) + and isinstance(end, tirx.IntImm) ): size = size.value begin = begin.value @@ -1891,8 +1891,8 @@ def test_wildcard_struct_info_with_symbolic_vars(): broadcasted `R.add`. """ - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") pat_lhs = wildcard().has_struct_info(R.Tensor([m, n])) pat_rhs = wildcard().has_struct_info(R.Tensor([m, n])) diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py index e591ee4dc422..9e1578d70b0f 100644 --- a/tests/python/relax/test_dataflow_rewriter.py +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -22,7 +22,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_rewrite_defined_by_ir_module(): diff --git a/tests/python/relax/test_dlpack_integration.py b/tests/python/relax/test_dlpack_integration.py index 4c038678b5cb..50181f72c26c 100644 --- a/tests/python/relax/test_dlpack_integration.py +++ b/tests/python/relax/test_dlpack_integration.py @@ -31,10 +31,10 @@ import torch import tvm -from tvm import relax, tir +from tvm import relax, tirx from tvm.relax import BasePyModule from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T class TestDLPackIntegration: diff --git a/tests/python/relax/test_e2e_op_dynamic.py b/tests/python/relax/test_e2e_op_dynamic.py index a3637ad7ac00..e728ce355c0c 100644 --- a/tests/python/relax/test_e2e_op_dynamic.py +++ b/tests/python/relax/test_e2e_op_dynamic.py @@ -24,9 +24,9 @@ from tvm import relax from tvm.relax.transform import LegalizeOps from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T -# TODO(tvm-team): `tir.transform.DefaultGPUSchedule` does not work. +# TODO(tvm-team): `tirx.transform.DefaultGPUSchedule` does not work. target, dev = "llvm", tvm.cpu() diff --git a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py index ecc96a6cbf4f..904d8704b185 100644 --- a/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py +++ b/tests/python/relax/test_eliminate_pad_branch_using_buffer_assumption.py @@ -24,7 +24,7 @@ import tvm.script import tvm.testing from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -49,7 +49,7 @@ def add( "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, "op_pattern": 0, "operator_name": "add", - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): @@ -146,7 +146,7 @@ def add( "op_attrs": {"lhs_axis": 0, "op_name": "qnn.add", "rhs_axis": 0}, "op_pattern": 0, "operator_name": "add", - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): @@ -248,7 +248,7 @@ def sub( "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, "op_pattern": 0, "operator_name": "sub", - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): @@ -345,7 +345,7 @@ def sub( "op_attrs": {"lhs_axis": 0, "op_name": "qnn.subtract", "rhs_axis": 0}, "op_pattern": 0, "operator_name": "sub", - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): @@ -447,7 +447,7 @@ def mul( "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, "op_pattern": 0, "operator_name": "mul", - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): @@ -544,7 +544,7 @@ def mul( "op_attrs": {"lhs_axis": 0, "op_name": "qnn.mul", "rhs_axis": 0}, "op_pattern": 0, "operator_name": "mul", - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index 775e2a77daa4..b9f12c2f4a2b 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -20,7 +20,7 @@ import tvm from tvm import relax as rx -from tvm import tir +from tvm import tirx from tvm.relax.expr import make_shape from tvm.script import relax as R @@ -106,8 +106,8 @@ def test_tuple_sinfo_requires_fields_with_known_sinfo(): def test_match_cast() -> None: # match_cast([16, 8], [m, n]) - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") var = rx.Var("v0", R.Shape()) b0 = rx.MatchCast(var, shape, R.Tensor([m, n], "int32")) @@ -129,8 +129,8 @@ def test_match_cast() -> None: def test_match_cast() -> None: - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") ivalue = rx.Var("input_value") sinfo = rx.TensorStructInfo([n, m], "float32") b0 = rx.MatchCast(rx.Var("v"), ivalue, sinfo) @@ -148,8 +148,8 @@ def test_var_binding() -> None: def test_binding_block() -> None: - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) @@ -163,8 +163,8 @@ def test_binding_block() -> None: def test_dataflow_block() -> None: - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") shape = rx.const([16, 8], "int32") b0 = rx.MatchCast(rx.Var("v0"), shape, R.Tensor([m, n], "int32")) @@ -211,8 +211,8 @@ def test_shape_of(): def test_shape_expr(): - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") s = rx.ShapeExpr([m, n]) assert s.values[0] == m assert s.values[1] == n @@ -238,7 +238,7 @@ def test_shape_expr(): assert x.struct_info.shape[1] == 20 tvm.ir.assert_structural_equal(x.struct_info.shape.struct_info, R.Shape((10, 20))) - m = tir.Var("m", "int32") + m = tirx.Var("m", "int32") with pytest.raises( tvm.TVMError, match="the value in ShapeStructInfo can only have dtype of int64" ): @@ -246,14 +246,14 @@ def test_shape_expr(): def test_prim_value(): - pv = rx.PrimValue(tir.IntImm("int64", 1)) + pv = rx.PrimValue(tirx.IntImm("int64", 1)) assert pv.value.value == 1 - _check_equal(pv, rx.PrimValue(tir.IntImm("int64", 1))) + _check_equal(pv, rx.PrimValue(tirx.IntImm("int64", 1))) _check_json_roundtrip(pv) def test_prim_value_with_var(): - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") pv = rx.PrimValue(n) assert pv.value.same_as(n) tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n)) @@ -262,7 +262,7 @@ def test_prim_value_with_var(): def test_prim_value_with_expr(): - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") pv = rx.PrimValue(n + 1) tvm.ir.assert_structural_equal(pv.struct_info, rx.PrimStructInfo(value=n + 1)) _check_equal(pv, rx.PrimValue(n + 1)) diff --git a/tests/python/relax/test_expr_functor.py b/tests/python/relax/test_expr_functor.py index cf6da59e70c1..61b8696ee0bd 100644 --- a/tests/python/relax/test_expr_functor.py +++ b/tests/python/relax/test_expr_functor.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import Op from tvm.ir.base import assert_structural_equal from tvm.relax import PyExprMutator, PyExprVisitor @@ -47,7 +47,7 @@ ) from tvm.script import relax as R -m, n = tir.Var("m", "int64"), tir.Var("n", "int64") +m, n = tirx.Var("m", "int64"), tirx.Var("n", "int64") x = relax.Var("x", R.Tensor([n], "float32")) y = relax.Var("y", R.Tensor([m, n], "float32")) bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_frontend_common.py b/tests/python/relax/test_frontend_common.py index 2813d4019bfa..b3ea93a7aae5 100644 --- a/tests/python/relax/test_frontend_common.py +++ b/tests/python/relax/test_frontend_common.py @@ -20,7 +20,7 @@ from tvm.relax.frontend import detach_params from tvm.relax.frontend.common import autopad from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.parser import relax as R @@ -73,7 +73,7 @@ def pad( x: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(4)), "float32"), PadInput: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): with T.sblock("PadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -113,7 +113,7 @@ def replicate_pad( (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): with T.sblock("ReplicatePadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -174,7 +174,7 @@ def mirror_pad( (T.int64(1), T.int64(1), T.int64(5), T.int64(5)), "float32" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(5), T.int64(5)): with T.sblock("MirrorPadInput"): v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index b8ba7c5883a2..936d4d1a5fd1 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -26,12 +26,12 @@ import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.relax.frontend.torch import relax_dynamo from tvm.s_tir import meta_schedule as ms from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T torch_version = torch.__version__ @@ -58,7 +58,7 @@ def main( compute: T.Buffer((T.int64(10), T.int64(10)), "float32"), ): # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") matmul = T.sblock_alloc_buffer([T.int64(10), T.int64(10)], dtype="float32") diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 63b771cef0d7..a9cea19fdcf0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -30,7 +30,7 @@ from tvm.relax.frontend.torch import from_exported_program from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def verify_model( diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 91ec5223f9cd..102b8a0f420f 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -33,7 +33,7 @@ from tvm.relax.frontend.torch import from_fx from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def verify_model(torch_model, input_info, binding, expected): diff --git a/tests/python/relax/test_frontend_nn_debug.py b/tests/python/relax/test_frontend_nn_debug.py index f3ead2e9c011..c6681772827c 100644 --- a/tests/python/relax/test_frontend_nn_debug.py +++ b/tests/python/relax/test_frontend_nn_debug.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.relax.frontend import nn from tvm.relax.frontend.nn import op, spec from tvm.runtime import Tensor @@ -60,7 +60,7 @@ def _debug( # pylint: disable=too-many-arguments assert var_int == 8 class Layer(nn.Module): - def forward(self, x: nn.Tensor, v: tir.Var): # pylint: disable=invalid-name + def forward(self, x: nn.Tensor, v: tirx.Var): # pylint: disable=invalid-name op.debug_func("testing.relax.frontend.nn.test_debug_func", x, 1, 2.0, "test", v) return x diff --git a/tests/python/relax/test_frontend_nn_exporter.py b/tests/python/relax/test_frontend_nn_exporter.py index df2a4b1601ce..2b28d950a304 100644 --- a/tests/python/relax/test_frontend_nn_exporter.py +++ b/tests/python/relax/test_frontend_nn_exporter.py @@ -19,12 +19,12 @@ import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import assert_structural_equal from tvm.relax.frontend import nn from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_simple(): @@ -121,7 +121,7 @@ def test_dynamic_shape(): slm_mod = nn.modules.ReLU() exported_mod, _ = slm_mod.export_tvm( - spec={"forward": {"x": nn.spec.Tensor([tir.Var("batch_size", "int64"), 8], "float32")}}, + spec={"forward": {"x": nn.spec.Tensor([tirx.Var("batch_size", "int64"), 8], "float32")}}, debug=False, ) @@ -152,8 +152,8 @@ def forward_silu(self, x: nn.Tensor): slm_mod = Before() exported_mod, _ = slm_mod.export_tvm( spec={ - "forward_relu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")}, - "forward_silu": {"x": nn.spec.Tensor((tir.Var("batch_size", "int64"), 8), "float32")}, + "forward_relu": {"x": nn.spec.Tensor((tirx.Var("batch_size", "int64"), 8), "float32")}, + "forward_silu": {"x": nn.spec.Tensor((tirx.Var("batch_size", "int64"), 8), "float32")}, }, debug=False, ) @@ -221,7 +221,7 @@ def forward(self, x: nn.Tensor): exported_mod, _ = slm_mod.export_tvm( spec={ "forward": { - "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16") + "x": nn.spec.Tensor((tirx.Var("batch_size", "int64"), hidden_size), "float16") }, }, debug=False, @@ -345,7 +345,7 @@ def forward(self, x: nn.Tensor): exported_mod, _ = slm_mod.export_tvm( spec={ "forward": { - "x": nn.spec.Tensor((tir.Var("batch_size", "int64"), hidden_size), "float16") + "x": nn.spec.Tensor((tirx.Var("batch_size", "int64"), hidden_size), "float16") }, }, debug=False, @@ -533,22 +533,22 @@ def forward(self, state: nn.Tensor): args = ["hidden_size", "intermediate_size"] expected_num_symbolic_vars = 2 elif dynamic_type == "same_tir_var": - # Symbolic variables can be specified as tir.Var instances. + # Symbolic variables can be specified as tirx.Var instances. # Providing the same variable for the two different shape # parameters uses the symbolic variable in both locations. - dim = tir.Var("hidden_size", "int64") + dim = tirx.Var("hidden_size", "int64") args = [dim, dim] expected_num_symbolic_vars = 1 elif dynamic_type == "distinct_tir_vars_with_distinct_names": # Providing distinct TIR variables for the two different shape # parameters uses each TIR variable in the specified location. - args = [tir.Var("hidden_size", "int64"), tir.Var("intermediate_size", "int64")] + args = [tirx.Var("hidden_size", "int64"), tirx.Var("intermediate_size", "int64")] expected_num_symbolic_vars = 2 elif dynamic_type == "distinct_tir_vars_with_same_name": # TIR variable have reference equality. Even if two different # TIR variables have the same name, providing two distinct TIR # variables still results in two distinct symbolic variables. - args = [tir.Var("hidden_size", "int64"), tir.Var("hidden_size", "int64")] + args = [tirx.Var("hidden_size", "int64"), tirx.Var("hidden_size", "int64")] expected_num_symbolic_vars = 2 else: raise ValueError(f"Unexpected dynamic_type: {dynamic_type}") diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index d8c59adb3158..7f884f0dfb13 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -80,7 +80,7 @@ def _check_ir_equality(mod): # pylint: disable=import-outside-toplevel from tvm.script import ir as I from tvm.script import relax as R - from tvm.script import tir as T + from tvm.script import tirx as T # pylint: enable=import-outside-toplevel diff --git a/tests/python/relax/test_frontend_nn_jit.py b/tests/python/relax/test_frontend_nn_jit.py index 31328104c7a1..a70808b4ca46 100644 --- a/tests/python/relax/test_frontend_nn_jit.py +++ b/tests/python/relax/test_frontend_nn_jit.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.relax.frontend import nn from tvm.relax.frontend.nn import spec @@ -52,7 +52,7 @@ class Layer(nn.Module): def __init__(self): pass - def forward(self, x: nn.Tensor, i: tir.Var): + def forward(self, x: nn.Tensor, i: tirx.Var): y = nn.add(x, x) y = nn.reshape(y, (i, 5, 5)) return y @@ -74,7 +74,7 @@ class Layer(nn.Module): def __init__(self): self.cache = nn.KVCache(10, [10, 5]) - def forward(self, x: nn.Tensor, total_seq_len: tir.Var): + def forward(self, x: nn.Tensor, total_seq_len: tirx.Var): self.cache.append(x) y = self.cache.view(total_seq_len) return y diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index 3fb81ff08c0c..601c4891ead2 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -27,7 +27,7 @@ from tvm.relax.frontend.nn import core, modules, spec from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_relu(): @@ -307,7 +307,7 @@ def forward( R.output(gv1) return gv1 - mod = modules.Conv2D(tvm.tir.Var("in_channels", "int64"), 32, 3, bias=True) + mod = modules.Conv2D(tvm.tirx.Var("in_channels", "int64"), 32, 3, bias=True) tvm_mod, _ = mod.export_tvm( spec={ "forward": { diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 7675f46fc4e0..7d47ed7d4484 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -20,11 +20,11 @@ import tvm import tvm.testing -from tvm import relax, s_tir, tir +from tvm import relax, s_tir, tirx from tvm.relax.frontend.nn import Module, Tensor, op, spec from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # mypy: disable-error-code="attr-defined,valid-type,name-defined" @@ -460,7 +460,7 @@ def test(self, x: Tensor): triu_out = op.triu(x) full_with_scalar_out = op.full([10, 10], fill_value=10) # type: ignore full_with_FloatImm_out = op.full( - [10, 10], fill_value=tir.FloatImm(dtype="float32", value=10) + [10, 10], fill_value=tirx.FloatImm(dtype="float32", value=10) ) full_with_Tensor_out = op.full( [10, 10], fill_value=Tensor.from_scalar(10, dtype="float32") @@ -593,7 +593,7 @@ def test(self, x: Tensor): class Expected: @T.prim_func(private=True) def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer((T.int64(10), T.int64(10)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(10), T.int64(10)): with T.sblock("T_add"): @@ -658,7 +658,7 @@ def fused_rope( # pylint: disable=too-many-locals T.evaluate(offset) class Model(Module): - def test(self, qkv: Tensor, offset: tir.Var): + def test(self, qkv: Tensor, offset: tirx.Var): tensor_expr_op_out = op.tensor_ir_op( fused_rope, "llama_fused_rope", @@ -725,7 +725,7 @@ def test_tensor_ir_inplace_op(): def inplace_take( var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64 ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) vocab_size = T.int64() weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype) seq_len = T.int64() @@ -758,7 +758,7 @@ class Expected: def inplace_take( var_weight: T.handle, var_pos: T.handle, var_embeddings: T.handle, offset: T.int64 ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) vocab_size = T.int64() weight = T.match_buffer(var_weight, (vocab_size, hidden_size), dtype) seq_len = T.int64() @@ -1152,7 +1152,7 @@ def foo( class Expected: @T.prim_func(private=True) def filter_with_top_p_top_k(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), B: T.Buffer((T.int64(2), T.int64(1)), "float32"), filter_with_top_p_top_k: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j in T.grid(T.int64(2), T.int64(3)): with T.sblock("filter_with_top_p_top_k"): diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py index 6bbf57aeadde..9ea44781b8bc 100644 --- a/tests/python/relax/test_frontend_nn_subroutines.py +++ b/tests/python/relax/test_frontend_nn_subroutines.py @@ -90,7 +90,7 @@ def activation( return dataflow_output mod = Layer(64, 32) - batch_size = tvm.tir.Var("batch_size", "int64") + batch_size = tvm.tirx.Var("batch_size", "int64") tvm_mod, _ = mod.export_tvm( spec={"forward": {"input": nn.spec.Tensor((batch_size, 64), "float32")}}, debug=True ) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index ecbc6c9e8a5e..7ea80c1bbe84 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -36,7 +36,7 @@ from tvm.relax.frontend.onnx import from_onnx from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T bg = np.random.MT19937(0) rg = np.random.Generator(bg) diff --git a/tests/python/relax/test_frontend_stablehlo.py b/tests/python/relax/test_frontend_stablehlo.py index 2261b32a9c5e..5632421f90b5 100644 --- a/tests/python/relax/test_frontend_stablehlo.py +++ b/tests/python/relax/test_frontend_stablehlo.py @@ -29,7 +29,7 @@ from tvm.relax.frontend.stablehlo import from_stablehlo from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def generate_np_inputs( diff --git a/tests/python/relax/test_inline_functions.py b/tests/python/relax/test_inline_functions.py index 73d402e124d9..e4efd077d972 100644 --- a/tests/python/relax/test_inline_functions.py +++ b/tests/python/relax/test_inline_functions.py @@ -22,7 +22,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.mark.parametrize("key_type", [tvm.ir.GlobalVar, str]) @@ -171,8 +171,8 @@ def test_subroutine_with_symbolic_vars(): """Inlined subroutines should use the caller's symbolic variables Before inlining, the subroutine and the caller have distinct - `tir::Var` for each symbolic variables. After inlining, only the - caller's `tir::Var` symbolic variables should remain. + `tirx::Var` for each symbolic variables. After inlining, only the + caller's `tirx::Var` symbolic variables should remain. """ @I.ir_module diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py b/tests/python/relax/test_meta_schedule_relax_integration.py index 76af1473d948..72699a2c6fc5 100644 --- a/tests/python/relax/test_meta_schedule_relax_integration.py +++ b/tests/python/relax/test_meta_schedule_relax_integration.py @@ -23,7 +23,7 @@ from tvm.s_tir import meta_schedule as ms from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off @@ -59,7 +59,7 @@ def test_extracting_tasks(): relax_mod = relax.transform.FuseTIR()(relax_mod) relax_expectation = { - "structural": 2, # The relax constants do not reach the tir at the lowering. + "structural": 2, # The relax constants do not reach the tirx at the lowering. "ignore-tensor": 2, "anchor-block": 1, } diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py index 90f71b7f388a..7049e6aaef87 100644 --- a/tests/python/relax/test_op_binary.py +++ b/tests/python/relax/test_op_binary.py @@ -21,7 +21,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -66,16 +66,16 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r (binary_arith_op, tir_arith_op) = tvm.testing.parameters( - (relax.op.add, tir.Add), - (relax.op.divide, tir.Div), - (relax.op.floor_divide, tir.FloorDiv), - (relax.op.multiply, tir.Mul), - (relax.op.power, tir.pow), - (relax.op.subtract, tir.Sub), - (relax.op.maximum, tir.Max), - (relax.op.minimum, tir.Min), - (relax.op.mod, tir.Mod), - (relax.op.floor_mod, tir.FloorMod), + (relax.op.add, tirx.Add), + (relax.op.divide, tirx.Div), + (relax.op.floor_divide, tirx.FloorDiv), + (relax.op.multiply, tirx.Mul), + (relax.op.power, tirx.pow), + (relax.op.subtract, tirx.Sub), + (relax.op.maximum, tirx.Max), + (relax.op.minimum, tirx.Min), + (relax.op.mod, tirx.Mod), + (relax.op.floor_mod, tirx.FloorMod), ) @@ -147,8 +147,8 @@ def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value( ): bb = relax.BlockBuilder() - tir_x = tir.Var("tir_x", "float32") - tir_y = tir.Var("tir_y", "float32") + tir_x = tirx.Var("tir_x", "float32") + tir_y = tirx.Var("tir_y", "float32") x = relax.Var("x", R.Prim(value=tir_x)) y = relax.Var("y", R.Prim(value=tir_y)) @@ -158,12 +158,12 @@ def test_infer_struct_info_binary_arith_known_prim_value_with_prim_value( (binary_cmp_op, tir_cmp_op) = tvm.testing.parameters( - (relax.op.equal, tir.EQ), - (relax.op.greater, tir.GT), - (relax.op.greater_equal, tir.GE), - (relax.op.less, tir.LT), - (relax.op.less_equal, tir.LE), - (relax.op.not_equal, tir.NE), + (relax.op.equal, tirx.EQ), + (relax.op.greater, tirx.GT), + (relax.op.greater_equal, tirx.GE), + (relax.op.less, tirx.LT), + (relax.op.less_equal, tirx.LE), + (relax.op.not_equal, tirx.NE), ) @@ -205,8 +205,8 @@ def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value( ): bb = relax.BlockBuilder() - tir_x = tir.Var("tir_x", "float32") - tir_y = tir.Var("tir_y", "float32") + tir_x = tirx.Var("tir_x", "float32") + tir_y = tirx.Var("tir_y", "float32") x = relax.Var("x", R.Prim(value=tir_x)) y = relax.Var("y", R.Prim(value=tir_y)) @@ -217,9 +217,9 @@ def test_infer_struct_info_binary_cmp_known_prim_value_to_prim_value( def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") - k = tir.Var("k", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") + k = tirx.Var("k", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((1, n), "float32")) x2 = relax.Var("x", R.Tensor((k, n, m), "float32")) diff --git a/tests/python/relax/test_op_ccl.py b/tests/python/relax/test_op_ccl.py index c114653d321e..71ee96b271e9 100644 --- a/tests/python/relax/test_op_ccl.py +++ b/tests/python/relax/test_op_ccl.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op from tvm.script import relax as R @@ -57,8 +57,8 @@ def test_allreduce_infer_struct_info(): def test_allreduce_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) @@ -109,8 +109,8 @@ def test_allgather_infer_struct_info(): def test_allgather_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) @@ -171,8 +171,8 @@ def test_broadcast_from_worker0_infer_struct_info(): def test_broadcast_from_worker0_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) @@ -231,15 +231,15 @@ def test_scatter_from_worker0_infer_struct_info(): def test_scatter_from_worker0_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) _check_inference( bb, relax.op.ccl.scatter_from_worker0(x0, 2), - relax.TensorStructInfo((tir.div(m, 2), n), "float32"), + relax.TensorStructInfo((tirx.div(m, 2), n), "float32"), ) _check_inference( bb, relax.op.ccl.scatter_from_worker0(x1, 2), relax.TensorStructInfo((2, n), "float32") diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py index 7269dfdbcf47..81c5cd2143f1 100644 --- a/tests/python/relax/test_op_create.py +++ b/tests/python/relax/test_op_create.py @@ -18,10 +18,10 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_op_correctness(): @@ -131,7 +131,7 @@ def test_full_infer_struct_info(): def test_full_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") + a = tirx.Var("a", "int64") v = relax.Var("v", R.Tensor((), "float32")) s0 = relax.ShapeExpr((a, 3)) s1 = relax.Var("s", relax.ShapeStructInfo((a, 3))) @@ -206,7 +206,7 @@ def test_full_infer_struct_info_fill_value_not_scalar_tensor(): def test_full_shape_not_tuple(): - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") v = relax.Var("v", R.Tensor((), "float32")) with pytest.raises(TypeError): @@ -291,8 +291,8 @@ def test_full_like_infer_struct_info(): def test_full_like_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) v = relax.Var("v", R.Tensor((), "float16")) @@ -416,8 +416,8 @@ def test_ones_zeros_infer_struct_info(): def test_ones_zeros_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") s0 = relax.ShapeExpr((m, n)) s1 = relax.Var("s", relax.ShapeStructInfo((m, n))) @@ -447,7 +447,7 @@ def test_ones_zeros_infer_struct_info_more_input_dtype(): def test_ones_zeros_shape_not_tuple(): - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") with pytest.raises(TypeError): relax.op.ones(10, "float32") @@ -502,8 +502,8 @@ def test_ones_like_zeros_like_infer_struct_info(): def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) @@ -557,9 +557,9 @@ def test_eye_infer_struct_info(): def test_eye_infer_struct_info_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") - k = tir.Var("k", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") + k = tirx.Var("k", "int64") _check_inference(bb, relax.op.eye(n), relax.TensorStructInfo((n, n), "float32")) _check_inference(bb, relax.op.eye(n, m), relax.TensorStructInfo((n, m), "float32")) @@ -583,10 +583,10 @@ def test_eye_like_infer_struct_info(): def test_eye_like_infer_struct_info_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - m = tir.Var("m", "int64") + n = tirx.Var("n", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((n, m), "float32")) - k = tir.Var("k", "int64") + k = tirx.Var("k", "int64") _check_inference(bb, relax.op.eye_like(x), relax.TensorStructInfo((n, m), "float32")) _check_inference(bb, relax.op.eye_like(x, k=k), relax.TensorStructInfo((n, m), "float32")) @@ -619,9 +619,9 @@ def test_arange_infer_struct_info(): def test_arange_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - start = tir.Var("start", "int64") - stop = tir.Var("stop", "int64") - step = tir.Var("step", "int64") + start = tirx.Var("start", "int64") + stop = tirx.Var("stop", "int64") + step = tirx.Var("step", "int64") _check_inference(bb, relax.op.arange(stop), relax.TensorStructInfo((stop,), "int64")) _check_inference(bb, relax.op.arange(1, stop), relax.TensorStructInfo((stop - 1,), "int64")) @@ -639,9 +639,9 @@ def test_arange_infer_struct_info_shape_var(): relax.TensorStructInfo(((stop + step - start - 1) // step,), "int64"), ) - start = tir.Var("start", "float32") - stop = tir.Var("stop", "float32") - step = tir.Var("step", "float32") + start = tirx.Var("start", "float32") + stop = tirx.Var("stop", "float32") + step = tirx.Var("step", "float32") _check_inference( bb, @@ -695,9 +695,9 @@ def test_tril_triu_infer_struct_info(): def test_tril_triu_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() vdev0 = VDevice("llvm") - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((a, b, c), "float32")) x1 = relax.Var("x", R.Tensor((a, b, c))) x2 = relax.Var("x", R.Tensor((a, b, c), "float32", vdev0)) diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py index cc11176bc036..bd7536b1c341 100644 --- a/tests/python/relax/test_op_datatype.py +++ b/tests/python/relax/test_op_datatype.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op from tvm.script import relax as R @@ -59,8 +59,8 @@ def test_astype_infer_struct_info(): def test_astype_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py index 53b15b3ede69..43ebc7929892 100644 --- a/tests/python/relax/test_op_image.py +++ b/tests/python/relax/test_op_image.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -105,12 +105,12 @@ def test_resize2d_infer_struct_info(): def test_resize2d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") - oh = tir.Var("oh", "int64") - ow = tir.Var("ow", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") + oh = tirx.Var("oh", "int64") + ow = tirx.Var("ow", "int64") x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, ih, iw, 16), "float32")) diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py index 7688cc9f4b98..21aa08945d42 100644 --- a/tests/python/relax/test_op_index.py +++ b/tests/python/relax/test_op_index.py @@ -20,11 +20,11 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_op_correctness(): @@ -218,11 +218,11 @@ def test_take_infer_struct_info_prim_value_index(): def test_take_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") - i = tir.Var("i", "int64") - j = tir.Var("j", "int64") - k = tir.Var("k", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") + i = tirx.Var("i", "int64") + j = tirx.Var("j", "int64") + k = tirx.Var("k", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) y0 = relax.Var("y", R.Tensor((n,), "float32")) @@ -490,30 +490,30 @@ def test_strided_slice_infer_struct_info_shape_out_of_range(): def test_strided_slice_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((m, n))) _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]), - relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), "float32"), + relax.TensorStructInfo((tirx.min(3, m) - tirx.min(1, m), n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]), - relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), "float32"), + relax.TensorStructInfo(((tirx.min(8, m) + 2 - tirx.min(1, m)) // 3, n), "float32"), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]), - relax.TensorStructInfo((tir.min(3, m) - tir.min(1, m), n), dtype=""), + relax.TensorStructInfo((tirx.min(3, m) - tirx.min(1, m), n), dtype=""), ) _check_inference( bb, relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]), - relax.TensorStructInfo(((tir.min(8, m) + 2 - tir.min(1, m)) // 3, n), dtype=""), + relax.TensorStructInfo(((tirx.min(8, m) + 2 - tirx.min(1, m)) // 3, n), dtype=""), ) @@ -586,40 +586,40 @@ def test_strided_slice_infer_struct_info_more_input_dtype(): def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): bb = relax.BlockBuilder() - var = tir.Var("var", "int64") - size_var = tir.SizeVar("size_var", "int64") + var = tirx.Var("var", "int64") + size_var = tirx.SizeVar("size_var", "int64") x = relax.Var("x", R.Tensor((8, 9), "float32")) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[var], end=[8]), relax.TensorStructInfo( - (tir.max(8 - tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 0), 9), + (tirx.max(8 - tirx.max(tirx.if_then_else(var < 0, var + 8, var), 0), 0), 9), dtype="float32", ), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[size_var], end=[8]), - relax.TensorStructInfo((tir.max(8 - size_var, 0), 9), dtype="float32"), + relax.TensorStructInfo((tirx.max(8 - size_var, 0), 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[var]), relax.TensorStructInfo( - (tir.min(tir.max(tir.if_then_else(var < 0, var + 8, var), 0), 8), 9), dtype="float32" + (tirx.min(tirx.max(tirx.if_then_else(var < 0, var + 8, var), 0), 8), 9), dtype="float32" ), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[size_var]), - relax.TensorStructInfo((tir.min(size_var, 8), 9), dtype="float32"), + relax.TensorStructInfo((tirx.min(size_var, 8), 9), dtype="float32"), ) _check_inference( bb, relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[var]), relax.TensorStructInfo( - [tir.if_then_else(var < 0, -8 // (0 - var) + 1, (var + 7) // var), 9], + [tirx.if_then_else(var < 0, -8 // (0 - var) + 1, (var + 7) // var), 9], dtype="float32", ), ) @@ -632,8 +632,8 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound(): bb = relax.BlockBuilder() - var = tir.Var("var", "int64") - size_var = tir.SizeVar("size_var", "int64") + var = tirx.Var("var", "int64") + size_var = tirx.SizeVar("size_var", "int64") x = relax.Var("x", R.Tensor((8, 9), "float32")) _check_inference( @@ -673,8 +673,8 @@ def test_strided_slice_infer_struct_info_symbolic_begin_end_strides_inbound(): def test_strided_slice_infer_struct_info_no_axis(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((m, n))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s2 = relax.Var("s", relax.ShapeStructInfo()) @@ -861,10 +861,10 @@ def test_dynamic_strided_slice_infer_struct_info(): def test_dynamic_strided_slice_infer_struct_info_symbolic(): bb = relax.BlockBuilder() - i = tir.Var("i", "int64") - j = tir.Var("j", "int64") - k = tir.Var("k", "int64") - l = tir.Var("l", "int64") + i = tirx.Var("i", "int64") + j = tirx.Var("j", "int64") + k = tirx.Var("k", "int64") + l = tirx.Var("l", "int64") x0 = relax.Var("x", R.Tensor((i, j, k, l), "float32")) x1 = relax.Var("x", R.Tensor("float32", ndim=4)) x2 = relax.Var("x", R.Tensor("float32")) @@ -956,7 +956,7 @@ def test_dynamic_strided_slice_infer_struct_info_arg_wrong_dtype(): def test_dynamic_strided_slice_infer_struct_info_arg_wrong_shape_info(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") # invalid arg b0 = relax.Var("begin", R.Tensor("int64", ndim=2)) b1 = relax.Var("begin", R.Tensor((1,), "int64")) @@ -1004,7 +1004,7 @@ def strided_slice( B: T.Buffer((T.int64(1), T.int64(16))), index: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for iters in T.grid(*B.shape): with T.sblock("T_dynamic_strided_slice"): i, j = T.axis.remap("SS", iters) @@ -1031,7 +1031,7 @@ def main(A: R.Tensor((16, 16), "float32"), B: R.Shape(["index"])) -> R.Tensor((1 class expected: @T.prim_func(private=True) def strided_slice(A: T.Buffer((T.int64(16), T.int64(16)), "float32"), var_T_dynamic_strided_slice_with_axes: T.handle, index: T.int64): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16))) # with T.sblock("root"): for ax0, ax1 in T.grid(T.max(T.int64(16) - T.max(T.if_then_else(index < T.int64(0), index + T.int64(16), index), T.int64(0)), T.int64(0)), T.int64(16)): diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py index a08a27ff5f55..035220d21815 100644 --- a/tests/python/relax/test_op_linear_algebra.py +++ b/tests/python/relax/test_op_linear_algebra.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -80,14 +80,14 @@ def test_matmul_infer_struct_info(): def test_matmul_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") - k0 = tir.Var("k0", "int64") - k1 = tir.Var("k1", "int64") - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - b1 = tir.Var("b", "int64") - c = tir.Var("c", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") + k0 = tirx.Var("k0", "int64") + k1 = tirx.Var("k1", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + b1 = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((m, k0), "float32")) x1 = relax.Var("x", R.Tensor((k0,), "float32")) x2 = relax.Var("x", R.Tensor((a, b, m, k0), "float32")) @@ -197,7 +197,7 @@ def test_matmul_infer_struct_info_not_broadcastable(): def test_matmul_infer_struct_info_unequal_reduction_length(): bb = relax.BlockBuilder() - k = tir.Var("k", "int64") + k = tirx.Var("k", "int64") x0 = relax.Var("x", R.Tensor((3, 4), "float32")) x1 = relax.Var("x", R.Tensor((3, k), "float32")) y0 = relax.Var("y", R.Tensor((6, 5), "float32")) @@ -314,9 +314,9 @@ def test_einsum_infer_struct_info(): def test_einsum_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b), "float32")) y = relax.Var("y", R.Tensor((b, c), "float32")) z = relax.Var("z", R.Tensor((a, a), "float32")) diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index a86a8127b1dc..537bc9c06c04 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -19,10 +19,10 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_op_correctness(): @@ -130,10 +130,10 @@ def test_reshape_infer_struct_info(): def test_reshape_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) s0 = relax.Var("s", R.Shape((c, a, d, b))) s1 = relax.Var("s", R.Shape()) @@ -155,7 +155,7 @@ def test_reshape_infer_struct_info_shape_symbolic(): _check_inference( bb, relax.op.reshape(x, (2, -1, a)), - relax.TensorStructInfo((2, tir.floordiv(b * c * d, 2), a), "float32"), + relax.TensorStructInfo((2, tirx.floordiv(b * c * d, 2), a), "float32"), ) _check_inference( bb, @@ -268,7 +268,7 @@ def test_reshape_infer_struct_info_inference_not_deducible(): def test_reshape_new_shape_not_tuple(): - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) with pytest.raises(TypeError): @@ -391,10 +391,10 @@ def test_permute_dims_infer_struct_info(): def test_permute_dims_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) _check_inference( @@ -592,8 +592,8 @@ def test_expand_dims_infer_struct_info(): def test_expand_dims_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, 4, b), "float32")) _check_inference( @@ -780,8 +780,8 @@ def test_layout_transform_infer_struct_info_unknown_shape(): def test_layout_transform_infer_struct_info_symbolic_shape(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((a, b), "float32")) tiling_transform = lambda a, b: (a, b // 3, b % 3) @@ -820,8 +820,8 @@ def test_layout_transform_infer_struct_info_shape_var(): relax.TensorStructInfo(dtype="float32", ndim=3), ) - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b))) x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32")) _check_inference( @@ -875,8 +875,8 @@ def test_squeeze_infer_struct_info(): def test_squeeze_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((a, 1, b), "float32")) x1 = relax.Var("x", R.Tensor((a, 1, b))) @@ -888,8 +888,8 @@ def test_squeeze_infer_struct_info_shape_symbolic(): def test_squeeze_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((2, 1, 3, 1, 1, 4))) s1 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) s2 = relax.Var("s", relax.ShapeStructInfo((a, 1, b))) @@ -988,7 +988,7 @@ def test_squeeze_infer_struct_info_repetitive_axes(): def test_squeeze_infer_struct_info_axis_length_not_one(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") + a = tirx.Var("a", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) s1 = relax.Var("s", relax.ShapeStructInfo((a, 3, 4))) x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) @@ -1061,8 +1061,8 @@ def test_flatten_infer_struct_info(): def test_flatten_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((a, b), "float32")) x1 = relax.Var("x", R.Tensor((a, b))) @@ -1255,12 +1255,12 @@ def test_concat_infer_struct_info_with_axis(): def test_concat_infer_struct_info_with_axis_shape_symbolic(): bb = relax.BlockBuilder() - a0 = tir.Var("a0", "int64") - a1 = tir.Var("a1", "int64") - b0 = tir.Var("b0", "int64") - b1 = tir.Var("b1", "int64") - b2 = tir.Var("b2", "int64") - c = tir.Var("c", "int64") + a0 = tirx.Var("a0", "int64") + a1 = tirx.Var("a1", "int64") + b0 = tirx.Var("b0", "int64") + b1 = tirx.Var("b1", "int64") + b2 = tirx.Var("b2", "int64") + c = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) x1 = relax.Var("x", R.Tensor((a1, b0, c), "float32")) x2 = relax.Var("x", R.Tensor((a0, b0, c), "float32")) @@ -1294,12 +1294,12 @@ def test_concat_infer_struct_info_with_axis_shape_symbolic(): def test_concat_infer_struct_info_with_axis_shape_var(): bb = relax.BlockBuilder() - a0 = tir.Var("a0", "int64") - a1 = tir.Var("a1", "int64") - b0 = tir.Var("b0", "int64") - b1 = tir.Var("b1", "int64") - b2 = tir.Var("b2", "int64") - c = tir.Var("c", "int64") + a0 = tirx.Var("a0", "int64") + a1 = tirx.Var("a1", "int64") + b0 = tirx.Var("b0", "int64") + b1 = tirx.Var("b1", "int64") + b2 = tirx.Var("b2", "int64") + c = tirx.Var("c", "int64") sx0 = relax.Var("sx", relax.ShapeStructInfo((2, 3, 4))) sx1 = relax.Var("sx", relax.ShapeStructInfo((a0, b0, c))) sx2 = relax.Var("sx", relax.ShapeStructInfo((a1, b0, c))) @@ -1384,8 +1384,8 @@ def test_concat_infer_struct_info_without_axis(): def test_concat_infer_struct_info_without_axis_shape_symbolic(): bb = relax.BlockBuilder() - a0 = tir.Var("a0", "int64") - a1 = tir.Var("a1", "int64") + a0 = tirx.Var("a0", "int64") + a1 = tirx.Var("a1", "int64") x0 = relax.Var("x", R.Tensor((a0,), "float32")) x1 = relax.Var("x", R.Tensor((a0,), "")) y0 = relax.Var("y", R.Tensor((a1,), "float32")) @@ -1452,9 +1452,9 @@ def test_concat_infer_struct_info_more_input_dtype(): def test_concat_infer_struct_info_tuple_var(): bb = relax.BlockBuilder() - a = tir.Var("a0", "int64") - b0 = tir.Var("b0", "int64") - b1 = tir.Var("b1", "int64") + a = tirx.Var("a0", "int64") + b0 = tirx.Var("b0", "int64") + b1 = tirx.Var("b1", "int64") t0 = relax.Var( "t", relax.TupleStructInfo( @@ -1530,7 +1530,7 @@ def test_concat_infer_struct_info_tuple_var(): def test_concat_infer_struct_info_single_input_tensor(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") + a = tirx.Var("a", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((3, a))) s1 = relax.Var("s", relax.ShapeStructInfo((a,))) s2 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) @@ -1676,7 +1676,7 @@ def test_concat_infer_struct_info_axis_out_of_range(): def test_concat_infer_struct_info_unequal_shape(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") + a = tirx.Var("a", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((3, 4))) s1 = relax.Var("s", relax.ShapeStructInfo((3, a + 2))) x0 = relax.Var("x", R.Tensor((3, 4), "float32")) @@ -1834,8 +1834,8 @@ def test_split_infer_struct_info_by_indices(): def test_split_infer_struct_info_by_indices_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, b), "float32")) _check_inference( @@ -1989,8 +1989,8 @@ def test_split_infer_struct_info_by_n_section(): def test_split_infer_struct_info_by_n_section_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, b), "float32")) _check_inference( @@ -2103,8 +2103,8 @@ def test_split_infer_struct_info_more_input_dtype(): def test_split_infer_struct_info_single_output(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((a, b))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) s2 = relax.Var("s", relax.ShapeStructInfo()) @@ -2189,7 +2189,7 @@ def test_split_indices_or_sections_int64(): def test_split_infer_struct_info(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((16, 4))) y = relax.Var("y", R.Tensor((16, 4), "float32")) z = relax.Var("z", R.Tensor((n, 16))) @@ -2298,8 +2298,8 @@ def test_split_infer_struct_info(): def test_split_infer_struct_info_non_integer_indices(): bb = relax.BlockBuilder() - a = tir.Var("c", "int64") - b = tir.Var("d", "int64") + a = tirx.Var("c", "int64") + b = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((3, 4), "float32")) with pytest.raises(TypeError): @@ -2307,7 +2307,7 @@ def test_split_infer_struct_info_non_integer_indices(): def test_split_invalid_n_section(): - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((3, 4), "float32")) with pytest.raises((TVMError, TypeError)): @@ -2393,10 +2393,10 @@ def test_broadcast_to_infer_struct_info(): def test_broadcast_to_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") x0 = relax.Var("x", R.Tensor((b, 1, 1, d), "float32")) x1 = relax.Var("x", R.Tensor((b, 1, 1, d))) @@ -2434,10 +2434,10 @@ def test_broadcast_to_infer_struct_info_shape_var(): def test_broadcast_to_infer_struct_info_tgt_shape_var(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") s0 = relax.Var("s", relax.ShapeStructInfo((b, 1, 1, d))) s1 = relax.Var("s", relax.ShapeStructInfo(ndim=4)) s2 = relax.Var("s", relax.ShapeStructInfo()) @@ -2544,8 +2544,8 @@ def test_broadcast_to_infer_struct_info_not_broadcastable_static(): def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") s = relax.Var("s", relax.ShapeStructInfo((2, a))) x0 = relax.Var("x", R.Tensor((2, a), "float32")) x1 = relax.Var("x", relax.TensorStructInfo(s, "float32")) @@ -2645,8 +2645,8 @@ def test_collapse_sum_like_infer_struct_info(): def test_collapse_sum_like_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) y0 = relax.Var("y", R.Tensor((4, a), "float32")) x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) @@ -2710,8 +2710,8 @@ def test_collapse_sum_like_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32")) - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x1 = relax.Var("z", R.Tensor((3, a, 5), "float32")) y1 = relax.Var("w", R.Tensor((3, b, 5), "float32")) @@ -2763,8 +2763,8 @@ def test_collapse_sum_to_infer_struct_info(): def test_collapse_sum_to_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x0 = relax.Var("x", R.Tensor((3, 4, a), "float32")) x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) @@ -2827,8 +2827,8 @@ def test_collapse_sum_to_infer_struct_info_wrong_input_type(): def test_collapse_sum_to_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) s0 = relax.Var("s0", relax.ShapeStructInfo((3, 4, 5))) @@ -2852,10 +2852,10 @@ def test_collapse_sum_to_infer_struct_info_shape_mismatch(): def test_collapse_sum_to_infer_struct_info_struct_info_tgt_shape_var(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") s0 = relax.Var("s0", relax.ShapeStructInfo((3, a, b))) s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) s2 = relax.Var("s2", relax.ShapeStructInfo()) @@ -2973,9 +2973,9 @@ def test_repeat_infer_struct_info(): def test_repeat_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference(bb, relax.op.repeat(x, 2, 0), relax.TensorStructInfo((a * 2, b, c), "float32")) @@ -3036,8 +3036,8 @@ def test_repeat_infer_struct_info_wrong_input_type(): x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - r1 = tir.Var("r", "float32") - r2 = tir.StringImm("abc") + r1 = tirx.Var("r", "float32") + r2 = tirx.StringImm("abc") with pytest.raises((TypeError, TVMError)): bb.normalize(relax.op.repeat(x0, 2)) @@ -3115,9 +3115,9 @@ def test_tile_infer_struct_info(): def test_tile_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference(bb, relax.op.tile(x, 2), relax.TensorStructInfo((a, b, c * 2), "float32")) @@ -3161,8 +3161,8 @@ def test_tile_infer_struct_info_wrong_input_type(): x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))) x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))) x2 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) - r1 = tir.Var("a", "float32") - r2 = tir.StringImm("abc") + r1 = tirx.Var("a", "float32") + r2 = tirx.StringImm("abc") with pytest.raises((TypeError, TVMError)): bb.normalize(relax.op.tile(x0, 2)) @@ -3198,8 +3198,8 @@ def test_flip_infer_struct_info(): def test_flip_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, b), "float32")) _check_inference(bb, relax.op.flip(x, axis=0), relax.TensorStructInfo((a, b), "float32")) @@ -3261,8 +3261,8 @@ def test_gather_elements_infer_struct_info(): def test_gather_elements_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") x = relax.Var("x", R.Tensor((a, b), "float32")) i = relax.Var("i", R.Tensor((a, b), "int64")) @@ -3319,9 +3319,9 @@ def test_gather_nd_infer_struct_info(): def test_gather_nd_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) i = relax.Var("i", R.Tensor((2, 2), "int64")) @@ -3437,12 +3437,12 @@ def test_scatter_elements_infer_struct_info(): def test_scatter_elements_infer_struct_info_symbolic_shape(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") - e = tir.Var("e", "int64") - f = tir.Var("f", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") + e = tirx.Var("e", "int64") + f = tirx.Var("f", "int64") d0 = relax.Var("data", R.Tensor((a, b), "float32")) i0 = relax.Var("indices", R.Tensor((c, d), "int64")) @@ -3472,8 +3472,8 @@ def test_scatter_elements_infer_struct_info_wrong_indices_type(): def test_scatter_elements_infer_struct_info_rank_shape_mismatch(): - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") bb = relax.BlockBuilder() d0 = relax.Var("data", R.Tensor((4, 4), "float32")) @@ -3591,7 +3591,7 @@ def test_one_hot_infer_struct_info(): ) # Test case 3: With symbolic shape - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") i2 = relax.Var("indices", R.Tensor((n,), "int32")) _check_inference( bb, diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index 2e652b7c3899..42a055ce2eb5 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -19,7 +19,7 @@ import tvm.testing from tvm import relax as rx from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.register_global_func("test.op.identity", override=True) @@ -64,7 +64,7 @@ def test_call_tir_with_grad(): def test_implicit_op(): - m, n = tvm.tir.Var("m", "int64"), tvm.tir.Var("n", "int64") + m, n = tvm.tirx.Var("m", "int64"), tvm.tirx.Var("n", "int64") x = rx.Var("x", R.Tensor([m, n], "float32")) y = rx.Var("y", R.Tensor([m, n], "float32")) func = rx.Var( diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index 574c7ec3e4ce..417f502ed152 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -90,8 +90,8 @@ def test_linear_unit_infer_struct_info(): def test_linear_unit_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) @@ -188,8 +188,8 @@ def test_softmax_log_softmax_infer_struct_info(): def test_softmax_log_softmax_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) @@ -406,11 +406,11 @@ def test_batch_norm_infer_struct_info(): def test_batch_norm_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c0 = tir.Var("c", "int64") - c1 = tir.Var("c", "int64") - h = tir.Var("h", "int64") - w = tir.Var("w", "int64") + n = tirx.Var("n", "int64") + c0 = tirx.Var("c", "int64") + c1 = tirx.Var("c", "int64") + h = tirx.Var("h", "int64") + w = tirx.Var("w", "int64") x0 = relax.Var("x", R.Tensor((n, c0, h, w), "float32")) x1 = relax.Var("x", R.Tensor((n, c1, h, w), "float32")) x2 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -630,7 +630,7 @@ def test_batch_norm_infer_struct_info_ndim_mismatch(): def test_batch_norm_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() - c = tir.Var("c", "int64") + c = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32")) x1 = relax.Var("x", R.Tensor((2, c, 28, 28), "float32")) gamma0 = relax.Var("gamma", R.Tensor((3,), "float32")) @@ -714,11 +714,11 @@ def test_layer_norm_infer_struct_info(): def test_layer_norm_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c0 = tir.Var("c", "int64") - c1 = tir.Var("c", "int64") + n = tirx.Var("n", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c0 = tirx.Var("c", "int64") + c1 = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) x2 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -854,7 +854,7 @@ def test_layer_norm_infer_struct_info_ndim_mismatch(): def test_layer_norm_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() - c0 = tir.Var("c", "int64") + c0 = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) @@ -928,11 +928,11 @@ def test_group_norm_infer_struct_info(): def test_group_norm_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c0 = tir.Var("c", "int64") - c1 = tir.Var("c", "int64") + n = tirx.Var("n", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c0 = tirx.Var("c", "int64") + c1 = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32")) x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32")) x2 = relax.Var("x", R.Tensor("float32", ndim=4)) @@ -1084,7 +1084,7 @@ def test_group_norm_infer_struct_info_ndim_mismatch(): def test_group_norm_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() - c0 = tir.Var("c", "int64") + c0 = tirx.Var("c", "int64") x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32")) gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32")) @@ -1180,8 +1180,8 @@ def test_dropout_infer_struct_info(): def test_dropout_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((m, n), "float32")) _check_inference( @@ -1283,9 +1283,9 @@ def test_cross_entropy_infer_struct_info(): def test_cross_entropy_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m0 = tir.Var("m", "int64") - m1 = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m0 = tirx.Var("m", "int64") + m1 = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m0, n), "float32")) x1 = relax.Var("x", R.Tensor((m1, n), "float32")) y = relax.Var("y", R.Tensor((m0, n), "float32")) @@ -1346,7 +1346,7 @@ def test_cross_entropy_infer_struct_info_wrong_ndim(): def test_cross_entropy_infer_struct_info_shape_mismatch(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x0 = relax.Var("x", R.Tensor((2, 3), "float32")) y0 = relax.Var("y", R.Tensor((2, 4), "float32")) @@ -1520,10 +1520,10 @@ def test_nll_loss_infer_struct_info(): def test_nll_loss_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - N = tir.Var("N", "int64") - C = tir.Var("C", "int64") - d1 = tir.Var("d", "int64") - d2 = tir.Var("d", "int64") + N = tirx.Var("N", "int64") + C = tirx.Var("C", "int64") + d1 = tirx.Var("d", "int64") + d2 = tirx.Var("d", "int64") x0 = relax.Var("x", R.Tensor((N, C, d1, d2), "float32")) x1 = relax.Var("x", R.Tensor((N, C), "float32")) x2 = relax.Var("x", R.Tensor((C,), "float32")) @@ -1636,10 +1636,10 @@ def test_nll_loss_infer_struct_info_no_weights(): def test_nll_loss_infer_struct_info_no_weights_symbolic(): - N = tir.Var("N", "int64") - C = tir.Var("C", "int64") - d1 = tir.Var("d", "int64") - d2 = tir.Var("d", "int64") + N = tirx.Var("N", "int64") + C = tirx.Var("C", "int64") + d1 = tirx.Var("d", "int64") + d2 = tirx.Var("d", "int64") bb = relax.BlockBuilder() x = relax.Var("x", R.Tensor((N, C, d1, d2), "float32")) y = relax.Var("y", R.Tensor((N, d1, d2), "int64")) @@ -1877,10 +1877,10 @@ def test_batch_flatten_infer_struct_info(): def test_batch_flatten_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") - h = tir.Var("h", "int64") - w = tir.Var("w", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") + h = tirx.Var("h", "int64") + w = tirx.Var("w", "int64") x0 = relax.Var("x", R.Tensor((m, n, h, w), "float32")) x1 = relax.Var("x", R.Tensor((4, n, 8, 8), "float32")) diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py index 320ee0a80c35..07b9469abff2 100644 --- a/tests/python/relax/test_op_nn_convolution.py +++ b/tests/python/relax/test_op_nn_convolution.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -141,13 +141,13 @@ def test_conv1d_infer_struct_info(): def test_conv1d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - iw = tir.Var("iw", "int64") - ki = tir.Var("ki", "int64") - ko = tir.Var("ko", "int64") - kw = tir.Var("kw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + iw = tirx.Var("iw", "int64") + ki = tirx.Var("ki", "int64") + ko = tirx.Var("ko", "int64") + kw = tirx.Var("kw", "int64") x0 = relax.Var("x", R.Tensor((n, c, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32")) w0 = relax.Var("w", R.Tensor((ko, ki, kw), "float32")) @@ -173,7 +173,7 @@ def test_conv1d_infer_struct_info_shape_symbolic(): bb, relax.op.nn.conv1d(x0, w0, strides=2, padding=1, dilation=2), relax.TensorStructInfo( - (n, ko, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + (n, ko, tvm.tirx.floordiv(iw + 3, 2) + 1 - kw), "float32", ), ) @@ -232,9 +232,9 @@ def test_conv1d_infer_struct_info_groups(): def test_conv1d_infer_struct_info_symbolic_groups(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x = relax.Var("x", R.Tensor((n, ic * 4, 28), "float32")) w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3), "float32")) w1 = relax.Var("w", R.Tensor((oc, ic, 3), "float32")) @@ -251,9 +251,9 @@ def test_conv1d_infer_struct_info_symbolic_groups(): def test_conv1d_infer_struct_info_input_channel_group_incompatible(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) w0 = relax.Var("w", R.Tensor((48, 20, 3), "float32")) x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32")) @@ -267,9 +267,9 @@ def test_conv1d_infer_struct_info_input_channel_group_incompatible(): def test_conv1d_infer_struct_info_output_channel_group_incompatible(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x0 = relax.Var("x", R.Tensor((2, 120, 28), "float32")) w0 = relax.Var("w", R.Tensor((128, 20, 3), "float32")) x1 = relax.Var("x", R.Tensor((n, ic * 6, 28), "float32")) @@ -336,7 +336,7 @@ def test_conv1d_infer_struct_info_mixed_precision(): def test_conv1d_unequal_input_channel(): bb = relax.BlockBuilder() - ic = tir.Var("ic", "int64") + ic = tirx.Var("ic", "int64") x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32")) w0 = relax.Var("w", R.Tensor([3, 4, 3], "float32")) x1 = relax.Var("x", R.Tensor([2, ic, 28], "float32")) @@ -525,13 +525,13 @@ def test_conv1d_transpose_infer_struct_info(): def test_conv1d_transpose_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - iw = tir.Var("iw", "int64") - ki = tir.Var("ki", "int64") - ko = tir.Var("ko", "int64") - kw = tir.Var("kw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + iw = tirx.Var("iw", "int64") + ki = tirx.Var("ki", "int64") + ko = tirx.Var("ko", "int64") + kw = tirx.Var("kw", "int64") x0 = relax.Var("x", R.Tensor((n, c, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32")) w0 = relax.Var("w", R.Tensor((ki, ko, kw), "float32")) @@ -620,9 +620,9 @@ def test_conv1d_transpose_infer_struct_info_groups(): def test_conv1d_transpose_infer_struct_info_symbolic_groups(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x = relax.Var("x", R.Tensor((n, ic * 4, 28), "float32")) w0 = relax.Var("w", R.Tensor((ic, oc, 3), "float32")) @@ -635,9 +635,9 @@ def test_conv1d_transpose_infer_struct_info_symbolic_groups(): def test_conv1d_transpose_infer_struct_info_input_channel_group_incompatible(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x0 = relax.Var("x", R.Tensor((2, 128, 28), "float32")) w0 = relax.Var("w", R.Tensor((128, 20, 3), "float32")) x1 = relax.Var("x", R.Tensor((n, ic, 28), "float32")) @@ -686,7 +686,7 @@ def test_conv1d_transpose_infer_struct_info_more_input_dtype(): def test_conv1d_transpose_unequal_input_channel(): bb = relax.BlockBuilder() - ic = tir.Var("ic", "int64") + ic = tirx.Var("ic", "int64") x0 = relax.Var("x", R.Tensor([2, 3, 28], "float32")) w0 = relax.Var("w", R.Tensor([4, 3, 3], "float32")) x1 = relax.Var("x", R.Tensor([2, ic, 28], "float32")) @@ -901,15 +901,15 @@ def test_conv2d_infer_struct_info(): def test_conv2d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") - ki = tir.Var("ki", "int64") - ko = tir.Var("ko", "int64") - kh = tir.Var("kh", "int64") - kw = tir.Var("kw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") + ki = tirx.Var("ki", "int64") + ko = tirx.Var("ko", "int64") + kh = tirx.Var("kh", "int64") + kw = tirx.Var("kw", "int64") x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) w0 = relax.Var("w", R.Tensor((ko, ki, kh, kw), "float32")) @@ -937,7 +937,7 @@ def test_conv2d_infer_struct_info_shape_symbolic(): bb, relax.op.nn.conv2d(x0, w0, strides=(2, 2), padding=(1, 1), dilation=(2, 2)), relax.TensorStructInfo( - (n, ko, tvm.tir.floordiv(ih + 3, 2) + 1 - kh, tvm.tir.floordiv(iw + 3, 2) + 1 - kw), + (n, ko, tvm.tirx.floordiv(ih + 3, 2) + 1 - kh, tvm.tirx.floordiv(iw + 3, 2) + 1 - kw), "float32", ), ) @@ -996,9 +996,9 @@ def test_conv2d_infer_struct_info_groups(): def test_conv2d_infer_struct_info_symbolic_groups(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32")) w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32")) @@ -1015,9 +1015,9 @@ def test_conv2d_infer_struct_info_symbolic_groups(): def test_conv2d_infer_struct_info_input_channel_group_incompatible(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32")) x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) @@ -1031,9 +1031,9 @@ def test_conv2d_infer_struct_info_input_channel_group_incompatible(): def test_conv2d_infer_struct_info_output_channel_group_incompatible(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32")) x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32")) @@ -1106,7 +1106,7 @@ def test_conv2d_infer_struct_info_mixed_precision(): def test_conv2d_unequal_input_channel(): bb = relax.BlockBuilder() - ic = tir.Var("ic", "int64") + ic = tirx.Var("ic", "int64") x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32")) x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32")) @@ -1314,15 +1314,15 @@ def test_conv2d_transpose_infer_struct_info(): def test_conv2d_transpose_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") - ki = tir.Var("ki", "int64") - ko = tir.Var("ko", "int64") - kh = tir.Var("kh", "int64") - kw = tir.Var("kw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") + ki = tirx.Var("ki", "int64") + ko = tirx.Var("ko", "int64") + kh = tirx.Var("kh", "int64") + kw = tirx.Var("kw", "int64") x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) w0 = relax.Var("w", R.Tensor((ki, ko, kh, kw), "float32")) @@ -1415,9 +1415,9 @@ def test_conv2d_transpose_infer_struct_info_groups(): def test_conv2d_transpose_infer_struct_info_symbolic_groups(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((ic, oc, 3, 3), "float32")) @@ -1430,9 +1430,9 @@ def test_conv2d_transpose_infer_struct_info_symbolic_groups(): def test_conv2d_transpose_infer_struct_info_input_channel_group_incompatible(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - ic = tir.Var("c", "int64") - oc = tir.Var("oc", "int64") + n = tirx.Var("n", "int64") + ic = tirx.Var("c", "int64") + oc = tirx.Var("oc", "int64") x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32")) w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32")) x1 = relax.Var("x", R.Tensor((n, ic, 28, 28), "float32")) @@ -1481,7 +1481,7 @@ def test_conv2d_transpose_infer_struct_info_more_input_dtype(): def test_conv2d_transpose_unequal_input_channel(): bb = relax.BlockBuilder() - ic = tir.Var("ic", "int64") + ic = tirx.Var("ic", "int64") x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32")) w0 = relax.Var("w", R.Tensor([4, 3, 3, 3], "float32")) x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32")) @@ -1711,17 +1711,17 @@ def test_conv3d_infer_struct_info(): def test_conv3d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - id = tir.Var("id", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") - ki = tir.Var("ki", "int64") - ko = tir.Var("ko", "int64") - kd = tir.Var("kd", "int64") - kh = tir.Var("kh", "int64") - kw = tir.Var("kw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + id = tirx.Var("id", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") + ki = tirx.Var("ki", "int64") + ko = tirx.Var("ko", "int64") + kd = tirx.Var("kd", "int64") + kh = tirx.Var("kh", "int64") + kw = tirx.Var("kw", "int64") x0 = relax.Var("x", R.Tensor((n, c, id, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, id, ih, iw, c16), "float32")) w0 = relax.Var("w", R.Tensor((ko, ki, kd, kh, kw), "float32")) @@ -1752,9 +1752,9 @@ def test_conv3d_infer_struct_info_shape_symbolic(): ( n, ko, - tvm.tir.floordiv(id + 3, 2) + 1 - kd, - tvm.tir.floordiv(ih + 3, 2) + 1 - kh, - tvm.tir.floordiv(iw + 3, 2) + 1 - kw, + tvm.tirx.floordiv(id + 3, 2) + 1 - kd, + tvm.tirx.floordiv(ih + 3, 2) + 1 - kh, + tvm.tirx.floordiv(iw + 3, 2) + 1 - kw, ), "float32", ), diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py index 1e69817c0f17..8a7120af0b97 100644 --- a/tests/python/relax/test_op_nn_pooling.py +++ b/tests/python/relax/test_op_nn_pooling.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -87,10 +87,10 @@ def test_max_pool1d_infer_struct_info(): def test_max_pool1d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - w = tir.Var("w", "int64") - c16 = tir.Var("c16", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + w = tirx.Var("w", "int64") + c16 = tirx.Var("c16", "int64") x0 = relax.Var("x", R.Tensor((n, c, w), "float32")) x1 = relax.Var("x", R.Tensor((n, c, w, c16), "float32")) @@ -102,7 +102,7 @@ def test_max_pool1d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(w - 1, 3) + 1, + tvm.tirx.floordiv(w - 1, 3) + 1, ), "float32", ), @@ -157,15 +157,15 @@ def test_max_pool1d_infer_struct_info_ceil_mode(): def test_max_pool1d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - w = tir.Var("w", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + w = tirx.Var("w", "int64") x = relax.Var("x", R.Tensor((n, c, w), "float32")) _check_inference( bb, relax.op.nn.max_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), - relax.TensorStructInfo((n, c, tvm.tir.floordiv(w, 2)), "float32"), + relax.TensorStructInfo((n, c, tvm.tirx.floordiv(w, 2)), "float32"), ) @@ -308,11 +308,11 @@ def test_max_pool2d_infer_struct_info(): def test_max_pool2d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) @@ -325,8 +325,8 @@ def test_max_pool2d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(ih - 1, 3) + 1, - tvm.tir.floordiv(iw - 1, 3) + 1, + tvm.tirx.floordiv(ih - 1, 3) + 1, + tvm.tirx.floordiv(iw - 1, 3) + 1, ), "float32", ), @@ -380,10 +380,10 @@ def test_max_pool2d_infer_struct_info_ceil_mode(): def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) _check_inference( @@ -391,7 +391,9 @@ def test_max_pool2d_infer_struct_info_ceil_mode_symbolic(): relax.op.nn.max_pool2d( x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True ), - relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + relax.TensorStructInfo( + (n, c, tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32" + ), ) @@ -540,12 +542,12 @@ def test_max_pool3d_infer_struct_info(): def test_max_pool3d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - id = tir.Var("id", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + id = tirx.Var("id", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, id, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, id, ih, iw, c16), "float32")) @@ -558,9 +560,9 @@ def test_max_pool3d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(id - 1, 3) + 1, - tvm.tir.floordiv(ih - 1, 3) + 1, - tvm.tir.floordiv(iw - 1, 3) + 1, + tvm.tirx.floordiv(id - 1, 3) + 1, + tvm.tirx.floordiv(ih - 1, 3) + 1, + tvm.tirx.floordiv(iw - 1, 3) + 1, ), "float32", ), @@ -615,11 +617,11 @@ def test_max_pool3d_infer_struct_info_ceil_mode(): def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - id_ = tir.Var("id", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + id_ = tirx.Var("id", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) _check_inference( @@ -633,7 +635,7 @@ def test_max_pool3d_infer_struct_info_ceil_mode_symbolic(): ceil_mode=True, ), relax.TensorStructInfo( - (n, c, tvm.tir.floordiv(id_, 2), tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), + (n, c, tvm.tirx.floordiv(id_, 2), tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32", ), ) @@ -779,10 +781,10 @@ def test_avg_pool1d_infer_struct_info(): def test_avg_pool1d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, iw, c16), "float32")) @@ -793,7 +795,7 @@ def test_avg_pool1d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(iw - 1, 3) + 1, + tvm.tirx.floordiv(iw - 1, 3) + 1, ), "float32", ), @@ -847,16 +849,16 @@ def test_avg_pool1d_infer_struct_info_ceil_mode(): def test_avg_pool1d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + iw = tirx.Var("iw", "int64") x = relax.Var("x", R.Tensor((n, c, iw), "float32")) _check_inference( bb, relax.op.nn.avg_pool1d(x, pool_size=3, strides=2, padding=1, dilation=2, ceil_mode=True), relax.TensorStructInfo( - (n, c, tvm.tir.floordiv(iw, 2)), + (n, c, tvm.tirx.floordiv(iw, 2)), "float32", ), ) @@ -997,11 +999,11 @@ def test_avg_pool2d_infer_struct_info(): def test_avg_pool2d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) @@ -1014,8 +1016,8 @@ def test_avg_pool2d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(ih - 1, 3) + 1, - tvm.tir.floordiv(iw - 1, 3) + 1, + tvm.tirx.floordiv(ih - 1, 3) + 1, + tvm.tirx.floordiv(iw - 1, 3) + 1, ), "float32", ), @@ -1069,10 +1071,10 @@ def test_avg_pool2d_infer_struct_info_ceil_mode(): def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) _check_inference( @@ -1080,7 +1082,9 @@ def test_avg_pool2d_infer_struct_info_ceil_mode_symbolic(): relax.op.nn.avg_pool2d( x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True ), - relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"), + relax.TensorStructInfo( + (n, c, tvm.tirx.floordiv(ih, 2), tvm.tirx.floordiv(iw, 2)), "float32" + ), ) @@ -1230,12 +1234,12 @@ def test_avg_pool3d_infer_struct_info(): def test_avg_pool3d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - id_ = tir.Var("id", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + id_ = tirx.Var("id", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, id_, ih, iw, c16), "float32")) @@ -1248,9 +1252,9 @@ def test_avg_pool3d_infer_struct_info_shape_symbolic(): ( n, c, - tvm.tir.floordiv(id_ - 1, 3) + 1, - tvm.tir.floordiv(ih - 1, 3) + 1, - tvm.tir.floordiv(iw - 1, 3) + 1, + tvm.tirx.floordiv(id_ - 1, 3) + 1, + tvm.tirx.floordiv(ih - 1, 3) + 1, + tvm.tirx.floordiv(iw - 1, 3) + 1, ), "float32", ), @@ -1304,11 +1308,11 @@ def test_avg_pool3d_infer_struct_info_ceil_mode(): def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - id_ = tir.Var("id", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + id_ = tirx.Var("id", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x = relax.Var("x", R.Tensor((n, c, id_, ih, iw), "float32")) _check_inference( @@ -1325,9 +1329,9 @@ def test_avg_pool3d_infer_struct_info_ceil_mode_symbolic(): ( n, c, - tvm.tir.floordiv(id_, 2), - tvm.tir.floordiv(ih, 2), - tvm.tir.floordiv(iw, 2), + tvm.tirx.floordiv(id_, 2), + tvm.tirx.floordiv(ih, 2), + tvm.tirx.floordiv(iw, 2), ), "float32", ), @@ -1461,9 +1465,9 @@ def test_adaptive_avg_pool1d_infer_struct_info(): def test_adaptive_avg_pool1d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - l = tir.Var("l", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + l = tirx.Var("l", "int64") x0 = relax.Var("x", R.Tensor((n, c, l), "float32")) @@ -1619,11 +1623,11 @@ def test_adaptive_avg_pool2d_infer_struct_info(): def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32")) @@ -1798,12 +1802,12 @@ def test_adaptive_avg_pool3d_infer_struct_info(): def test_adaptive_avg_pool3d_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") - c = tir.Var("c", "int64") - c16 = tir.Var("c16", "int64") - d = tir.Var("d", "int64") - ih = tir.Var("ih", "int64") - iw = tir.Var("iw", "int64") + n = tirx.Var("n", "int64") + c = tirx.Var("c", "int64") + c16 = tirx.Var("c16", "int64") + d = tirx.Var("d", "int64") + ih = tirx.Var("ih", "int64") + iw = tirx.Var("iw", "int64") x0 = relax.Var("x", R.Tensor((n, c, d, ih, iw), "float32")) x1 = relax.Var("x", R.Tensor((n, c, d, ih, iw, c16), "float32")) diff --git a/tests/python/relax/test_op_qdq.py b/tests/python/relax/test_op_qdq.py index d773a6c7d28a..2c876eb4a34b 100644 --- a/tests/python/relax/test_op_qdq.py +++ b/tests/python/relax/test_op_qdq.py @@ -16,7 +16,7 @@ # under the License. import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import Op from tvm.script import relax as R @@ -53,7 +53,7 @@ def test_qdq_op_infer_struct_info(): def test_qdq_op_infer_struct_info_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) dx = relax.Var("dx", R.Tensor((n, 3), "int8")) s = relax.Var("s", R.Tensor([3], "float32")) @@ -70,7 +70,7 @@ def test_qdq_op_infer_struct_info_symbolic(): def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic(): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) dx = relax.Var("dx", R.Tensor((n, 3), "float8_e4m3fn")) s = relax.Var("s", R.Tensor([3], "float32")) @@ -90,7 +90,7 @@ def test_qdq_float8_e4m3fn_op_infer_struct_info_symbolic(): def test_qdq_float8_e5m2_op_infer_struct_info_symbolic(): dtype = "float8_e5m2" bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor((n, 3), "float32")) dx = relax.Var("dx", R.Tensor((n, 3), dtype)) s = relax.Var("s", R.Tensor([3], "float32")) diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py index 53d2c3fa683b..6648d8cf3a93 100644 --- a/tests/python/relax/test_op_search.py +++ b/tests/python/relax/test_op_search.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -104,12 +104,12 @@ def test_where_infer_struct_info(): def test_where_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d0 = tir.Var("d", "int64") - d1 = tir.Var("d", "int64") - e = tir.Var("e", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d0 = tirx.Var("d", "int64") + d1 = tirx.Var("d", "int64") + e = tirx.Var("e", "int64") cond = relax.Var("cond", R.Tensor((a, b, 1, d0, 1), "bool")) x0 = relax.Var("x", R.Tensor((b, 1, d0, e), "float32")) x1 = relax.Var("x", R.Tensor((b, 1, d1, e), "float32")) @@ -362,10 +362,10 @@ def test_argmax_argmin_infer_struct_info(argmax_argmin_op: Callable): def test_argmax_argmin_infer_struct_info_shape_symbolic(argmax_argmin_op: Callable): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "int64")) _check_inference(bb, argmax_argmin_op(x, axis=1), relax.TensorStructInfo((a, c, d), "int64")) diff --git a/tests/python/relax/test_op_set.py b/tests/python/relax/test_op_set.py index 87521a5296ac..a8a21db92607 100644 --- a/tests/python/relax/test_op_set.py +++ b/tests/python/relax/test_op_set.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -498,9 +498,9 @@ def test_unique_infer_struct_info(): def test_unique_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference( diff --git a/tests/python/relax/test_op_sort.py b/tests/python/relax/test_op_sort.py index e3e2765de855..0a29633daf4f 100644 --- a/tests/python/relax/test_op_sort.py +++ b/tests/python/relax/test_op_sort.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -65,9 +65,9 @@ def test_sort_infer_struct_info(): def test_sort_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference(bb, relax.op.sort(x, axis=1), relax.TensorStructInfo((a, b, c), "float32")) @@ -142,9 +142,9 @@ def test_argsort_infer_struct_info(): def test_argsort_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference(bb, relax.op.argsort(x, axis=1), relax.TensorStructInfo((a, b, c), "int32")) @@ -263,9 +263,9 @@ def test_topk_infer_struct_info(): def test_topk_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference( diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py index 37793ecc73c7..af5de101eec5 100644 --- a/tests/python/relax/test_op_statistical.py +++ b/tests/python/relax/test_op_statistical.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -124,10 +124,10 @@ def test_statistical_infer_struct_info(): def test_statistical_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) _check_inference(bb, relax.op.min(x, axis=[1, 2]), relax.TensorStructInfo((a, d), "float32")) @@ -240,9 +240,9 @@ def test_scan_op_infer_struct_info(scan_op: Callable): def test_scan_op_infer_struct_info_shape_symbolic(scan_op: Callable): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") x = relax.Var("x", R.Tensor((a, b, c), "float32")) _check_inference(bb, scan_op(x, axis=1), relax.TensorStructInfo((a, b, c), "float32")) @@ -382,10 +382,10 @@ def test_statistical_ext_infer_struct_info(): def test_statistical_ext_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - a = tir.Var("a", "int64") - b = tir.Var("b", "int64") - c = tir.Var("c", "int64") - d = tir.Var("d", "int64") + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + c = tirx.Var("c", "int64") + d = tirx.Var("d", "int64") x = relax.Var("x", R.Tensor((a, b, c, d), "float32")) _check_inference( diff --git a/tests/python/relax/test_op_take.py b/tests/python/relax/test_op_take.py index 6e2615c92141..38dd5574917e 100644 --- a/tests/python/relax/test_op_take.py +++ b/tests/python/relax/test_op_take.py @@ -21,7 +21,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T axis = tvm.testing.parameter(0, 1) diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py index 541f90a2a0c3..11098082f404 100644 --- a/tests/python/relax/test_op_ternary.py +++ b/tests/python/relax/test_op_ternary.py @@ -18,7 +18,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -63,8 +63,8 @@ def test_ewise_fma_infer_struct_info(): def test_ewise_fma_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) y0 = relax.Var("y", R.Tensor((m, n), "float32")) y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py index ca4e65f89c2f..17da8903b236 100644 --- a/tests/python/relax/test_op_unary.py +++ b/tests/python/relax/test_op_unary.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.ir import Op, VDevice from tvm.script import relax as R @@ -116,8 +116,8 @@ def test_unary_arith_infer_struct_info(unary_arith_op: Callable): def test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op: Callable): bb = relax.BlockBuilder() - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x0 = relax.Var("x", R.Tensor((m, n), "float32")) x1 = relax.Var("x", R.Tensor((4, n), "float32")) @@ -206,8 +206,8 @@ def test_clip_infer_struct_info(): _check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorStructInfo(dtype="")) # Symbolic - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x5 = relax.Var("x", R.Tensor((m, n), "float32")) x6 = relax.Var("x", R.Tensor((4, n), "float32")) diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py index 4d309e3f72a0..b634eac65fcf 100644 --- a/tests/python/relax/test_op_view.py +++ b/tests/python/relax/test_op_view.py @@ -22,7 +22,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_infer_shape_of_1d_static_view(): diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 660b5d27720b..753ee14140bf 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm import TVMError, relax, tir +from tvm import TVMError, relax, tirx from tvm.relax.transform import LegalizeOps from tvm.script import relax as R @@ -63,9 +63,9 @@ def test_all_class_non_max_suppression_wrong_input_number(): def test_all_class_non_max_suppression_infer_struct_info_shape_var(): bb = relax.BlockBuilder() - batch_size = tir.Var("batch_size", "int64") - num_classes = tir.Var("num_classes", "int64") - num_boxes = tir.Var("num_boxes", "int64") + batch_size = tirx.Var("batch_size", "int64") + num_classes = tirx.Var("num_classes", "int64") + num_boxes = tirx.Var("num_boxes", "int64") boxes = relax.Var("boxes", R.Tensor((batch_size, num_boxes, 4), "float32")) scores = relax.Var("scores", R.Tensor((batch_size, num_classes, num_boxes), "float32")) max_output_boxes_per_class = relax.const(10, "int64") diff --git a/tests/python/relax/test_optimize_layout_transform.py b/tests/python/relax/test_optimize_layout_transform.py index 046fb789a6fd..cd60ce1d2bc9 100644 --- a/tests/python/relax/test_optimize_layout_transform.py +++ b/tests/python/relax/test_optimize_layout_transform.py @@ -26,7 +26,7 @@ from tvm.relax.transform import DeadCodeElimination, FuseTIR, OptimizeLayoutTransform from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _run_pass_compare_output(Before, Expected): @@ -286,7 +286,7 @@ def relax_relu_replacement( @T.prim_func(private=True) def remove_pad(var_input: T.handle, var_output: T.handle): - T.func_attr({"operator_name": "remove_pad", "tir.noalias": True}) + T.func_attr({"operator_name": "remove_pad", "tirx.noalias": True}) p0 = T.int64() input = T.match_buffer(var_input, (p0,)) i0 = T.int64() @@ -363,7 +363,7 @@ def relax_relu_replacement( @T.prim_func(private=True) def remove_pad(var_input: T.handle, var_output: T.handle): - T.func_attr({"operator_name": "remove_pad", "tir.noalias": True}) + T.func_attr({"operator_name": "remove_pad", "tirx.noalias": True}) p0 = T.int64() input = T.match_buffer(var_input, (p0,)) i0 = T.int64() diff --git a/tests/python/relax/test_pipeline.py b/tests/python/relax/test_pipeline.py index 482c45fbdd85..91a9228e8c27 100644 --- a/tests/python/relax/test_pipeline.py +++ b/tests/python/relax/test_pipeline.py @@ -21,7 +21,7 @@ import tvm.testing from tvm import relax from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_pipeline_compile(): diff --git a/tests/python/relax/test_pytorch_integration.py b/tests/python/relax/test_pytorch_integration.py index 2eb6dfc9b9c8..681c66e45267 100644 --- a/tests/python/relax/test_pytorch_integration.py +++ b/tests/python/relax/test_pytorch_integration.py @@ -32,11 +32,11 @@ import torch.nn.functional as F import tvm -from tvm import relax, tir +from tvm import relax, tirx from tvm.relax import BasePyModule from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index 192f89e399cb..056833bdc2a5 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -28,7 +28,7 @@ from tvm.base import TVMError from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T exec_mode = tvm.testing.parameter("bytecode", "compiled") diff --git a/tests/python/relax/test_relax_to_pyfunc_converter.py b/tests/python/relax/test_relax_to_pyfunc_converter.py index a223ef15c216..6a14a10b9a08 100644 --- a/tests/python/relax/test_relax_to_pyfunc_converter.py +++ b/tests/python/relax/test_relax_to_pyfunc_converter.py @@ -30,7 +30,7 @@ from tvm.relax.relax_to_pyfunc_converter import RelaxToPyFuncConverter from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py index 8f75804cbba6..6658c52581f0 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py @@ -142,7 +142,7 @@ def set_global_func(head_dim, dtype): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) builts.append(f.main) ( @@ -188,13 +188,13 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla - ["tir", fattn_prefill_ragged], - ["tir", fattn_prefill], - ["tir", fattn_decode], - ["tir", fattn_prefill_sliding_window], - ["tir", fattn_decode_sliding_window], - ["tir", fattn_prefill_with_tree_mask_paged_kv_cache], - ["tir", fattn_prefill_with_tree_mask], + ["tirx", fattn_prefill_ragged], + ["tirx", fattn_prefill], + ["tirx", fattn_decode], + ["tirx", fattn_prefill_sliding_window], + ["tirx", fattn_decode_sliding_window], + ["tirx", fattn_prefill_with_tree_mask_paged_kv_cache], + ["tirx", fattn_prefill_with_tree_mask], [], # f_mla_prefill [fmerge_state], fsplit_rotary, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index fdc706d4d457..da9071c1e43d 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -136,7 +136,7 @@ def set_global_func(rope_mode: RopeMode): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) builts.append(f.main) ( diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py index 07da9da8f7ba..4683f6fdde0b 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py @@ -146,7 +146,7 @@ def set_global_func(dtype): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) builts.append(f.main) ( diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py index 38c5cda9b1d3..7b13112cac00 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -132,7 +132,7 @@ def set_global_func(dtype): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) builts.append(f.main) ( @@ -187,14 +187,14 @@ def create_kv_cache(dtype): tvm.runtime.empty((), dtype, device=device), None, # f_transpose_append_mha ftranspose_append, - ["tir", fmla_prefill_ragged], # fattn_prefill_ragged + ["tirx", fmla_prefill_ragged], # fattn_prefill_ragged [], # fattn_prefill [], # fattn_decode [], # fattn_prefill_sliding_window [], # fattn_decode_sliding_window [], # fattn_prefill_with_tree_mask_paged_kv_cache [], # fattn_prefill_with_tree_mask - ["tir", fmla_prefill], + ["tirx", fmla_prefill], [fmerge_state, fmerge_state_additional], fdumb, # fsplit_rotary fcopy_single_page, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index f385e53db083..4d04f01fed83 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -144,7 +144,7 @@ def set_global_func(head_dim, dtype): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) builts.append(f.main) ( @@ -190,13 +190,13 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window): tvm.runtime.empty((), dtype, device=device), ftranspose_append, None, # f_transpose_append_mla - ["tir", fattn_prefill_ragged], - ["tir", fattn_prefill], - ["tir", fattn_decode], - ["tir", fattn_prefill_sliding_window], - ["tir", fattn_decode_sliding_window], - ["tir", fattn_prefill_with_tree_mask_paged_kv_cache], - ["tir", fattn_prefill_with_tree_mask], + ["tirx", fattn_prefill_ragged], + ["tirx", fattn_prefill], + ["tirx", fattn_decode], + ["tirx", fattn_prefill_sliding_window], + ["tirx", fattn_decode_sliding_window], + ["tirx", fattn_prefill_with_tree_mask_paged_kv_cache], + ["tirx", fattn_prefill_with_tree_mask], [], # f_mla_prefill [fmerge_state], fsplit_rotary, diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py index f67b42a7ee2c..18cb9c15c60b 100644 --- a/tests/python/relax/test_runtime_builtin_rnn_state.py +++ b/tests/python/relax/test_runtime_builtin_rnn_state.py @@ -22,10 +22,10 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.runtime import ShapeTuple from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name @@ -80,7 +80,7 @@ def _build(tir_func): mod = tvm.IRModule({"main": tir_func}) with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) # pylint: disable=not-callable - f = tvm.tir.build(mod["main"], target=target) + f = tvm.tirx.build(mod["main"], target=target) return f.main _f_tir_gets, _f_tir_sets = [], [] @@ -218,7 +218,7 @@ def _rnn_state_get( def rnn_state_set( - shape: Sequence[int | tir.Var], + shape: Sequence[int | tirx.Var], dtype: str, ): # fmt: off diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py index cf3202ddc915..31060f11b365 100644 --- a/tests/python/relax/test_struct_info.py +++ b/tests/python/relax/test_struct_info.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm import TVMError, tir +from tvm import TVMError, tirx from tvm import relax as rx @@ -91,7 +91,7 @@ def test_prim_struct_info(): def test_prim_struct_info_with_expr(): - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") sinfo = rx.PrimStructInfo(value=n + 1) _check_equal(sinfo, rx.PrimStructInfo(value=n + 1)) @@ -107,7 +107,7 @@ def test_prim_struct_info_with_expr(): def test_shape_struct_info(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s0 = rx.ShapeStructInfo([1, n + 1, m]) s1 = rx.ShapeStructInfo([1, n + 1, m]) @@ -148,7 +148,7 @@ def test_shape_struct_info(): def test_tensor_struct_info(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s0 = rx.TensorStructInfo([1, n + 1, m], "float32") s1 = rx.TensorStructInfo(rx.ShapeExpr([1, n + 1, m]), "float32") @@ -193,7 +193,7 @@ def test_tensor_struct_info(): def test_tuple_struct_info(): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") s0 = rx.TensorStructInfo([1, 2, m + n], "float32") s1 = rx.ObjectStructInfo() @@ -221,7 +221,7 @@ def test_tuple_struct_info(): def test_func_struct_info(): def fn_info(c): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = rx.TensorStructInfo([c, n, m], "float32") y = rx.TensorStructInfo([c, n, 1], "float32") z = rx.TensorStructInfo([c, n, m], "float32") diff --git a/tests/python/relax/test_testing_nn.py b/tests/python/relax/test_testing_nn.py index 036d272031c4..f9a508b8863c 100644 --- a/tests/python/relax/test_testing_nn.py +++ b/tests/python/relax/test_testing_nn.py @@ -21,7 +21,7 @@ from tvm.relax.testing import nn from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_emit(): @@ -108,7 +108,7 @@ def activation(state: R.Tensor(("batch_size", 32), dtype="float32")) -> R.Tensor return state model = Layer(64, 32) - batch_size = tvm.tir.Var("batch_size", "int64") + batch_size = tvm.tirx.Var("batch_size", "int64") input = nn.Placeholder((batch_size, 64), dtype="float32", name="input") bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_tir_call_source_kernel.py b/tests/python/relax/test_tir_call_source_kernel.py index 42ce4ea4f5e6..e17e63f4f805 100644 --- a/tests/python/relax/test_tir_call_source_kernel.py +++ b/tests/python/relax/test_tir_call_source_kernel.py @@ -22,7 +22,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T add_cuda_source = """ extern "C" __global__ void add_kernel(float* x, float* y, float* output, int n_elements) { diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 8beca9ef0cde..7f331f439f0a 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -24,7 +24,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_to_non_dataflow(): @@ -281,7 +281,7 @@ class Input: @T.prim_func def zeros(A: T.Buffer((2, 3), "int32")): # just overwrites A with 0s - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -299,7 +299,7 @@ def foo(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): class Expected: @T.prim_func def zeros(A: T.Buffer((2, 3), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -325,7 +325,7 @@ def copy( A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") ): # copies the contents of C into A and B - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -354,7 +354,7 @@ def copy( A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") ): # copies the contents of C into A and B - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -388,7 +388,7 @@ def copy( out2: T.Buffer((2, 3), "int32"), ): # copies the contents of C into A, out1, and out2 - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -427,7 +427,7 @@ def copy( out1: T.Buffer((2, 3), "int32"), out2: T.Buffer((2, 3), "int32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py b/tests/python/relax/test_transform_adjust_matmul_order.py index 17cf4ee9a5e0..a086f3abdb8d 100644 --- a/tests/python/relax/test_transform_adjust_matmul_order.py +++ b/tests/python/relax/test_transform_adjust_matmul_order.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T class Base: diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py index 39c278575ae4..b0d911a5d4eb 100644 --- a/tests/python/relax/test_transform_alter_op_impl.py +++ b/tests/python/relax/test_transform_alter_op_impl.py @@ -22,8 +22,8 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T -from tvm.tir import IndexMap +from tvm.script import tirx as T +from tvm.tirx import IndexMap kOperatorName = "operator_name" @@ -378,7 +378,7 @@ def relax_relu_replacement( @T.prim_func(private=True) def remove_pad(var_input: T.handle, var_output: T.handle): - T.func_attr({"operator_name": "remove_pad", "tir.noalias": True}) + T.func_attr({"operator_name": "remove_pad", "tirx.noalias": True}) p0 = T.int64() input = T.match_buffer(var_input, (p0,)) i0 = T.int64() diff --git a/tests/python/relax/test_transform_annotate_tir_op_pattern.py b/tests/python/relax/test_transform_annotate_tir_op_pattern.py index f457a3ae5e22..9590adb9d20d 100644 --- a/tests/python/relax/test_transform_annotate_tir_op_pattern.py +++ b/tests/python/relax/test_transform_annotate_tir_op_pattern.py @@ -22,7 +22,7 @@ import tvm.script import tvm.testing from tvm import relax -from tvm.script import tir as T +from tvm.script import tirx as T class OpPatternKind(enum.IntEnum): @@ -203,7 +203,7 @@ def tir_bias_add( C: T.Buffer((1, 1000), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True}) + T.func_attr({"global_symbol": "tir_bias_add", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1 in T.grid(1, 1000): @@ -227,7 +227,7 @@ def add_with_unit_dim_len_broadcast( B: T.Buffer((64, 1, 1), "float32"), C: T.Buffer((1, 64, 112, 112), "float32"), ) -> None: - T.func_attr({"global_symbol": "add5", "tir.noalias": True}) + T.func_attr({"global_symbol": "add5", "tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(1, 64, 112, 112): with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -249,7 +249,7 @@ def add_zero_dim( B: T.Buffer((), "float32"), C: T.Buffer((128,), "float32"), ) -> None: - T.func_attr({"global_symbol": "add8", "tir.noalias": True}) + T.func_attr({"global_symbol": "add8", "tirx.noalias": True}) for i0 in T.serial(128): with T.sblock("T_add"): ax0 = T.axis.spatial(128, i0) diff --git a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py index 36be33f1595b..f6801a2bd5d5 100644 --- a/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py +++ b/tests/python/relax/test_transform_attach_attr_layout_free_buffers.py @@ -18,11 +18,11 @@ import numpy as np import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.relax.transform import CombineParallelMatmul from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index ad6800d55f30..4d57cc8a9661 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -20,11 +20,11 @@ import tvm import tvm.script -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir import assert_structural_equal from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_basic(): diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index 6742a017492c..da4796f0172b 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -23,7 +23,7 @@ import tvm.testing from tvm import relax from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T use_np_array = tvm.testing.parameter(False, True) diff --git a/tests/python/relax/test_transform_bind_symbolic_vars.py b/tests/python/relax/test_transform_bind_symbolic_vars.py index 74b809418ae6..d7deae250f9e 100644 --- a/tests/python/relax/test_transform_bind_symbolic_vars.py +++ b/tests/python/relax/test_transform_bind_symbolic_vars.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_bind_tensors(): diff --git a/tests/python/relax/test_transform_bundle_model_params.py b/tests/python/relax/test_transform_bundle_model_params.py index 942c99976ccf..1fa9ecff5649 100644 --- a/tests/python/relax/test_transform_bundle_model_params.py +++ b/tests/python/relax/test_transform_bundle_model_params.py @@ -22,7 +22,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_basic(): diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index d341625e9b29..ccbb011bb61b 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -26,7 +26,7 @@ from tvm.relax.transform.transform import CanonicalizeBindings from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def verify(input, expected): diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 6e5bbdd72996..c9a8497efd57 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -24,13 +24,13 @@ import tvm import tvm.testing -from tvm import relax, s_tir, tir +from tvm import relax, s_tir, tirx from tvm.contrib import utils from tvm.relax.dpl import is_op, wildcard from tvm.relax.testing import transform from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index 2b14aa335beb..11bcd8c2ba26 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -16,10 +16,10 @@ # under the License. # ruff: noqa: E731, F401, F841 import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.relax.transform import CombineParallelMatmul from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import relax as relax_builder diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py index 9fe892f5249a..1a1a283f6888 100644 --- a/tests/python/relax/test_transform_compute_prim_value.py +++ b/tests/python/relax/test_transform_compute_prim_value.py @@ -19,7 +19,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_prim_value_in_assert_condition(): @@ -42,7 +42,7 @@ def main(A: R.Tensor(["N"])): @T.prim_func(private=True) def compute_symbolic_expr(N: T.int64) -> T.bool: - T.func_attr({"tir.is_host_func": True}) + T.func_attr({"tirx.is_host_func": True}) T.ret(N % 16 == 0) After = tvm.relax.transform.ComputePrimValue()(Before) @@ -75,7 +75,7 @@ def main(A: R.Tensor(["N"])): @T.prim_func(private=True) def compute_symbolic_expr(N: T.int64) -> T.bool: - T.func_attr({"tir.is_host_func": True}) + T.func_attr({"tirx.is_host_func": True}) T.ret(N % 16 == 0) After = tvm.relax.transform.ComputePrimValue()(Before) @@ -103,7 +103,7 @@ def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> R.Prim(value="N*M"): @T.prim_func(private=True) def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64: - T.func_attr({"tir.is_host_func": True}) + T.func_attr({"tirx.is_host_func": True}) T.ret(N * M) After = tvm.relax.transform.ComputePrimValue()(Before) diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py index 29871b550e50..32d42f9478f2 100644 --- a/tests/python/relax/test_transform_convert_layout.py +++ b/tests/python/relax/test_transform_convert_layout.py @@ -21,7 +21,7 @@ from tvm.relax.transform import ConvertLayout, Normalize from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def verify(input, expected, extra_ops={}, cb=None): diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index ac3894e00e9e..76d34e6c9dc5 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -24,7 +24,7 @@ from tvm.relax.transform import EliminateCommonSubexpr from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def verify(input, expected, call_only=False): diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index db9f194bd2ef..25ba006b3999 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -23,7 +23,7 @@ from tvm.relax.transform import DeadCodeElimination from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def verify(input, expected): diff --git a/tests/python/relax/test_transform_decompose_ops.py b/tests/python/relax/test_transform_decompose_ops.py index 599a1d28f26b..137ea0750205 100644 --- a/tests/python/relax/test_transform_decompose_ops.py +++ b/tests/python/relax/test_transform_decompose_ops.py @@ -25,7 +25,7 @@ from tvm.relax import Function from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_batch_norm_inference(): diff --git a/tests/python/relax/test_transform_expand_tuple_args.py b/tests/python/relax/test_transform_expand_tuple_args.py index f5389cacb6fe..11be7b6f7e6f 100644 --- a/tests/python/relax/test_transform_expand_tuple_args.py +++ b/tests/python/relax/test_transform_expand_tuple_args.py @@ -20,7 +20,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_simple(): diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py index d529fd5ff612..3fdf8335f76e 100644 --- a/tests/python/relax/test_transform_fold_constant.py +++ b/tests/python/relax/test_transform_fold_constant.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def gen_mod(mod, name, binding): diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 90a859843899..892c578b3c02 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -21,7 +21,7 @@ from tvm import relax, topi from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _check(mod_actual, mod_expected): @@ -1010,7 +1010,7 @@ def main( class Expected: @T.prim_func(private=True) def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64)): with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1020,7 +1020,7 @@ def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64( @T.prim_func(private=True) def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(320)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1030,7 +1030,7 @@ def add1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), rxplace @T.prim_func(private=True) def add2(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(320), T.int64(1), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(64), T.int64(64)): with T.sblock("T_add"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1040,7 +1040,7 @@ def add2(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64 @T.prim_func(private=True) def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32"), rxplaceholder_1: T.Buffer((T.int64(320), T.int64(320), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int64(64)), "float32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer((T.int64(2), T.int64(320), T.int64(66), T.int64(66))) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(320), T.int64(66), T.int64(66)): with T.sblock("pad_temp"): @@ -1059,7 +1059,7 @@ def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(320), T.int64(64), T.int @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1280), T.int64(320)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(320)), "float32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) for i0, i1, k in T.grid(T.int64(2), T.int64(320), T.int64(1280)): with T.sblock("matmul"): v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) @@ -1071,7 +1071,7 @@ def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(1280)), "float32"), rxpl @T.prim_func(private=True) def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"), T_reshape: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float32")): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(320), T.int64(1), T.int64(1)): with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1081,7 +1081,7 @@ def reshape(rxplaceholder: T.Buffer((T.int64(320),), "float32"), T_reshape: T.Bu @T.prim_func(private=True) def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(320), T.int64(1), T.int64(1)), "float32")): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(320), T.int64(1), T.int64(1)): with T.sblock("T_reshape"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1091,7 +1091,7 @@ def reshape1(rxplaceholder: T.Buffer((T.int64(2), T.int64(320)), "float32"), T_r @T.prim_func(private=True) def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), "float32"), T_transpose: T.Buffer((T.int64(1280), T.int64(320)), "float32")): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(1280), T.int64(320)): with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1165,7 +1165,7 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d class Expected: @T.prim_func(private=True) def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(1), T.int64(128)): with T.sblock("T_add"): @@ -1176,7 +1176,7 @@ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceh @T.prim_func(private=True) def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceholder_1: T.Buffer((T.int64(10),), "float32"), T_add: T.Buffer((T.int64(1), T.int64(10)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(1), T.int64(10)): with T.sblock("T_add"): @@ -1187,7 +1187,7 @@ def add1(rxplaceholder: T.Buffer((T.int64(1), T.int64(10)), "float32"), rxplaceh @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxplaceholder_1: T.Buffer((T.int64(784), T.int64(128)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(1), T.int64(128), T.int64(784)): with T.sblock("matmul"): @@ -1200,7 +1200,7 @@ def matmul(rxplaceholder: T.Buffer((T.int64(1), T.int64(784)), "float32"), rxpla @T.prim_func(private=True) def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(10)), "float32"), matmul: T.Buffer((T.int64(1), T.int64(10)), "float32")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(1), T.int64(10), T.int64(128)): with T.sblock("matmul"): @@ -1213,7 +1213,7 @@ def matmul1(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), rxpl @T.prim_func(private=True) def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute: T.Buffer((T.int64(1), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(1), T.int64(128)): with T.sblock("compute"): @@ -1224,7 +1224,7 @@ def relu(rxplaceholder: T.Buffer((T.int64(1), T.int64(128)), "float32"), compute @T.prim_func(private=True) def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), T_transpose: T.Buffer((T.int64(784), T.int64(128)), "float32")): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(784), T.int64(128)): with T.sblock("T_transpose"): @@ -1235,7 +1235,7 @@ def transpose(rxplaceholder: T.Buffer((T.int64(128), T.int64(784)), "float32"), @T.prim_func(private=True) def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T_transpose: T.Buffer((T.int64(128), T.int64(10)), "float32")): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(128), T.int64(10)): with T.sblock("T_transpose"): @@ -1317,7 +1317,7 @@ def main(s: R.Shape(["n"])): n = T.int64() with R.dataflow(): lv0 = R.emit_te(topi.full, [n, n], "float32", 0) - lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True) + lv1 = R.emit_te(topi.trilu, lv0, tvm.tirx.const(1, "int32"), upper=True) gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n]) R.output(gv) return gv @@ -1332,7 +1332,7 @@ def fused_full_trilu_broadcast_to( n = T.int64() with R.dataflow(): lv0 = R.emit_te(topi.full, [n, n], "float32", 0) - lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True) + lv1 = R.emit_te(topi.trilu, lv0, tvm.tirx.const(1, "int32"), upper=True) gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n]) R.output(gv) return gv @@ -1359,7 +1359,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): n = T.int64() with R.dataflow(): lv0 = R.emit_te(topi.full, [n, n], "float32", 0) - lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True) + lv1 = R.emit_te(topi.trilu, lv0, tvm.tirx.const(1, "int32"), upper=True) lv2 = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n]) gv = R.call_pure_packed( "vm.builtin.attention_kv_cache_view", @@ -1380,7 +1380,7 @@ def fused_full_trilu_broadcast_to( n = T.int64() with R.dataflow(): lv0 = R.emit_te(topi.full, [n, n], "float32", 0) - lv1 = R.emit_te(topi.trilu, lv0, tvm.tir.const(1, "int32"), upper=True) + lv1 = R.emit_te(topi.trilu, lv0, tvm.tirx.const(1, "int32"), upper=True) gv = R.emit_te(topi.broadcast_to, lv1, [1, 1, n, n]) R.output(gv) return gv @@ -1517,7 +1517,7 @@ def add( B: T.Buffer((), "float32"), Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1527,7 +1527,7 @@ def add( @T.prim_func(private=True) def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) @@ -1537,7 +1537,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): @T.prim_func(private=True) def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1579,7 +1579,7 @@ def add( B: T.Buffer((), "float32"), Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True, "op_pattern": 0}) + T.func_attr({"tirx.noalias": True, "op_pattern": 0}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1589,7 +1589,7 @@ def add( @T.prim_func(private=True) def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True, "op_pattern": 0}) + T.func_attr({"tirx.noalias": True, "op_pattern": 0}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) @@ -1599,7 +1599,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): @T.prim_func(private=True) def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True, "op_pattern": 0}) + T.func_attr({"tirx.noalias": True, "op_pattern": 0}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1655,7 +1655,7 @@ def test_packed_params(): class Before: @T.prim_func(private=True) def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer((T.int64(16), T.int64(16)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(16), T.int64(16)): with T.sblock("compute"): @@ -1666,7 +1666,7 @@ def cast(lv: T.Buffer((T.int64(16), T.int64(16)), "float16"), compute: T.Buffer( @T.prim_func(private=True) def matmul(x: T.Buffer((T.int64(16), T.int64(16)), "float32"), lv2: T.Buffer((T.int64(16), T.int64(16)), "float32"), T_matmul: T.Buffer((T.int64(16), T.int64(16)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, k in T.grid(T.int64(16), T.int64(16), T.int64(16)): with T.sblock("T_matmul"): diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index b684e43e6b0e..0d0842637a61 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -32,7 +32,7 @@ from tvm.relax.transform import PatternCheckContext from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index a77dc17954cf..9b3ac325409a 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -22,7 +22,7 @@ from tvm import relax, topi from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _check(mod_before, mod_expected): @@ -509,12 +509,12 @@ def te_argmax_idx_val(val): from tvm import te def f_combine(x, y): - lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) - rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + lhs = tvm.tirx.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tirx.Select((x[1] >= y[1]), x[1], y[1]) return lhs, rhs def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): - return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + return tvm.tirx.const(-1, dtype0), tvm.te.min_value(dtype1) argmax = te.comm_reducer(f_combine, f_identity, name="argmax") m, n = val.shape @@ -637,7 +637,7 @@ def fused_add1_exp1_squeeze1( p0: T.Buffer((), "float32"), T_squeeze: T.Buffer((T.int64(20), T.int64(10)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_add = T.sblock_alloc_buffer((T.int64(20), T.int64(10))) compute = T.sblock_alloc_buffer((T.int64(20), T.int64(10))) for ax0, ax1 in T.grid(T.int64(20), T.int64(10)): @@ -665,7 +665,7 @@ def fused_add_exp_squeeze( p0: T.Buffer((), "float32"), T_squeeze: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_add = T.sblock_alloc_buffer((T.int64(10), T.int64(20))) compute = T.sblock_alloc_buffer((T.int64(10), T.int64(20))) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): @@ -782,7 +782,7 @@ def fused_function( X: T.Buffer([T.int64(16), T.int64(32)], "float32"), Z: T.Buffer([T.int64(16), T.int64(32)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Y = T.sblock_alloc_buffer(X.shape, "float32") for iters in T.grid(*X.shape): with T.sblock("compute_Y"): @@ -866,7 +866,7 @@ def fused_function( C: T.Buffer(T.int64(32), "float32"), Z: T.Buffer(T.int64(512), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Y = T.sblock_alloc_buffer((T.int64(512),)) for i, j in T.grid(T.int64(16), T.int64(32)): with T.sblock("compute"): @@ -1008,7 +1008,7 @@ def fused( rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"), m: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_add = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128))) for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)): with T.sblock("T_add"): @@ -1053,7 +1053,7 @@ def concatenate( ), T_concat: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"), ): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4), T.int64(64), T.int64(64)): with T.sblock("T_concat"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1073,7 +1073,7 @@ def transpose2( rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(64), T.int64(64)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32"), ): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(64), T.int64(64), T.int64(4)): with T.sblock("T_transpose"): v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) @@ -1121,7 +1121,7 @@ def fused_concatenate_transpose2( (T.int64(2), T.int64(64), T.int64(64), T.int64(4)), "float32" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_concat_handle_intermediate = T.sblock_alloc_buffer( (T.int64(2), T.int64(4), T.int64(64), T.int64(64)) ) @@ -1199,7 +1199,7 @@ def fused_transpose_matmul( p_output0: T.handle, n: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) y = T.match_buffer(p_y, (n - T.int64(1), T.int64(4))) var_T_matmul_intermediate = T.match_buffer(p_output0, (n - T.int64(1), T.int64(3))) var_T_transpose_intermediate = T.sblock_alloc_buffer((T.int64(4), T.int64(3))) @@ -1246,7 +1246,7 @@ def reshape( A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"), T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"), ): - T.func_attr({"op_pattern": 2, "tir.noalias": True}) + T.func_attr({"op_pattern": 2, "tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): with T.sblock("T_reshape"): @@ -1311,7 +1311,7 @@ def fused_reshape( (T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): with T.sblock("T_reshape"): @@ -1411,7 +1411,7 @@ def fused_func( input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), Out_intermediate_1: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Out_intermediate = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.sblock("add"): @@ -1525,7 +1525,7 @@ def fused( rotary_handle: T.handle, m: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) sequence_length = T.int64() @@ -1625,7 +1625,7 @@ def fused( X: T.Buffer([T.int64(64)], "float32"), Y: T.Buffer([T.int64(1)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(T.int64(64)): with T.sblock("sum"): @@ -1724,7 +1724,7 @@ def fused( Y: T.Buffer([T.int64(16)], "float32"), Out: T.Buffer([T.int64(1)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) XSum = T.sblock_alloc_buffer([T.int64(1)], "float32") YSum = T.sblock_alloc_buffer([T.int64(1)], "float32") @@ -1825,7 +1825,7 @@ def fused( X: T.Buffer([T.int64(64)], "float32"), Y: T.Buffer([T.int64(1)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(T.int64(64)): with T.sblock("sum"): @@ -1907,7 +1907,7 @@ def fused_func( input_embeds: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) Out_handle_intermediate = T.sblock_alloc_buffer( (T.int64(4096), T.int64(4096)), "float16" ) @@ -1947,7 +1947,7 @@ class Module: def add_inplace( A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: T.Buffer((), "float32") ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1957,7 +1957,7 @@ def add_inplace( @T.prim_func(private=True) def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) @@ -1967,7 +1967,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): @T.prim_func(private=True) def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2026,7 +2026,7 @@ class Expected: def fused_add_exp_squeeze( x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: T.Buffer((), "float32") ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2070,7 +2070,7 @@ def add( B: T.Buffer((), "float32"), Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2078,7 +2078,7 @@ def add( @T.prim_func(private=True) def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) @@ -2086,7 +2086,7 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): @T.prim_func(private=True) def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2139,7 +2139,7 @@ def fused_add_exp_squeeze( p0: T.Buffer((), "float32"), p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2180,7 +2180,7 @@ def add( B: T.Buffer((), "float32"), Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2231,7 +2231,7 @@ def fused_sums( p0: T.Buffer((), "float32"), p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.sblock("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -2370,7 +2370,7 @@ def main( class Expected: @T.prim_func(private=True) def fused_function(x: T.handle, y: T.handle, z: T.handle, c: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) X = T.match_buffer(x, [T.int64(16), T.int64(32)], "float32", axis_separators=[1]) Y = T.match_buffer(y, [T.int64(16), T.int64(32)], "float32", axis_separators=[1]) Z = T.match_buffer(z, [T.int64(16), T.int64(32)], "float32", axis_separators=[1]) @@ -2453,7 +2453,7 @@ def test_block_name_numeric_suffix_deduplication(): class Before: @T.prim_func(private=True) def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(10): with T.sblock("compute1"): vi = T.axis.spatial(10, i) @@ -2461,7 +2461,7 @@ def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): @T.prim_func(private=True) def mul1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(10): with T.sblock("compute1"): vi = T.axis.spatial(10, i) @@ -2489,7 +2489,7 @@ def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32" class Expected: @T.prim_func(private=True) def fused_add_mul(p_x: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) x = T.match_buffer(p_x, (T.int64(10),)) y_intermediate_1 = T.match_buffer(p_output0, (T.int64(10),), elem_offset=T.int32(0)) with T.sblock("root"): diff --git a/tests/python/relax/test_transform_fuse_transpose_matmul.py b/tests/python/relax/test_transform_fuse_transpose_matmul.py index f51cc794064d..3117d56ff3b9 100644 --- a/tests/python/relax/test_transform_fuse_transpose_matmul.py +++ b/tests/python/relax/test_transform_fuse_transpose_matmul.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_transform_fuse_transpose_matmul(): @@ -48,7 +48,7 @@ def NT_matmul( w: T.Buffer((T.int64(128), T.int64(256)), "float32"), NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)): with T.sblock("NT_matmul"): @@ -103,7 +103,7 @@ def NT_matmul( w: T.Buffer((T.int64(128), T.int64(256)), "float32"), NT_matmul: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, k in T.grid(T.int64(128), T.int64(128), T.int64(256)): with T.sblock("NT_matmul"): diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py index 8b440666d08f..b5ad7a998115 100644 --- a/tests/python/relax/test_transform_gradient.py +++ b/tests/python/relax/test_transform_gradient.py @@ -25,7 +25,7 @@ from tvm.ir.base import assert_structural_equal from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_simple(): @@ -1212,7 +1212,7 @@ def sum( rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for k0, k1 in T.grid(T.int64(3), T.int64(3)): with T.sblock("rxplaceholder_red"): v_k0, v_k1 = T.axis.remap("RR", [k0, k1]) diff --git a/tests/python/relax/test_transform_gradient_te_register.py b/tests/python/relax/test_transform_gradient_te_register.py index db42ddbbf3d5..8621c99ab5ac 100644 --- a/tests/python/relax/test_transform_gradient_te_register.py +++ b/tests/python/relax/test_transform_gradient_te_register.py @@ -15,19 +15,19 @@ # specific language governing permissions and limitations # under the License. # ruff: noqa: E501 -"""Unit tests for registering tir gradient functions in the gradient pass.""" +"""Unit tests for registering tirx gradient functions in the gradient pass.""" import pytest import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir.base import assert_structural_equal from tvm.relax.training.utils import register_te_gradient from tvm.relax.transform import Gradient from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T # Only run once in the whole test session @@ -64,7 +64,7 @@ def get_expected_1(): class Expected: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_1: T.Buffer((T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): with T.sblock("f_mul"): @@ -75,7 +75,7 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64 @T.prim_func(private=True) def f_mul_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), C: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_grad_1: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_grad_2: T.Buffer((T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): with T.sblock("f_mul_grad_1"): @@ -151,7 +151,7 @@ def test_call_tir(register_te_grads): class Before: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul_1: T.Buffer((T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): with T.sblock("f_mul"): @@ -180,7 +180,7 @@ def get_expected_2(): class Expected: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): with T.sblock("f_mul2"): @@ -191,7 +191,7 @@ def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T. @T.prim_func(private=True) def f_mulk_grad(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), B: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mulk_grad_1: T.Buffer((T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): with T.sblock("f_mulk_grad"): @@ -260,7 +260,7 @@ def test_call_tir_kwargs(register_te_grads): class Before: @T.prim_func(private=True) def f_mul(A: T.Buffer((T.int64(5), T.int64(5)), "float32"), f_mul2: T.Buffer((T.int64(5), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(5), T.int64(5)): with T.sblock("f_mul2"): @@ -289,7 +289,7 @@ def get_expected_3(): class Expected: @T.prim_func(private=True) def f_mul(var_A: T.handle, var_B: T.handle, var_f_mul: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (n, n)) B = T.match_buffer(var_B, (n, n)) @@ -304,7 +304,7 @@ def f_mul(var_A: T.handle, var_B: T.handle, var_f_mul: T.handle): @T.prim_func(private=True) def f_mul_grad(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_f_mul_grad_1: T.handle, var_f_mul_grad_2: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (n, n)) B = T.match_buffer(var_B, (n, n)) @@ -362,7 +362,7 @@ def mul(*idx): return tvm.te.compute(src1.shape, mul, name="f_mul") - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") a = relax.Var("a", relax.TensorStructInfo([n, n], "float32")) b = relax.Var("b", relax.TensorStructInfo([n, n], "float32")) diff --git a/tests/python/relax/test_transform_inline_private_functions.py b/tests/python/relax/test_transform_inline_private_functions.py index 41302d53c9a6..61d09cb3af3e 100644 --- a/tests/python/relax/test_transform_inline_private_functions.py +++ b/tests/python/relax/test_transform_inline_private_functions.py @@ -24,7 +24,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_inline_simple(): diff --git a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py index fa68c16e691d..2d3aef17e23b 100644 --- a/tests/python/relax/test_transform_ipc_allreduce_rewrite.py +++ b/tests/python/relax/test_transform_ipc_allreduce_rewrite.py @@ -20,7 +20,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_ipc_allreduce_rewrite(): diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index 7fdca378d6c6..e0b08c5f2baf 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -24,7 +24,7 @@ from tvm.relax import transform from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _check_equal(x, y): diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 51fe14618a51..f792c51930fb 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -23,7 +23,7 @@ from tvm.relax.transform import LazyTransformParams from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_lazy_transform_params(): diff --git a/tests/python/relax/test_transform_legalize_ops.py b/tests/python/relax/test_transform_legalize_ops.py index d238bc8e0a58..cd6da2fc7fa7 100644 --- a/tests/python/relax/test_transform_legalize_ops.py +++ b/tests/python/relax/test_transform_legalize_ops.py @@ -25,7 +25,7 @@ from tvm.relax.transform.legalize_ops.common import register_legalize from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_customize_legalize(): @@ -48,7 +48,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def add(rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -111,7 +111,7 @@ def identity(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_id: @T.prim_func(private=True) def multiply(rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): with T.sblock("T_multiply"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -195,7 +195,7 @@ def multiply( rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float16"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): with T.sblock("T_multiply"): @@ -219,7 +219,7 @@ def multiply( rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "uint8"), T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): with T.sblock("T_multiply"): @@ -241,14 +241,14 @@ def equal( rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "bool"), T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(3), T.int64(3)): with T.sblock("T_equal"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) T.reads(rxplaceholder[v_ax0, v_ax1]) T.writes(T_equal[v_ax0, v_ax1]) - T_equal[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] == tvm.tir.const(True, "bool") + T_equal[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] == tvm.tirx.const(True, "bool") @R.function def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3), dtype="bool"): @@ -401,7 +401,7 @@ def add( B: T.Buffer((T.int64(32), T.int64(32)), "float32"), C: T.Buffer((T.int64(32), T.int64(32)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for iters in T.grid(T.int64(32), T.int64(32)): with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", iters) @@ -426,7 +426,7 @@ def add_llvm( B: T.Buffer((T.int64(32), T.int64(32)), "float32"), C: T.Buffer((T.int64(32), T.int64(32)), "float32"), ): - T.func_attr({"target": T.target("llvm"), "tir.noalias": True}) + T.func_attr({"target": T.target("llvm"), "tirx.noalias": True}) for iters in T.grid(T.int64(32), T.int64(32)): with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", iters) diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py index eea040928086..f9b2074eab4d 100644 --- a/tests/python/relax/test_transform_legalize_ops_binary.py +++ b/tests/python/relax/test_transform_legalize_ops_binary.py @@ -21,7 +21,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T ##################### Binary arithmetic ##################### @@ -45,7 +45,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -76,7 +76,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -107,7 +107,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def add(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -146,7 +146,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def add(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_add: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -194,7 +194,7 @@ def add( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -222,7 +222,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -253,7 +253,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -284,7 +284,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -323,7 +323,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_divide: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -371,7 +371,7 @@ def divide( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -399,7 +399,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def floor_divide(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_floor_divide: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_floor_divide"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -430,7 +430,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_floor_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -461,7 +461,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def floor_divide(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_floor_divide: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_floor_divide"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -500,7 +500,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def floor_divide(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_floor_divide: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -548,7 +548,7 @@ def floor_divide( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_floordiv"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -576,7 +576,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def multiply(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_multiply: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_multiply"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -615,7 +615,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def multiply(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_multiply: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -663,7 +663,7 @@ def multiply( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -686,7 +686,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") class Expected: @T.prim_func(private=True) def power(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_power: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_power"): @@ -723,7 +723,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), class Expected: @T.prim_func(private=True) def power(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_power: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) c = T.int64() d = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, d)) @@ -781,7 +781,7 @@ def power( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_power"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -809,7 +809,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def subtract(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_subtract: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_subtract"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -848,7 +848,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def subtract(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_subtract: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -896,7 +896,7 @@ def subtract( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -927,7 +927,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -958,7 +958,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): @T.prim_func(private=True) def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -989,7 +989,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): @T.prim_func(private=True) def equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1028,7 +1028,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_equal: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1076,7 +1076,7 @@ def equal( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1104,7 +1104,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def greater(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_greater"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -1135,7 +1135,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): @T.prim_func(private=True) def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_greater"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1166,7 +1166,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): @T.prim_func(private=True) def greater(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_greater: T.Buffer((T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_greater"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1205,7 +1205,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def greater(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1253,7 +1253,7 @@ def greater( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1281,7 +1281,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def greater_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_greater_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_greater_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -1320,7 +1320,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def greater_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_greater_equal: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1368,7 +1368,7 @@ def greater_equal( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1396,7 +1396,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def less(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_less"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -1435,7 +1435,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def less(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1483,7 +1483,7 @@ def less( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1511,7 +1511,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def less_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_less_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_less_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -1542,7 +1542,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): @T.prim_func(private=True) def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_less_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1573,7 +1573,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): @T.prim_func(private=True) def less_equal(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_less_equal: T.Buffer((T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_less_equal"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1612,7 +1612,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def less_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_less_equal: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1660,7 +1660,7 @@ def less_equal( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1688,7 +1688,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def not_equal(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_not_equal: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "bool")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_not_equal"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -1727,7 +1727,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def not_equal(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_not_equal: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1775,7 +1775,7 @@ def not_equal( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "bool"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1804,7 +1804,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def maximum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_maximum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_maximum"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -1835,7 +1835,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_maximum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1866,7 +1866,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def maximum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_maximum: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_maximum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1905,7 +1905,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def maximum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_maximum: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1953,7 +1953,7 @@ def maximum( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) @@ -1982,7 +1982,7 @@ def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32") @T.prim_func(private=True) def minimum(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_minimum: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_minimum"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -2013,7 +2013,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_minimum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -2044,7 +2044,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def minimum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_minimum: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_minimum"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -2083,7 +2083,7 @@ def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), @T.prim_func(private=True) def minimum(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_minimum: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -2131,7 +2131,7 @@ def minimum( rhs: T.float32, output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j, k in T.grid(*lhs.shape): with T.sblock("T_add"): vi, vj, vk = T.axis.remap("SSS", [i, j, k]) diff --git a/tests/python/relax/test_transform_legalize_ops_ccl.py b/tests/python/relax/test_transform_legalize_ops_ccl.py index 8009d847757e..47192c02e900 100644 --- a/tests/python/relax/test_transform_legalize_ops_ccl.py +++ b/tests/python/relax/test_transform_legalize_ops_ccl.py @@ -21,7 +21,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_allreduce(): @@ -110,7 +110,7 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10,5), "float32"): class Expected: @T.prim_func(private=True) def reshape(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_reshape: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(10), T.int64(2), T.int64(5)): with T.sblock("T_reshape"): @@ -121,7 +121,7 @@ def reshape(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_reshape: T.Buf @T.prim_func(private=True) def transpose(A: T.Buffer((T.int64(10), T.int64(2), T.int64(5)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(10), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(10), T.int64(5)): with T.sblock("T_transpose"): diff --git a/tests/python/relax/test_transform_legalize_ops_create_datatype.py b/tests/python/relax/test_transform_legalize_ops_create_datatype.py index 5f367e5f98d1..c1c289825aae 100644 --- a/tests/python/relax/test_transform_legalize_ops_create_datatype.py +++ b/tests/python/relax/test_transform_legalize_ops_create_datatype.py @@ -20,7 +20,7 @@ import tvm.testing from tvm.relax.transform import LegalizeOps from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T ##################### Creation ##################### @@ -43,7 +43,7 @@ def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "int32"): @T.prim_func(private=True) def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -74,7 +74,7 @@ def main() -> R.Tensor((2, 3), "int32"): @T.prim_func(private=True) def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -105,7 +105,7 @@ def main(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -140,7 +140,7 @@ def main(dumb_param: R.Tensor(("m", "n")), v: R.Tensor((), "int32")) -> R.Tensor @T.prim_func(private=True) def full(rxplaceholder: T.Buffer((), "int32"), var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") @@ -174,7 +174,7 @@ def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(( @T.prim_func(private=True) def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -205,7 +205,7 @@ def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"): @T.prim_func(private=True) def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -236,7 +236,7 @@ def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(( @T.prim_func(private=True) def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -271,7 +271,7 @@ def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tens @T.prim_func(private=True) def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="int32") @@ -305,7 +305,7 @@ def main() -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -340,7 +340,7 @@ def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def ones(var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") @@ -374,7 +374,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): @T.prim_func(private=True) def ones(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -409,7 +409,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def ones(var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") @@ -443,7 +443,7 @@ def main() -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -478,7 +478,7 @@ def main(dumb_param: R.Tensor(("m", "n"))) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def zeros(var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") @@ -512,7 +512,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "int32"): @T.prim_func(private=True) def zeros(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -547,7 +547,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def zeros(var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() T_full = T.match_buffer(var_T_full, [m, n], dtype="float32") @@ -605,7 +605,7 @@ def main(x: R.Tensor(["n"], "float32")): @T.prim_func(private=True) def arange(var_T_arange: T.handle, n: T.int64): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_arange = T.match_buffer(var_T_arange, (n // T.int64(2),), "int64") for ax0 in range(n // T.int64(2)): with T.sblock("T_arange"): @@ -635,7 +635,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): @T.prim_func(private=True) def tril(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.sblock("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) @@ -672,7 +672,7 @@ def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int @T.prim_func(private=True) def tril(var_rxplaceholder: T.handle, var_trilu: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) k = T.int64() m = T.int64() n = T.int64() @@ -708,7 +708,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): @T.prim_func(private=True) def triu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), trilu: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.sblock("trilu"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) @@ -745,7 +745,7 @@ def main(x: R.Tensor(("m", "n", "k"), "int8")) -> R.Tensor(("m", "n", "k"), "int @T.prim_func(private=True) def triu(var_rxplaceholder: T.handle, var_trilu: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) k = T.int64() m = T.int64() n = T.int64() @@ -784,7 +784,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "int32"): @T.prim_func(private=True) def cast(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.sblock("compute"): i0_1, i1_1, i2_1 = T.axis.remap("SSS", [i0, i1, i2]) @@ -840,7 +840,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "int32"): @T.prim_func(private=True) def cast(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") diff --git a/tests/python/relax/test_transform_legalize_ops_distributed.py b/tests/python/relax/test_transform_legalize_ops_distributed.py index 04f3ce36bc06..61255e10b38f 100644 --- a/tests/python/relax/test_transform_legalize_ops_distributed.py +++ b/tests/python/relax/test_transform_legalize_ops_distributed.py @@ -22,7 +22,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_redistribute_replica_to_shard(): @@ -38,7 +38,7 @@ def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10, 5), "float32"): class Expected: @T.prim_func(private=True) def strided_slice(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), redistribute_replica_to_shard: T.Buffer((T.int64(10), T.int64(5)), "float32"), worker_id: T.int64): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(10), T.int64(5)): with T.sblock("redistribute_replica_to_shard"): diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index 30d74fa5cc5f..4855a8c26bcc 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -20,7 +20,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_nll_loss_backward(): @@ -36,7 +36,7 @@ def main(output_grad: R.Tensor((), "float32"), predictions: R.Tensor((2, 3, 4, 5 class Expected: @T.prim_func(private=True) def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), rxplaceholder_3: T.Buffer((T.int64(4),), "float32"), pred_grad: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): all_weights = T.sblock_alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) T_broadcast_to = T.sblock_alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) @@ -99,7 +99,7 @@ def main(output_grad: R.Tensor((), "float32"), predictions: R.Tensor((2, 3, 4, 5 class Expected: @T.prim_func(private=True) def te_nll_loss_backward_no_weight(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), pred_grad: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_full = T.sblock_alloc_buffer((T.int64(3),)) all_weights = T.sblock_alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) @@ -175,7 +175,7 @@ def main(output_grad: R.Tensor((), dtype="float32"), predictions: R.Tensor((4,), @T.prim_func(private=True) def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((), "int64"), rxplaceholder_3: T.Buffer((T.int64(4),), "float32"), pred_grad: T.Buffer((T.int64(4),), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): all_weights = T.sblock_alloc_buffer(()) T_broadcast_to = T.sblock_alloc_buffer(()) @@ -220,7 +220,7 @@ def main(output_grad: R.Tensor((3, 2, 6, 5), "float32"), data: R.Tensor((3, 2, 1 class Expected: @T.prim_func(private=True) def max_pool2d_backward(A: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), B: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pad_temp = T.sblock_alloc_buffer((T.int64(3), T.int64(2), T.int64(15), T.int64(13))) maxpool_grad_argmax_v0 = T.sblock_alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "int64") @@ -276,7 +276,7 @@ def main(output_grad: R.Tensor((3, 2, 6, 5), "float32"), data: R.Tensor((3, 2, 1 class Expected: @T.prim_func(private=True) def avg_pool2d_backward(output_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), data: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)): with T.sblock("T_pool_grad"): @@ -311,7 +311,7 @@ def main(output_grad: R.Tensor((3, 2, 5), "float32"), x: R.Tensor((3, 4, 5), "fl class Expected: @T.prim_func(private=True) def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, out_buf: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3), T.int64(2), T.int64(5)), offset_factor=1) rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(3), T.int64(4), T.int64(5)), offset_factor=1) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(2),), "int32", offset_factor=1) @@ -348,7 +348,7 @@ def main(output_grad: R.Tensor(("m", "i"), "float32"), x: R.Tensor(("m", "n"), " class Expected: @T.prim_func(private=True) def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_take_backward: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m, i = T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (m, i), offset_factor=1) n = T.int64() diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py index 19cc290f3b1d..23c128eea565 100644 --- a/tests/python/relax/test_transform_legalize_ops_image.py +++ b/tests/python/relax/test_transform_legalize_ops_image.py @@ -20,7 +20,7 @@ import tvm.testing from tvm.relax.transform import LegalizeOps from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_image_resize2d(): @@ -41,7 +41,7 @@ def main(x: R.Tensor((2, 8, 8, 3), "float32")) -> R.Tensor((2, 16, 16, 3), "floa @T.prim_func(private=True) def resize2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(8), T.int64(8), T.int64(3)), "float32"), resize: T.Buffer((T.int64(2), T.int64(16), T.int64(16), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(16), T.int64(3)): with T.sblock("resize"): i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -80,7 +80,7 @@ def main(dumb_param: R.Tensor(("oh", "ow")), x: R.Tensor(("n", "c", "h", "w", 16 @T.prim_func(private=True) def resize2d(var_rxplaceholder: T.handle, var_resize: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) c = T.int64() h = T.int64() n = T.int64() diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index b24106faa1e9..1be7a397812b 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -22,7 +22,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T ##################### Indexing ##################### @@ -45,7 +45,7 @@ def main(x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((4,), "int64")) -> @T.prim_func(private=True) def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "int64"), T_take: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): with T.sblock("T_take"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -76,7 +76,7 @@ def main(x: R.Tensor((2, 3, 4), "float32"), index: R.Prim("int64")) -> R.Tensor( @T.prim_func(private=True) def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), index: T.int64, T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i2 in T.grid(T.int64(2), T.int64(4)): with T.sblock("T_take"): ax0, ax2 = T.axis.remap("SS", [i0, i2]) @@ -107,7 +107,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 4), "float32"): @T.prim_func(private=True) def take(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_take: T.Buffer((T.int64(2), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i2 in T.grid(T.int64(2), T.int64(4)): with T.sblock("T_take"): ax0, ax2 = T.axis.remap("SS", [i0, i2]) @@ -142,7 +142,7 @@ def main(x: R.Tensor(("m", "n"), "float32"), indices: R.Tensor(("i",), "int64")) @T.prim_func(private=True) def take(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_take: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) i = T.int64() m = T.int64() n = T.int64() @@ -183,7 +183,7 @@ def take(x_handle: T.handle, T_take: T.Buffer((T.int64(2), T.int64(4)), "float32 n = T.int64() rxplaceholder = T.match_buffer(x_handle, (T.int64(2), n, T.int64(4)), "float32") - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i2 in T.grid(T.int64(2), T.int64(4)): with T.sblock("T_take"): ax0, ax2 = T.axis.remap("SS", [i0, i2]) @@ -214,7 +214,7 @@ def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")) -> R.Tensor((4, 9, 10, 3) @T.prim_func(private=True) def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(4), T.int64(9), T.int64(10), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(9), T.int64(10), T.int64(3)): with T.sblock("T_strided_slice_with_axes"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -245,7 +245,7 @@ def main(x: R.Tensor((8, 9, 10, 10), dtype="float32")): @T.prim_func(private=True) def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(7), T.int64(9), T.int64(10), T.int64(2)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(7), T.int64(9), T.int64(10), T.int64(2)): with T.sblock("T_strided_slice_with_axes"): @@ -273,7 +273,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((2, "n"), "float32"): class Expected: @T.prim_func(private=True) def strided_slice(var_A: T.handle, var_T_dynamic_strided_slice_with_axes: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m, n = T.int64(), T.int64() A = T.match_buffer(var_A, (m, n)) T_dynamic_strided_slice_with_axes = T.match_buffer(var_T_dynamic_strided_slice_with_axes, (T.int64(3), n)) @@ -318,7 +318,7 @@ def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype="f @T.prim_func(private=True) def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") @@ -354,7 +354,7 @@ def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype="f @T.prim_func(private=True) def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") @@ -386,7 +386,7 @@ def main(x: R.Tensor((10, "n"), dtype="float32")) -> R.Tensor((3, "n"), dtype="f @T.prim_func(private=True) def strided_slice(var_rxplaceholder: T.handle, var_T_strided_slice_with_axes: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [T.int64(10), n], dtype="float32") T_strided_slice_with_axes = T.match_buffer(var_T_strided_slice_with_axes, [T.int64(3), n], dtype="float32") @@ -418,7 +418,7 @@ def dynamic_strided_slice( rxplaceholder_3: T.Buffer((T.int64(4),), "int64"), var_T_strided_slice_dynamic: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) s, s_1, s_2, s_3 = T.int64(), T.int64(), T.int64(), T.int64() T_strided_slice_dynamic = T.match_buffer( var_T_strided_slice_dynamic, (s, s_1, s_2, s_3) @@ -463,7 +463,7 @@ def shape_func( rxplaceholder_3: T.Buffer((T.int64(4),), "int64"), T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(4),), "int64"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i in range(T.int64(4)): with T.sblock("T_shape_func_strided_slice_dynamic"): @@ -706,7 +706,7 @@ def dynamic_strided_slice( rxplaceholder_2: T.Buffer((T.int64(2),), "int64"), var_T_strided_slice_dynamic: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n)) s, s_1 = T.int64(), T.int64() @@ -741,7 +741,7 @@ def shape_func( rxplaceholder_2: T.Buffer((T.int64(2),), "int64"), T_shape_func_strided_slice_dynamic: T.Buffer((T.int64(2),), "int64"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() rxplaceholder_3 = T.match_buffer(var_rxplaceholder, (T.int64(10), n)) # with T.sblock("root"): @@ -904,7 +904,7 @@ def main(x: R.Tensor((4,), "float32"), y: R.Tensor((2, 3, 4, 5), "float32")) -> @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(5), T.int64(4)): with T.sblock("matmul"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) @@ -937,7 +937,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((5,), "float32")) -> @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer(T.int64(5), "float32"), matmul: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.sblock("matmul"): i0_1, i1_1, i2_1, k = T.axis.remap("SSSR", [i0, i1, i2, i3]) @@ -970,7 +970,7 @@ def main(x: R.Tensor((4,), "float32"), y: R.Tensor((4,), "float32")) -> R.Tensor @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer(T.int64(4), "float32"), rxplaceholder_1: T.Buffer(T.int64(4), "float32"), matmul: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in T.serial(T.int64(4)): with T.sblock("matmul"): k = T.axis.reduce(T.int64(4), i0) @@ -1003,7 +1003,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float16"), y: R.Tensor((6, 2, 3, 5, 7), "flo @T.prim_func(private=True) def matmul(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(5), T.int64(7)), "float16"), matmul: T.Buffer((T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(6), T.int64(2), T.int64(3), T.int64(4), T.int64(7), T.int64(5)): with T.sblock("matmul"): i0_1, i1_1, i2_1, i3_1, i4_1, k = T.axis.remap("SSSSSR", [i0, i1, i2, i3, i4, i5]) @@ -1046,7 +1046,7 @@ def main(x: R.Tensor(("b", 1, "m", "k"), "float32"), y: R.Tensor(("a", 1, "c", " @T.prim_func(private=True) def matmul(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_matmul: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1083,7 +1083,7 @@ def main(x: R.Tensor((1, 1, 4, 5), "float32"), y: R.Tensor((1, 1, 5, 7), "float3 class Expected: @T.prim_func(private=True) def matmul(A: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(1), T.int64(1), T.int64(5), T.int64(7)), "float32"), matmul_1: T.Buffer((T.int64(1), T.int64(1), T.int64(4), T.int64(7)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(1), T.int64(4), T.int64(7), T.int64(5)): with T.sblock("matmul"): @@ -1130,7 +1130,7 @@ def einsum( rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4)), "float32"), T_einsum: T.Buffer((T.int64(2), T.int64(4)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1, j in T.grid(T.int64(2), T.int64(4), T.int64(3)): with T.sblock("T_einsum"): v_ax0, v_ax1, v_j = T.axis.remap("SSR", [ax0, ax1, j]) @@ -1177,7 +1177,7 @@ def einsum( var_rxplaceholder_1: T.handle, var_T_einsum: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b = T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b)) c = T.int64() diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index cb7311c51a0d..05b6c50c923b 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -21,7 +21,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T ##################### Manipulation ##################### @@ -44,7 +44,7 @@ def main(x: R.Tensor((2, 1, 3), "float32")) -> R.Tensor((4, 2, 5, 3), "float32") @T.prim_func(private=True) def broadcast_to(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3)), "float32"), T_broadcast_to: T.Buffer((T.int64(4), T.int64(2), T.int64(5), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(2), T.int64(5), T.int64(3)): with T.sblock("T_broadcast_to"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -83,7 +83,7 @@ def main(dumb_param: R.Tensor(("a", "c")), x: R.Tensor(("b", 1, "d"), "float32") @T.prim_func(private=True) def broadcast_to(var_rxplaceholder: T.handle, var_T_broadcast_to: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -120,7 +120,7 @@ def main(x1: R.Tensor((1, 2, 3), "float32"), x2: R.Tensor((1, 3, 3), "float32"), @T.prim_func(private=True) def concatenate(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(1), T.int64(3), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(1), T.int64(4), T.int64(3)), "float32"), T_concat: T.Buffer((T.int64(1), T.int64(9), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(1), T.int64(9), T.int64(3)): with T.sblock("T_concat"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -153,7 +153,7 @@ def main(t: R.Tuple(R.Tensor((3, 4), "float32"), R.Tensor((3, 5), "float32"))) - @T.prim_func(private=True) def concatenate(rxplaceholder: T.Buffer((T.int64(3), T.int64(4)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(5)), "float32"), T_concat: T.Buffer((T.int64(3), T.int64(9)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(3), T.int64(9)): with T.sblock("T_concat"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -195,7 +195,7 @@ def main(t: R.Tuple(R.Tensor(("a", "b0"), "float32"), R.Tensor(("a", "b1"), "flo @T.prim_func(private=True) def concatenate(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_concat: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b0 = T.int64() b1 = T.int64() @@ -234,7 +234,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 1, 1, 1, 3, 1, 4, 1) @T.prim_func(private=True) def expand_dims(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), expand_dims: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(3), T.int64(1), T.int64(4), T.int64(1)): with T.sblock("expand_dims"): i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap("SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7]) @@ -271,7 +271,7 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", 1, "b", 1, " @T.prim_func(private=True) def expand_dims(var_rxplaceholder: T.handle, var_expand_dims: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -307,7 +307,7 @@ def main(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((24,), "float32"): @T.prim_func(private=True) def reshape(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer(T.int64(24), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in T.serial(T.int64(24)): with T.sblock("T_reshape"): ax0 = T.axis.spatial(T.int64(24), i0) @@ -338,7 +338,7 @@ def main(x: R.Tensor((), "float32")) -> R.Tensor((1,), "float32"): @T.prim_func(private=True) def reshape(rxplaceholder: T.Buffer((), "float32"), T_reshape: T.Buffer(T.int64(1), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0 in T.serial(T.int64(1)): with T.sblock("T_reshape"): ax0 = T.axis.spatial(T.int64(1), i0) @@ -375,7 +375,7 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a * b * c",), "f @T.prim_func(private=True) def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -411,7 +411,7 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float3 @T.prim_func(private=True) def transpose(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_transpose: T.Buffer((T.int64(2), T.int64(4), T.int64(3), T.int64(1)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(4), T.int64(3), T.int64(1)): with T.sblock("T_transpose"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -450,7 +450,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", " @T.prim_func(private=True) def transpose(var_rxplaceholder: T.handle, var_T_transpose: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -487,7 +487,7 @@ def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), "float32"): @T.prim_func(private=True) def reshape(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(8), T.int64(3)): with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -517,7 +517,7 @@ def reshape( rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(8), T.int64(3)): with T.sblock("T_reshape"): @@ -571,7 +571,7 @@ def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "f @T.prim_func(private=True) def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") @@ -611,7 +611,7 @@ def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "f @T.prim_func(private=True) def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32") @@ -651,7 +651,7 @@ def main(x: R.Tensor((10, "b"), "float32")) -> R.Tensor((5, "b * 2"), "float32") class Expected3: @T.prim_func(private=True) def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(10), b)) T_reshape = T.match_buffer(var_T_reshape, (T.int64(5), b * T.int64(2))) @@ -725,7 +725,7 @@ def reshape( rxplaceholder: T.Buffer(T.int64(16), "float32"), var_T_reshape: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) M = T.int64() N = T.int64() T_reshape = T.match_buffer(var_T_reshape, [M,N], "float32") @@ -758,7 +758,7 @@ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 3, 4), "fl @T.prim_func(private=True) def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32"), T_split_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_2: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.sblock("T_split"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -801,7 +801,7 @@ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 4, 4), "fl @T.prim_func(private=True) def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(4), T.int64(4)), "float32"), T_split_sections_2: T.Buffer((T.int64(2), T.int64(2), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(4)): with T.sblock("T_split_sections"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -845,7 +845,7 @@ def main(x: R.Tensor((2, 10, 4), "float32")) -> R.Tuple([R.Tensor((2, 5, 4), "fl @T.prim_func(private=True) def split(rxplaceholder: T.Buffer((T.int64(2), T.int64(10), T.int64(4)), "float32"), T_split_sections: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32"), T_split_sections_1: T.Buffer((T.int64(2), T.int64(5), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(5), T.int64(4)): with T.sblock("T_split_sections"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -886,7 +886,7 @@ def main(dumb_param: R.Tensor(("n",)), x: R.Tensor(("m", "(n * 3)"), "float32")) @T.prim_func(private=True) def split(var_rxplaceholder: T.handle, var_T_split_sections: T.handle, var_T_split_sections_1: T.handle, var_T_split_sections_2: T.handle, n: T.int64): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n * T.int64(3)], dtype="float32") T_split_sections = T.match_buffer(var_T_split_sections, [m, (n * T.int64(3) + T.int64(3) - T.int64(1)) // T.int64(3)], dtype="float32") @@ -934,7 +934,7 @@ def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), " @T.prim_func(private=True) def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(1), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(1), T.int64(4)): with T.sblock("T_squeeze"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -965,7 +965,7 @@ def main(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) : @T.prim_func(private=True) def squeeze(rxplaceholder: T.Buffer((T.int64(2), T.int64(1), T.int64(3), T.int64(1), T.int64(1), T.int64(4)), "float32"), T_squeeze: T.Buffer((T.int64(2), T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(4)): with T.sblock("T_squeeze"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -1000,7 +1000,7 @@ def main(x: R.Tensor(("a", 1, "b", 1), "float32")) -> R.Tensor(("a", "b", 1), "f @T.prim_func(private=True) def squeeze(var_rxplaceholder: T.handle, var_T_squeeze: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [a, T.int64(1), b, T.int64(1)], dtype="float32") @@ -1035,7 +1035,7 @@ def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1, 3), "float32")) -> R.Te @T.prim_func(private=True) def collapse_sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(1), T.int64(3), T.int64(2)): with T.sblock("rxplaceholder_red"): ax0, ax1, k0 = T.axis.remap("SSR", [i0, i1, i2]) @@ -1071,7 +1071,7 @@ def main( @T.prim_func(private=True) def collapse_sum(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1, k0, k2 in T.grid(T.int64(2), T.int64(1), T.int64(3), T.int64(3)): with T.sblock("rxplaceholder_red"): v_ax0, v_ax1, v_k0, v_k2 = T.axis.remap("SSRR", [ax0, ax1, k0, k2]) @@ -1104,7 +1104,7 @@ def main(x: R.Tensor((3, 2, 3), dtype="float32")) -> R.Tensor((6, 2, 3), dtype=" @T.prim_func(private=True) def repeat(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_repeat: T.Buffer((T.int64(6), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2 in T.grid(T.int64(6), T.int64(2), T.int64(3)): with T.sblock("T_repeat"): @@ -1141,7 +1141,7 @@ def repeat( rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_repeat: T.Buffer((T.int64(36),), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_reshape = T.sblock_alloc_buffer((T.int64(18),)) for ax0 in range(T.int64(18)): @@ -1185,7 +1185,7 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")): class Expected: @T.prim_func(private=True) def repeat(var_rxplaceholder: T.handle, var_T_repeat: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1225,7 +1225,7 @@ def main(x: R.Tensor((3, 2, 3), "float32")): class Expected: @T.prim_func(private=True) def tile(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32"), T_tile: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(9)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(9)): with T.sblock("T_tile"): @@ -1257,7 +1257,7 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")): class Expected: @T.prim_func(private=True) def tile(var_rxplaceholder: T.handle, var_T_tile: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1305,7 +1305,7 @@ def flip( rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_reverse_sequence: T.Buffer((T.int64(2), T.int64(3)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_reverse_sequence"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -1344,7 +1344,7 @@ def main( @T.prim_func(private=True) def flip(var_rxplaceholder: T.handle, var_T_reverse_sequence: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b = T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b)) T_reverse_sequence = T.match_buffer(var_T_reverse_sequence, (a, b)) @@ -1380,7 +1380,7 @@ def scatter_elements( var_rxplaceholder_2: T.handle, out_buf: T.Buffer((T.int64(4), T.int64(4)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder = T.match_buffer( var_rxplaceholder, (T.int64(4), T.int64(4)), offset_factor=1 ) @@ -1477,7 +1477,7 @@ def scatter_elements( var_rxplaceholder_2: T.handle, var_scatter_elements_generic: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b = T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b), offset_factor=1) m, n = T.int64(), T.int64() @@ -1569,7 +1569,7 @@ def main(x: R.Tensor((10, 21, 30), "float32")): class Expected: @T.prim_func(private=True) def te_layout_transform(A: T.Buffer((T.int64(10), T.int64(21), T.int64(30)), "float32"), te_layout_transform_1: T.Buffer((T.int64(10), T.int64(30), T.int64(7), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2 in T.grid(T.int64(10), T.int64(21), T.int64(30)): with T.sblock("te_layout_transform"): @@ -1607,7 +1607,7 @@ def main(x: R.Tensor((10, 20, 30), "float32")): class Expected: @T.prim_func(private=True) def te_layout_transform_with_pad(A: T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "float32"), te_layout_transform_with_pad_1: T.Buffer((T.int64(10), T.int64(30), T.int64(7), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30), T.int64(7), T.int64(3)): with T.sblock("te_layout_transform_with_pad"): @@ -1645,7 +1645,7 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")): class Expected: @T.prim_func(private=True) def te_layout_transform_with_pad(var_A: T.handle, var_te_layout_transform_with_pad: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b, c = T.int64(), T.int64(), T.int64() A = T.match_buffer(var_A, (a, b, c)) te_layout_transform_with_pad_1 = T.match_buffer(var_te_layout_transform_with_pad, (a, c, (b - b % T.int64(-3)) // T.int64(3), T.int64(3))) @@ -1690,7 +1690,7 @@ def main(x: R.Tensor((10, 20, 30), "float32")): class Expected: @T.prim_func(private=True) def te_layout_transform_with_pad_axis_separator(A: T.Buffer((T.int64(10), T.int64(20), T.int64(30)), "float32"), var_te_layout_transform_with_pad_axis_separator: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) te_layout_transform_with_pad_axis_separator_1 = T.match_buffer(var_te_layout_transform_with_pad_axis_separator, (T.int64(10), T.int64(30), T.int64(7), T.int64(3)), axis_separators=[3]) # with T.sblock("root"): for axis0, axis1, axis2, axis3 in T.grid(T.int64(10), T.int64(30), T.int64(7), T.int64(3)): @@ -1768,7 +1768,7 @@ def te_layout_transform( A: T.Buffer((T.int64(16),), "float32"), te_layout_transform: T.Buffer((T.int64(4), T.int64(4)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(T.int64(16)): with T.sblock("te_layout_transform"): vi = T.axis.spatial(T.int64(16), i) @@ -1807,7 +1807,7 @@ def main( @T.prim_func(private=True) def scatter_nd(var_data: T.handle, var_indices: T.handle, var_updates: T.handle, var_scatter_nd_generic: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) data = T.match_buffer(var_data, (T.int64(8),), offset_factor=1) indices = T.match_buffer(var_indices, (T.int64(4), T.int64(1)), "int64") updates = T.match_buffer(var_updates, (T.int64(4),), offset_factor=1) diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py index 5ce24a60c851..53a4fa7b1c8c 100644 --- a/tests/python/relax/test_transform_legalize_ops_nn.py +++ b/tests/python/relax/test_transform_legalize_ops_nn.py @@ -23,7 +23,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T ##################### Neural network ##################### @@ -46,7 +46,7 @@ def main(x: R.Tensor((2, 128, 28), dtype="float32"), w: R.Tensor((64, 16, 3), dt @T.prim_func(private=True) def conv1d(A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), B: T.Buffer((T.int64(64), T.int64(16), T.int64(3)), "float32"), group_conv1d_ncw: T.Buffer((T.int64(2), T.int64(64), T.int64(13)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer((T.int64(2), T.int64(128), T.int64(30))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(128), T.int64(30)): with T.sblock("pad_temp"): @@ -86,7 +86,7 @@ def main(x: R.Tensor((2, 3, 28), dtype="float32"), w: R.Tensor((4, 3, 3), dtype= @T.prim_func(private=True) def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3)), "float32"), conv1d_ncw: T.Buffer((T.int64(2), T.int64(4), T.int64(26)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pad_temp = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(28))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(3), T.int64(28)): @@ -127,7 +127,7 @@ def main(x: R.Tensor((2, 28, 128), dtype="float32"), w: R.Tensor((64, 128, 3), d @T.prim_func(private=True) def conv1d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3)), "float32"), conv1d_nwc: T.Buffer((T.int64(2), T.int64(26), T.int64(64)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pad_temp = T.sblock_alloc_buffer((T.int64(2), T.int64(28), T.int64(128))) for i0, i1, i2 in T.grid(T.int64(2), T.int64(28), T.int64(128)): @@ -177,7 +177,7 @@ def main(x: R.Tensor(("n", "c", "w"), dtype="float32"), kernel: R.Tensor(("f", " @T.prim_func(private=True) def conv1d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv1d_ncw: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, c, w = T.int64(), T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (n, c, w)) f, kw = T.int64(), T.int64() @@ -218,7 +218,7 @@ def main(x: R.Tensor((2, 128, 28), "float32"), w: R.Tensor((128, 16, 3), "float3 class Expected: @T.prim_func(private=True) def conv1d_transpose(x: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), w: T.Buffer((T.int64(128), T.int64(16), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) data_dilate = T.sblock_alloc_buffer((T.int64(2), T.int64(128), T.int64(55))) data_pad = T.sblock_alloc_buffer((T.int64(2), T.int64(128), T.int64(58))) kernel = T.sblock_alloc_buffer((T.int64(16), T.int64(128), T.int64(3))) @@ -270,7 +270,7 @@ def main(x: R.Tensor((2, 128, 28, 28), "float32"), w: R.Tensor((64, 16, 3, 3), " @T.prim_func(private=True) def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(16), T.int64(3), T.int64(3)), "float32"), group_conv2d_nchw: T.Buffer((T.int64(2), T.int64(64), T.int64(13), T.int64(13)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer([T.int64(2), T.int64(128), T.int64(30), T.int64(30)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(128), T.int64(30), T.int64(30)): with T.sblock("pad_temp"): @@ -310,7 +310,7 @@ def main(x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "floa @T.prim_func(private=True) def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(3), T.int64(3)), "float32"), conv2d_nchw: T.Buffer((T.int64(2), T.int64(4), T.int64(26), T.int64(26)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(28), T.int64(28)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(28), T.int64(28)): with T.sblock("pad_temp"): @@ -350,7 +350,7 @@ def main(x: R.Tensor((2, 28, 28, 128), "float32"), w: R.Tensor((64, 128, 3, 3), @T.prim_func(private=True) def conv2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(28), T.int64(28), T.int64(128)), "float32"), rxplaceholder_1: T.Buffer((T.int64(64), T.int64(128), T.int64(3), T.int64(3)), "float32"), conv2d_nhwc: T.Buffer((T.int64(2), T.int64(26), T.int64(26), T.int64(64)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer([T.int64(2), T.int64(28), T.int64(28), T.int64(128)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(28), T.int64(28), T.int64(128)): with T.sblock("pad_temp"): @@ -402,7 +402,7 @@ def main(x: R.Tensor(("n", "c", "h", "w"), "float32"), kernel: R.Tensor(("f", "c @T.prim_func(private=True) def conv2d(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_conv2d_nchw: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) c = T.int64() f = T.int64() h = T.int64() @@ -452,7 +452,7 @@ def main(x: R.Tensor((2, 128, 28, 28), dtype="float32"), w: R.Tensor((128, 16, 3 @T.prim_func(private=True) def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(128), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(16), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(128), T.int64(56), T.int64(84)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): data_dilate = T.sblock_alloc_buffer((T.int64(2), T.int64(128), T.int64(55), T.int64(82))) data_pad = T.sblock_alloc_buffer((T.int64(2), T.int64(128), T.int64(58), T.int64(86))) @@ -507,7 +507,7 @@ def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((3, 4, 3, 3), @T.prim_func(private=True) def conv2d_transpose(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(4), T.int64(3), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4), T.int64(30), T.int64(30)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): data_dilate = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(28), T.int64(28))) data_pad = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(32), T.int64(32))) @@ -569,7 +569,7 @@ def main(x: R.Tensor(("n", "c", "h", "w"), dtype="float32"), kernel: R.Tensor((" @T.prim_func(private=True) def conv2d_transpose(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() c = T.int64() h = T.int64() @@ -634,7 +634,7 @@ def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), " @T.prim_func(private=True) def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer([T.int64(4), T.int64(114), T.int64(114), T.int64(6)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(114), T.int64(114), T.int64(6)): with T.sblock("pad_temp"): @@ -675,7 +675,7 @@ def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 1 @T.prim_func(private=True) def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): with T.sblock("pool_max"): ax0, ax1, ax2, ax3, ax4, rv0, rv1 = T.axis.remap("SSSSSRR", [i0, i1, i2, i3, i4, i5, i6]) @@ -709,7 +709,7 @@ def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> R.Tensor((4, 6, 38, @T.prim_func(private=True) def max_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_max: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) pad_temp = T.sblock_alloc_buffer([T.int64(4), T.int64(6), T.int64(116), T.int64(116)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(6), T.int64(116), T.int64(116)): with T.sblock("pad_temp"): @@ -767,7 +767,7 @@ def main(x: R.Tensor((4, 112, 112, 6), "float32")) -> R.Tensor((4, 56, 56, 6), " class Expected: @T.prim_func(private=True) def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(112), T.int64(112), T.int64(6)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pad_temp = T.sblock_alloc_buffer((T.int64(4), T.int64(114), T.int64(114), T.int64(6))) pool_sum = T.sblock_alloc_buffer((T.int64(4), T.int64(56), T.int64(56), T.int64(6))) @@ -816,7 +816,7 @@ def main(x: R.Tensor((4, 4, 112, 112, 16), "float32")) -> R.Tensor((4, 4, 110, 1 class Expected: @T.prim_func(private=True) def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(4), T.int64(112), T.int64(112), T.int64(16)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pool_sum = T.sblock_alloc_buffer((T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16))) for ax0, ax1, ax2, ax3, ax4, rv0, rv1 in T.grid(T.int64(4), T.int64(4), T.int64(110), T.int64(110), T.int64(16), T.int64(3), T.int64(3)): @@ -857,7 +857,7 @@ def main(x: R.Tensor((4, 6, 112, 112), "float32")) -> R.Tensor((4, 6, 38, 38), " class Expected: @T.prim_func(private=True) def avg_pool2d(rxplaceholder: T.Buffer((T.int64(4), T.int64(6), T.int64(112), T.int64(112)), "float32"), pool_avg: T.Buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pad_temp = T.sblock_alloc_buffer((T.int64(4), T.int64(6), T.int64(116), T.int64(116))) pool_sum = T.sblock_alloc_buffer((T.int64(4), T.int64(6), T.int64(38), T.int64(38))) @@ -934,7 +934,7 @@ def main(x: R.Tensor((2, 4, 7, 7, 16), "float32")) -> R.Tensor((2, 4, 1, 1, 16), @T.prim_func(private=True) def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(7), T.int64(7), T.int64(16)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) adaptive_pool_sum = T.sblock_alloc_buffer([T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16)], dtype="float32") for i0, i1, i2, i3, i4, i5, i6 in T.grid(T.int64(2), T.int64(4), T.int64(1), T.int64(1), T.int64(16), T.int64(7), T.int64(7)): with T.sblock("adaptive_pool_sum"): @@ -975,7 +975,7 @@ def main(x: R.Tensor((2, 16, 7, 7), "float32")) -> R.Tensor((2, 16, 7, 7), "floa @T.prim_func(private=True) def adaptive_avg_pool2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(16), T.int64(7), T.int64(7)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) adaptive_pool_sum = T.sblock_alloc_buffer([T.int64(2), T.int64(16), T.int64(7), T.int64(7)], dtype="float32") for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(16), T.int64(7), T.int64(7), T.int64(1), T.int64(1)): with T.sblock("adaptive_pool_sum"): @@ -1035,7 +1035,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def relu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("compute"): i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) @@ -1070,7 +1070,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def relu(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") @@ -1106,7 +1106,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def leaky_relu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) @@ -1141,7 +1141,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def leaky_relu(var_x: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m, n = T.int64(), T.int64() x = T.match_buffer(var_x, (m, n)) compute = T.match_buffer(var_compute, (m, n)) @@ -1175,7 +1175,7 @@ def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32" @T.prim_func(private=True) def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): slope_broadcasted = T.sblock_alloc_buffer((T.int64(3),)) for c in range(T.int64(3)): @@ -1216,7 +1216,7 @@ def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float3 @T.prim_func(private=True) def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() x = T.match_buffer(var_x, (m, T.int64(7))) compute = T.match_buffer(var_compute, (m, T.int64(7))) @@ -1258,7 +1258,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def gelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_multiply_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) compute = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) T_multiply_2 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) @@ -1321,7 +1321,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def gelu(var_x: T.handle, var_T_multiply: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m, n = T.int64(), T.int64() x = T.match_buffer(var_x, (m, n)) T_multiply = T.match_buffer(var_T_multiply, (m, n)) @@ -1383,7 +1383,7 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3 @T.prim_func(private=True) def gelu_tanh(A: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_multiply_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) T_multiply_2 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) T_multiply_3 = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) @@ -1473,7 +1473,7 @@ def main(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m", "n"), dtype @T.prim_func(private=True) def gelu_tanh(var_A: T.handle, var_T_multiply: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m, n = T.int64(), T.int64() A = T.match_buffer(var_A, (m, n)) T_multiply = T.match_buffer(var_T_multiply, (m, n)) @@ -1564,7 +1564,7 @@ def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): @T.prim_func(private=True) def silu(rxplaceholder: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_multiply: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) compute = T.sblock_alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("compute"): @@ -1606,7 +1606,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @T.prim_func(private=True) def silu(var_rxplaceholder: T.handle, var_T_multiply: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int64() n = T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, [m, n], dtype="float32") @@ -1648,7 +1648,7 @@ def main(x: R.Tensor((2, 3, 16, 32), "float32")) -> R.Tensor((2, 3, 16, 32), "fl @T.prim_func(private=True) def softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), T_softmax_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_softmax_maxelem = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") T_softmax_exp = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(16), T.int64(32)], dtype="float32") T_softmax_expsum = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") @@ -1711,7 +1711,7 @@ def main(x: R.Tensor(("a", "b", "c"), "float32")) -> R.Tensor(("a", "b", "c"), " @T.prim_func(private=True) def softmax(var_rxplaceholder: T.handle, var_T_softmax_norm: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1773,7 +1773,7 @@ def main(x: R.Tensor((2, 3, 16, 32), dtype="float32")) -> R.Tensor((2, 3, 16, 32 @T.prim_func(private=True) def log_softmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"), compute: T.Buffer((T.int64(2), T.int64(3), T.int64(16), T.int64(32)), "float32"),): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_softmax_maxelem = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") compute_1 = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(32)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(32), T.int64(16)): @@ -1830,7 +1830,7 @@ def main(x: R.Tensor(("a", "b", "c"), dtype="float32")) -> R.Tensor(("a", "b", " @T.prim_func(private=True) def log_softmax(var_rxplaceholder: T.handle, var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1885,7 +1885,7 @@ def main(x: R.Tensor((3,), dtype="float32"), y: R.Tensor((3,), dtype="float32")) @T.prim_func(private=True) def cross_entropy_with_logits(x: T.Buffer((T.int64(3),), "float32"), y: T.Buffer((T.int64(3),), "float32"), T_multiply: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_multiply_1 = T.sblock_alloc_buffer((T.int64(3),)) T_multiply_red = T.sblock_alloc_buffer(()) for ax0 in range(T.int64(3)): @@ -1931,7 +1931,7 @@ def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float3 @T.prim_func(private=True) def cross_entropy_with_logits(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_divide: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_multiply = T.sblock_alloc_buffer((T.int64(2), T.int64(3))) T_multiply_red = T.sblock_alloc_buffer(()) T_multiply_1 = T.sblock_alloc_buffer(()) @@ -1985,7 +1985,7 @@ def main(x: R.Tensor(("n", "m"), dtype="float32"), y: R.Tensor(("n", "m"), dtype @T.prim_func(private=True) def cross_entropy_with_logits(var_x: T.handle, var_y: T.handle, T_divide: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m, n = T.int64(), T.int64() x = T.match_buffer(var_x, (n, m)) y = T.match_buffer(var_y, (n, m)) @@ -2035,7 +2035,7 @@ def main(x: R.Tensor((2, 3, 28, 28), "float32"), gamma: R.Tensor((3,), "float32" class Expected: @T.prim_func(private=True) def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28), T.int64(28))) gamma = T.match_buffer(var_gamma, (T.int64(3),)) beta = T.match_buffer(var_beta, (T.int64(3),)) @@ -2328,7 +2328,7 @@ def main(x: R.Tensor(("n", "h", "w", "c"), "float32"), gamma: R.Tensor(("c",), " class Expected: @T.prim_func(private=True) def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta: T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add: T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64() x = T.match_buffer(var_x, (n, h, w, c)) gamma = T.match_buffer(var_gamma, (c,)) @@ -2626,7 +2626,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32"), gamma: R.Tensor((4, 5), "float32" @T.prim_func(private=True) def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([T.int64(2), T.int64(3)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): @@ -2669,7 +2669,7 @@ def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,) class LayerNorm_1D_Expected: @T.prim_func(private=True) def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): x_red_temp_v0 = T.sblock_alloc_buffer(()) x_red_temp_v1 = T.sblock_alloc_buffer(()) @@ -2719,7 +2719,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float16"), gamma: R.Tensor((4, 5), "float16" class Expected: @T.prim_func(private=True) def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16") rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4), T.int64(5)), "float16") rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4), T.int64(5)), "float16") @@ -2793,7 +2793,7 @@ def main(x: R.Tensor(("n", "s", "f"), "float32"), gamma: R.Tensor(("s", "f"), "f @T.prim_func(private=True) def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) f = T.int64() n = T.int64() s = T.int64() @@ -2839,7 +2839,7 @@ def main(x: R.Tensor((2, 4, 4, 5), "float32"), gamma: R.Tensor((4,), "float32"), class Expected: @T.prim_func(private=True) def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2: T.Buffer((T.int64(4),), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_reshape_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer((T.int64(2), T.int64(2))) rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer((T.int64(2), T.int64(2))) @@ -2916,7 +2916,7 @@ def main(x: R.Tensor((2, 4, 4, 5), dtype="float16"), gamma: R.Tensor((4,), dtype @T.prim_func(private=True) def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(4),), "float16"), rxplaceholder_2: T.Buffer((T.int64(4),), "float16"), T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_reshape_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5)), "float16") T_cast = T.sblock_alloc_buffer((T.int64(2), T.int64(2), T.int64(2), T.int64(4), T.int64(5))) @@ -2996,7 +2996,7 @@ def main(s: R.Shape(["c"]), x: R.Tensor(("n", "4 * c", "h", "w"), "float32"), ga class Expected: @T.prim_func(private=True) def group_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_reshape: T.handle, c: T.int64): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() h = T.int64() w = T.int64() @@ -3080,7 +3080,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32 class Expected: @T.prim_func(private=True) def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_cast_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) @@ -3156,7 +3156,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4, 5), "float16 class Expected: @T.prim_func(private=True) def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"), B: T.Buffer((T.int64(4), T.int64(5)), "float16"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_cast_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) @@ -3235,7 +3235,7 @@ def main(x: R.Tensor(("n", "s", "f"), "float32"), weight: R.Tensor(("s", "f"), " class Expected: @T.prim_func(private=True) def rms_norm(var_A: T.handle, var_B: T.handle, var_T_cast: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, s, f = T.int64(), T.int64(), T.int64() A = T.match_buffer(var_A, (n, s, f)) B = T.match_buffer(var_B, (s, f)) @@ -3318,7 +3318,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32 class Expected: @T.prim_func(private=True) def rms_norm(A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), B: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_cast_1 = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) T_multiply = T.sblock_alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5))) @@ -3395,7 +3395,7 @@ def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 32, 8), "flo class Expected: @T.prim_func(private=True) def attention_bias(q: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), k: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), v: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), bias: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): T_transpose_1 = T.sblock_alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8))) T_reshape = T.sblock_alloc_buffer((T.int64(128), T.int64(16), T.int64(8))) @@ -3620,7 +3620,7 @@ def nll_loss( output: T.Buffer((), "float32"), ): # function attr dict - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # body # with T.sblock("root") nll_loss = T.sblock_alloc_buffer([T.int64(2), T.int64(4), T.int64(5)], dtype="float32") @@ -3685,7 +3685,7 @@ def main(predictions: R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor @T.prim_func(private=True) def nll_loss_without_weight(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64"), T_divide: T.Buffer((), "float32"),): # function attr dict - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # body # with T.sblock("root") T_full = T.sblock_alloc_buffer([T.int64(3)], dtype="float32") @@ -3757,7 +3757,7 @@ def main(predictions: R.Tensor(("C",), dtype="float32"), targets: R.Tensor((), d @T.prim_func(private=True) def nll_loss(var_rxplaceholder: T.handle, rxplaceholder: T.Buffer((), "int64"), var_rxplaceholder_1: T.handle, T_divide: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) C = T.int64() rxplaceholder_1 = T.match_buffer(var_rxplaceholder, (C,)) rxplaceholder_2 = T.match_buffer(var_rxplaceholder_1, (C,)) @@ -3805,7 +3805,7 @@ def main(predictions: R.Tensor(("N", "C", "d1", "d2"), dtype="float32"), targets @T.prim_func(private=True) def nll_loss(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, T_divide: T.Buffer((), "float32"),): # function attr dict - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) C = T.int64() N = T.int64() d1 = T.int64() @@ -3879,7 +3879,7 @@ def pad( A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"), PadInput: T.Buffer((T.int64(2), T.int64(130), T.int64(30)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2 in T.grid(T.int64(2), T.int64(130), T.int64(30)): with T.sblock("PadInput"): @@ -3917,7 +3917,7 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 60), dtype= @T.prim_func(private=True) def reshape(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(60)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(60)): with T.sblock("T_reshape"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) diff --git a/tests/python/relax/test_transform_legalize_ops_qdq.py b/tests/python/relax/test_transform_legalize_ops_qdq.py index ecfded76a4c1..251d7db8c981 100644 --- a/tests/python/relax/test_transform_legalize_ops_qdq.py +++ b/tests/python/relax/test_transform_legalize_ops_qdq.py @@ -19,7 +19,7 @@ import tvm.testing from tvm.relax.transform import LegalizeOps from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_quantize_fp32_to_int8(): @@ -43,7 +43,7 @@ def quantize( C: T.Buffer((T.int64(2),), "int8"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("quantized"): @@ -97,7 +97,7 @@ def quantize( C: T.Buffer((T.int64(2),), "int8"), quantized: T.Buffer((T.int64(2), T.int64(4)), "uint8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("quantized"): @@ -146,7 +146,7 @@ def main( class Expected: @T.prim_func(private=True) def quantize(var_A: T.handle, var_B: T.handle, var_C: T.handle, var_quantized: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(4), n)) B = T.match_buffer(var_B, (n,)) @@ -202,7 +202,7 @@ def quantize( A: T.Buffer((T.int64(2), T.int64(4)), "float32"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("quantized"): @@ -252,7 +252,7 @@ def quantize( C: T.Buffer((T.int64(2),), "int8"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("quantized"): @@ -301,7 +301,7 @@ def quantize( A: T.Buffer((T.int64(2), T.int64(4)), "float16"), quantized: T.Buffer((T.int64(2), T.int64(4)), "int8"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("quantized"): @@ -349,7 +349,7 @@ def dequantize( C: T.Buffer((T.int64(2),), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("dequantized"): @@ -393,7 +393,7 @@ def dequantize( A: T.Buffer((T.int64(2), T.int64(4)), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("dequantized"): @@ -432,7 +432,7 @@ class Expected: def dequantize( var_A: T.handle, var_B: T.handle, var_C: T.handle, var_dequantized: T.handle ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(2), n), "int8") B = T.match_buffer(var_B, (n,)) @@ -486,7 +486,7 @@ def dequantize( C: T.Buffer((T.int64(2),), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("dequantized"): @@ -540,7 +540,7 @@ def dequantize( A: T.Buffer((T.int64(2), T.int64(4)), "int8"), dequantized: T.Buffer((T.int64(2), T.int64(4)), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i0, i1 in T.grid(T.int64(2), T.int64(4)): with T.sblock("dequantized"): diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py index 8165cd715a9f..304227d30d82 100644 --- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py +++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py @@ -21,7 +21,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T ##################### Search ##################### @@ -44,7 +44,7 @@ def main(condition: R.Tensor((3, 2, 1), "bool"), x: R.Tensor((2, 3), "float32"), @T.prim_func(private=True) def where(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(1)), "bool"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(1)), "float32"), T_where: T.Buffer((T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2 in T.grid(T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_where"): ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2]) @@ -81,7 +81,7 @@ def main(condition: R.Tensor(("a", "b", 1), "bool"), x: R.Tensor(("b", "c"), "fl @T.prim_func(private=True) def where(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_rxplaceholder_2: T.handle, var_T_where: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -119,7 +119,7 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 4, 5), dtyp @T.prim_func(private=True) def argmax(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(4), T.int64(5)), "int64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer((T.int64(2), T.int64(4), T.int64(5)), "int64") rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer((T.int64(2), T.int64(4), T.int64(5))) for ax0, ax1, ax2, k1 in T.grid(T.int64(2), T.int64(4), T.int64(5), T.int64(3)): @@ -170,7 +170,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("a", 1 @T.prim_func(private=True) def argmax(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -217,7 +217,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "int64"): class Expected: @T.prim_func(private=True) def argmin(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((), "int64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer((), "int64") rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer(()) for k0, k1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): @@ -261,7 +261,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), class Expected: @T.prim_func(private=True) def argmin(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "int64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -319,7 +319,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 5), "float32"): @T.prim_func(private=True) def max(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(5), T.int64(3), T.int64(4)): with T.sblock("rxplaceholder_red"): ax0, ax1, k1, k2 = T.axis.remap("SSRR", [i0, i1, i2, i3]) @@ -356,7 +356,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", "d"), " @T.prim_func(private=True) def max(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -395,7 +395,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 1, 1, 5), "float3 @T.prim_func(private=True) def min(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(2), T.int64(1), T.int64(1), T.int64(5)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3, i4, i5 in T.grid(T.int64(2), T.int64(1), T.int64(1), T.int64(5), T.int64(3), T.int64(4)): with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k1, k2 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) @@ -432,7 +432,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor(("a", 1, 1, " @T.prim_func(private=True) def min(var_rxplaceholder: T.handle, var_rxplaceholder_red: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -471,7 +471,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): @T.prim_func(private=True) def sum(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.sblock("rxplaceholder_red"): k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3]) @@ -504,7 +504,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32" @T.prim_func(private=True) def sum(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -542,7 +542,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 1, 1, 1), "float3 @T.prim_func(private=True) def prod(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4), T.int64(5)): with T.sblock("rxplaceholder_red"): ax0, ax1, ax2, ax3, k0, k1, k2, k3 = T.axis.remap("SSSSRRRR", [i0, i1, i2, i3, i4, i5, i6, i7]) @@ -575,7 +575,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, 1, 1, 1), @T.prim_func(private=True) def prod(var_rxplaceholder: T.handle, rxplaceholder_red: T.Buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -613,7 +613,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): @T.prim_func(private=True) def mean(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder_red = T.sblock_alloc_buffer([T.int64(3), T.int64(4)], dtype="float32") for i0, i1, i2, i3 in T.grid(T.int64(3), T.int64(4), T.int64(2), T.int64(5)): with T.sblock("rxplaceholder_red"): @@ -657,7 +657,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) -> R.Tensor(("b", " @T.prim_func(private=True) def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -703,7 +703,7 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, @T.prim_func(private=True) def median(var_x: T.handle, T_squeeze: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "float32"), T_squeeze_1: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "int64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) data_buf = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), align=8) # with T.sblock("root"): T_full = T.sblock_alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64") @@ -776,7 +776,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((), "float32"): class Expected: @T.prim_func(private=True) def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), compute: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): rxplaceholder_red = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) T_divide = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(1), T.int64(1))) @@ -853,7 +853,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((), "float32" class Expected: @T.prim_func(private=True) def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a, b, c, d = T.int64(), T.int64(), T.int64(), T.int64() rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c, d)) # with T.sblock("root"): @@ -941,7 +941,7 @@ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((1, 3, 4, 1), d @T.prim_func(private=True) def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) rxplaceholder_red = T.sblock_alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") T_divide_1 = T.sblock_alloc_buffer([T.int64(1), T.int64(3), T.int64(4), T.int64(1)], dtype="float32") T_subtract = T.sblock_alloc_buffer([T.int64(2), T.int64(3), T.int64(4), T.int64(5)], dtype="float32") @@ -1015,7 +1015,7 @@ def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((1, "b", "c", @T.prim_func(private=True) def variance(var_rxplaceholder: T.handle, var_T_divide: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) a = T.int64() b = T.int64() c = T.int64() @@ -1086,7 +1086,7 @@ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((3, 4), "float32"): class Expected: @T.prim_func(private=True) def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(3), T.int64(4)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): rxplaceholder_red = T.sblock_alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1))) T_divide_1 = T.sblock_alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(1))) diff --git a/tests/python/relax/test_transform_legalize_ops_unary.py b/tests/python/relax/test_transform_legalize_ops_unary.py index 55fa34ab2204..e32f95a3c4d2 100644 --- a/tests/python/relax/test_transform_legalize_ops_unary.py +++ b/tests/python/relax/test_transform_legalize_ops_unary.py @@ -27,7 +27,7 @@ from tvm.relax.transform import LegalizeOps from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _test_static_shape(name: str, relax_op: Callable, te_func: Callable, dtype: str): diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index 4d99912aeac9..48b3b6357dcb 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -25,7 +25,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.mark.parametrize("consume_params", [True, False]) @@ -1438,7 +1438,7 @@ def test_symbolic_var_2(): class Before: @T.prim_func def zeros(var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (n, n)) for ax0, ax1 in T.grid(n, n): @@ -1464,7 +1464,7 @@ def main(shape: R.Shape(["n"])) -> R.Shape(["n"]): class Expected: @T.prim_func def zeros(var_T_full: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (n, n)) # with T.sblock("root"): @@ -1532,7 +1532,7 @@ def slice( Output_Slice: T.Buffer(shape=[16], dtype="int32"), slice_index: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for j in range(16): with T.sblock("T_full"): vj = T.axis.remap("S", [j]) @@ -1586,7 +1586,7 @@ def slice( Output_Slice: T.Buffer(shape=[16], dtype="int32"), slice_index: T.int64, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for j in range(16): with T.sblock("T_full"): vj = T.axis.remap("S", [j]) diff --git a/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py b/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py index 16cfed0f79bd..379d52f262e5 100644 --- a/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py +++ b/tests/python/relax/test_transform_lower_gpu_ipc_alloc_storage.py @@ -20,7 +20,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_alloc_storage(): diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index bdf6efebf3ca..b896244ec9ed 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -20,7 +20,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -1140,7 +1140,7 @@ def relu( Input: T.Buffer(T.int64(10), "float32"), Output: T.Buffer(T.int64(10), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(T.int64(10)): with T.sblock("compute"): vi = T.axis.remap("S", [i]) @@ -1192,7 +1192,7 @@ def relu( Input: T.Buffer(T.int64(10), "float32"), Output: T.Buffer(T.int64(10), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i in range(T.int64(10)): with T.sblock("compute"): vi = T.axis.remap("S", [i]) diff --git a/tests/python/relax/test_transform_meta_schedule_apply_database.py b/tests/python/relax/test_transform_meta_schedule_apply_database.py index 857a09b58746..dd34726cf20d 100644 --- a/tests/python/relax/test_transform_meta_schedule_apply_database.py +++ b/tests/python/relax/test_transform_meta_schedule_apply_database.py @@ -18,10 +18,10 @@ import tvm import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.s_tir import meta_schedule as ms from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T target = tvm.target.Target({"kind": "llvm", "num-cores": 16}) @@ -31,7 +31,7 @@ def test_apply_to_func_with_different_block_name(): class RecordModule: @T.prim_func def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i in T.serial(2): with T.sblock("block"): vi = T.axis.spatial(2, i) @@ -41,7 +41,7 @@ def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): class BlockRenamedModule: @T.prim_func def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i in T.serial(2): with T.sblock("renamed_block"): vi = T.axis.spatial(2, i) @@ -53,9 +53,9 @@ class Expected: def main(A: T.Buffer((2,), "float32"), B: T.Buffer((2,), "float32")): T.func_attr( { - "tir.is_scheduled": True, + "tirx.is_scheduled": True, "global_symbol": "main", - "tir.noalias": True, + "tirx.noalias": True, } ) for i in T.serial(2): diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index 7e36260ab02d..a9baae65ed76 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -1,4 +1,20 @@ # Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file @@ -26,7 +42,7 @@ from tvm.ir.module import IRModule from tvm.ir.transform import PassContext from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T target = tvm.target.Target({"kind": "llvm", "num-cores": 16}) @@ -156,7 +172,7 @@ def tir_matmul( B: T.Buffer((32, 32), "float32"), C: T.Buffer((32, 32), "float32"), ): - T.func_attr({"global_symbol": "tir_matmul", "tir.is_scheduled": True}) + T.func_attr({"global_symbol": "tir_matmul", "tirx.is_scheduled": True}) # with T.sblock("root"): for i0_j0_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i0_j0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): @@ -173,7 +189,7 @@ def tir_matmul( @T.prim_func def tir_relu(A: T.Buffer((32, 32), "float32"), B: T.Buffer((32, 32), "float32")): - T.func_attr({"global_symbol": "tir_relu", "tir.is_scheduled": True}) + T.func_attr({"global_symbol": "tir_relu", "tirx.is_scheduled": True}) # with T.sblock("root"): for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 0fb390826977..468d0c3381a4 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -20,15 +20,15 @@ import tvm import tvm.script import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir.base import assert_structural_equal from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_normalize_function(): - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor([m, n], "float16")) # Note: the parser automatically normalize the IR written in TVMScript, diff --git a/tests/python/relax/test_transform_normalize_global_var.py b/tests/python/relax/test_transform_normalize_global_var.py index 34f6c3acd6c4..71c1832bf03f 100644 --- a/tests/python/relax/test_transform_normalize_global_var.py +++ b/tests/python/relax/test_transform_normalize_global_var.py @@ -20,11 +20,11 @@ import tvm import tvm.script import tvm.testing -from tvm import relax, tir +from tvm import relax, tirx from tvm.ir.base import assert_structural_equal from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.mark.skip_well_formed_check_before_transform diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py index 2b5d9d4ee221..9ceb9c424b79 100644 --- a/tests/python/relax/test_transform_operator_specific_normalization.py +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -26,7 +26,7 @@ from tvm import relax from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T define_normalization = tvm.testing.parameter(True) diff --git a/tests/python/relax/test_transform_realize_vdevice.py b/tests/python/relax/test_transform_realize_vdevice.py index ba90213e3524..fbf4a8a26af8 100644 --- a/tests/python/relax/test_transform_realize_vdevice.py +++ b/tests/python/relax/test_transform_realize_vdevice.py @@ -23,7 +23,7 @@ from tvm.relax.transform import RealizeVDevice from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def verify(input, expected): diff --git a/tests/python/relax/test_transform_remove_unused_outputs.py b/tests/python/relax/test_transform_remove_unused_outputs.py index ab3ea89a94f2..1a2f305bfebf 100644 --- a/tests/python/relax/test_transform_remove_unused_outputs.py +++ b/tests/python/relax/test_transform_remove_unused_outputs.py @@ -20,7 +20,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_simple(): diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py index 52f4f1cbf0ab..1be2a5ac2f9d 100644 --- a/tests/python/relax/test_transform_remove_unused_parameters.py +++ b/tests/python/relax/test_transform_remove_unused_parameters.py @@ -19,7 +19,7 @@ import tvm.testing from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_remove_unused_relax_parameter(): diff --git a/tests/python/relax/test_transform_reorder_take_after_matmul.py b/tests/python/relax/test_transform_reorder_take_after_matmul.py index da36c39ffdfd..8e7243d02ee0 100644 --- a/tests/python/relax/test_transform_reorder_take_after_matmul.py +++ b/tests/python/relax/test_transform_reorder_take_after_matmul.py @@ -24,7 +24,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T class Base: diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index fd022a828e17..3e4759eeb3f2 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.fixture(autouse=True) @@ -40,7 +40,7 @@ class Before: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "exp"}) for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): with T.sblock("compute"): @@ -82,7 +82,7 @@ class Expected: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "exp"}) # body # with T.sblock("root") for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): @@ -152,7 +152,7 @@ class Before: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "exp"}) # body # with T.sblock("root") for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): @@ -194,7 +194,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float3 class Expected: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): - T.func_attr({"global_symbol": "exp", "tir.noalias": True}) + T.func_attr({"global_symbol": "exp", "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): @@ -260,7 +260,7 @@ class Before: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "exp"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "exp"}) for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): with T.sblock("compute"): @@ -295,7 +295,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4), dtype="float32 class Expected: @T.prim_func def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")): - T.func_attr({"global_symbol": "exp", "tir.noalias": True}) + T.func_attr({"global_symbol": "exp", "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): @@ -400,7 +400,7 @@ def fused_conv2d_relu( (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): pad_temp = T.sblock_alloc_buffer( (T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16" @@ -462,7 +462,7 @@ def layer_norm( C: T.Buffer((T.int64(16),), "float16"), T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"), ): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) # with T.sblock("root"): A_red_temp_v0 = T.sblock_alloc_buffer((T.int64(16), T.int64(32), T.int64(32))) A_red_temp_v1 = T.sblock_alloc_buffer((T.int64(16), T.int64(32), T.int64(32))) diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index 7d77796f7465..7b6299991916 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -21,7 +21,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_reshape_expand_dims(): @@ -229,7 +229,7 @@ def test_reshape_dynamic_shape(): class Module: @T.prim_func(private=True) def reshape(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() A = T.match_buffer(var_A, (n, 16, 128), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16") @@ -271,7 +271,7 @@ def main(x: R.Tensor((8, 16, 128), dtype="float16")) -> R.Tensor( class Expected: @T.prim_func(private=True) def reshape(var_A: T.handle, var_T_reshape: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() A = T.match_buffer(var_A, (n, 16, 128), "float16") T_reshape = T.match_buffer(var_T_reshape, (1, n, 16, 128), "float16") @@ -360,7 +360,7 @@ def fused_reshape5( (T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)): with T.sblock("T_reshape"): @@ -421,7 +421,7 @@ def fused_reshape5( (T.int64(2), T.int64(4096), T.int64(8), T.int64(40)), "float16" ), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4096), T.int64(8), T.int64(40)): with T.sblock("T_reshape"): @@ -483,7 +483,7 @@ def strided_slice( A: T.Buffer((T.int64(1), T.int64(1024)), "int32"), T_strided_slice: T.Buffer((T.int64(1), T.int64(1000)), "int32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(1), T.int64(1000)): with T.sblock("T_strided_slice"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -557,7 +557,7 @@ def add( B: T.Buffer((T.int64(1),), "float32"), T_add: T.Buffer((T.int64(1),), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0 in range(T.int64(1)): with T.sblock("T_add"): @@ -568,7 +568,7 @@ def add( @T.prim_func(private=True) def reshape(A: T.Buffer((), "float32"), T_reshape: T.Buffer((T.int64(1),), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for ax0 in range(T.int64(1)): with T.sblock("T_reshape"): @@ -621,7 +621,7 @@ def add( y2: T.Buffer((T.int64(64), T.int64(4)), "float32"), z: T.Buffer((T.int64(64), T.int64(4)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for iters in T.grid(T.int64(64), T.int64(4)): with T.sblock("T_add"): @@ -686,7 +686,7 @@ def add( # y2 = T.match_buffer(y2_handle, [N // 4, 4], "float32") # z = T.match_buffer(z_handle, [N // 4, 4], "float32") -# T.func_attr({"tir.noalias": True}) +# T.func_attr({"tirx.noalias": True}) # for iters in T.grid(T.int64(64), T.int64(4)): # with T.sblock("T_add"): @@ -750,7 +750,7 @@ def add( y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32") z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32") - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for iters in T.grid(N * 4, T.int64(4)): with T.sblock("T_add"): diff --git a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py index 68f33e6c419e..995a2a3dc951 100644 --- a/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py +++ b/tests/python/relax/test_transform_specialize_primfunc_based_on_callsite.py @@ -24,7 +24,7 @@ from tvm.relax.transform.legalize_ops import adreno as legalize_adreno from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T @visitor diff --git a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py index 46a5f361e959..d3222c7d6683 100644 --- a/tests/python/relax/test_transform_split_layout_rewrite_preproc.py +++ b/tests/python/relax/test_transform_split_layout_rewrite_preproc.py @@ -19,7 +19,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_single_buffer(): @@ -225,7 +225,7 @@ def tir_func( W: T.Buffer((224, 224), "float32"), Out: T.Buffer((224, 224), "float32"), ): - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True}) W_rewrite = T.sblock_alloc_buffer((4, 4, 56, 56)) for i, j in T.grid(224, 224): with T.sblock("W_rewrite"): @@ -260,7 +260,7 @@ def tir_func_prepacked( W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), Out: T.Buffer((224, 224), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, j0, i1, j1 in T.grid(4, 4, 56, 56): with T.sblock("Out"): vi = T.axis.spatial(224, i0 * 56 + i1) @@ -272,7 +272,7 @@ def tir_func_weight_prepack( W: T.Buffer((224, 224), "float32"), W_rewrite: T.Buffer((4, 4, 56, 56), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(224, 224): with T.sblock("W_rewrite"): vi, vj = T.axis.remap("SS", [i, j]) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index d0e7c966494c..e6d6a7071b30 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -23,7 +23,7 @@ from tvm import TVMError, relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_basic(): diff --git a/tests/python/relax/test_transform_to_mixed_precision.py b/tests/python/relax/test_transform_to_mixed_precision.py index 3d30a92695d8..204d06bf9454 100644 --- a/tests/python/relax/test_transform_to_mixed_precision.py +++ b/tests/python/relax/test_transform_to_mixed_precision.py @@ -23,7 +23,7 @@ from tvm.relax.transform import ToMixedPrecision from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def _assert_test(input, expected=None, expected2=None): diff --git a/tests/python/relax/test_transform_update_vdevice.py b/tests/python/relax/test_transform_update_vdevice.py index b240536fc1e1..618a2861159b 100644 --- a/tests/python/relax/test_transform_update_vdevice.py +++ b/tests/python/relax/test_transform_update_vdevice.py @@ -22,7 +22,7 @@ from tvm.relax.transform import UpdateVDevice from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def verify(input, new_vdevice, vdevice_index, expected): diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index 0f92534eaa32..ade2378b7937 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -17,7 +17,7 @@ # ruff: noqa: F401 import tvm import tvm.testing -from tvm import relax, tir, topi +from tvm import relax, tirx, topi from tvm.script.ir_builder import relax as R from tvm.script.ir_builder.base import IRBuilder @@ -90,8 +90,8 @@ def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(n R.func_name("foo") x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32")) y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32")) - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32")) y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32")) v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) @@ -104,8 +104,8 @@ def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(n func = ir_builder.get() # create with BlockBuilder - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1)) y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1)) v = relax.Var("v", relax.TensorStructInfo((n,), "float32")) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 720b8290e385..4716c64f0401 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -23,11 +23,11 @@ import tvm import tvm.script import tvm.testing -from tvm import IRModule, relax, tir, topi +from tvm import IRModule, relax, tirx, topi from tvm.ir import DummyGlobalInfo, VDevice from tvm.script.parser import ir as I from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def _check( @@ -115,7 +115,7 @@ def test_unexpected_tir_cast_args(): @R.function def f(x: R.Tensor(("m",), "float32")): m = T.int64() - # tir.cast expects 2 arguments, but got 3 + # tirx.cast expects 2 arguments, but got 3 return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) @@ -135,7 +135,7 @@ def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) - @R.function def foo(x: R.Tensor(("m", "m"), "float32")): m = T.int64() - # tir.max expects 2 arguments, but got 1 + # tirx.max expects 2 arguments, but got 1 gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32")) return gv @@ -144,7 +144,7 @@ def foo(x: R.Tensor(("m", "m"), "float32")): @R.function def f(x: R.Tensor(("m", "n"), "float32")): m = T.int64() - # call_tir expected a tir prim_func + # call_tir expected a tirx prim_func return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32")) @@ -198,7 +198,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -227,7 +227,7 @@ def plus_one( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"some_attr": "foo", "another_attr": True, "tir.noalias": True}) + T.func_attr({"some_attr": "foo", "another_attr": True, "tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -289,7 +289,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -339,7 +339,7 @@ def tir_func( x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i, j in T.grid(T.int64(128), T.int64(128)): with T.sblock(): vi, vj = T.axis.remap("SS", [i, j]) @@ -436,7 +436,7 @@ def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float3 return gv0 def _expected(name: str): - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = relax.Var("x", R.Tensor([m, n], "float32")) bb = relax.BlockBuilder() with bb.function(name, (x,)): @@ -489,8 +489,8 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): x = relax.Var("x", R.Tensor("float32")) y = relax.Var("y", R.Tensor("float32")) - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") y2 = relax.Var("y", R.Tensor([n], "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x, y)): @@ -528,7 +528,7 @@ def foo(x: R.Tensor("float32", ndim=2)): return (x0, R.shape([n + 1, m, 1])) x = relax.Var("x", R.Tensor("float32", ndim=2)) - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") bb = relax.BlockBuilder() with bb.function("foo", (x,)): x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) @@ -547,7 +547,7 @@ def foo(x: R.Tensor("float32", ndim=2)): return t1 x = relax.Var("x", R.Tensor("float32", ndim=2)) - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") bb = relax.BlockBuilder() with bb.function("foo", (x,)): x0 = bb.match_cast(x, R.Tensor((n, m), "float32")) @@ -631,8 +631,8 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - m = tir.Var("m", dtype="int64") - n = tir.Var("n", dtype="int64") + m = tirx.Var("m", dtype="int64") + n = tirx.Var("n", dtype="int64") with bb.function("foo", (x,)): gv0 = bb.emit( relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) @@ -1027,7 +1027,7 @@ def copy( out1: T.Buffer((2, 3), "int32"), ): # copies the contents of B into A and out1 - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -1078,7 +1078,7 @@ def copy( out1: T.Buffer((2, 3), "int32"), ): # copies the contents of B into A and out1 - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for iters in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): i, j = T.axis.remap("SS", iters) @@ -1560,7 +1560,7 @@ def foo(x: R.Tensor(("m + 1",), "float32"), y: R.Tensor(("m", 1), "float32")): z = R.add(x, y) return z - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", relax.TensorStructInfo([m + 1], "float32")) y = relax.Var("y", relax.TensorStructInfo([m, 1], "float32")) bb = relax.BlockBuilder() @@ -1582,14 +1582,14 @@ def bar(x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32" z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", relax.TensorStructInfo([m], "float32")) - y = relax.Var("y", relax.TensorStructInfo([tir.max(m, 20)], "float32")) + y = relax.Var("y", relax.TensorStructInfo([tirx.max(m, 20)], "float32")) bb = relax.BlockBuilder() with bb.function("bar", (x, y)): z = bb.emit( relax.call_dps_packed( - "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32") + "test_intrin", (x, y), R.Tensor((tirx.max(m, 20) + 1,), dtype="float32") ) ) bb.emit_func_output(z) @@ -1606,7 +1606,7 @@ def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) return z - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", relax.ShapeStructInfo([m])) y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) bb = relax.BlockBuilder() @@ -1626,7 +1626,7 @@ def baz(x: R.Prim(value="m"), y: R.Tensor(("m * 2",), "float32")): z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) return z - m = tir.Var("m", "int64") + m = tirx.Var("m", "int64") x = relax.Var("x", relax.PrimStructInfo(value=m)) y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) bb = relax.BlockBuilder() @@ -1676,8 +1676,8 @@ def foo(x: R.Tensor(("m", "n"), "float32"), y: R.Tensor(("m", "n"), "float32")): t2 = tuple_expr[0][0] # <= Will normalize to two bindings return (a0, a1, a2, a3, a4, a5, a6, c0, c1, c2, c3, t0, t1, t2) - m = tir.Var("m", "int64") - n = tir.Var("n", "int64") + m = tirx.Var("m", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", relax.TensorStructInfo([m, n], "float32")) y = relax.Var("y", relax.TensorStructInfo([m, n], "float32")) bb = relax.BlockBuilder() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 6b6a4a90be80..c50c1fcb254d 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -19,10 +19,10 @@ import tvm import tvm.testing -from tvm import IRModule, relax, tir +from tvm import IRModule, relax, tirx from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def _assert_print(obj, expected): @@ -156,7 +156,7 @@ def test_shape_struct_info_1(): def test_shape_struct_info_2(): - obj = relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]) + obj = relax.ShapeStructInfo([1, tirx.Var("a", "int64"), 3]) _assert_print( obj, """ @@ -167,7 +167,7 @@ def test_shape_struct_info_2(): def test_tensor_struct_info(): obj = relax.TensorStructInfo( - shape=relax.ShapeExpr([1, tir.Var("a", "int64"), 3]), + shape=relax.ShapeExpr([1, tirx.Var("a", "int64"), 3]), dtype="float32", ) _assert_print( @@ -189,7 +189,7 @@ def test_tuple_struct_info(): [ relax.PrimStructInfo("float32"), relax.ObjectStructInfo(), - relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), + relax.ShapeStructInfo([1, tirx.Var("a", "int64"), 3]), ] ) _assert_print( @@ -206,8 +206,8 @@ def test_func_struct_info(): params=[ relax.PrimStructInfo("float32"), relax.ObjectStructInfo(), - relax.ShapeStructInfo([1, tir.Var("a", "int64"), 3]), - relax.PrimStructInfo(value=tir.Var("b", "int64")), + relax.ShapeStructInfo([1, tirx.Var("a", "int64"), 3]), + relax.PrimStructInfo(value=tirx.Var("b", "int64")), ], ret=relax.TensorStructInfo( shape=relax.ShapeExpr([1, 2, 3]), @@ -284,7 +284,7 @@ def test_data_type_imm(): def test_var(): - obj = relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + obj = relax.Var("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")) _assert_print( obj, """ @@ -295,7 +295,7 @@ def test_var(): def test_dataflow_var(): - obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")) + obj = relax.DataflowVar("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")) _assert_print( obj, """ @@ -308,9 +308,9 @@ def test_dataflow_var(): def test_tuple(): obj = relax.Tuple( [ - relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), - relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), - relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + relax.Var("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tirx.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tirx.Var("z", "int64"), 3], "float32")), ] ) _assert_print( @@ -331,9 +331,9 @@ def test_tuple_get_item(): obj = relax.TupleGetItem( relax.Tuple( [ - relax.Var("a", relax.TensorStructInfo([1, tir.Var("x", "int64"), 3], "float32")), - relax.Var("b", relax.TensorStructInfo([1, tir.Var("y", "int64"), 3], "float32")), - relax.Var("c", relax.TensorStructInfo([1, tir.Var("z", "int64"), 3], "float32")), + relax.Var("a", relax.TensorStructInfo([1, tirx.Var("x", "int64"), 3], "float32")), + relax.Var("b", relax.TensorStructInfo([1, tirx.Var("y", "int64"), 3], "float32")), + relax.Var("c", relax.TensorStructInfo([1, tirx.Var("z", "int64"), 3], "float32")), ] ), 0, @@ -358,7 +358,7 @@ def test_shape_expr(): def test_call(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x]) o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info) @@ -381,7 +381,7 @@ def test_call(): def test_call_tir_with_grad(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") v0 = relax.Var("v0", R.Tensor([54, 96], "float32")) v1 = relax.call_tir_with_grad( relax.GlobalVar("tir_func"), @@ -404,7 +404,7 @@ def test_call_tir_with_grad(): def test_call_tir_inplace(): x = relax.Var("x", R.Tensor((32, 32), dtype="int32")) y = relax.Var("y", R.Tensor((32, 32), dtype="int32")) - t = tir.Var("t", dtype="int64") + t = tirx.Var("t", dtype="int64") call = relax.call_tir_inplace( relax.GlobalVar("tir_func"), ( @@ -427,7 +427,7 @@ def test_call_tir_inplace(): def test_seq_expr(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) @@ -458,7 +458,7 @@ def test_seq_expr(): def test_binding_block(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) @@ -480,7 +480,7 @@ def test_binding_block(): def test_dataflow_block(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) b = relax.DataflowVar("b", relax.TensorStructInfo([1, x, 3], "float32")) c = relax.Var("c", relax.TensorStructInfo([1, x, 3], "float32")) @@ -504,7 +504,7 @@ def test_dataflow_block(): def test_match_cast(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3])) b = relax.Var("b", relax.TensorStructInfo([1, 5, 3])) obj = relax.MatchCast( @@ -523,7 +523,7 @@ def test_match_cast(): def test_var_binding(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) b = relax.Var("b", relax.TensorStructInfo([1, x, 3], "float32")) obj = relax.VarBinding(b, relax.op.sin(a)) @@ -561,7 +561,7 @@ def test_if(): def test_builtin_keywords(): - x = tir.Var("x", "int64") + x = tirx.Var("x", "int64") a = relax.Var("R", relax.TensorStructInfo([1, x, 3], "float32")) b = relax.Var("T", relax.TensorStructInfo([1, x, 3], "float32")) obj = relax.VarBinding(b, relax.op.sin(a)) @@ -595,7 +595,7 @@ def foo(x: R.Tensor((128,), "float32")) -> R.Tensor((128,), "float32"): TestModule, """ # from tvm.script import ir as I -# from tvm.script import tir as T +# from tvm.script import tirx as T # from tvm.script import relax as R @I.ir_module @@ -618,7 +618,7 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32 module_str, """ # from tvm.script import ir as I -# from tvm.script import tir as T +# from tvm.script import tirx as T # from tvm.script import relax as R @I.ir_module diff --git a/tests/python/relax/test_tvmscript_pyfunc.py b/tests/python/relax/test_tvmscript_pyfunc.py index 601c8dd4f46e..2c8f84db4ba0 100644 --- a/tests/python/relax/test_tvmscript_pyfunc.py +++ b/tests/python/relax/test_tvmscript_pyfunc.py @@ -34,7 +34,7 @@ from tvm.relax import BasePyModule from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module @@ -63,7 +63,7 @@ def simple_tir_func( var_A: T.handle, var_B: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() A = T.match_buffer(var_A, (n,), "float32") B = T.match_buffer(var_B, (n,), "float32") diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 064c9f04747c..8eb3961ef8dd 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -24,7 +24,7 @@ from tvm import relax from tvm.ir.base import assert_structural_equal from tvm.script.parser import relax as R -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_copy_with_new_vars(): diff --git a/tests/python/relax/test_vm_alloc_storage_with_scope.py b/tests/python/relax/test_vm_alloc_storage_with_scope.py index d9355cd6648e..3db64b13e9ed 100644 --- a/tests/python/relax/test_vm_alloc_storage_with_scope.py +++ b/tests/python/relax/test_vm_alloc_storage_with_scope.py @@ -23,7 +23,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T @I.ir_module diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index d2ab2a7c41ad..ad716a1e71a9 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -25,14 +25,14 @@ import tvm import tvm.script import tvm.testing -from tvm import relax, rpc, te, tir, topi +from tvm import relax, rpc, te, tirx, topi from tvm.contrib import cc, popen_pool, utils from tvm.relax.testing import nn from tvm.relax.testing.vm import check_saved_func from tvm.runtime import ShapeTuple from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T EXEC_MODE = ["bytecode", "compiled"] @@ -238,7 +238,7 @@ def copy( out1: T.Buffer((2, 3), "int32"), ): # copies the contents of C into A, B, and out1 - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -292,7 +292,7 @@ class TestCallTIRInplaceE2ERW: @T.prim_func def inplace_add(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")): # sums A and B, storing the result in A - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -331,7 +331,7 @@ def test_vm_emit_te_extern(exec_mode): print("skip because extern function is not available") return bb = relax.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = relax.Var("x", R.Tensor([n, m], "float32")) y = relax.Var("y", R.Tensor([m, n], "float32")) @@ -355,12 +355,12 @@ def test_vm_emit_te_extern(exec_mode): def test_vm_emit_te_concat(exec_mode): # concatenate of two vectors of size (n,) and (m,) bb = relax.BlockBuilder() - n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + n, m = tirx.Var("n", "int64"), tirx.Var("m", "int64") x = relax.Var("x", R.Tensor([n], "float32")) y = relax.Var("y", R.Tensor([m], "float32")) def te_func(A, B): - C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n])) + C = te.compute((n + m), lambda i: tvm.tirx.if_then_else(i < n, A[i], B[i - n])) return C with bb.function("rx_func", [x, y]): @@ -391,7 +391,7 @@ def te_func(A, B): def test_vm_emit_te_dtype_change(exec_mode): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor([n], "float32")) # convert a tensor with dtype of float32 to int16 @@ -420,11 +420,11 @@ def te_func(A): def test_vm_emit_te_floor_symbolic_shape(exec_mode): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor([n], "float32")) def te_func(A): - C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1) + C = te.compute((tirx.floordiv(n, 2),), lambda i: A[i] + 1) return C with bb.function("rx_func", [x]): @@ -499,7 +499,7 @@ def test_vm_emit_te_constant_param_gpu(exec_mode): def test_vm_relax_symbolic_shape(exec_mode): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") x = relax.Var("x", R.Tensor([n], "float32")) y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32")) @@ -648,7 +648,7 @@ def main( def test_vm_relax_dyn_tir_shape(exec_mode): # case where TIR variables are unbound in generated PrimFunc bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") def te_func(A): C = te.compute((n + 1), lambda i: A[i]) @@ -680,7 +680,7 @@ def te_func(A): def test_vm_tuple(exec_mode): bb = relax.BlockBuilder() - n = tir.Var("n", "int64") + n = tirx.Var("n", "int64") with bb.function("rx_func"): x = nn.Placeholder((n,), dtype="float32", name="x") diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py index e4b957b0e109..50bbc4fb7246 100644 --- a/tests/python/relax/test_vm_builtin_lower.py +++ b/tests/python/relax/test_vm_builtin_lower.py @@ -22,7 +22,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def test_vm_builtin_lower_mem_alloc_storage(): diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index f94ba675b341..810847193474 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -30,7 +30,7 @@ from tvm.relax.testing.vm import check_saved_func from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T EXEC_MODE = ["bytecode", "compiled"] @@ -367,7 +367,7 @@ def test_vm_kill_object(exec_mode): class TestKillObject: @T.prim_func def full(T_full: T.Buffer((T.int64(4),), "float32")): - T.func_attr({"global_symbol": "full", "tir.noalias": True}) + T.func_attr({"global_symbol": "full", "tirx.noalias": True}) for ax0 in range(T.int64(4)): with T.sblock("T_full"): v_ax0 = T.axis.spatial(T.int64(4), ax0) @@ -377,7 +377,7 @@ def full(T_full: T.Buffer((T.int64(4),), "float32")): @T.prim_func def full1(T_full: T.Buffer((T.int64(4),), "float32")): - T.func_attr({"global_symbol": "full1", "tir.noalias": True}) + T.func_attr({"global_symbol": "full1", "tirx.noalias": True}) for ax0 in range(T.int64(4)): with T.sblock("T_full"): v_ax0 = T.axis.spatial(T.int64(4), ax0) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 756e14a66c50..5e0e61e8a2c1 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -24,7 +24,7 @@ from tvm import relax from tvm.ir import assert_structural_equal from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def get_tir_mod(mod): @@ -119,7 +119,7 @@ def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): T.func_attr({"global_symbol": "__vmtir__ife"}) if T.Call( "bool", - tvm.ir.Op.get("tir.tvm_call_packed"), + tvm.ir.Op.get("tirx.tvm_call_packed"), ["vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))], ): T.anylist_setitem_call_packed( diff --git a/tests/python/relax/test_vm_cuda_graph.py b/tests/python/relax/test_vm_cuda_graph.py index 91e7adb8aa11..b2eccb2fa88b 100644 --- a/tests/python/relax/test_vm_cuda_graph.py +++ b/tests/python/relax/test_vm_cuda_graph.py @@ -24,7 +24,7 @@ from tvm import relax from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off diff --git a/tests/python/relax/texture/test_texture_nd.py b/tests/python/relax/texture/test_texture_nd.py index ce1a28f2de03..cf725e208606 100644 --- a/tests/python/relax/texture/test_texture_nd.py +++ b/tests/python/relax/texture/test_texture_nd.py @@ -28,13 +28,13 @@ DataType, IRModule, relax, - tir, + tirx, ) from tvm.contrib import ndk from tvm.relax.transform.legalize_ops import adreno as legalize_adreno from tvm.rpc import connect_tracker from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -62,11 +62,11 @@ def preprocess_pipeline(mod: IRModule) -> IRModule: desired_layouts = {"relax.nn.conv2d": ["NCHW16c", "OIHW16o", "NCHW16c"]} seq = tvm.transform.Sequential( [ - tvm.tir.transform.BindTarget(Target.current(allow_none=False)), + tvm.tirx.transform.BindTarget(Target.current(allow_none=False)), tvm.relax.transform.FoldConstant(), tvm.relax.transform.DecomposeOpsForInference(), tvm.relax.transform.FoldConstant(), - tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)), + tvm.tirx.transform.BindTarget(tvm.target.Target.current(allow_none=False)), tvm.relax.transform.ConvertLayout(desired_layouts), tvm.relax.transform.Normalize(), tvm.relax.transform.FoldConstant(), diff --git a/tests/python/runtime/test_evaluator_with_preproc.py b/tests/python/runtime/test_evaluator_with_preproc.py index 55630828e819..14462a50d454 100644 --- a/tests/python/runtime/test_evaluator_with_preproc.py +++ b/tests/python/runtime/test_evaluator_with_preproc.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -45,7 +45,7 @@ def test_time_evalutor_with_preproc(f_preproc: str): i, j, k = sch.get_loops(blk) sch.bind(i, "blockIdx.x") sch.bind(j, "threadIdx.x") - f = tvm.tir.build(sch.mod["main"], target="cuda") + f = tvm.tirx.build(sch.mod["main"], target="cuda") dev = tvm.cuda(0) evaluator = f.time_evaluator(f.entry_name, dev, repeat=1000, number=1, f_preproc=f_preproc) diff --git a/tests/python/runtime/test_executable.py b/tests/python/runtime/test_executable.py index 4d6830b8b6a4..b4ccfcdb4026 100644 --- a/tests/python/runtime/test_executable.py +++ b/tests/python/runtime/test_executable.py @@ -24,7 +24,7 @@ import tvm import tvm.testing from tvm.runtime import Executable -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -41,7 +41,7 @@ def add( def test_executable_init(): """Test initialization of Executable class.""" - lib = tvm.tir.build(MyModule, target="llvm") + lib = tvm.tirx.build(MyModule, target="llvm") executable = Executable(lib) assert executable.mod is lib @@ -50,7 +50,7 @@ def test_executable_init(): def test_executable_getitem(): """Test __getitem__ method of Executable class.""" - lib = tvm.tir.build(MyModule, target="llvm") + lib = tvm.tirx.build(MyModule, target="llvm") executable = Executable(lib) # Jit the module first @@ -72,7 +72,7 @@ def test_executable_getitem(): def test_executable_jit_already_jitted(): """Test jit method when module is already jitted.""" - lib = tvm.tir.build(MyModule, target="llvm") + lib = tvm.tirx.build(MyModule, target="llvm") executable = Executable(lib) # First jit call @@ -101,7 +101,7 @@ def test_executable_jit_already_jitted(): def test_executable_export_library(): """Test export_library method.""" - lib = tvm.tir.build(MyModule, target="llvm") + lib = tvm.tirx.build(MyModule, target="llvm") executable = Executable(lib) # Create a temporary directory for the library @@ -136,7 +136,7 @@ def test_executable_export_library(): def test_executable_export_library_with_workspace(): """Test export_library method with workspace_dir.""" - lib = tvm.tir.build(MyModule, target="llvm") + lib = tvm.tirx.build(MyModule, target="llvm") executable = Executable(lib) # Create temporary directories @@ -176,7 +176,7 @@ def test_executable_integration(): """Integration test for Executable with a simple TVM module.""" # Create target and build target = tvm.target.Target("llvm") - lib = tvm.tir.build(MyModule, target=target) + lib = tvm.tirx.build(MyModule, target=target) # Create an executable executable = Executable(lib) @@ -232,7 +232,7 @@ def test_executable_jit_force_recompile(): """Test jit method with force_recompile=True.""" # Create target and build target = tvm.target.Target("c") - lib = tvm.tir.build(MyModule, target=target) + lib = tvm.tirx.build(MyModule, target=target) # Create an executable executable = Executable(lib) diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 83f718383116..d603b25f72e5 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -125,11 +125,11 @@ def test_conversion_of_arg(): func = tvm.get_global_func("testing.AcceptsPrimExpr") res = func(1) - assert isinstance(res, tvm.tir.IntImm) + assert isinstance(res, tvm.tirx.IntImm) assert res.dtype == "int32" res = func(True) - assert isinstance(res, tvm.tir.IntImm) + assert isinstance(res, tvm.tirx.IntImm) assert res.dtype == "bool" @@ -146,9 +146,9 @@ def test_conversion_of_array_elements(): func = tvm.get_global_func("testing.AcceptsArrayOfPrimExpr") res = func([1, False]) - assert isinstance(res[0], tvm.tir.IntImm) + assert isinstance(res[0], tvm.tirx.IntImm) assert res[0].dtype == "int32" - assert isinstance(res[1], tvm.tir.IntImm) + assert isinstance(res[1], tvm.tirx.IntImm) assert res[1].dtype == "bool" @@ -165,9 +165,9 @@ def test_conversion_of_map_values(): func = tvm.get_global_func("testing.AcceptsMapOfPrimExpr") res = func({"a": 1, "b": False}) - assert isinstance(res["a"], tvm.tir.IntImm) + assert isinstance(res["a"], tvm.tirx.IntImm) assert res["a"].dtype == "int32" - assert isinstance(res["b"], tvm.tir.IntImm) + assert isinstance(res["b"], tvm.tirx.IntImm) assert res["b"].dtype == "bool" diff --git a/tests/python/runtime/test_runtime_extension.py b/tests/python/runtime/test_runtime_extension.py index e2500c9c91bc..65d9afd9cee2 100644 --- a/tests/python/runtime/test_runtime_extension.py +++ b/tests/python/runtime/test_runtime_extension.py @@ -18,7 +18,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_dltensor_compatible(): diff --git a/tests/python/runtime/test_runtime_measure.py b/tests/python/runtime/test_runtime_measure.py index 9e52e48d59d2..2559709072b0 100644 --- a/tests/python/runtime/test_runtime_measure.py +++ b/tests/python/runtime/test_runtime_measure.py @@ -35,8 +35,8 @@ def my_debug(filename): with open(filename, "a") as fout: fout.write("c") - X = te.compute((), lambda: tvm.tir.call_packed("my_debug", filename)) - func = tvm.tir.build(te.create_prim_func([X])) + X = te.compute((), lambda: tvm.tirx.call_packed("my_debug", filename)) + func = tvm.tirx.build(te.create_prim_func([X])) x = tvm.runtime.empty((), dtype="int32") ftimer = func.time_evaluator(func.entry_name, tvm.cpu(), number=1, repeat=1) diff --git a/tests/python/runtime/test_runtime_module_export.py b/tests/python/runtime/test_runtime_module_export.py index 30f311677d7b..47a1ffd41f2e 100644 --- a/tests/python/runtime/test_runtime_module_export.py +++ b/tests/python/runtime/test_runtime_module_export.py @@ -34,8 +34,8 @@ def test_import_static_library(): te.create_prim_func([A, B]).with_attr("global_symbol", "myadd1") ) - mod0 = tvm.tir.build(irmod0, target="llvm") - mod1 = tvm.tir.build(irmod1, target="llvm") + mod0 = tvm.tirx.build(irmod0, target="llvm") + mod1 = tvm.tirx.build(irmod1, target="llvm") assert mod0.implements_function("myadd0") assert mod1.implements_function("myadd1") diff --git a/tests/python/runtime/test_runtime_module_load.py b/tests/python/runtime/test_runtime_module_load.py index 6910ad8d4b4d..38ac7e36e8b8 100644 --- a/tests/python/runtime/test_runtime_module_load.py +++ b/tests/python/runtime/test_runtime_module_load.py @@ -51,20 +51,20 @@ def test_dso_module_load(target): def save_object(names): n = te.size_var("n") - Ab = tvm.tir.decl_buffer((n,), dtype) + Ab = tvm.tirx.decl_buffer((n,), dtype) i = te.var("i") # for i in 0 to n-1: - stmt = tvm.tir.For( + stmt = tvm.tirx.For( i, 0, n - 1, - tvm.tir.ForKind.SERIAL, - tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]), + tvm.tirx.ForKind.SERIAL, + tvm.tirx.BufferStore(Ab, tvm.tirx.BufferLoad(Ab, [i]) + 1, [i + 1]), ) mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main") + tvm.tirx.PrimFunc([Ab], stmt).with_attr("global_symbol", "main") ) - m = tvm.tir.build(mod, target=target) + m = tvm.tirx.build(mod, target=target) for name in names: m.write_to_file(name) @@ -166,8 +166,8 @@ def test_combine_module_llvm(): def check_llvm(): dev = tvm.cpu(0) temp = utils.tempdir() - fadd1 = tvm.tir.build(mod1, "llvm") - fadd2 = tvm.tir.build(mod2, "llvm") + fadd1 = tvm.tirx.build(mod1, "llvm") + fadd2 = tvm.tirx.build(mod2, "llvm") path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") @@ -192,8 +192,8 @@ def check_system_lib(): return temp = utils.tempdir() print("Running popen check") - fadd1 = tvm.tir.build(mod1.with_attr("system_lib_prefix", ""), "llvm") - fadd2 = tvm.tir.build(mod2.with_attr("system_lib_prefix", ""), "llvm") + fadd1 = tvm.tirx.build(mod1.with_attr("system_lib_prefix", ""), "llvm") + fadd2 = tvm.tirx.build(mod2.with_attr("system_lib_prefix", ""), "llvm") path1 = temp.relpath("myadd1.o") path2 = temp.relpath("myadd2.o") path_dso = temp.relpath("mylib.so") diff --git a/tests/python/runtime/test_runtime_module_property.py b/tests/python/runtime/test_runtime_module_property.py index a8a875421914..89bac851e9d2 100644 --- a/tests/python/runtime/test_runtime_module_property.py +++ b/tests/python/runtime/test_runtime_module_property.py @@ -33,7 +33,7 @@ def create_csource_module(): def create_llvm_module(): A = te.placeholder((1024,), name="A") B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B") - return tvm.tir.build(te.create_prim_func([A, B]), target="llvm") + return tvm.tirx.build(te.create_prim_func([A, B]), target="llvm") def test_property(): diff --git a/tests/python/runtime/test_runtime_rpc.py b/tests/python/runtime/test_runtime_rpc.py index e2eec960c287..5600c7f88756 100644 --- a/tests/python/runtime/test_runtime_rpc.py +++ b/tests/python/runtime/test_runtime_rpc.py @@ -33,7 +33,7 @@ from tvm.rpc.proxy import Proxy from tvm.rpc.tracker import Tracker from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T if __name__ == "__main__": # NOTE: must live here to avoid registering PackedFunc with libtvm.so twice. @@ -71,7 +71,7 @@ def test_bigendian_rpc(): def verify_rpc(remote, target, shape, dtype): A = te.placeholder(shape, dtype=dtype) - B = te.compute(A.shape, lambda i: A[i] + tvm.tir.const(1, A.dtype)) + B = te.compute(A.shape, lambda i: A[i] + tvm.tirx.const(1, A.dtype)) f = tvm.compile(te.create_prim_func([A, B]), target=target) dev = remote.cpu(0) diff --git a/tests/python/runtime/test_runtime_trace.py b/tests/python/runtime/test_runtime_trace.py index f0806e88c8e7..2ab8dbeb7d59 100644 --- a/tests/python/runtime/test_runtime_trace.py +++ b/tests/python/runtime/test_runtime_trace.py @@ -24,7 +24,7 @@ def test_trace_default_action(): n = 2 x = te.placeholder((n, n, n), name="X", dtype="float32") - y = te.compute(x.shape, lambda i, j, k: tvm.tir.trace([i, j, k, x[i][j][k]])) + y = te.compute(x.shape, lambda i, j, k: tvm.tirx.trace([i, j, k, x[i][j][k]])) f = tvm.compile(te.create_prim_func([x, y]), target="llvm") xnd = tvm.runtime.tensor(np.ones((n, n, n), dtype=x.dtype)) ynd = tvm.runtime.tensor(np.zeros((n, n, n), dtype=y.dtype)) @@ -32,7 +32,7 @@ def test_trace_default_action(): def test_trace_expr_assign(): - @tvm.register_global_func("tvm.tir.trace_callback2") + @tvm.register_global_func("tvm.tirx.trace_callback2") def trace_buffer(x): return @@ -40,10 +40,10 @@ def check_assign(dtype): n = 4 x = te.placeholder((n, n, n), name="X", dtype=dtype) y = te.compute( - x.shape, lambda i, j, k: tvm.tir.trace([x[i][j][k]], "tvm.tir.trace_callback2") + x.shape, lambda i, j, k: tvm.tirx.trace([x[i][j][k]], "tvm.tirx.trace_callback2") ) z = te.compute( - x.shape, lambda i, j, k: tvm.tir.trace([y[i][j][k]], "tvm.tir.trace_callback2") + x.shape, lambda i, j, k: tvm.tirx.trace([y[i][j][k]], "tvm.tirx.trace_callback2") ) f = tvm.compile(te.create_prim_func([x, y, z]), "llvm") @@ -61,7 +61,7 @@ def check_assign(dtype): def test_trace_expr_sum_generated(): - @tvm.register_global_func("tvm.tir.trace_callback3") + @tvm.register_global_func("tvm.tirx.trace_callback3") def trace_buffer(x): return @@ -72,8 +72,8 @@ def check_expr_sum(dtype): c = te.compute( a.shape, lambda i, j, k: ( - tvm.tir.trace([a[i][j][k]], "tvm.tir.trace_callback3") - + tvm.tir.trace([b[i][j][k]], "tvm.tir.trace_callback3") + tvm.tirx.trace([a[i][j][k]], "tvm.tirx.trace_callback3") + + tvm.tirx.trace([b[i][j][k]], "tvm.tirx.trace_callback3") ), ) f = tvm.compile(te.create_prim_func([a, b, c])) @@ -88,7 +88,7 @@ def check_expr_sum(dtype): def test_trace_expr_sum_args(): - @tvm.register_global_func("tvm.tir.trace_silent") + @tvm.register_global_func("tvm.tirx.trace_silent") def silent(*args): return @@ -102,10 +102,10 @@ def check_expr_sum(dtype): c = te.compute( a.shape, lambda i, j, k: ( - tvm.tir.trace([i, j, k, a[i][j][k]], "tvm.tir.trace_silent") - + tvm.tir.trace([i, j, k, b[i][j][k]], "tvm.tir.trace_silent") - + tvm.tir.trace([i, j, k, d[i][j][k]], "tvm.tir.trace_silent") - + tvm.tir.trace([i, j, k, e[i][j][k]], "tvm.tir.trace_silent") + tvm.tirx.trace([i, j, k, a[i][j][k]], "tvm.tirx.trace_silent") + + tvm.tirx.trace([i, j, k, b[i][j][k]], "tvm.tirx.trace_silent") + + tvm.tirx.trace([i, j, k, d[i][j][k]], "tvm.tirx.trace_silent") + + tvm.tirx.trace([i, j, k, e[i][j][k]], "tvm.tirx.trace_silent") ), ) f = tvm.compile(te.create_prim_func([a, b, d, e, c])) @@ -124,7 +124,7 @@ def check_expr_sum(dtype): def test_trace_expr_sum_custom(): - @tvm.register_global_func("tvm.tir.trace_callback4") + @tvm.register_global_func("tvm.tirx.trace_callback4") def trace_buffer(x): return @@ -135,8 +135,8 @@ def check_expr_sum_custom(dtype): c = te.compute( a.shape, lambda i, j: ( - tvm.tir.trace([a[i][j]], "tvm.tir.trace_callback4") - + tvm.tir.trace([b[i][j]], "tvm.tir.trace_callback4") + tvm.tirx.trace([a[i][j]], "tvm.tirx.trace_callback4") + + tvm.tirx.trace([b[i][j]], "tvm.tirx.trace_callback4") ), ) f = tvm.compile(te.create_prim_func([a, b, c])) @@ -153,19 +153,21 @@ def check_expr_sum_custom(dtype): def test_trace_can_change_traced_value_int(): - @tvm.register_global_func("tvm.tir.trace_change_int_first") + @tvm.register_global_func("tvm.tirx.trace_change_int_first") def trace_buffer(x): return 13 - @tvm.register_global_func("tvm.tir.trace_change_int_second") + @tvm.register_global_func("tvm.tirx.trace_change_int_second") def trace_buffer(x): return 14 def check_assign(dtype): n = 4 x = te.placeholder((n,), name="X", dtype=dtype) - y = te.compute(x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_int_first")) - z = te.compute(x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_int_second")) + y = te.compute(x.shape, lambda i: tvm.tirx.trace([x[i]], "tvm.tirx.trace_change_int_first")) + z = te.compute( + x.shape, lambda i: tvm.tirx.trace([y[i]], "tvm.tirx.trace_change_int_second") + ) f = tvm.compile(te.create_prim_func([x, y, z])) xnd = tvm.runtime.tensor(np.ones((n,), dtype=x.dtype)) @@ -182,20 +184,22 @@ def check_assign(dtype): def test_trace_can_change_traced_value_float(): - @tvm.register_global_func("tvm.tir.trace_change_float_first") + @tvm.register_global_func("tvm.tirx.trace_change_float_first") def trace_buffer(x): return 13.0 - @tvm.register_global_func("tvm.tir.trace_change_float_second") + @tvm.register_global_func("tvm.tirx.trace_change_float_second") def trace_buffer(x): return 14.0 def check_assign(dtype): n = 4 x = te.placeholder((n,), name="X", dtype=dtype) - y = te.compute(x.shape, lambda i: tvm.tir.trace([x[i]], "tvm.tir.trace_change_float_first")) + y = te.compute( + x.shape, lambda i: tvm.tirx.trace([x[i]], "tvm.tirx.trace_change_float_first") + ) z = te.compute( - x.shape, lambda i: tvm.tir.trace([y[i]], "tvm.tir.trace_change_float_second") + x.shape, lambda i: tvm.tirx.trace([y[i]], "tvm.tirx.trace_change_float_second") ) f = tvm.compile(te.create_prim_func([x, y, z]), target="llvm") diff --git a/tests/python/s_tir/analysis/test_s_tir_analysis_calculate_allocated_memory.py b/tests/python/s_tir/analysis/test_s_tir_analysis_calculate_allocated_memory.py index 5b8145274395..769527b4da64 100644 --- a/tests/python/s_tir/analysis/test_s_tir_analysis_calculate_allocated_memory.py +++ b/tests/python/s_tir/analysis/test_s_tir_analysis_calculate_allocated_memory.py @@ -19,8 +19,8 @@ import pytest import tvm -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks diff --git a/tests/python/s_tir/analysis/test_s_tir_analysis_estimate_tir_flops.py b/tests/python/s_tir/analysis/test_s_tir_analysis_estimate_tir_flops.py index ce1891d67928..a59faedf3698 100644 --- a/tests/python/s_tir/analysis/test_s_tir_analysis_estimate_tir_flops.py +++ b/tests/python/s_tir/analysis/test_s_tir_analysis_estimate_tir_flops.py @@ -24,7 +24,7 @@ from tvm.ir import IRModule from tvm.s_tir.analysis import estimate_tir_flops from tvm.s_tir.meta_schedule.testing.te_workload import create_te_workload -from tvm.script import tir as T +from tvm.script import tirx as T @pytest.mark.parametrize( diff --git a/tests/python/s_tir/analysis/test_s_tir_analysis_identify_memcpy.py b/tests/python/s_tir/analysis/test_s_tir_analysis_identify_memcpy.py index 47003ad140ad..e22c2ceebea1 100644 --- a/tests/python/s_tir/analysis/test_s_tir_analysis_identify_memcpy.py +++ b/tests/python/s_tir/analysis/test_s_tir_analysis_identify_memcpy.py @@ -23,8 +23,8 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T -from tvm.tir import BufferRegion, StringImm +from tvm.script import tirx as T +from tvm.tirx import BufferRegion, StringImm identify_memcpy = tvm.s_tir.analysis._ffi_api._identify_memcpy diff --git a/tests/python/s_tir/analysis/test_s_tir_analysis_is_pure_function.py b/tests/python/s_tir/analysis/test_s_tir_analysis_is_pure_function.py index 862983f92a44..57e7d4dbdf4c 100644 --- a/tests/python/s_tir/analysis/test_s_tir_analysis_is_pure_function.py +++ b/tests/python/s_tir/analysis/test_s_tir_analysis_is_pure_function.py @@ -19,7 +19,7 @@ import tvm.testing from tvm.s_tir.analysis import assert_pure_function, is_pure_function -from tvm.script import tir as T +from tvm.script import tirx as T class CheckPureFunction: diff --git a/tests/python/s_tir/analysis/test_s_tir_analysis_oob.py b/tests/python/s_tir/analysis/test_s_tir_analysis_oob.py index a99b3253ef90..252f2f0fb80f 100644 --- a/tests/python/s_tir/analysis/test_s_tir_analysis_oob.py +++ b/tests/python/s_tir/analysis/test_s_tir_analysis_oob.py @@ -17,7 +17,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func diff --git a/tests/python/s_tir/analysis/test_sblock_access_region.py b/tests/python/s_tir/analysis/test_sblock_access_region.py index b008488c70e2..039363644110 100644 --- a/tests/python/s_tir/analysis/test_sblock_access_region.py +++ b/tests/python/s_tir/analysis/test_sblock_access_region.py @@ -20,7 +20,7 @@ import tvm.testing from tvm import s_tir from tvm.ir import Range -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -216,7 +216,7 @@ def test_block_access_region_detector(): tvm.ir.assert_structural_equal(block.writes, ret[1]) D = alloc_buffers[-1] tvm.ir.assert_structural_equal( - [tvm.tir.BufferRegion(D, [Range(0, 128), Range(0, 128)])], ret[2] + [tvm.tirx.BufferRegion(D, [Range(0, 128), Range(0, 128)])], ret[2] ) diff --git a/tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py b/tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py index 2593d09a31fa..91c1bb366052 100644 --- a/tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py +++ b/tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py @@ -17,7 +17,7 @@ # ruff: noqa: F401 import tvm from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func diff --git a/tests/python/s_tir/base/test_sblock_dependence_info.py b/tests/python/s_tir/base/test_sblock_dependence_info.py index 5a48bfd72466..6385ac13e418 100644 --- a/tests/python/s_tir/base/test_sblock_dependence_info.py +++ b/tests/python/s_tir/base/test_sblock_dependence_info.py @@ -23,13 +23,13 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir import IRModule from tvm.s_tir import SBlockDependenceInfo from tvm.s_tir.sblock_scope import DepKind -from tvm.script import tir as T -from tvm.tir import PrimFunc -from tvm.tir.stmt_functor import post_order_visit +from tvm.script import tirx as T +from tvm.tirx import PrimFunc +from tvm.tirx.stmt_functor import post_order_visit # pylint: disable=no-member,invalid-name,unused-variable @@ -90,10 +90,10 @@ def get_sblocks(func: PrimFunc): blocks = {} def update_blocks(node): - if isinstance(node, tvm.tir.SBlock): + if isinstance(node, tvm.tirx.SBlock): blocks[node.name_hint] = node - # post_order_visit(func.body, lambda node: blocks[node.name_hint] = node if isinstance(node, tvm.tir.SBlock) else None) + # post_order_visit(func.body, lambda node: blocks[node.name_hint] = node if isinstance(node, tvm.tirx.SBlock) else None) post_order_visit(func.body, update_blocks) return blocks diff --git a/tests/python/s_tir/base/test_tir_te_extern_primfunc.py b/tests/python/s_tir/base/test_tir_te_extern_primfunc.py index 7ea7728da57e..cc8ea82e887f 100644 --- a/tests/python/s_tir/base/test_tir_te_extern_primfunc.py +++ b/tests/python/s_tir/base/test_tir_te_extern_primfunc.py @@ -24,7 +24,7 @@ import tvm import tvm.testing from tvm import te -from tvm.script import tir as T +from tvm.script import tirx as T # TODO(csullivan): Additional tests cases needed: # - PrimFunc with 1 arg, inplace update diff --git a/tests/python/s_tir/dlight/test_benchmark.py b/tests/python/s_tir/dlight/test_benchmark.py index 6196227be0bc..c80440b63e34 100644 --- a/tests/python/s_tir/dlight/test_benchmark.py +++ b/tests/python/s_tir/dlight/test_benchmark.py @@ -34,7 +34,7 @@ from tvm.s_tir.meta_schedule.testing.local_rpc import LocalRPC from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T # The test function uses an undefined symbolic var in Relax. @@ -45,7 +45,7 @@ class Module: @T.prim_func def full1(var_T_full: T.handle): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), T.int64(1), n), "float16") # with T.sblock("root"): @@ -58,7 +58,7 @@ def full1(var_T_full: T.handle): @T.prim_func def full2(var_T_full: T.handle): - T.func_attr({"op_pattern": 0, "tir.noalias": True}) + T.func_attr({"op_pattern": 0, "tirx.noalias": True}) n = T.int64() T_full = T.match_buffer(var_T_full, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") # with T.sblock("root"): @@ -71,7 +71,7 @@ def full2(var_T_full: T.handle): @T.prim_func def matmul1(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") @@ -102,7 +102,7 @@ def test(): @T.prim_func def cuda_workload(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) m = T.int64() inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) diff --git a/tests/python/s_tir/dlight/test_cpu_gemv.py b/tests/python/s_tir/dlight/test_cpu_gemv.py index 76da6163f312..610a1acd9d7d 100644 --- a/tests/python/s_tir/dlight/test_cpu_gemv.py +++ b/tests/python/s_tir/dlight/test_cpu_gemv.py @@ -20,7 +20,7 @@ import tvm.testing from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -28,7 +28,7 @@ def test_gemv_basic(): # fmt: off @T.prim_func(private=True) def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") @@ -73,7 +73,7 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l @T.prim_func(private=True) def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") @@ -114,7 +114,7 @@ def test_decode_gemv_256_threads(): # fmt: off @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate = T.sblock_alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): @@ -134,7 +134,7 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for u_fused in range(1): for ax0_fused_0 in T.parallel(172): @@ -162,7 +162,7 @@ def test_decode_gemv1(): @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate = T.sblock_alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): @@ -182,7 +182,7 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for u_fused in range(1): for ax0_fused_0 in T.parallel(172): @@ -210,7 +210,7 @@ def test_decode_gemv2(): @T.prim_func(private=True) def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((32000, 4096), "float16") var_NT_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 32000), "float16") @@ -237,7 +237,7 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) @T.prim_func(private=True) def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_NT_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 32000), "float16") for u_fused in range(1): @@ -272,7 +272,7 @@ def test_decode_gemv3(): @T.prim_func(private=True) def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((T.int64(4096), T.int64(11008)), "float16") var_NT_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") @@ -299,7 +299,7 @@ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.B @T.prim_func(private=True) def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_NT_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") for u_fused in range(1): @@ -334,7 +334,7 @@ def test_autogptq_decode_gemv(): # fmt: off @T.prim_func(private=True) def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): decode_intermediate = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096)), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") @@ -378,7 +378,7 @@ def before( lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 4096), "float16") @@ -399,7 +399,7 @@ def before( @T.prim_func(private=True) def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 4096), "float16") @@ -434,7 +434,7 @@ def test_outer_reduction_adreno_dynamic(): # fmt: off @T.prim_func(private=True) def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) v = T.int64() lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") @@ -465,7 +465,7 @@ def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T @T.prim_func(private=True) def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) v = T.int64() lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") @@ -524,7 +524,7 @@ def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "flo @T.prim_func(private=True) def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): for expert_id in T.thread_binding(2, thread="blockIdx.y"): with T.sblock("gemv_o"): diff --git a/tests/python/s_tir/dlight/test_gpu_conv.py b/tests/python/s_tir/dlight/test_gpu_conv.py index c393650f044f..aad1cf374980 100644 --- a/tests/python/s_tir/dlight/test_gpu_conv.py +++ b/tests/python/s_tir/dlight/test_gpu_conv.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -45,7 +45,7 @@ def before( @T.prim_func(private=True) def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): C_reindex_pad_local = T.sblock_alloc_buffer((1, 14336, 1280), "float16", scope="local") pad_A_reindex_pad_shared = T.sblock_alloc_buffer((1, 14336, 1184), "float16", scope="shared") diff --git a/tests/python/s_tir/dlight/test_gpu_fallback.py b/tests/python/s_tir/dlight/test_gpu_fallback.py index bc49b92f9fd0..58f9cc4ad1c2 100644 --- a/tests/python/s_tir/dlight/test_gpu_fallback.py +++ b/tests/python/s_tir/dlight/test_gpu_fallback.py @@ -20,7 +20,7 @@ from tvm.ir import assert_structural_equal from tvm.s_tir import dlight as dl from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -49,7 +49,7 @@ def main( A: T.Buffer((1, 32, 1, 128), "float16"), C: T.Buffer((1, 1, 4096), "float16"), ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): with T.sblock("T_reshape"): @@ -85,7 +85,7 @@ def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), "float32")): class Expected: @T.prim_func def main(A: T.Buffer((1, 6144), "float32"), B: T.Buffer((1,), "float32")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) for ax0_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): with T.sblock("block_init"): @@ -145,7 +145,7 @@ def func( # fmt: off @T.prim_func(private=True) def expected(var_pages: T.handle, var_page_table_indptr: T.handle, var_page_table_values: T.handle, var_values: T.handle, seq_id: T.int32): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) nhead = T.int32() nlayer = T.int32() seqlen = T.int32() @@ -226,7 +226,7 @@ def gpu_func( A: T.Buffer((1, 32, 1, 128), "float16"), C: T.Buffer((1, 1, 4096), "float16"), ): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) for ax0_fused_0 in T.thread_binding(4, thread="blockIdx.x"): for ax0_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): with T.sblock("T_reshape"): diff --git a/tests/python/s_tir/dlight/test_gpu_gemv.py b/tests/python/s_tir/dlight/test_gpu_gemv.py index ee07e7a0db2a..68aabbd09566 100644 --- a/tests/python/s_tir/dlight/test_gpu_gemv.py +++ b/tests/python/s_tir/dlight/test_gpu_gemv.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -27,7 +27,7 @@ def test_gemv_basic(): # fmt: off @T.prim_func(private=True) def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") @@ -72,7 +72,7 @@ def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_l @T.prim_func(private=True) def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") @@ -181,7 +181,7 @@ def test_decode_gemv_256_threads(): # fmt: off @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate = T.sblock_alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): @@ -201,7 +201,7 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_NT_matmul_intermediate_rf_local = T.sblock_alloc_buffer((16, 1, 1, 22016), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.sblock_alloc_buffer((16, 1, 1, 22016), "float16", scope="local") @@ -277,7 +277,7 @@ def test_decode_gemv1(): @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate = T.sblock_alloc_buffer((22016, 4096), "float16") for i, j in T.grid(22016, 4096): @@ -297,7 +297,7 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_NT_matmul_intermediate_rf_local = T.sblock_alloc_buffer((128, 1, 1, 22016), "float16", scope="local") var_NT_matmul_intermediate_rf_local_1 = T.sblock_alloc_buffer((32, 1, 1, 22016), "float16", scope="local") @@ -385,7 +385,7 @@ def test_decode_gemv2(): @T.prim_func(private=True) def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((32000, 4096), "float16") var_NT_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 32000), "float16") @@ -412,7 +412,7 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) @T.prim_func(private=True) def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_NT_matmul_intermediate_local = T.sblock_alloc_buffer((1, 1, 32000), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.sblock_alloc_buffer((128, 1, 1, 32000), "float16", scope="local") @@ -508,7 +508,7 @@ def test_decode_gemv3(): @T.prim_func(private=True) def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((T.int64(4096), T.int64(11008)), "float16") var_NT_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") @@ -535,7 +535,7 @@ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.B @T.prim_func(private=True) def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_NT_matmul_intermediate_local = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") var_NT_matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local") @@ -631,7 +631,7 @@ def test_autogptq_decode_gemv(): # fmt: off @T.prim_func(private=True) def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): decode_intermediate = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096)), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") @@ -675,7 +675,7 @@ def before( lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 4096), "float16") @@ -696,7 +696,7 @@ def before( @T.prim_func(private=True) def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_matmul_intermediate_local = T.sblock_alloc_buffer((1, 1, 4096), "float16", scope="local") var_matmul_intermediate_rf_local = T.sblock_alloc_buffer((32, 1, 1, 4096), "float16", scope="local") @@ -781,7 +781,7 @@ def test_outer_reduction_adreno_dynamic(): # fmt: off @T.prim_func(private=True) def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) v = T.int64() lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") @@ -812,7 +812,7 @@ def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T @T.prim_func(private=True) def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) v = T.int64() lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") @@ -942,7 +942,7 @@ def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "flo @T.prim_func(private=True) def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): for expert_id in T.thread_binding(2, thread="blockIdx.y"): with T.sblock("gemv_o"): diff --git a/tests/python/s_tir/dlight/test_gpu_general_reduction.py b/tests/python/s_tir/dlight/test_gpu_general_reduction.py index 3df3d2b1df2f..7c6b995b377a 100644 --- a/tests/python/s_tir/dlight/test_gpu_general_reduction.py +++ b/tests/python/s_tir/dlight/test_gpu_general_reduction.py @@ -21,7 +21,7 @@ from tvm.ir import IRModule, assert_structural_equal from tvm.s_tir import dlight as dl from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -40,7 +40,7 @@ def test_softmax_1(): class Before: @T.prim_func def main(p_lv44: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") @@ -89,7 +89,7 @@ def main(p_lv44: T.handle, p_output0: T.handle): class After: @T.prim_func def main(p_lv44: T.handle, p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") @@ -182,7 +182,7 @@ def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_sof class After: @T.prim_func def main(A: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): T_softmax_maxelem_shared = T.sblock_alloc_buffer((T.int64(1), T.int64(1)), scope="shared") T_softmax_expsum_shared = T.sblock_alloc_buffer((T.int64(1), T.int64(1)), scope="shared") @@ -268,7 +268,7 @@ def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), " class After: @T.prim_func def main(input: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32"), T_softmax_norm: T.Buffer((T.int64(1), T.int64(4), T.int64(32), T.int64(8192)), "float32")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): T_softmax_maxelem_shared = T.sblock_alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared") T_softmax_expsum_shared = T.sblock_alloc_buffer((T.int64(1), T.int64(4), T.int64(8192)), scope="shared") @@ -320,7 +320,7 @@ def test_layer_norm(): class Before: @T.prim_func def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") @@ -357,7 +357,7 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: class After: @T.prim_func def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") @@ -398,7 +398,7 @@ def test_rms_norm(): class Before: @T.prim_func def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") @@ -423,7 +423,7 @@ def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm class After: @T.prim_func def main(var_A: T.handle, B: T.Buffer((T.int64(4096),), "float16"), var_rms_norm: T.handle): - T.func_attr({"op_pattern": 4, "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), n, T.int64(4096)), "float16") rms_norm_1 = T.match_buffer(var_rms_norm, (T.int64(1), n, T.int64(4096)), "float16") @@ -459,7 +459,7 @@ def test_group_norm(): class Before: @T.prim_func def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: T.Buffer((2048,), "float32"), T_reshape: T.Buffer((1, 2048), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) T_reshape_1 = T.sblock_alloc_buffer((1, 32, 64)) A_red_temp_v0 = T.sblock_alloc_buffer((1, 32)) A_red_temp_v1 = T.sblock_alloc_buffer((1, 32)) @@ -513,7 +513,7 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: class After: @T.prim_func def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: T.Buffer((2048,), "float32"), T_reshape: T.Buffer((1, 2048), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): A_red_temp_v0_shared = T.sblock_alloc_buffer((1, 32), scope="shared") A_red_temp_v1_shared = T.sblock_alloc_buffer((1, 32), scope="shared") @@ -550,7 +550,7 @@ def test_logsumexp(): class Before: @T.prim_func def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) batch_size = T.int64(is_size_var=True) vocab_size = T.int64(is_size_var=True) num_chunks = T.int64(is_size_var=True) @@ -596,7 +596,7 @@ def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): class After: @T.prim_func def compute_lse(var_A: T.handle, var_blocked_lse: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) batch_size, vocab_size = T.int64(is_size_var=True), T.int64(is_size_var=True) A = T.match_buffer(var_A, (batch_size, vocab_size)) num_chunks = T.int64(is_size_var=True) diff --git a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py index 1d1cddb82dd8..83099682b036 100644 --- a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py @@ -19,7 +19,7 @@ import tvm.testing from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -28,7 +28,7 @@ def test_batch_decode_gemv(): @T.prim_func(private=True) def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True, "tir.HoistIfThenElseExprWithBlock": 1}) + T.func_attr({"tirx.noalias": True, "tirx.HoistIfThenElseExprWithBlock": 1}) batch_size = T.int64() lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16") @@ -58,7 +58,7 @@ def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.B @T.prim_func(private=True) def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle): - T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.HoistIfThenElseExprWithBlock": 1, "tirx.is_scheduled": True, "tirx.noalias": True}) batch_size = T.int64() lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16") NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16") @@ -156,7 +156,7 @@ def test_batch_gemv(): # fmt: off @T.prim_func(private=True) def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), "float16"), var_NT_matmul: T.handle): - T.func_attr({"tir.noalias": True, "tir.HoistIfThenElseExprWithBlock": 1}) + T.func_attr({"tirx.noalias": True, "tirx.HoistIfThenElseExprWithBlock": 1}) batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(K)), "float16") NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(N)), "float16") @@ -172,7 +172,7 @@ def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), "float16"), va @T.prim_func(private=True) def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle): - T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.HoistIfThenElseExprWithBlock": 1, "tirx.is_scheduled": True, "tirx.noalias": True}) batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(4096)), "float16") NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(4096)), "float16") @@ -257,7 +257,7 @@ def test_reduction_symbolic_var(): # fmt: off @T.prim_func(private=True) def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) kv_seq_len = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len)) B = T.match_buffer(var_B, (T.int64(1), T.int64(32), kv_seq_len, T.int64(128))) @@ -280,7 +280,7 @@ def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int def test_small_spatial_axis(): @T.prim_func(private=True) def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") @@ -296,7 +296,7 @@ def func(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), v # fmt: off @T.prim_func(private=True) def expected(var_A: T.handle, B: T.Buffer((T.int64(8), T.int64(4096)), "float16"), var_C: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) batch_size = T.int64() A = T.match_buffer(var_A, (batch_size, T.int64(4096)), "float16") C = T.match_buffer(var_C, (batch_size, T.int64(8)), "float16") @@ -414,7 +414,7 @@ def before( @T.prim_func(private=True) def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "float16"), var_A: T.handle, var_C: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) batch_size = T.int32() A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 6144), "float16") diff --git a/tests/python/s_tir/dlight/test_gpu_matmul.py b/tests/python/s_tir/dlight/test_gpu_matmul.py index b8262a1f039f..0c1aefd4c8d0 100644 --- a/tests/python/s_tir/dlight/test_gpu_matmul.py +++ b/tests/python/s_tir/dlight/test_gpu_matmul.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -39,7 +39,7 @@ def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "f @T.prim_func(private=True) def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) m = T.int64() inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) @@ -131,7 +131,7 @@ def func(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul @T.prim_func(private=True) def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_matmul: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) m = T.int32() inp0 = T.match_buffer(var_inp0, (1, m, 4096)) matmul = T.match_buffer(var_matmul, (1, m, 4096)) @@ -236,7 +236,7 @@ def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T. @T.prim_func(private=True) def expected(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(32), T.int64(4096)), "float32")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): var_matmul_intermediate_reindex_local = T.sblock_alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="local") A_reindex_shared = T.sblock_alloc_buffer((T.int64(1), T.int64(32), T.int64(4096)), scope="shared") @@ -313,7 +313,7 @@ def test_skip_gemv(): # fmt: off @T.prim_func(private=True) def before(W: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), S: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), A: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), C: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), Out: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) var_decode_intermediate = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096))) var_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(4096))) for i, j in T.grid(T.int64(4096), T.int64(4096)): @@ -351,7 +351,7 @@ def test_output_fp32(): # fmt: off @T.prim_func(private=True) def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16") lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16") @@ -403,7 +403,7 @@ def before(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buff @T.prim_func(private=True) def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_lv48: T.handle, lv13_1: T.Buffer((T.int64(4096),), "float16"), p_lv3: T.handle, p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int64() lv48 = T.match_buffer(p_lv48, (T.int64(1), n, T.int64(4096)), "float16") lv3 = T.match_buffer(p_lv3, (T.int64(1), n, T.int64(4096)), "float16") @@ -485,7 +485,7 @@ def test_inline_consumer_chain(): # fmt: off @T.prim_func(private=True) def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16") lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048))) @@ -537,7 +537,7 @@ def before(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "floa @T.prim_func(private=True) def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "float16"), p_lv52: T.handle, p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int64() lv26 = T.match_buffer(p_lv26, (n, T.int64(2048)), "float16") lv52 = T.match_buffer(p_lv52, (T.int64(1), n, T.int64(2048))) @@ -631,7 +631,7 @@ def before(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "f @T.prim_func(private=True) def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_matmul: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) m = T.int64() inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096))) matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096))) @@ -712,7 +712,7 @@ def test_fused_dequant_matmul_android(): # fmt: off @T.prim_func(private=True) def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) seq_len = T.int64() rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") @@ -749,7 +749,7 @@ def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.B @T.prim_func(private=True) def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) seq_len = T.int64() rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16") T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16") diff --git a/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py b/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py index 1fb8ba9c125b..2c6d780c69ce 100644 --- a/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py +++ b/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.s_tir import dlight as dl -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -27,7 +27,7 @@ def test_matmul_tensorize(): # fmt: off @T.prim_func(private=True) def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j, k in T.grid(256, 256, 256): with T.sblock("compute"): @@ -40,7 +40,7 @@ def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16" @T.prim_func(private=True) def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): X_reindex_shared_dyn = T.sblock_alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") W_reindex_shared_dyn = T.sblock_alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") @@ -77,7 +77,7 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(X[v1, v2]) T.writes(X_reindex_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): @@ -89,7 +89,7 @@ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float1 v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(W[v1, v2]) T.writes(W_reindex_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): @@ -166,7 +166,7 @@ def test_matmul_tensorize_too_small(): # fmt: off @T.prim_func(private=True) def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) m = T.int32() X = T.match_buffer(var_X, (m, 256), "float16") compute = T.match_buffer(var_compute, (m, 15)) @@ -182,7 +182,7 @@ def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.ha @T.prim_func(private=True) def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) m = T.int32() X = T.match_buffer(var_X, (m, 256), "float16") compute = T.match_buffer(var_compute, (m, 15)) @@ -262,7 +262,7 @@ def test_matmul_tensorize_epilogue(): # fmt: off @T.prim_func(private=True) def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Buffer((T.int32(4096), T.int32(64)), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16") lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16") @@ -300,7 +300,7 @@ def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Bu @T.prim_func(private=True) def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16") lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16") @@ -341,7 +341,7 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(lv42[v0, v1, v2]) T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0)) for ax0_ax1_fused_0 in range(4): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): @@ -353,7 +353,7 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) T.reads(lv686[v1, v2 // 8], lv687[v1, v2 // 32]) T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32] for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): @@ -430,7 +430,7 @@ def test_matmul_int8_tensorize(): # fmt: off @T.prim_func(private=True) def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): for i, j, r in T.grid(256, 256, 256): with T.sblock("compute"): @@ -443,7 +443,7 @@ def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), com @T.prim_func(private=True) def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): X_reindex_shared_dyn = T.sblock_alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") W_reindex_shared_dyn = T.sblock_alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") @@ -480,7 +480,7 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(X[v1, v2]) T.writes(X_reindex_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] for ax0_ax1_fused_0 in range(1): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): @@ -492,7 +492,7 @@ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), c v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(W[v1, v2]) T.writes(W_reindex_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): @@ -568,7 +568,7 @@ def test_matmul_int8_tensorize_3d2d_dyn(): # fmt: off @T.prim_func(private=True) def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (1, m, 22016), "int8") matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") @@ -584,7 +584,7 @@ def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.ha @T.prim_func(private=True) def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): - T.func_attr({"op_pattern": 4, "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.is_scheduled": True, "tirx.noalias": True}) m = T.int32() A = T.match_buffer(var_A, (1, m, 22016), "int8") matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") @@ -624,7 +624,7 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(A[v0, v1, v2]) T.writes(A_reindex_pad_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0)) for ax0_ax1_fused_0 in range(1): for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): @@ -636,7 +636,7 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T. v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) T.reads(B[v1, v2]) T.writes(B_reindex_shared_dyn[v0, v1, v2]) - T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tirx.manifest_shared_memory_local_stage": 1}) B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0 in T.unroll(2): @@ -730,7 +730,7 @@ def before( @T.prim_func(private=True) def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) batch_size = T.int32() A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") @@ -875,7 +875,7 @@ def before( @T.prim_func(private=True) def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "float16"), var_A: T.handle, var_C: T.handle): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) batch_size = T.int32() A = T.match_buffer(var_A, (batch_size, 1, 4096), "float16") C = T.match_buffer(var_C, (batch_size, 1, 28672), "float16") diff --git a/tests/python/s_tir/dlight/test_gpu_reduction.py b/tests/python/s_tir/dlight/test_gpu_reduction.py index 39ce3e26874f..5b00f733b071 100644 --- a/tests/python/s_tir/dlight/test_gpu_reduction.py +++ b/tests/python/s_tir/dlight/test_gpu_reduction.py @@ -21,7 +21,7 @@ from tvm.ir import assert_structural_equal from tvm.s_tir import dlight as dl from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -32,7 +32,7 @@ def test_decode_gemv_1(): class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): @@ -55,7 +55,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") class After: @T.prim_func def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T.handle): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.is_scheduled": True, "tirx.noalias": True}) W = T.match_buffer(W_handle, (4096, 512), "uint32") S = T.match_buffer(S_handle, (4096, 128), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") @@ -107,7 +107,7 @@ def test_decode_gemv_2(): class Before: @T.prim_func def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): @@ -130,7 +130,7 @@ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16") class After: @T.prim_func def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): C_rf_local = T.sblock_alloc_buffer((16, 1, 1, 4096), "float16", scope="local") for i2_i0_i1_fused_0 in T.thread_binding(256, thread="blockIdx.x"): @@ -170,7 +170,7 @@ def test_decode_gemv_3(): class Before: @T.prim_func def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): @@ -192,7 +192,7 @@ def func(W: T.Buffer((512, 4096), "uint32"), S: T.Buffer((128, 4096), "float16") class After: @T.prim_func def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T.handle): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.is_scheduled": True, "tirx.noalias": True}) W = T.match_buffer(W_handle, (512, 4096), "uint32") S = T.match_buffer(S_handle, (128, 4096), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") @@ -246,7 +246,7 @@ def test_decode_gemv_4(): class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((4096, 4096), "float16") for i, j in T.grid(4096, 4096): @@ -269,7 +269,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") class After: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): C_rf_local = T.sblock_alloc_buffer((16, 1, 1, 4096), "float16", scope="local") for i2_0_i0_i1_fused_0 in T.thread_binding(32, thread="blockIdx.x"): @@ -311,7 +311,7 @@ def test_decode_gemv_sigmoid(): class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), D: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((4096, 4096), "float16") C = T.sblock_alloc_buffer((1, 1, 4096), "float16") @@ -340,7 +340,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") class After: @T.prim_func def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, D_handle: T.handle): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.is_scheduled": True, "tirx.noalias": True}) W = T.match_buffer(W_handle, (4096, 512), "uint32") S = T.match_buffer(S_handle, (4096, 128), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") @@ -400,7 +400,7 @@ def test_decode_gemv_1_fp32(): class Before: @T.prim_func def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16"), V: T.Buffer((1, 1, 4096), "float16"), C: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((4096, 4096), "float16") C_fp32 = T.sblock_alloc_buffer((1, 1, 4096), "float32") @@ -429,7 +429,7 @@ def func(W: T.Buffer((4096, 512), "uint32"), S: T.Buffer((4096, 128), "float16") class After: @T.prim_func def func(W_handle: T.handle, S_handle: T.handle, V_handle: T.handle, C_handle: T.handle): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.is_scheduled": True, "tirx.noalias": True}) W = T.match_buffer(W_handle, (4096, 512), "uint32") S = T.match_buffer(S_handle, (4096, 128), "float16") V = T.match_buffer(V_handle, (1, 1, 4096), "float16") @@ -488,7 +488,7 @@ def test_reduction_no_spatial(): class Before: @T.prim_func def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), rms_norm: T.Buffer((1, 4096), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) Ared_temp = T.sblock_alloc_buffer((1, 1)) for ax0 in range(4096): with T.sblock("Ared_temp"): @@ -505,7 +505,7 @@ def main(A: T.Buffer((1, 1, 4096), "float16"), B: T.Buffer((4096,), "float16"), class After: @T.prim_func def main(A_handle: T.handle, B_handle: T.handle, rms_norm_handle: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) A = T.match_buffer(A_handle, (1, 1, 4096), "float16") B = T.match_buffer(B_handle, (4096,), "float16") rms_norm = T.match_buffer(rms_norm_handle, (1, 4096), "float16") @@ -561,7 +561,7 @@ def test_spatial_inner_no_broadcasting(): class Module: @T.prim_func def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) p_output0_intermediate_1 = T.sblock_alloc_buffer((11008, 4096), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((1, 1, 4096), "float16") for i, j in T.grid(11008, 4096): @@ -589,7 +589,7 @@ def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), " class Expected: @T.prim_func def main(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) var_matmul_intermediate_local = T.sblock_alloc_buffer((1, 1, 4096), "float16", scope="local") var_matmul_intermediate_rf_local = T.sblock_alloc_buffer((16, 1, 1, 4096), "float16", scope="local") for ax0_fused_0 in T.thread_binding(256, thread="blockIdx.x"): @@ -640,7 +640,7 @@ def test_spatial_inner_broadcasting(): class Module: @T.prim_func def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) temp_local = T.sblock_alloc_buffer((256,)) for j in T.serial(256): for k in T.serial(256): @@ -662,7 +662,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")) class Expected: @T.prim_func def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) temp_local_shared = T.sblock_alloc_buffer((256,), scope="shared") temp_local_rf_local = T.sblock_alloc_buffer((16, 256), scope="local") for ax0_fused_0 in T.thread_binding(16, thread="blockIdx.x"): @@ -715,7 +715,7 @@ def test_reduction_inner_no_broadcasting(): class Module: @T.prim_func def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) temp_local = T.sblock_alloc_buffer((256,)) for i in T.serial(256): for k in T.serial(256): @@ -737,7 +737,7 @@ def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): class Expected: @T.prim_func def main(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256,), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): temp_local_local = T.sblock_alloc_buffer((256,), scope="local") temp_local_rf_local = T.sblock_alloc_buffer((256, 256), scope="local") @@ -783,7 +783,7 @@ def test_reduction_inner_no_broadcasting2(): class Module: @T.prim_func def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float16"), lv1: T.Buffer((1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 2560), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): p_output0_intermediate_1 = T.sblock_alloc_buffer((2560, 2560), "float16") var_matmul_intermediate = T.sblock_alloc_buffer((1, 2560), "float16") @@ -812,7 +812,7 @@ def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float class Expected: @T.prim_func def main(lv9: T.Buffer((2560, 320), "uint32"), lv10: T.Buffer((2560, 80), "float16"), lv1: T.Buffer((1, 2560), "float16"), p_output0_intermediate: T.Buffer((1, 2560), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): var_matmul_intermediate_local = T.sblock_alloc_buffer((1, 2560), "float16", scope="local") var_matmul_intermediate_rf_local = T.sblock_alloc_buffer((16, 1, 2560), "float16", scope="local") @@ -865,7 +865,7 @@ def test_reduction_inner_spatial_choose_perfect_factor(): class Module: @T.prim_func def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16") @@ -882,7 +882,7 @@ def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64 class Expected: @T.prim_func def main(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(100)), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int64() A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), n), "float16") B = T.match_buffer(var_B, (T.int64(1), T.int64(32), n, T.int64(100)), "float16") @@ -933,7 +933,7 @@ def test_repeat_transpose_gemv(): class Before: @T.prim_func(private=True) def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_astype66: T.handle, var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) kv_seq_len = T.int64() lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16") astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16") @@ -964,7 +964,7 @@ def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_ast class Expected: @T.prim_func(private=True) def fused_relax_repeat_relax_permute_dims_relax_matmul1(p_lv716: T.handle, p_astype66: T.handle, var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) kv_seq_len = T.int64() lv716 = T.match_buffer(p_lv716, (T.int64(1), kv_seq_len, T.int64(8), T.int64(128)), "float16") astype66 = T.match_buffer(p_astype66, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len), "float16") @@ -1019,7 +1019,7 @@ def main( B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_C: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) vocab_size = T.int64() A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16") C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size)) @@ -1046,7 +1046,7 @@ def main( class Expected: @T.prim_func(private=True) def main(var_A: T.handle, B: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), var_C: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) vocab_size = T.int64() A = T.match_buffer(var_A, (T.int64(4096), vocab_size), "float16") C = T.match_buffer(var_C, (T.int64(1), T.int64(1), vocab_size)) @@ -1099,7 +1099,7 @@ def test_gemv_output_one_element(): class Before: @T.prim_func(private=True) def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), T.int64(1)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) NT_matmul_intermediate = T.sblock_alloc_buffer((T.int64(1), T.int64(1)), "float16") for i0, i1, k in T.grid(T.int64(1), T.int64(1), T.int64(2048)): with T.sblock("NT_matmul"): @@ -1117,7 +1117,7 @@ def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer(( class Expected: @T.prim_func(private=True) def main(A: T.Buffer((T.int64(1), T.int64(2048)), "float16"), weight: T.Buffer((T.int64(1), T.int64(2048)), "float16"), out: T.Buffer((T.int64(1), T.int64(1)), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) NT_matmul_intermediate_shared = T.sblock_alloc_buffer((T.int64(1), T.int64(1)), "float16", scope="shared") NT_matmul_intermediate_rf_local = T.sblock_alloc_buffer((T.int64(1024), T.int64(1), T.int64(1)), "float16", scope="local") for ax0_fused in T.thread_binding(T.int64(1), thread="blockIdx.x"): @@ -1161,7 +1161,7 @@ def test_no_reduction_loop_check(): class Before: @T.prim_func(private=True) def matmul(lv43: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16"), lv44: T.Buffer((T.int64(1), T.int64(1), T.int64(1)), "float16"), matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), "float16")): - T.func_attr({"op_pattern": 4, "tir.noalias": True}) + T.func_attr({"op_pattern": 4, "tirx.noalias": True}) # with T.sblock("root"): for i0, i1, i2, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(1)): with T.sblock("matmul"): diff --git a/tests/python/s_tir/dlight/test_gpu_rmsnorm.py b/tests/python/s_tir/dlight/test_gpu_rmsnorm.py index 3d3825730c62..e565a672f60d 100644 --- a/tests/python/s_tir/dlight/test_gpu_rmsnorm.py +++ b/tests/python/s_tir/dlight/test_gpu_rmsnorm.py @@ -20,7 +20,7 @@ from tvm.ir import IRModule, assert_structural_equal from tvm.s_tir import dlight as dl from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -39,7 +39,7 @@ def test_rms_norm_with_casting(): class Before: @T.prim_func def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() data = T.match_buffer(var_data, (1, n, 4096), "float16") T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") @@ -99,7 +99,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T class After: @T.prim_func def main(var_data: T.handle, weight: T.Buffer((4096,), "float16"), var_T_cast: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() data = T.match_buffer(var_data, (1, n, 4096), "float16") T_cast = T.match_buffer(var_T_cast, (1, n, 4096), "float16") @@ -171,7 +171,7 @@ def test_rms_norm_without_casting(): class Before: @T.prim_func def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() data = T.match_buffer(var_data, (1, n, 4096)) T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) @@ -217,7 +217,7 @@ def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T class After: @T.prim_func def main(var_data: T.handle, weight: T.Buffer((4096,), "float32"), var_T_cast: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) n = T.int32() data = T.match_buffer(var_data, (1, n, 4096)) T_cast = T.match_buffer(var_T_cast, (1, n, 4096)) diff --git a/tests/python/s_tir/dlight/test_gpu_transpose.py b/tests/python/s_tir/dlight/test_gpu_transpose.py index 7463ef5fdabb..38f9bd34478c 100644 --- a/tests/python/s_tir/dlight/test_gpu_transpose.py +++ b/tests/python/s_tir/dlight/test_gpu_transpose.py @@ -20,7 +20,7 @@ from tvm.ir import IRModule, assert_structural_equal from tvm.s_tir import dlight as dl from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -39,7 +39,7 @@ def test_transpose(): class Before: @T.prim_func def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(4096), T.int64(512)): with T.sblock("T_transpose"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) @@ -49,7 +49,7 @@ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_tr class After: @T.prim_func def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "float32"), T_transpose: T.Buffer((T.int64(4096), T.int64(512)), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): rxplaceholder_shared = T.sblock_alloc_buffer((T.int64(512), T.int64(4096)), scope="shared") for ax0_0_0 in T.thread_binding(T.int64(512), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): @@ -85,7 +85,7 @@ def test_decode_transpose(): class Before: @T.prim_func def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) decode = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096))) for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.sblock("decode"): @@ -104,7 +104,7 @@ def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxpla class After: @T.prim_func def main(rxplaceholder: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), rxplaceholder_1: T.Buffer((T.int64(128), T.int64(4096)), "uint32"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) decode_shared = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096)), scope="shared") for ax0_0_0 in T.thread_binding(T.int64(64), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): for ax1_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"): @@ -139,7 +139,7 @@ def test_decode_int3_transpose(): class Before: @T.prim_func def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) decode_1 = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096)), "float16") for i, j in T.grid(T.int64(4096), T.int64(4096)): with T.sblock("decode"): @@ -158,7 +158,7 @@ def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.in class After: @T.prim_func def main(A: T.Buffer((T.int64(412), T.int64(4096)), "uint32"), B: T.Buffer((T.int64(103), T.int64(4096)), "float16"), T_transpose: T.Buffer((T.int64(4096), T.int64(4096)), "float16")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): decode_1_shared = T.sblock_alloc_buffer((T.int64(4096), T.int64(4096)), "float16", scope="shared") for ax0_0_0 in T.thread_binding(T.int64(52), thread="blockIdx.y", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): diff --git a/tests/python/s_tir/dlight/test_primitives.py b/tests/python/s_tir/dlight/test_primitives.py index 5426d15a9611..a1cdb1936a62 100644 --- a/tests/python/s_tir/dlight/test_primitives.py +++ b/tests/python/s_tir/dlight/test_primitives.py @@ -19,12 +19,12 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func def main(p0: T.Buffer((), "int32"), T_stack: T.Buffer((T.int64(3),), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): compile_engine_const = T.sblock_alloc_buffer((), "int32") compile_engine_const_1 = T.sblock_alloc_buffer((), "int32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_arg_info.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_arg_info.py index 2d37640cd616..86a3757c8985 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_arg_info.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_arg_info.py @@ -17,7 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm.s_tir.meta_schedule.arg_info import ArgInfo, TensorInfo -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_builder.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_builder.py index 9586f684cda1..1421e775cc71 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_builder.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_builder.py @@ -33,7 +33,7 @@ LocalBuilder, PyBuilder, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -43,7 +43,7 @@ class MatmulModule: @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "matmul", "tir.noalias": True}) + T.func_attr({"global_symbol": "matmul", "tirx.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") @@ -61,7 +61,7 @@ class MatmulReluModule: def matmul_relu( # pylint: disable=no-self-argument a: T.handle, b: T.handle, d: T.handle ) -> None: - T.func_attr({"global_symbol": "matmul_relu", "tir.noalias": True}) + T.func_attr({"global_symbol": "matmul_relu", "tirx.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") @@ -84,7 +84,7 @@ class BatchMatmulModule: def batch_matmul( # pylint: disable=no-self-argument a: T.handle, b: T.handle, c: T.handle ) -> None: - T.func_attr({"global_symbol": "batch_matmul", "tir.noalias": True}) + T.func_attr({"global_symbol": "batch_matmul", "tirx.noalias": True}) A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py index 57ee5db6f07d..0f1c91ad0d88 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py @@ -35,7 +35,7 @@ from tvm.s_tir.meta_schedule.tune_context import TuneContext from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule.schedule import Schedule -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -43,7 +43,7 @@ class Matmul: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") @@ -59,7 +59,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s class FullModule: @T.prim_func def main(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for ax0, ax1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py index 4eaea25b3488..f8421600a59f 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py @@ -27,12 +27,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir.module import IRModule from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.database import TuningRecord, Workload -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -58,7 +58,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: class MatmulRelu: @T.prim_func def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py index 24657f5ba8d5..fb897b99a735 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py @@ -24,9 +24,9 @@ import tvm import tvm.testing -from tvm import s_tir, te, tir +from tvm import s_tir, te, tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T N_FEATURES = 164 @@ -38,7 +38,7 @@ def matmul( C: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1, i2 in T.grid(512, 512, 512): @@ -54,13 +54,13 @@ def matmul( # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off -# from tvm.script import tir as T +# from tvm.script import tirx as T @tvm.script.ir_module class LayoutTransform: @T.prim_func def main(placeholder: T.Buffer((1, 16, 7, 7, 32), "float32"), placeholder_1: T.Buffer((25088,), "float32"), T_layout_trans: T.Buffer((1, 1, 7, 7, 512), "float32")) -> None: # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") for i0_i1_i2_i3_i4_fused in T.parallel(25088, annotations={"pragma_auto_unroll_max_step":64, "pragma_unroll_explicit":1}): @@ -783,8 +783,8 @@ def _create_schedule(): _, b_j = sch.split(b_ij, factors=[None, 16]) # outer: 8 sch.bind(b_j, "threadIdx.x") # auto unroll - sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tir.IntImm("int32", 1024)) - sch.annotate(i0_j0, "pragma_unroll_explicit", tir.IntImm("int32", 1)) + sch.annotate(i0_j0, "pragma_auto_unroll_max_step", tirx.IntImm("int32", 1024)) + sch.annotate(i0_j0, "pragma_unroll_explicit", tirx.IntImm("int32", 1)) return sch extractor = ms.feature_extractor.PerStoreFeature() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py index b6fa306ce06d..d386bbad4fc2 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py @@ -23,7 +23,7 @@ import tvm from tvm.s_tir import meta_schedule as ms from tvm.s_tir.schedule import Schedule -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, # fmt: off diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py index db53f7b3571b..15487893d0d9 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mma_tensorize.py @@ -23,7 +23,7 @@ import tvm.s_tir.tensor_intrin # pylint: disable=unused-import import tvm.testing from tvm.s_tir.schedule import Schedule -from tvm.script import tir as T +from tvm.script import tirx as T torch = pytest.importorskip("torch") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py index 96a16bd71c14..ce0b4b8bb312 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_compute_location.py @@ -17,7 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name, no-member diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py index 7678ee7a418b..fe367f414788 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_parallel.py @@ -18,7 +18,7 @@ from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name, no-member diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py index 69d687476833..11fc2a9abf82 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_thread_binding.py @@ -17,7 +17,7 @@ # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name, no-member diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py index bfb0ff3eb2cb..c9aa7d9e666b 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_tile_size.py @@ -20,7 +20,7 @@ from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name, no-member diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py index bec385163652..c29e190afb1c 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_mutator_mutate_unroll.py @@ -18,7 +18,7 @@ from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name, no-member diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py index c6b477a7275f..fdc47532f1e4 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py @@ -33,7 +33,7 @@ from tvm.s_tir.meta_schedule.space_generator import PostOrderApply from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule import SBlockRV, Schedule -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, @@ -124,7 +124,7 @@ def main(a: T.handle, d: T.handle) -> None: A = T.match_buffer(a, [1024, 1024], dtype="float32") D = T.match_buffer(d, [1024, 1024], dtype="float32") # body - # with tir.block("root") + # with tirx.block("root") B = T.sblock_alloc_buffer([1024, 1024], dtype="float32") for i0_0, i1_0, i0_1, i1_1 in T.grid(16, 64, 64, 16): with T.sblock("A"): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py index a9330454473c..4f6a5abfe96f 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_async_strided_mem_copy.py @@ -18,9 +18,9 @@ # ruff: noqa: F401 import tvm -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py index 4b36c3fdeb73..b125c926295a 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -18,9 +18,9 @@ # ruff: noqa: F401 import tvm -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index a1c6c3cd8bef..d0af40adb7ec 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -19,10 +19,10 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.testing import te_workload -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target from tvm.te import create_prim_func @@ -55,7 +55,7 @@ class AfterRewrite0: @T.prim_func def main(var_A: T.handle, var_B: T.handle, var_C: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(var_A, [512, 512], dtype="float32") B = T.match_buffer(var_B, [512, 512], dtype="float32") C = T.match_buffer(var_C, [512, 512], dtype="float32") @@ -113,7 +113,7 @@ def main( C: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C_local = T.sblock_alloc_buffer([512, 512], dtype="float32", scope="local") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py index 5897b617799f..ef2444eac1e3 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_layout.py @@ -22,7 +22,7 @@ import tvm.testing from tvm.s_tir import meta_schedule as ms from tvm.s_tir.schedule.testing import assert_structural_equal_ignore_global_symbol -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -226,7 +226,7 @@ def test_layout_rewrite(): class Conv2dCacheRead: @T.prim_func def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), "float32"), conv2d_nhwc: T.Buffer((1, 56, 56, 64), "float32")): - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True, "global_symbol": "main"}) pad_temp = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") conv2d_nhwc_global = T.sblock_alloc_buffer([1, 56, 56, 64], dtype="float32") pad_temp_global = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") @@ -303,7 +303,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), class Conv2dCacheReadRewritten: @T.prim_func def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), "float32"), conv2d_nhwc: T.Buffer((1, 56, 56, 64), "float32")): - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True, "global_symbol": "main"}) pad_temp = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") conv2d_nhwc_global = T.sblock_alloc_buffer([1, 56, 56, 64], dtype="float32") pad_temp_global = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") @@ -388,7 +388,7 @@ def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), class Conv2dCacheReadMultipleRewritten: @T.prim_func def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((3, 3, 64, 64), "float32"), conv2d_nhwc: T.Buffer((1, 56, 56, 64), "float32")): - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True, "global_symbol": "main"}) pad_temp = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") conv2d_nhwc_global = T.sblock_alloc_buffer([1, 56, 56, 64], dtype="float32") pad_temp_global = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") @@ -504,7 +504,7 @@ def before( p1: T.Buffer((T.int64(12), T.int64(197), T.int64(64)), "int8"), T_batch_matmul_NT: T.Buffer((T.int64(12), T.int64(197), T.int64(197)), "int32"), ): - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True}) for b_0_i_0_fused in T.parallel(T.int64(394)): for j_0 in T.serial(T.int64(1)): for b_1, i_1, j_1 in T.grid(T.int64(1), T.int64(1), T.int64(1)): @@ -565,7 +565,7 @@ def expected( p1: T.Buffer((T.int64(12), T.int64(197), T.int64(64)), "int8"), T_batch_matmul_NT: T.Buffer((T.int64(12), T.int64(197), T.int64(197)), "int32"), ): - T.func_attr({"tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"tirx.noalias": True, "layout_free_buffers": [1]}) p1_global = T.sblock_alloc_buffer( [T.int64(2), T.int64(64), T.int64(6), T.int64(197)], dtype="int8" ) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index 17c363963540..f70f16ea6c45 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -21,7 +21,7 @@ from tvm.s_tir.meta_schedule.postproc import RewriteParallelVectorizeUnroll from tvm.s_tir.schedule import Schedule from tvm.s_tir.schedule.testing import assert_structural_equal_ignore_global_symbol -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant # fmt: off @@ -189,7 +189,7 @@ def test_meta_schedule_postproc_rewrite_parallel_unroll_vectorize(): postproc = RewriteParallelVectorizeUnroll() sch = Schedule(Move_PUV) assert postproc.apply(sch) - mod = tvm.tir.transform.Simplify()(sch.mod) + mod = tvm.tirx.transform.Simplify()(sch.mod) tvm.ir.assert_structural_equal(mod["main"], Move_PUV0) @@ -265,7 +265,7 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "flo postproc = RewriteParallelVectorizeUnroll() sch = Schedule(layer_norm) assert postproc.apply(sch) - mod = tvm.tir.transform.Simplify()(sch.mod) + mod = tvm.tirx.transform.Simplify()(sch.mod) assert_structural_equal_ignore_global_symbol(mod["main"], expected) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py index b8eeb6991ee3..b9271f70e1e3 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -18,9 +18,9 @@ # ruff: noqa: E501, F401 import tvm -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py index f3563beb6a63..37eb421dde84 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_tensorize.py @@ -19,7 +19,7 @@ import tvm from tvm.s_tir import meta_schedule as ms from tvm.s_tir.tensor_intrin import cuda, rocm, x86 -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -30,7 +30,7 @@ def main( placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for ( i0_0, i1_0, @@ -151,7 +151,7 @@ def main( conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0 in T.grid( @@ -252,7 +252,7 @@ def main( W: T.Buffer((128, 128), "int8"), compute: T.Buffer((128, 128), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) compute_local = T.sblock_alloc_buffer([128, 128], dtype="int32", scope="local") X_shared = T.sblock_alloc_buffer([128, 128], dtype="int8", scope="shared") W_shared = T.sblock_alloc_buffer([128, 128], dtype="int8", scope="shared") @@ -341,7 +341,7 @@ def main( compute: T.Buffer((128, 128), "int32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") compute_local = T.sblock_alloc_buffer([128, 128], dtype="int32", scope="local") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py index 532c87a4c6a0..1f438034b150 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -18,9 +18,9 @@ # ruff: noqa: F401 import tvm -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -235,7 +235,7 @@ def before_unrolled_loop( placeholder: T.Buffer((1, 56, 56, 64), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) bgemm = T.sblock_alloc_buffer([6, 6, 196, 64], dtype="float32") inverse = T.sblock_alloc_buffer([4, 4, 196, 64], dtype="float32") for i2_0, i3_0, i2_1, i3_1 in T.grid(98, 4, 2, 16): @@ -259,7 +259,7 @@ def before_unrolled_loop( def after_unrolled_loop( placeholder: T.Buffer((1, 56, 56, 64), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") bgemm = T.sblock_alloc_buffer([6, 6, 196, 64], dtype="float32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py index 0e680fae7c34..6101bd34a218 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_gpu_code.py @@ -20,9 +20,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -400,7 +400,7 @@ def GMMCUDATensorCore( Z: T.Buffer((1024, 1024), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) s0 = T.int32() s0_1 = T.int32() s0_2 = T.int32() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py index 2c5d74b9ca4d..eaf1fba2881f 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_verify_vtcm_limit.py @@ -18,9 +18,9 @@ # ruff: noqa: E501, F401, F841 import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir import meta_schedule as ms -from tvm.script import tir as T +from tvm.script import tirx as T def _create_context(mod, target) -> ms.TuneContext: @@ -44,7 +44,7 @@ def _create_context(mod, target) -> ms.TuneContext: class Conv2dNCHWcVTCM: @T.prim_func def main(p0: T.Buffer((T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"), p1: T.Buffer((T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"), conv2d_NCHWc_int8: T.Buffer((T.int64(1), T.int64(2), T.int64(54), T.int64(54), T.int64(32)), "int32")): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) p0_global_vtcm = T.sblock_alloc_buffer([T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)], dtype="uint8", scope="global.vtcm") p1_global_vtcm = T.sblock_alloc_buffer([T.int64(2), T.int64(2), T.int64(3), T.int64(3), T.int64(8), T.int64(32), T.int64(4)], dtype="uint8", scope="global.vtcm") for n_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step":16, "pragma_unroll_explicit":1}): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py index 0ad8ddb67d84..b58be23698da 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py @@ -56,9 +56,9 @@ derived_object, get_global_func_with_default_on_worker, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target -from tvm.tir import FloatImm +from tvm.tirx import FloatImm MATMUL_N = 16 MATMUL_M = 32 @@ -70,7 +70,7 @@ class MatmulModule: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") @@ -86,7 +86,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s class MatmulReluModule: @T.prim_func def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (16, 16), "float32") B = T.match_buffer(b, (16, 16), "float32") D = T.match_buffer(d, (16, 16), "float32") @@ -107,7 +107,7 @@ def main(a: T.handle, b: T.handle, d: T.handle) -> None: # pylint: disable=no-s class BatchMatmulModule: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, [16, 32, 32]) B = T.match_buffer(b, [16, 32, 32]) C = T.match_buffer(c, [16, 32, 32]) @@ -123,7 +123,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s class AddModule: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, [32], "float32") B = T.match_buffer(b, [32], "float32") C = T.match_buffer(c, [32], "float32") @@ -138,7 +138,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-s class MatmulHugeModule: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (4096, 4096), "float32") B = T.match_buffer(b, (4096, 4096), "float32") C = T.match_buffer(c, (4096, 4096), "float32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py index 247c66de97ee..82100fee9f68 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py @@ -21,7 +21,7 @@ check_sketches, generate_design_space, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target from tvm.te import create_prim_func @@ -33,7 +33,7 @@ def cpu_matmul_0( B: T.Buffer((512, 4), "float32"), C: T.Buffer((4, 4), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0, i1, i2 in T.grid(4, 4, 512): with T.sblock("C"): i, j, k = T.axis.remap("SSR", [i0, i1, i2]) @@ -49,7 +49,7 @@ def cpu_matmul_1( B: T.Buffer((512, 4), "float32"), C: T.Buffer((4, 4), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) C_rf = T.sblock_alloc_buffer([4, 4, 128], dtype="float32") for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128): with T.sblock("C_rf"): @@ -77,7 +77,7 @@ def cpu_matmul_2( B: T.Buffer((512, 4), "float32"), C: T.Buffer((4, 4), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) C_rf = T.sblock_alloc_buffer([4, 4, 4], dtype="float32") for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128): with T.sblock("C_rf"): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py index 607327c0e58f..d2f3ecbf9af1 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_apply_custom_rule.py @@ -22,7 +22,7 @@ import tvm from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.schedule_rule import ApplyCustomRule -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -43,7 +43,7 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: @tvm.register_global_func("s_tir.meta_schedule.cpu.test_apply_custom_rule") -def sch_fn(sch: tvm.s_tir.Schedule, block: tvm.tir.SBlock) -> list[tvm.s_tir.Schedule]: +def sch_fn(sch: tvm.s_tir.Schedule, block: tvm.tirx.SBlock) -> list[tvm.s_tir.Schedule]: raise ValueError("Intended for s_tir.meta_schedule.cpu.test_apply_custom_rule") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py index c5097f9f0d72..335a391bf14d 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_bind.py @@ -21,7 +21,7 @@ check_sketches, generate_design_space, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py index 3b2613c1cc33..84231f3469bd 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_auto_inline.py @@ -24,7 +24,7 @@ from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.testing.space_generation import generate_design_space -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # fmt: off @@ -254,7 +254,7 @@ def main( placeholder_2: T.Buffer((1, 384, 768), "float32"), T_add: T.Buffer((1, 384, 768), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) compile_engine_const = T.sblock_alloc_buffer([], dtype="int64") T_less = T.sblock_alloc_buffer([1, 384], dtype="bool") compile_engine_const_1 = T.sblock_alloc_buffer([], dtype="int64") @@ -315,7 +315,7 @@ class AfterPureSpatial: @T.prim_func def main(placeholder: T.Buffer((1, 384), "int64"), placeholder_1: T.Buffer((30522, 768), "float32"), placeholder_2: T.Buffer((1, 384, 768), "float32"), T_add: T.Buffer((1, 384, 768), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1, i2 in T.grid(1, 384, 768): @@ -330,7 +330,7 @@ class ConstConsumer: @T.prim_func def main(T_full: T.Buffer((1, 12, 4096), "int64")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1, i2 in T.grid(1, 12, 4096): @@ -346,7 +346,7 @@ class Conv2dInt8: @T.prim_func def main(p0: T.Buffer((16, 14, 14, 256), "int8"), p1: T.Buffer((1024, 1, 1, 256), "int8"), p2: T.Buffer((1, 1, 1, 1024), "int32"), p3: T.Buffer((1, 1, 1, 1024), "int32"), p4: T.Buffer(1024, "int32"), p5: T.Buffer(1024, "int32"), p6: T.Buffer(1024, "int32"), p7: T.Buffer(1, "int32"), p8: T.Buffer((16, 14, 14, 1024), "int32"), compute: T.Buffer((16, 14, 14, 1024), "int32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") compile_engine_const = T.sblock_alloc_buffer([], dtype="int32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py index 6a3183822178..b7aea90a298d 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -23,7 +23,7 @@ check_sketches, generate_design_space, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target from tvm.te import create_prim_func @@ -67,7 +67,7 @@ def softmax_mn_0( T_softmax_norm: T.Buffer((256, 256), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") T_softmax_maxelem = T.sblock_alloc_buffer([256], dtype="float32") @@ -110,7 +110,7 @@ def softmax_mn_1( A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32") ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") T_softmax_maxelem_shared = T.sblock_alloc_buffer([256], dtype="float32", scope="shared") @@ -162,7 +162,7 @@ def softmax_mn_2( A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32") ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") T_softmax_maxelem = T.sblock_alloc_buffer([256], dtype="float32") @@ -214,7 +214,7 @@ def softmax_mn_3( A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32") ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") T_softmax_maxelem_shared = T.sblock_alloc_buffer([256], dtype="float32", scope="shared") @@ -500,7 +500,7 @@ def test_gpu_batch_norm_bmn(): @T.prim_func def batch_norm_bmn_0(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer([1], dtype="float32") @@ -522,7 +522,7 @@ def batch_norm_bmn_0(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "floa @T.prim_func def batch_norm_bmn_1(A: T.Buffer((1, 512, 512), "float32"), D: T.Buffer(1, "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C_shared = T.sblock_alloc_buffer([1], dtype="float32", scope="shared") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py index 639a45603b8b..09765eee70c7 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt.py @@ -25,7 +25,7 @@ generate_design_space, print_sketches, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -37,7 +37,7 @@ def cpu_matmul_0( C: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C_global = T.sblock_alloc_buffer([512, 512], dtype="float32") @@ -68,7 +68,7 @@ def cpu_matmul_1( C: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C_global = T.sblock_alloc_buffer([512, 512], dtype="float32") @@ -99,7 +99,7 @@ def cpu_matmul_2( C: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid( @@ -155,7 +155,7 @@ def cpu_matmul_relu_0( compute: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer([512, 512], dtype="float32") @@ -186,7 +186,7 @@ def cpu_matmul_relu_1( compute: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer([512, 512], dtype="float32") @@ -217,7 +217,7 @@ def cpu_matmul_relu_2( compute: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer([512, 512], dtype="float32") @@ -279,7 +279,7 @@ def cuda_matmul_0( C: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C_local = T.sblock_alloc_buffer([512, 512], dtype="float32", scope="local") @@ -385,7 +385,7 @@ def cuda_matmul_relu_0( compute: T.Buffer((512, 512), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer([512, 512], dtype="float32") @@ -528,7 +528,7 @@ def cpu_conv2d_nhwc( weight: T.Buffer((3, 3, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 56, 56, 64), "float16"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) PadInput = T.sblock_alloc_buffer((1, 58, 58, 64), "float16") for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): with T.sblock("PadInput"): @@ -633,7 +633,7 @@ def cache_read_specify_consumer_0( B: T.Buffer((512, 512), "float32"), T_add: T.Buffer((512, 512), "float32"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) C = T.sblock_alloc_buffer((512, 512)) C_local = T.sblock_alloc_buffer((512, 512), scope="local") A_shared = T.sblock_alloc_buffer((512, 512), scope="shared") @@ -742,7 +742,7 @@ def pool_blocked_cache_read_write( X: T.Buffer((1, 2, 8, 8, 8, 8, 32), "uint8"), pool: T.Buffer((1, 2, 4, 4, 8, 8, 32), "uint8"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) pool_global = T.sblock_alloc_buffer((1, 2, 4, 4, 8, 8, 32), "uint8") X_global = T.sblock_alloc_buffer((1, 2, 8, 8, 8, 8, 32), "uint8") for b_0, c_o_0, h_o_0, w_o_0, h_i_0, w_i_0, c_i_0 in T.grid(1, 2, 4, 1, 8, 1, 4): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py index cc8e5d8f7dfc..6fb7c78dab90 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -27,7 +27,7 @@ from tvm.s_tir.tensor_intrin.arm_cpu import DP4A_S8S8S32_INTRIN from tvm.s_tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN as AVX512_INTRIN from tvm.s_tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -40,7 +40,7 @@ def conv2d_nchwc( placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): with T.sblock("conv2d_NCHWc_int8"): ( @@ -74,7 +74,7 @@ def conv2d_nchwc( # fmt: off @T.prim_func def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): conv2d_NCHWc_int8_global = T.sblock_alloc_buffer((1, 16, 56, 56, 16), "int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1): @@ -120,7 +120,7 @@ def x86_conv2d_nchwc_0(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place @T.prim_func def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): conv2d_NCHWc_int8_global = T.sblock_alloc_buffer((1, 16, 56, 56, 16), "int32") for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1): @@ -166,7 +166,7 @@ def x86_conv2d_nchwc_1(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), place @T.prim_func def x86_conv2d_nchwc_2(placeholder: T.Buffer((1, 4, 56, 56, 16), "uint8"), placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): with T.sblock("conv2d_NCHWc_int8_o"): @@ -309,7 +309,7 @@ def dp4a_dense_0( W: T.Buffer((128, 128), "int8"), compute: T.Buffer((128, 128), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): compute_local = T.sblock_alloc_buffer((128, 128), "int32", scope="local") X_shared = T.sblock_alloc_buffer((128, 128), "int8", scope="shared") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py index 2b535502dac2..0fb711e4ece3 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py @@ -31,7 +31,7 @@ print_sketches, ) from tvm.s_tir.tensor_intrin.cuda import get_wmma_intrin_group -from tvm.script import tir as T +from tvm.script import tirx as T def multi_level_tiling_tensor_core( @@ -85,7 +85,7 @@ def test_matmul_relu(shared_scope): # fmt: off @T.prim_func def matmul_relu_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): C_reindex_shared = T.sblock_alloc_buffer((4, 8, 2, 1, 16, 16), scope=shared_scope) C_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") @@ -236,7 +236,7 @@ def test_matmul_relu_with_fallback(): # fmt: off @T.prim_func def matmul_relu_fallback_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): C_reindex_shared = T.sblock_alloc_buffer((4, 2, 2, 4, 16, 16), scope="shared") C_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((4, 2, 2, 4, 16, 16), scope="wmma.accumulator") @@ -394,7 +394,7 @@ def test_conv2d(shared_scope): # fmt: off @T.prim_func def conv2d_0(inputs: T.Buffer((1, 16, 16, 32), "float16"), weight: T.Buffer((3, 3, 32, 32), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 32), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): PadInput = T.sblock_alloc_buffer((1, 18, 18, 32), "float16") conv2d_nhwc_reindex_shared_dyn = T.sblock_alloc_buffer((16, 2, 1, 1, 16, 16), scope=shared_scope) @@ -577,7 +577,7 @@ def test_matmul_relu_pipeline(shared_scope): @T.prim_func def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16"), compute: T.Buffer((128, 128), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer((128, 128)) @@ -597,7 +597,7 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) T.reads(A[v0, v1]) T.writes(A_reindex_shared[v0, v1]) - T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 4, "tirx.manifest_shared_memory_local_stage": 1}) A_reindex_shared[v0, v1] = A[v0, v1] for ax0_ax1_fused in range(1024): with T.sblock("B_reindex_shared"): @@ -605,7 +605,7 @@ def matmul_relu_pipeline_0(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) T.reads(B[v0, v1]) T.writes(B_reindex_shared[v0, v1]) - T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tir.manifest_shared_memory_local_stage": 1}) + T.sblock_attr({"buffer_dim_align": [[0, 0, 32, 8]], "double_buffer_scope": 0, "meta_schedule.cooperative_fetch": 2, "tirx.manifest_shared_memory_local_stage": 1}) B_reindex_shared[v0, v1] = B[v0, v1] for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): for ax0_0, ax1_0 in T.grid(2, 1): @@ -757,7 +757,7 @@ def test_padded_matmul_relu(): # fmt: off @T.prim_func def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 127), "float16"), compute: T.Buffer((127, 127), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) C_reindex_shared = T.sblock_alloc_buffer((4, 8, 2, 1, 16, 16), scope="shared") C_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((4, 8, 2, 1, 16, 16), scope="wmma.accumulator") A_reindex_shared = T.sblock_alloc_buffer((128, 128), "float16", scope="shared") @@ -905,7 +905,7 @@ def test_conv_1x1(): # fmt: off @T.prim_func def conv2d_1x1_0(inputs: T.Buffer((1, 16, 16, 64), "float16"), weight: T.Buffer((1, 1, 64, 64), "float16"), conv2d_nhwc: T.Buffer((1, 16, 16, 64), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): conv2d_nhwc_reindex_shared = T.sblock_alloc_buffer((2, 1, 8, 4, 16, 16), scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((2, 1, 8, 4, 16, 16), scope="wmma.accumulator") @@ -1063,7 +1063,7 @@ def test_padded_conv(): # fmt: off @T.prim_func def padded_conv2d_0(inputs: T.Buffer((1, 224, 224, 3), "float16"), weight: T.Buffer((7, 7, 3, 64), "float16"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): conv2d_nhwc_reindex_shared = T.sblock_alloc_buffer((56, 2, 14, 2, 16, 16), scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((56, 2, 14, 2, 16, 16), scope="wmma.accumulator") @@ -1215,7 +1215,7 @@ def test_padded_matmul_single_padded_input(): # fmt: off @T.prim_func def padded_matmul_single_padded_input_0(A: T.Buffer((1023, 4096), "float16"), B: T.Buffer((4096, 1024), "float16"), C: T.Buffer((1023, 1024), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): C_reindex_pad_shared = T.sblock_alloc_buffer((8, 32, 8, 2, 16, 16), scope="shared") C_reindex_pad_shared_wmma_accumulator = T.sblock_alloc_buffer((8, 32, 8, 2, 16, 16), scope="wmma.accumulator") @@ -1363,7 +1363,7 @@ def test_padded_matmul_no_padded_output(): # fmt: off @T.prim_func def padded_matmul_no_padded_output_0(A: T.Buffer((1024, 4095), "float16"), B: T.Buffer((4095, 1024), "float16"), C: T.Buffer((1024, 1024), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): C_reindex_shared = T.sblock_alloc_buffer((32, 16, 2, 4, 16, 16), scope="shared") C_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((32, 16, 2, 4, 16, 16), scope="wmma.accumulator") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py index 2595705f6c76..56efaeeaf843 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_parallel_vectorize_unroll.py @@ -22,7 +22,7 @@ check_sketches, generate_design_space, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # fmt: off @@ -64,12 +64,12 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] -# from tvm.script import tir as T +# from tvm.script import tirx as T @tvm.script.ir_module class PureSpatial: @T.prim_func def main(placeholder: T.Buffer((1, 13, 13, 3, 85), "float32"), placeholder_1: T.Buffer((1, 26, 26, 3, 85), "float32"), placeholder_2: T.Buffer((1, 52, 52, 3, 85), "float32"), T_expand_dims: T.Buffer((1, 80, 10647), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) T_strided_slice_with_axes = T.sblock_alloc_buffer([1, 52, 52, 3, 1], dtype="float32") T_sigmoid = T.sblock_alloc_buffer([1, 52, 52, 3, 1], dtype="float32") T_strided_slice_with_axes_1 = T.sblock_alloc_buffer([1, 52, 52, 3, 80], dtype="float32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py index 56001c7fc068..f3d8dbfd4dca 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_random_compute_location.py @@ -21,7 +21,7 @@ check_sketches, generate_design_space, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # fmt: off diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py index 83e943609627..5df88ba7d511 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py @@ -27,7 +27,7 @@ from tvm.s_tir.meta_schedule.testing.dummy_object import DummyMutator from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule import Schedule, Trace -from tvm.script import tir as T +from tvm.script import tirx as T MATMUL_M = 32 diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cpu.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cpu.py index e614e8c649c6..d7e701e333d0 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cpu.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cpu.py @@ -24,7 +24,7 @@ print_sketches, ) from tvm.s_tir.meta_schedule.testing.te_workload import create_te_workload -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -45,7 +45,7 @@ def test_cpu_c1d(): # fmt: off @T.prim_func def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -81,7 +81,7 @@ def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2] @T.prim_func def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -122,7 +122,7 @@ def c1d_1(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 12 @T.prim_func def c1d_2(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -184,7 +184,7 @@ def test_cpu_c2d(): # fmt: off @T.prim_func def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -228,7 +228,7 @@ def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def c2d_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -268,7 +268,7 @@ def c2d_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def c2d_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -349,7 +349,7 @@ def test_cpu_c3d(): # fmt: off @T.prim_func def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -397,7 +397,7 @@ def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_global[v0, v1, v2, v3, v4] @T.prim_func def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -445,7 +445,7 @@ def c3d_1(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7 conv3d_ndhwc[v0, v1, v2, v3, v4] = conv3d_ndhwc_global[v0, v1, v2, v3, v4] @T.prim_func def c3d_2(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -535,7 +535,7 @@ def test_cpu_cap(): # fmt: off @T.prim_func def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -584,7 +584,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5] @T.prim_func def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -630,7 +630,7 @@ def cap_1(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( conv2d_capsule_nhwijc[v0, v1, v2, v3, v4, v5] = conv2d_capsule_nhwijc_global[v0, v1, v2, v3, v4, v5] @T.prim_func def cap_2(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -717,7 +717,7 @@ def test_cpu_dep(): # fmt: off @T.prim_func def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -756,7 +756,7 @@ def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. depth_conv2d_nhwc[v0, v1, v2, v3] = depth_conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -792,7 +792,7 @@ def dep_1(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T. depth_conv2d_nhwc[v0, v1, v2, v3] = depth_conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dep_2(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -866,7 +866,7 @@ def test_cpu_dil(): # fmt: off @T.prim_func def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -909,7 +909,7 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -953,7 +953,7 @@ def dil_1(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def dil_2(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1032,7 +1032,7 @@ def test_cpu_gmm(): # fmt: off @T.prim_func def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1061,7 +1061,7 @@ def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "flo Z[v0, v1, v2] = Z_global[v0, v1, v2] @T.prim_func def gmm_1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1090,7 +1090,7 @@ def gmm_1(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "flo Z[v0, v1, v2] = Z_global[v0, v1, v2] @T.prim_func def gmm_2(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1143,7 +1143,7 @@ def test_cpu_grp(): # fmt: off @T.prim_func def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1187,7 +1187,7 @@ def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def grp_1(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1227,7 +1227,7 @@ def grp_1(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, conv2d_nhwc[v0, v1, v2, v3] = conv2d_nhwc_global[v0, v1, v2, v3] @T.prim_func def grp_2(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1306,7 +1306,7 @@ def test_cpu_t2d(): # fmt: off @T.prim_func def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1346,7 +1346,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3] @T.prim_func def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1387,7 +1387,7 @@ def t2d_1(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 conv2d_transpose_nhwc[v0, v1, v2, v3] = conv2d_transpose_nhwc_global[v0, v1, v2, v3] @T.prim_func def t2d_2(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1456,7 +1456,7 @@ def test_cpu_nrm(): # fmt: off @T.prim_func def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1487,7 +1487,7 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N D[v_b] = T.sqrt(C[v_b]) @T.prim_func def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1518,7 +1518,7 @@ def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N D[v_b] = T.sqrt(C[v_b]) @T.prim_func def nrm_2(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1569,7 +1569,7 @@ def test_cpu_sfm(): # fmt: off @T.prim_func def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1620,7 +1620,7 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1681,7 +1681,7 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1722,7 +1722,7 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1787,7 +1787,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1847,7 +1847,7 @@ def sfm_4(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1902,7 +1902,7 @@ def sfm_5(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T_softmax_exp[v_i0, v_i1] / T_softmax_expsum[v_i0] @T.prim_func def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1946,7 +1946,7 @@ def sfm_6(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1988,7 +1988,7 @@ def sfm_7(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_8(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -2130,7 +2130,7 @@ def test_cpu_cbr(): # fmt: off @T.prim_func def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -2159,7 +2159,7 @@ def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 compute[v_i0, v_i1, v_i2, v_i3] = T.max((Conv2dOutput[v_i0, v_i1, v_i2, v_i3] + bias[v_i3]) * bn_scale[v_i3] + bn_offset[v_i3], T.float32(0)) @T.prim_func def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -2203,7 +2203,7 @@ def cbr_1(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 compute[v_i0, v_i1, v_i2, v_i3] = T.max((Conv2dOutput[v_i0, v_i1, v_i2, v_i3] + bias[v_i3]) * bn_scale[v_i3] + bn_offset[v_i3], T.float32(0)) @T.prim_func def cbr_2(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -2293,7 +2293,7 @@ def test_cpu_tbg(): # fmt: off @T.prim_func def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -2345,7 +2345,7 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3] @T.prim_func def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -2392,7 +2392,7 @@ def tbg_1(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3] @T.prim_func def tbg_2(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py index d3c93f0a498c..d99ce2fdcfce 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py @@ -24,7 +24,7 @@ print_sketches, ) from tvm.s_tir.meta_schedule.testing.te_workload import create_te_workload -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -45,7 +45,7 @@ def test_cuda_c1d(): # fmt: off @T.prim_func def c1d_0(inputs: T.Buffer((1, 256, 64), "float32"), weight: T.Buffer((3, 64, 128), "float32"), conv1d_nlc: T.Buffer((1, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -123,7 +123,7 @@ def test_cuda_c2d(): # fmt: off @T.prim_func def c2d_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -207,7 +207,7 @@ def test_cuda_c3d(): # fmt: off @T.prim_func def c3d_0(inputs: T.Buffer((1, 16, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 7, 3, 64), "float32"), conv3d_ndhwc: T.Buffer((1, 8, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -297,7 +297,7 @@ def test_cuda_cap(): # fmt: off @T.prim_func def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer((3, 3, 4, 4, 32, 32), "float32"), conv2d_capsule_nhwijc: T.Buffer((1, 8, 8, 4, 4, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -391,7 +391,7 @@ def test_cuda_dep(): # fmt: off @T.prim_func def dep_0(placeholder: T.Buffer((1, 112, 112, 32), "float32"), placeholder_1: T.Buffer((1, 3, 3, 32), "float32"), depth_conv2d_nhwc: T.Buffer((1, 112, 112, 32), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -472,7 +472,7 @@ def test_cuda_dil(): # fmt: off @T.prim_func def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 109, 109, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -553,7 +553,7 @@ def test_cuda_gmm(): # fmt: off @T.prim_func def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "float32"), Z: T.Buffer((1, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -627,7 +627,7 @@ def test_cuda_grp(): # fmt: off @T.prim_func def grp_0(inputs: T.Buffer((1, 56, 56, 64), "float32"), weight: T.Buffer((3, 3, 16, 128), "float32"), conv2d_nhwc: T.Buffer((1, 28, 28, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -709,7 +709,7 @@ def test_cuda_t2d(): # fmt: off @T.prim_func def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -793,7 +793,7 @@ def test_cuda_nrm(): # fmt: off @T.prim_func def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -819,7 +819,7 @@ def nrm_0(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N D[v_b] = T.sqrt(C[v_b]) @T.prim_func def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -866,7 +866,7 @@ def test_cuda_sfm(): # fmt: off @T.prim_func def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -906,7 +906,7 @@ def sfm_0(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -946,7 +946,7 @@ def sfm_1(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum[v_i0] @T.prim_func def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -988,7 +988,7 @@ def sfm_2(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 T_softmax_norm[v_i0, v_i1] = T.exp(A[v_i0, v_i1] - T_softmax_maxelem[v_i0]) / T_softmax_expsum_shared[v_i0] @T.prim_func def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1065,7 +1065,7 @@ def test_cuda_cbr(): # fmt: off @T.prim_func def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3, 64), "float32"), bias: T.Buffer(64, "float32"), bn_offset: T.Buffer(64, "float32"), bn_scale: T.Buffer(64, "float32"), compute: T.Buffer((1, 112, 112, 64), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1148,7 +1148,7 @@ def test_cuda_tbg(): # fmt: off @T.prim_func def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, 12, 64), "float32"), C: T.Buffer((1, 12, 128, 128), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py index 8448efa59bfb..993058e605e7 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py @@ -24,7 +24,7 @@ print_sketches, ) from tvm.s_tir.meta_schedule.testing.te_workload import create_te_workload -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target @@ -46,7 +46,7 @@ def get_c2d_prim_func(stage: int): # fmt: off @T.prim_func def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -107,7 +107,7 @@ def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3 # fmt: off @T.prim_func def c2d(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, 3, 64), "float32"), conv2d_nhwc: T.Buffer((1, 112, 112, 64), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -200,7 +200,7 @@ def get_gmm_prim_func(stage: int): # fmt: off @T.prim_func def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32"), Z: T.Buffer((1, 1024, 1024), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -255,7 +255,7 @@ def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "f # fmt: off @T.prim_func def gmm(X: T.Buffer((1, 1024, 1024), "float32"), Y: T.Buffer((1, 1024, 1024), "float32"), Z: T.Buffer((1, 1024, 1024), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py index a32673dc7765..0f9a164b8305 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py @@ -32,7 +32,7 @@ from tvm.s_tir.meta_schedule.tune_context import TuneContext from tvm.s_tir.meta_schedule.utils import derived_object from tvm.s_tir.schedule import Schedule -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument # fmt: off diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py index 3425194c52f0..25618c533433 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_post_opt.py @@ -26,7 +26,7 @@ import tvm.testing from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.runner.config import EvaluatorConfig -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target logging.basicConfig() diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py index ef2aed636e2a..2cb3aa5d3a31 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py @@ -27,7 +27,7 @@ from tvm.s_tir import Schedule from tvm.s_tir import meta_schedule as ms from tvm.s_tir.meta_schedule.testing.dummy_object import DummyBuilder, DummyRunner -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -40,7 +40,7 @@ def main( # type: ignore b: T.handle, c: T.handle, ) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") @@ -60,7 +60,7 @@ def main( # type: ignore b: T.handle, d: T.handle, ) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") D = T.match_buffer(d, (1024, 1024), "float32") @@ -85,7 +85,7 @@ def main( # type: ignore b: T.handle, c: T.handle, ) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_trace_apply.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_trace_apply.py index 9199f828e30f..9a57874fe07b 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_trace_apply.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_trace_apply.py @@ -23,10 +23,10 @@ from tvm.s_tir import Schedule from tvm.s_tir.tensor_intrin.cuda import * from tvm.s_tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import floordiv, floormod +from tvm.tirx import floordiv, floormod # fmt: off @@ -39,7 +39,7 @@ def main( T_matmul_NT: T.Buffer((128, 128), "float32"), ) -> None: # function attr dict - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") for i0, i1, i2 in T.grid(128, 128, 128): @@ -62,7 +62,7 @@ def main( T_add: T.Buffer((128, 128), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) # body # with T.sblock("root") T_matmul_NT = T.sblock_alloc_buffer([128, 128], dtype="float32") @@ -98,7 +98,7 @@ def main( T_add: T.Buffer((128, 128), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) # body # with T.sblock("root") T_matmul_NT_global = T.sblock_alloc_buffer([128, 128], dtype="float32") @@ -177,7 +177,7 @@ class DenseAdd_cpu_no_write_cache: @T.prim_func def main(p0: T.Buffer((128, 128), "float32"), p1: T.Buffer((128, 128), "float32"), T_add: T.Buffer((128, 128), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) # body # with T.sblock("root") T_matmul_NT = T.sblock_alloc_buffer([128, 128], dtype="float32") @@ -227,7 +227,7 @@ def main( T_add: T.Buffer((128, 128), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) # body # with T.sblock("root") T_matmul_NT_local = T.sblock_alloc_buffer([128, 128], dtype="float32", scope="local") @@ -377,7 +377,7 @@ class Conv2dInt8: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")) -> None: # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") pad_temp = T.sblock_alloc_buffer([16, 56, 56, 64], dtype="int8") @@ -493,7 +493,7 @@ class Conv2dInt8_target: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "uint8")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") pad_temp = T.sblock_alloc_buffer([16, 56, 56, 64], dtype="int8") @@ -636,7 +636,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " class Conv2dInt8_tensorcore_scheduled: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((1, 1, 1, 256), "int64"), p5: T.Buffer((1, 1, 1, 256), "int64"), p6: T.Buffer((1, 1, 1, 256), "int64"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "uint8")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): conv2d_nhwc_reindex_shared = T.sblock_alloc_buffer((50176, 256), "int32", scope="shared") conv2d_nhwc_reindex_shared_wmma_accumulator = T.sblock_alloc_buffer((50176, 256), "int32", scope="wmma.accumulator") @@ -738,7 +738,7 @@ class Conv2dInt8_NCHWc: @T.prim_func def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, 4, 16, 4), "int8"), p2: T.Buffer((1, 128, 1, 1, 16), "int32"), p3: T.Buffer((1, 128, 1, 1, 16), "float32"), p4: T.Buffer(1, "float32"), p5: T.Buffer((1, 128, 7, 7, 16), "int32"), compute: T.Buffer((1, 128, 7, 7, 16), "uint8")) -> None: # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") compile_engine_const = T.sblock_alloc_buffer([], dtype="float32") @@ -901,7 +901,7 @@ class Conv2dInt8_NCHWc_target: @T.prim_func def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, 4, 16, 4), "int8"), p2: T.Buffer((1, 128, 1, 1, 16), "int32"), p3: T.Buffer((1, 128, 1, 1, 16), "float32"), p4: T.Buffer(1, "float32"), p5: T.Buffer((1, 128, 7, 7, 16), "uint8"), T_cast: T.Buffer((1, 128, 7, 7, 16), "int32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") compile_engine_const = T.sblock_alloc_buffer([], dtype="float32") @@ -1119,7 +1119,7 @@ class Conv2dInt8_NCHWc_scheduled: @T.prim_func def main(p0: T.Buffer((1, 32, 7, 7, 16), "uint8"), p1: T.Buffer((128, 32, 1, 1, 4, 16, 4), "int8"), p2: T.Buffer((1, 128, 1, 1, 16), "int32"), p3: T.Buffer((1, 128, 1, 1, 16), "float32"), p4: T.Buffer(1, "float32"), p5: T.Buffer((1, 128, 7, 7, 16), "uint8"), T_cast: T.Buffer((1, 128, 7, 7, 16), "int32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") conv2d_NCHWc_int8 = T.sblock_alloc_buffer([1, 128, 7, 7, 16], dtype="int32") @@ -1182,7 +1182,7 @@ class Conv2dWinogradAddRelu: @T.prim_func def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), "float32"), p2: T.Buffer((1, 1, 1, 64), "float32"), T_relu: T.Buffer((1, 56, 56, 64), "float32")) -> None: # function attr dict - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"layout_free_buffers": [1], "tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") data_pad = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") @@ -1274,7 +1274,7 @@ class Conv2dWinogradAddResidualRelu: @T.prim_func def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), "float32"), p2: T.Buffer((1, 1, 1, 64), "float32"), p3: T.Buffer((1, 56, 56, 64), "float32"), T_relu: T.Buffer((1, 56, 56, 64), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) # body # with T.sblock("root") data_pad = T.sblock_alloc_buffer([1, 58, 58, 64], dtype="float32") @@ -1373,7 +1373,7 @@ class Conv2dWinogradAddResidualRelu_scheduled: @T.prim_func def main(p0: T.Buffer((1, 56, 56, 64), "float32"), p1: T.Buffer((6, 6, 64, 64), "float32"), p2: T.Buffer((1, 1, 1, 64), "float32"), p3: T.Buffer((1, 56, 56, 64), "float32"), T_relu: T.Buffer((1, 56, 56, 64), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) # body # with T.sblock("root") input_tile_local = T.sblock_alloc_buffer([6, 6, 196, 64], dtype="float32", scope="local") @@ -1513,7 +1513,7 @@ class Conv2dInt8_with_predicate: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer(256, "int32"), p5: T.Buffer(256, "int32"), p6: T.Buffer(256, "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")) -> None: # function attr dict - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # body # with T.sblock("root") pad_temp = T.sblock_alloc_buffer([16, 56, 56, 64], dtype="int8") @@ -1587,7 +1587,7 @@ class Conv2dInt8_with_predicate_target: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer(256, "int32"), p5: T.Buffer(256, "int32"), p6: T.Buffer(256, "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") pad_temp = T.sblock_alloc_buffer([16, 56, 56, 64], dtype="int8") @@ -1681,7 +1681,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " class Conv2dInt8_with_predicate_scheduled: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1853,7 +1853,7 @@ def apply_anchor_trace(sch: Schedule) -> None: sch.transform_layout( block=b58, buffer=("read", 2), - index_map=tvm.tir.IndexMap.from_func( + index_map=tvm.tirx.IndexMap.from_func( lambda i0, i1: ( floordiv(i0, 64), i1, @@ -1920,7 +1920,7 @@ def apply_trace(sch): sch.transform_layout( block=b49, buffer=("read", 2), - index_map=tvm.tir.IndexMap.from_func( + index_map=tvm.tirx.IndexMap.from_func( lambda i0, i1: ( floordiv(i1, 16), floordiv(i0, 32), diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py index 2187dac2a7e9..7590bee3cee9 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_context.py @@ -24,7 +24,7 @@ import tvm import tvm.testing from tvm.s_tir.meta_schedule import TuneContext -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,missing-docstring @@ -34,7 +34,7 @@ class Matmul: @T.prim_func def main(a: T.handle, b: T.handle, c: T.handle) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py index 877655ba6652..a5bd8e26597d 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py @@ -27,7 +27,7 @@ from tvm.s_tir.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.s_tir.meta_schedule.testing.local_rpc import LocalRPC from tvm.s_tir.schedule import SBlockRV, Schedule -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target logging.basicConfig() diff --git a/tests/python/s_tir/schedule/test_tir_schedule_analysis.py b/tests/python/s_tir/schedule/test_tir_schedule_analysis.py index 19afccb0cfdd..140bc3f2af81 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_analysis.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_analysis.py @@ -35,9 +35,9 @@ WMMA_SYNC_16x16x16_f16f16f32_INTRIN, ) from tvm.s_tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.te import create_prim_func -from tvm.tir import ( +from tvm.tirx import ( Evaluate, For, ForKind, @@ -47,9 +47,9 @@ floordiv, floormod, ) -from tvm.tir.analysis import expr_deep_equal -from tvm.tir.function import TensorIntrin -from tvm.tir.stmt_functor import pre_order_visit +from tvm.tirx.analysis import expr_deep_equal +from tvm.tirx.function import TensorIntrin +from tvm.tirx.stmt_functor import pre_order_visit def _make_vars(*args: str) -> list[Var]: @@ -164,7 +164,7 @@ def main( placeholder_1: T.Buffer((64, 256, 16, 4), "int8"), compute: T.Buffer((1024, 1024), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -188,7 +188,7 @@ def main( placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): with T.sblock("conv2d_NCHWc_int8"): ( @@ -225,7 +225,7 @@ def collect_loops(prim_func): loops = [] def callback(node): - if isinstance(node, tvm.tir.For): + if isinstance(node, tvm.tirx.For): loops.append(node) return True diff --git a/tests/python/s_tir/schedule/test_tir_schedule_annotate_buffer_access.py b/tests/python/s_tir/schedule/test_tir_schedule_annotate_buffer_access.py index 678409cf95c1..53e033a5d5ba 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_annotate_buffer_access.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_annotate_buffer_access.py @@ -18,12 +18,12 @@ # ruff: noqa: E501, F401 import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T def test_annotate_read_buffer_access(): diff --git a/tests/python/s_tir/schedule/test_tir_schedule_block_scope.py b/tests/python/s_tir/schedule/test_tir_schedule_block_scope.py index e37e13d8a091..d9c11b1d1ca6 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_block_scope.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_block_scope.py @@ -22,10 +22,10 @@ import tvm import tvm.testing -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.s_tir.schedule import DepKind -from tvm.script import tir as T -from tvm.tir.stmt_functor import post_order_visit +from tvm.script import tirx as T +from tvm.tirx.stmt_functor import post_order_visit # pylint: disable=no-member,invalid-name,unused-variable @@ -85,12 +85,12 @@ def _get_sblock(s: s_tir.ScheduleState, name_hint: str) -> s_tir.StmtSRef: def f_visit(node): nonlocal result - if isinstance(node, tvm.tir.SBlock) and node.name_hint == name_hint: + if isinstance(node, tvm.tirx.SBlock) and node.name_hint == name_hint: result = node func = s.mod["main"] post_order_visit(func.body, f_visit) - assert result is not None and isinstance(result, tvm.tir.SBlock) + assert result is not None and isinstance(result, tvm.tirx.SBlock) return s.get_sref(result) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_blockize.py b/tests/python/s_tir/schedule/test_tir_schedule_blockize.py index 4d7f094fb30d..ff915d817370 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_blockize.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_blockize.py @@ -20,9 +20,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import verify_trace_roundtrip -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks diff --git a/tests/python/s_tir/schedule/test_tir_schedule_cache_index.py b/tests/python/s_tir/schedule/test_tir_schedule_cache_index.py index dbee3854356b..6eee610c0fbd 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_cache_index.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_cache_index.py @@ -22,9 +22,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import verify_trace_roundtrip -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py index 60b7b0ce6f7d..5da2c56535a6 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py @@ -22,12 +22,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py b/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py index c01d6fad0a03..48182fd77bb8 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_compute_at.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm import te, tir +from tvm import te, tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -1003,7 +1003,7 @@ def floordiv_and_floormod_indices_after_reverse_compute_at(a: T.handle, b: T.han @T.prim_func def recursive_floordiv_floormod(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), C: T.Buffer((3, 512, 512), "float32")) -> None: - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((1, 128, 16, 8, 2, 32, 2), "float32") for axis1, axis2, axis3, axis4, axis5, axis6, axis7 in T.grid(1, 128, 16, 8, 2, 32, 2): @@ -1022,7 +1022,7 @@ def recursive_floordiv_floormod(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), @T.prim_func def recursive_floordiv_floormod_after_reverse_compute_at(A: T.Buffer((16, 64, 1, 8, 8, 32), "float32"), C: T.Buffer((3, 512, 512), "float32")) -> None: - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) # with T.sblock("root"): B = T.sblock_alloc_buffer((1, 128, 16, 8, 2, 32, 2)) for axis1, axis2, axis3 in T.grid(1, 128, 16): diff --git a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py index e77549d12ab3..64975a7467a4 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py @@ -21,12 +21,12 @@ import tvm import tvm.s_tir.tensor_intrin import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable @@ -749,7 +749,7 @@ class Conv2dInt8_TensorCore_with_predicate_before: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer(256, "int32"), p5: T.Buffer(256, "int32"), p6: T.Buffer(256, "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer(1, "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body with T.sblock("root"): T.reads() @@ -869,7 +869,7 @@ def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), " class Conv2dInt8_TensorCore_with_predicate_after: @T.prim_func def main(p0: T.Buffer((16, 56, 56, 64), "int8"), p1: T.Buffer((256, 1, 1, 64), "int8"), p2: T.Buffer((1, 1, 1, 256), "int32"), p3: T.Buffer((1, 1, 1, 256), "int32"), p4: T.Buffer((256,), "int32"), p5: T.Buffer((256,), "int32"), p6: T.Buffer((256,), "int32"), p7: T.Buffer((), "int32"), p8: T.Buffer((1,), "int32"), p9: T.Buffer((16, 56, 56, 256), "int32"), compute: T.Buffer((16, 56, 56, 256), "int32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -1311,7 +1311,7 @@ def test_compute_inline_softmax(): # fmt: off @T.prim_func def before(p_lv44: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") @@ -1357,7 +1357,7 @@ def before(p_lv44: T.handle, p_output0: T.handle): @T.prim_func def after(p_lv44: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n, m = T.int64(), T.int64() lv44 = T.match_buffer(p_lv44, (T.int64(1), T.int64(32), n, m)) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m), "float16") @@ -1405,7 +1405,7 @@ def test_reverse_compute_inline_layer_norm(): # fmt: off @T.prim_func def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") @@ -1446,7 +1446,7 @@ def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias @T.prim_func def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: T.Buffer((T.int64(2560),), "float32"), p_output0: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() lv6 = T.match_buffer(p_lv6, (T.int64(1), n, T.int64(2560))) var_compute_intermediate = T.match_buffer(p_output0, (T.int64(1), n, T.int64(2560)), "float16") diff --git a/tests/python/s_tir/schedule/test_tir_schedule_decompose_padding.py b/tests/python/s_tir/schedule/test_tir_schedule_decompose_padding.py index 6487ff0b2ed8..24ed9b9bdb17 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_decompose_padding.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_decompose_padding.py @@ -20,9 +20,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import assert_structural_equal_ignore_global_symbol -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg diff --git a/tests/python/s_tir/schedule/test_tir_schedule_error.py b/tests/python/s_tir/schedule/test_tir_schedule_error.py index 9df7971eb77f..3ca9ce57f300 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_error.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_error.py @@ -20,8 +20,8 @@ import tvm import tvm.testing -from tvm import s_tir, tir -from tvm.script import tir as T +from tvm import s_tir, tirx +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable @@ -43,7 +43,7 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @T.prim_func def two_kernels(var_A: T.handle, var_B: T.handle, seq_len: T.int32): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(var_A, (1, seq_len * 8), "int32") B = T.match_buffer(var_B, (1, seq_len * 8), "int32", align=8) with T.sblock("exclusive_scan"): diff --git a/tests/python/s_tir/schedule/test_tir_schedule_for_kind.py b/tests/python/s_tir/schedule/test_tir_schedule_for_kind.py index 06eaa08a326b..e391041102d0 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_for_kind.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_for_kind.py @@ -22,12 +22,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue.py b/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue.py index bfb728d90144..34556b92ff8f 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue.py @@ -21,12 +21,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py b/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py index c119dc9ed5a7..a7a35a892e74 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py @@ -21,12 +21,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py b/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py index f1d6d2286e21..e957edc59ae8 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_fuse_reduction_epilogue_relu.py @@ -21,12 +21,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_merge.py b/tests/python/s_tir/schedule/test_tir_schedule_merge.py index a2d3de732b38..e8df6c83ab0d 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_merge.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_merge.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_pad_einsum.py b/tests/python/s_tir/schedule/test_tir_schedule_pad_einsum.py index e8fb12e2c649..6d130d808abc 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_pad_einsum.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_pad_einsum.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -167,7 +167,7 @@ def before( m: T.handle, d: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() A = T.match_buffer(a, (1, n, 4096)) B = T.match_buffer(b, (11008, 4096)) @@ -189,7 +189,7 @@ def before( @T.prim_func def after(a: T.handle, b: T.handle, m: T.handle, d: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() A = T.match_buffer(a, (1, n, 4096)) B = T.match_buffer(b, (11008, 4096)) @@ -236,7 +236,7 @@ def before( w: T.handle, r: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() A = T.match_buffer(a, (1, n, 4096)) W = T.match_buffer(w, (4096,), "float32") @@ -260,7 +260,7 @@ def before( @T.prim_func def after(a: T.handle, w: T.handle, r: T.handle): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) n = T.int32() A = T.match_buffer(a, (1, n, 4096)) W = T.match_buffer(w, (4096,), "float32") diff --git a/tests/python/s_tir/schedule/test_tir_schedule_partition.py b/tests/python/s_tir/schedule/test_tir_schedule_partition.py index dd0aa29c4997..33a94fd4692e 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_partition.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_partition.py @@ -20,13 +20,13 @@ import tvm import tvm.testing -from tvm import te, tir +from tvm import te, tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T -from tvm.tir.expr import IntImm +from tvm.script import tirx as T +from tvm.tirx.expr import IntImm # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_read_write_at.py b/tests/python/s_tir/schedule/test_tir_schedule_read_write_at.py index 6ea8464f6c54..9c489611c1fc 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_read_write_at.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_read_write_at.py @@ -37,12 +37,12 @@ import pytest import tvm -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_reduction.py b/tests/python/s_tir/schedule/test_tir_schedule_reduction.py index 009e79b2fa76..b290572349a2 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_reduction.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_reduction.py @@ -22,13 +22,13 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg diff --git a/tests/python/s_tir/schedule/test_tir_schedule_reindex.py b/tests/python/s_tir/schedule/test_tir_schedule_reindex.py index 6f6afb8b4015..387d075ec99f 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_reindex.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_reindex.py @@ -20,13 +20,13 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.schedule import ScheduleError from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func diff --git a/tests/python/s_tir/schedule/test_tir_schedule_reorder.py b/tests/python/s_tir/schedule/test_tir_schedule_reorder.py index 9924722de8dc..0ec7ef6c968b 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_reorder.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_reorder.py @@ -22,12 +22,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_reorder_block_iter_var.py b/tests/python/s_tir/schedule/test_tir_schedule_reorder_block_iter_var.py index 082c8736cf60..4d44133e201b 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_reorder_block_iter_var.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_reorder_block_iter_var.py @@ -20,9 +20,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import verify_trace_roundtrip -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func diff --git a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py index 195d361a9a5e..94d234f53621 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_rfactor.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -1141,7 +1141,7 @@ def argmin_split_rfactor( def argmax_topi_rfactor( placeholder: T.Buffer((1, 32), "int32"), placeholder_red: T.Buffer(1, "int32") ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) placeholder_red_temp_v0 = T.sblock_alloc_buffer([1], dtype="int32") placeholder_red_temp_v1 = T.sblock_alloc_buffer([1], dtype="int32") placeholder_red_temp_v0_rf = T.sblock_alloc_buffer([1, 8], dtype="int32") @@ -1206,7 +1206,7 @@ def argmax_topi_rfactor( def argmin_topi_rfactor( placeholder: T.Buffer((1, 32), "int32"), placeholder_red: T.Buffer(1, "int32") ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) placeholder_red_temp_v0 = T.sblock_alloc_buffer([1], dtype="int32") placeholder_red_temp_v1 = T.sblock_alloc_buffer([1], dtype="int32") placeholder_red_temp_v0_rf = T.sblock_alloc_buffer([1, 8], dtype="int32") diff --git a/tests/python/s_tir/schedule/test_tir_schedule_rolling_buffer.py b/tests/python/s_tir/schedule/test_tir_schedule_rolling_buffer.py index d4ac8406e273..b52c33c58d25 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_rolling_buffer.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_rolling_buffer.py @@ -20,16 +20,16 @@ import tvm import tvm.testing -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T def check_rolling_buffer( - sch: s_tir.Schedule, origin: tir.PrimFunc, expected: tir.PrimFunc, check_run=False + sch: s_tir.Schedule, origin: tirx.PrimFunc, expected: tirx.PrimFunc, check_run=False ): scheduled = sch.mod["main"] assert_structural_equal_ignore_global_symbol(scheduled, expected) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_sampling.py b/tests/python/s_tir/schedule/test_tir_schedule_sampling.py index b0cf34d3a97b..573f1b2cf269 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_sampling.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_sampling.py @@ -22,9 +22,9 @@ import pytest import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import verify_trace_roundtrip -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable diff --git a/tests/python/s_tir/schedule/test_tir_schedule_set_axis_separator.py b/tests/python/s_tir/schedule/test_tir_schedule_set_axis_separator.py index 8630197910ce..498462bf3fa7 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_set_axis_separator.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_set_axis_separator.py @@ -20,14 +20,14 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) from tvm.script import ir as I -from tvm.script import tir as T -from tvm.tir import IndexMap +from tvm.script import tirx as T +from tvm.tirx import IndexMap # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg diff --git a/tests/python/s_tir/schedule/test_tir_schedule_set_dtype.py b/tests/python/s_tir/schedule/test_tir_schedule_set_dtype.py index ee8b23b55103..5f76f2daed7d 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_set_dtype.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_set_dtype.py @@ -21,12 +21,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg diff --git a/tests/python/s_tir/schedule/test_tir_schedule_set_scope.py b/tests/python/s_tir/schedule/test_tir_schedule_set_scope.py index b37f2a7cab5f..414641b14572 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_set_scope.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_set_scope.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg diff --git a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py index b83b6beeeda9..3fe374be3db9 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py @@ -20,13 +20,13 @@ import tvm import tvm.testing -from tvm import te, tir +from tvm import te, tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T -from tvm.tir.expr import IntImm +from tvm.script import tirx as T +from tvm.tirx.expr import IntImm # pylint: disable=no-member,invalid-name,unused-variable @@ -705,7 +705,7 @@ def test_sve_scalable_split_predicated(num_elements): @T.prim_func def before(a: T.handle): A = T.match_buffer(a, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i in T.serial(num_elements): with T.sblock("A"): v_i = T.axis.remap("S", [i]) @@ -714,7 +714,7 @@ def before(a: T.handle): @T.prim_func def after(a: T.handle): A = T.match_buffer(a, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4): with T.sblock("A"): v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + i_1) @@ -741,7 +741,7 @@ def test_sve_scalable_split_assume_exact_multiple(): @T.prim_func def before(a: T.handle): A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i in T.serial(128): with T.sblock("A"): v_i = T.axis.remap("S", [i]) @@ -750,7 +750,7 @@ def before(a: T.handle): @T.prim_func def after(a: T.handle): A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i_0, i_1 in T.grid(outer_extent, T.vscale() * 4): with T.sblock("A"): v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1) @@ -771,7 +771,7 @@ def test_sve_split_over_scalable_loop(): @T.prim_func def before(a: T.handle): A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i in T.serial(4 * T.vscale()): with T.sblock("A"): v_i = T.axis.remap("S", [i]) @@ -780,7 +780,7 @@ def before(a: T.handle): @T.prim_func def after(a: T.handle): A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i_0, i_1 in T.grid(T.vscale() * 2, T.vscale() * 2): with T.sblock("A"): v_i = T.axis.spatial(T.vscale() * 4, i_0 * (T.vscale() * 2) + i_1) @@ -802,7 +802,7 @@ def test_unsupported_target_scalable_split(capfd): @T.prim_func def before(a: T.handle): A = T.match_buffer(a, (128,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i in T.serial(128): with T.sblock("A"): v_i = T.axis.remap("S", [i]) @@ -811,7 +811,7 @@ def before(a: T.handle): sch = tvm.s_tir.Schedule(before) (a,) = sch.get_loops("A") - err_msg = "The product of factors is not larger than or equal to the extent of loop tir.For#0" + err_msg = "The product of factors is not larger than or equal to the extent of loop tirx.For#0" with pytest.raises(tvm.s_tir.schedule.ScheduleError, match=err_msg): sch.split(a, factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()]) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_state.py b/tests/python/s_tir/schedule/test_tir_schedule_state.py index f90d7332fd08..d28f3e85c6e7 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_state.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_state.py @@ -23,9 +23,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir import IRModule -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable @@ -94,7 +94,7 @@ def block_in_opaque_block(a: T.handle, b: T.handle) -> None: def replace_ir_builder(deep_copy=False, realize=False): new_func = tvm.script.from_source(elementwise.script()) s = tvm.s_tir.ScheduleState(new_func, debug_mask="all") - target = tvm.tir.SBlock( + target = tvm.tirx.SBlock( iter_vars=[], reads=[], writes=[], @@ -106,7 +106,7 @@ def replace_ir_builder(deep_copy=False, realize=False): annotations=None, ) if realize: - target = tvm.tir.SBlockRealize( + target = tvm.tirx.SBlockRealize( iter_values=[], predicate=True, block=target, @@ -122,7 +122,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): other_func = tvm.script.from_source(elementwise.script()) mod = IRModule(functions={"main": new_func, "other": other_func}) s = tvm.s_tir.ScheduleState(mod, debug_mask="all") - target = tvm.tir.SBlock( + target = tvm.tirx.SBlock( iter_vars=[], reads=[], writes=[], @@ -134,7 +134,7 @@ def replace_ir_builder_module(deep_copy=False, realize=False): annotations=None, ) if realize: - target = tvm.tir.SBlockRealize( + target = tvm.tirx.SBlockRealize( iter_values=[], predicate=True, block=target, @@ -323,12 +323,12 @@ def test_replace_block_in_opaque_block(): root_hash = s.mod["main"].__hash__() for_loop = s.mod["main"].body.block.body.body.block.body[1].then_case.block.body sref = s.get_sref(for_loop) - new_for_loop = tir.For( + new_for_loop = tirx.For( loop_var=for_loop.loop_var, min=0, extent=128, - kind=tir.ForKind.SERIAL, - body=tir.Evaluate(0), + kind=tirx.ForKind.SERIAL, + body=tirx.Evaluate(0), thread_binding=None, annotations=None, ) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_state_cached_flags.py b/tests/python/s_tir/schedule/test_tir_schedule_state_cached_flags.py index eda7c087e965..f97a880ed472 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_state_cached_flags.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_state_cached_flags.py @@ -22,10 +22,10 @@ import tvm import tvm.testing -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.s_tir.schedule.state import CachedFlags -from tvm.script import tir as T -from tvm.tir.stmt_functor import post_order_visit +from tvm.script import tirx as T +from tvm.tirx.stmt_functor import post_order_visit # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg # fmt: off @@ -371,7 +371,7 @@ def uncovered_producer_region(A: T.Buffer((128,), "float32"), B: T.Buffer((128,) @T.prim_func def matmul_relu_padding(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 127), "float16"), compute: T.Buffer((127, 127), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") C = T.sblock_alloc_buffer([127, 127], dtype="float32") @@ -468,12 +468,12 @@ def _get_sblock(s: s_tir.ScheduleState, name_hint: str) -> s_tir.StmtSRef: def f_visit(node): nonlocal result - if isinstance(node, tvm.tir.SBlock) and node.name_hint == name_hint: + if isinstance(node, tvm.tirx.SBlock) and node.name_hint == name_hint: result = node func = s.mod["main"] post_order_visit(func.body, f_visit) - assert result is not None and isinstance(result, tvm.tir.SBlock) + assert result is not None and isinstance(result, tvm.tirx.SBlock) return s.get_sref(result) diff --git a/tests/python/s_tir/schedule/test_tir_schedule_storage_align.py b/tests/python/s_tir/schedule/test_tir_schedule_storage_align.py index 7878b52eab7c..344641dfef32 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_storage_align.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_storage_align.py @@ -19,12 +19,12 @@ import pytest import tvm -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name diff --git a/tests/python/s_tir/schedule/test_tir_schedule_tensorize.py b/tests/python/s_tir/schedule/test_tir_schedule_tensorize.py index 73c85e1cc43c..de8f8d0ad94f 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_tensorize.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_tensorize.py @@ -21,7 +21,7 @@ import tvm import tvm.testing -from tvm import te, tir +from tvm import te, tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, @@ -37,7 +37,7 @@ from tvm.s_tir.tensor_intrin.hexagon import VDMPY_i16i16i32_INTRIN, VRMPY_u8u8i32_INTRIN from tvm.s_tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN from tvm.s_tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN, VNNI_DOT_16x4_INTRIN -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -496,11 +496,11 @@ def annotated_tensorized_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks -tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) -tir.TensorIntrin.register("test_annotated_mma_intrin", annotated_mma_desc, mma_intrin) -tir.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin) -tir.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin) -tir.TensorIntrin.register("test_dot_product_intrin_annotated", dot_product_desc, dot_product_intrin_annotated) +tirx.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin) +tirx.TensorIntrin.register("test_annotated_mma_intrin", annotated_mma_desc, mma_intrin) +tirx.TensorIntrin.register("test_dot_product_intrin", dot_product_desc, dot_product_intrin) +tirx.TensorIntrin.register("test_outer_product_intrin", outer_product_desc, outer_product_intrin) +tirx.TensorIntrin.register("test_dot_product_intrin_annotated", dot_product_desc, dot_product_intrin_annotated) def test_tensorize_matmul(): @@ -749,9 +749,9 @@ def fetch_to_shared(block, idx): def test_tensor_intrin_look_up(): intrin_name = 'non_existent_intrin' - assert tir.TensorIntrin.get(intrin_name, allow_missing=True) is None + assert tirx.TensorIntrin.get(intrin_name, allow_missing=True) is None with pytest.raises(ValueError): - tir.TensorIntrin.get(intrin_name) + tirx.TensorIntrin.get(intrin_name) def test_tensorize_matmul_mixed_dtype(): @@ -841,11 +841,11 @@ def tensorized_matmul_int64_shape( def _tir_packed_int_to_int_to_float(storage_nbit: int): storage_dtype = "int" + str(storage_nbit) - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + def f_convert(nbit: int, val: tirx.PrimExpr, pos: tirx.PrimExpr, dtype: str): assert val.dtype == storage_dtype - mask = tir.const((1 << nbit) - 1, "int32") - unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask - return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + mask = tirx.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tirx.const(nbit, "int32"))) & mask + return tirx.Cast(dtype, (unextended << tirx.const(32 - nbit, "int32")) >> tirx.const(32 - nbit, "int32")) return f_convert @@ -911,7 +911,7 @@ def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None 8, ) -tir.TensorIntrin.register("test_decode_i4s_to_f16_intrin", decode_i4s_to_f16_desc, decode_i4s_to_f16_impl) +tirx.TensorIntrin.register("test_decode_i4s_to_f16_intrin", decode_i4s_to_f16_desc, decode_i4s_to_f16_impl) def test_tensorize_arith_simplification(): # fmt: off diff --git a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index 99a15d3c7e63..2a081a2ab41d 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -68,7 +68,7 @@ def matmul(m, n, k, in_dtype, out_dtype, b_transposed): def maybe_cast(v): if in_dtype != out_dtype: - return tvm.tir.Cast(out_dtype, v) + return tvm.tirx.Cast(out_dtype, v) return v def maybe_swap(i, j): diff --git a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py index c290a51eb0e1..ec330adab2e9 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_tensorize_mfma_numeric.py @@ -58,7 +58,7 @@ def matmul(m, n, k, in_dtype, out_dtype, b_transposed): def maybe_cast(v): if in_dtype != out_dtype: - return tvm.tir.Cast(out_dtype, v) + return tvm.tirx.Cast(out_dtype, v) return v def maybe_swap(i, j): diff --git a/tests/python/s_tir/schedule/test_tir_schedule_trace.py b/tests/python/s_tir/schedule/test_tir_schedule_trace.py index f9ac5cf2ee91..a0c703599984 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_trace.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_trace.py @@ -23,10 +23,10 @@ import tvm import tvm.testing -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.s_tir.schedule import Instruction, InstructionKind, LoopRV, SBlockRV, Trace from tvm.s_tir.schedule.testing import assert_structural_equal_ignore_global_symbol -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable @@ -322,13 +322,13 @@ def test_apply_json_to_schedule_1(): def test_apply_json_to_schedule_sample_categorical(): - var = tir.Var("v", "int32") + var = tirx.Var("v", "int32") trace1 = Trace( insts=[ Instruction( kind=InstructionKind.get("SampleCategorical"), inputs=[], - attrs=[[tvm.tir.IntImm("int32", 3)], [tvm.tir.FloatImm("float32", 1.0)]], + attrs=[[tvm.tirx.IntImm("int32", 3)], [tvm.tirx.FloatImm("float32", 1.0)]], outputs=[var], ) ], diff --git a/tests/python/s_tir/schedule/test_tir_schedule_transform.py b/tests/python/s_tir/schedule/test_tir_schedule_transform.py index a466b55473af..d5d32cb1f114 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_transform.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_transform.py @@ -18,7 +18,7 @@ from tvm.s_tir import Schedule from tvm.s_tir.schedule.transform import tile_with_tensor_intrin from tvm.s_tir.tensor_intrin.x86 import AVX512_DOT_16x4_INTRIN, VNNI_DOT_16x4_INTRIN -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -29,7 +29,7 @@ def main( placeholder_1: T.Buffer((64, 256, 16, 4), "int8"), compute: T.Buffer((1024, 1024), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -54,7 +54,7 @@ def main( compute: T.Buffer((1024, 1024), "int32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1_0, i2_0, i1_1, i2_1 in T.grid(1024, 64, 256, 16, 4): @@ -79,7 +79,7 @@ def main( placeholder_1: T.Buffer((16, 4, 1, 1, 4, 16, 4), "int8"), conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): with T.sblock("conv2d_NCHWc_int8"): ( @@ -121,7 +121,7 @@ def main( conv2d_NCHWc_int8: T.Buffer((1, 16, 56, 56, 16), "int32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1, i2, i3, i4_0, i5, i6, i7, i8, i9_0, i4_1, i9_1 in T.grid( diff --git a/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py b/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py index 914e3b3c16c7..fb307cadf7f0 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_transform_layout.py @@ -22,13 +22,13 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -190,7 +190,7 @@ def main( p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): for ax0, ax1 in T.grid(T.int64(33), T.int64(128)): with T.sblock("T_add"): @@ -203,7 +203,7 @@ def main( class Expected: @T.prim_func def main(p0: T.Buffer((T.int64(33), T.int64(128)), "float32"), p1: T.Buffer((T.int64(33), T.int64(128)), "float32"), T_add: T.Buffer((T.int64(33), T.int64(128)), "float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): T_add_global = T.sblock_alloc_buffer((T.int64(2), T.int64(128), T.int64(32)), axis_separators=[2]) for axis0, axis1, axis2 in T.grid(T.int64(2), T.int64(128), T.int64(32)): @@ -683,7 +683,7 @@ def main(): "block", "A", lambda i: [i // 4, i % 4], - pad_value=tir.IntImm("int8", 0), + pad_value=tirx.IntImm("int8", 0), ) @@ -730,7 +730,7 @@ def before_func(A: T.Buffer(14, dtype)): vi = T.axis.remap("S", [i]) B[vi] = A[vi] - pad_value_imm = tir.IntImm(dtype, 0) + pad_value_imm = tirx.IntImm(dtype, 0) @T.prim_func(private=True) def expected_func(A: T.Buffer(14, dtype)): @@ -1071,7 +1071,7 @@ def main(A: T.Buffer(14, "int32")): sch = tvm.s_tir.Schedule(Before) A = sch.get(sch.get_sblock("block")).reads[0].buffer - other = tir.decl_buffer(1, A.dtype, name="other") + other = tirx.decl_buffer(1, A.dtype, name="other") with pytest.raises(tvm.s_tir.schedule.schedule.ScheduleError): sch.transform_layout( "block", @@ -1154,7 +1154,7 @@ def main(a: T.handle): sch.transform_layout( "block", "A", - lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4], + lambda i: [i // 4, tvm.tirx.IndexMap.AXIS_SEPARATOR, i % 4], pad_value=0, ) After = sch.mod @@ -1186,7 +1186,7 @@ def main(a: T.handle): sch.transform_layout( "block", "A", - lambda i: [i // 4, tvm.tir.IndexMap.AXIS_SEPARATOR, i % 4], + lambda i: [i // 4, tvm.tirx.IndexMap.AXIS_SEPARATOR, i % 4], pad_value=0, ) After = sch.mod @@ -1232,7 +1232,7 @@ def func(A: T.Buffer(T.int64(16), "int32")): # Triggering the error requires an IndexMap that introduces padding func = lambda i: [ # And a constant to be one of the output indices. - tir.const(0, i.dtype), + tirx.const(0, i.dtype), (i + 1) // 8, (i + 1) % 8, ] @@ -1255,7 +1255,7 @@ def test_transform_layout_with_symbolic_bound(): # pylint: disable=invalid-name,line-too-long,too-many-locals @T.prim_func def before(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16") B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") @@ -1271,7 +1271,7 @@ def before(a: T.handle, b: T.handle, c: T.handle): @T.prim_func def after(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16") B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") @@ -1305,7 +1305,7 @@ def test_transform_block_layout_with_symbolic_bound(): # pylint: disable=invalid-name,line-too-long,too-many-locals @T.prim_func def before(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16") B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") @@ -1321,7 +1321,7 @@ def before(a: T.handle, b: T.handle, c: T.handle): @T.prim_func def after(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float16") B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)), "float16") diff --git a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py index ecc4f22635ce..5b948bd67524 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py @@ -22,13 +22,13 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir import IRModule from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable @@ -132,7 +132,7 @@ def vector_add_2( @T.prim_func def tuple_reduction(data: T.Buffer((4, 32), "float32"), T_add: T.Buffer((4,), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body with T.sblock("root"): T.reads() diff --git a/tests/python/s_tir/test_s_tir_renew_defs.py b/tests/python/s_tir/test_s_tir_renew_defs.py index 6eb327753f48..e8fd00a3d1aa 100644 --- a/tests/python/s_tir/test_s_tir_renew_defs.py +++ b/tests/python/s_tir/test_s_tir_renew_defs.py @@ -17,10 +17,10 @@ import tvm import tvm.testing -from tvm.script import tir as T -from tvm.tir.buffer import Buffer -from tvm.tir.function import PrimFunc -from tvm.tir.stmt import SBlock +from tvm.script import tirx as T +from tvm.tirx.buffer import Buffer +from tvm.tirx.function import PrimFunc +from tvm.tirx.stmt import SBlock def _check_func_signature_remap(lhs: PrimFunc, rhs: PrimFunc): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py b/tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py index e30b55fd9b0c..ad14904139db 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_annotate_irregular_loop.py @@ -20,9 +20,9 @@ import tvm import tvm.testing -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_handle_irrgular_unit_loop(): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py b/tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py index b2135683f892..82d2d0d71d53 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_canonicalize_loop.py @@ -18,8 +18,8 @@ import pytest import tvm -from tvm import s_tir, tir -from tvm.script import tir as T +from tvm import s_tir, tirx +from tvm.script import tirx as T def test_canonicalize_loop(): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py index 1481b220f121..37f56b51ba7d 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py @@ -17,8 +17,8 @@ # ruff: noqa: E501 import tvm import tvm.testing -from tvm import s_tir, tir -from tvm.script import tir as T +from tvm import s_tir, tirx +from tvm.script import tirx as T class BaseCompactTest: @@ -36,7 +36,9 @@ def test_compact(self): before = tvm.IRModule.from_expr(self.before.with_attr("global_symbol", "main")) expected = tvm.IRModule.from_expr(self.expected.with_attr("global_symbol", "main")) - simplify = tvm.transform.Sequential([tir.transform.Simplify(), tir.transform.RemoveNoOp()]) + simplify = tvm.transform.Sequential( + [tirx.transform.Simplify(), tirx.transform.RemoveNoOp()] + ) after = simplify(s_tir.transform.CompactBufferAllocation(is_strict=is_strict)(before)) expected = simplify(expected) try: diff --git a/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py index a690d34e3670..84668a86ac6a 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_convert_blocks_to_opaque.py @@ -19,16 +19,16 @@ import tvm import tvm.testing -from tvm import s_tir, te, tir +from tvm import s_tir, te, tirx from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.ConvertBlocksToOpaque()(mod) - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.Simplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_decorate_device_scope.py b/tests/python/s_tir/transform/test_s_tir_transform_decorate_device_scope.py index bd75cc7b467c..6c5e39415ffd 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_decorate_device_scope.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_decorate_device_scope.py @@ -18,8 +18,8 @@ def test_decorate_device(): - x = tvm.tir.Var("x", "int32") - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x))) + x = tvm.tirx.Var("x", "int32") + mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x))) stmt = tvm.s_tir.transform.DecorateDeviceScope()(mod)["main"].body assert stmt.attr_key == "device_scope" diff --git a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py index a2587111a2b3..0212585b534d 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.s_tir.transform import DefaultGPUSchedule -from tvm.script import tir as T +from tvm.script import tirx as T def test_broadcast_to_symbolic(): @@ -32,7 +32,7 @@ def broadcast_to( rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), var_T_broadcast_to: T.handle, ): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) x_0 = T.int64() x_1 = T.int64() T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1)) @@ -48,7 +48,7 @@ def broadcast_to( class Expected: @T.prim_func def broadcast_to(rxplaceholder: T.Buffer((T.int64(3), T.int64(1)), "float32"), var_T_broadcast_to: T.handle): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) x_0, x_1 = T.int64(), T.int64() T_broadcast_to = T.match_buffer(var_T_broadcast_to, (x_0, x_1)) for ax0_ax1_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): @@ -78,7 +78,7 @@ def matmul( B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): with T.sblock("C"): @@ -102,7 +102,7 @@ def matmul_gpu( "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), - "tir.noalias": True}) + "tirx.noalias": True}) # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): with T.sblock("C"): @@ -121,7 +121,7 @@ def matmul_cpu( ): T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), - "tir.noalias": True}) + "tirx.noalias": True}) # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): with T.sblock("C"): @@ -140,7 +140,7 @@ def matmul( B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16"), ): - T.func_attr({"tir.is_scheduled": True, "global_symbol": "main", "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): @@ -161,7 +161,7 @@ def matmul( @T.prim_func def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")): - T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "target": T.target({"keys": ["cpu"], "kind": "llvm", "tag": ""}), "tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i, j, k in T.grid(32, 32, 32): with T.sblock("C"): @@ -174,7 +174,7 @@ def matmul_cpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16" @T.prim_func def matmul_gpu(A: T.Buffer((32, 32), "float16"), B: T.Buffer((32, 32), "float16"), C: T.Buffer((32, 32), "float16")): - T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "target": T.target({"arch": "sm_86", "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32}), "tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i_j_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for i_j_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): @@ -203,7 +203,7 @@ def test_add(): class Before: @T.prim_func def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -221,7 +221,7 @@ def add( ), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"), ): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_i2_i3_fused_1 in T.thread_binding( @@ -275,7 +275,7 @@ def test_full(): class Before: @T.prim_func def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -290,7 +290,7 @@ def full( rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), ): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): @@ -326,7 +326,7 @@ def full( rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), ): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): @@ -359,7 +359,7 @@ def test_multiple(): class Before: @T.prim_func def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -369,7 +369,7 @@ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32") @T.prim_func def full(rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.sblock("T_full"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -387,7 +387,7 @@ def add( ), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32"), ): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_i2_i3_fused_1 in T.thread_binding( @@ -431,7 +431,7 @@ def full( rxplaceholder: T.Buffer((), "int32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32"), ): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_fused_1 in T.thread_binding(T.int64(6), thread="threadIdx.x"): @@ -462,7 +462,7 @@ def test_add_on_metal(): class Before: @T.prim_func def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for i0, i1, i2, i3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)): with T.sblock("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -474,7 +474,7 @@ def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32") class Expected: @T.prim_func def add(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_add: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) for i0_i1_i2_i3_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"): for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(72), thread="threadIdx.x"): with T.sblock("T_add"): @@ -500,7 +500,7 @@ def test_scalar_add(): class Before: @T.prim_func def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) with T.sblock("T_add"): vi = T.axis.spatial(1, T.int64(0)) T.reads(rxplaceholder[()]) @@ -511,7 +511,7 @@ def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")): class Expected: @T.prim_func def add(rxplaceholder: T.Buffer((), "int64"), T_add: T.Buffer((), "int64")): - T.func_attr({"tir.is_scheduled": True, "tir.noalias": True}) + T.func_attr({"tirx.is_scheduled": True, "tirx.noalias": True}) # with T.sblock("root"): for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): @@ -547,7 +547,7 @@ def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "f class Expected: @T.prim_func def sum(A: T.Buffer((T.int64(2), T.int64(2)), "float64"), A_red: T.Buffer((), "float64")): - T.func_attr({"tir.is_scheduled": True}) + T.func_attr({"tirx.is_scheduled": True}) # with T.sblock("root"): for u_fused_0 in T.thread_binding(1, thread="blockIdx.x"): for u_fused_1 in T.thread_binding(1, thread="threadIdx.x"): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py index 8bca1da57793..8c8dede155fb 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py @@ -19,7 +19,7 @@ import tvm.testing from tvm import s_tir from tvm.s_tir.transform import HoistedConditionals, HoistedLetBindings -from tvm.script import tir as T +from tvm.script import tirx as T def _run_transform(before, hoisted_conditionals, hoisted_let_bindings): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py index be404519465d..2c0ce74108b2 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_if.py @@ -21,7 +21,7 @@ import tvm from tvm import s_tir from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.testing import enabled_targets var_list = [] @@ -33,25 +33,25 @@ def verify_structure(stmt, expected_struct): def _extract_vars(op): global var_list - if isinstance(op, tvm.tir.Var): + if isinstance(op, tvm.tirx.Var): var_list.append(op.name) def _visit(op): key = op - if isinstance(op, tvm.tir.IfThenElse): + if isinstance(op, tvm.tirx.IfThenElse): global var_list - tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars) - val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))] + tvm.tirx.stmt_functor.post_order_visit(op.condition, _extract_vars) + val = [(op.then_case, op.else_case), ("tirx.IfThenElse", tuple(var_list))] var_list.clear() - elif isinstance(op, tvm.tir.For): - val = [(op.body,), ("tir.For", op.loop_var.name)] - elif isinstance(op, tvm.tir.AttrStmt): - val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))] + elif isinstance(op, tvm.tirx.For): + val = [(op.body,), ("tirx.For", op.loop_var.name)] + elif isinstance(op, tvm.tirx.AttrStmt): + val = [(op.body,), ("tirx.AttrStmt", op.attr_key, int(op.value))] else: return node_dict[key] = val - tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + tvm.tirx.stmt_functor.post_order_visit(stmt, _visit) for key, val in node_dict.items(): struct[val[1]] = tuple( node_dict[child][1] if child in node_dict else None for child in val[0] @@ -64,7 +64,7 @@ def _visit(op): def _opaque_eval(var): - return tvm.tir.Evaluate(tvm.tir.call_extern("int32", "dummy", var)) + return tvm.tirx.Evaluate(tvm.tirx.call_extern("int32", "dummy", var)) def test_hoist_top_for(): @@ -81,10 +81,10 @@ def func(l: T.int32, m: T.int32, n: T.int32): mod = tvm.IRModule.from_expr(func) new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body expected_struct = { - ("tir.For", "k"): (None,), - ("tir.For", "j"): (("tir.For", "k"),), - ("tir.IfThenElse", ("i",)): (("tir.For", "j"), ("tir.For", "j")), - ("tir.For", "i"): (("tir.IfThenElse", ("i",)),), + ("tirx.For", "k"): (None,), + ("tirx.For", "j"): (("tirx.For", "k"),), + ("tirx.IfThenElse", ("i",)): (("tirx.For", "j"), ("tirx.For", "j")), + ("tirx.For", "i"): (("tirx.IfThenElse", ("i",)),), } verify_structure(new_stmt, expected_struct) @@ -104,10 +104,10 @@ def func(l: T.int32, m: T.int32, n: T.int32): new_mod = tvm.s_tir.transform.HoistIfThenElse()(mod) new_stmt = new_mod["main"].body expected_struct = { - ("tir.For", "k"): (None,), - ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")), - ("tir.For", "j"): (("tir.IfThenElse", ("i", "j")),), - ("tir.For", "i"): (("tir.For", "j"),), + ("tirx.For", "k"): (None,), + ("tirx.IfThenElse", ("i", "j")): (("tirx.For", "k"), ("tirx.For", "k")), + ("tirx.For", "j"): (("tirx.IfThenElse", ("i", "j")),), + ("tirx.For", "i"): (("tirx.For", "j"),), } verify_structure(new_stmt, expected_struct) @@ -128,10 +128,10 @@ def func(data: T.handle("float32"), l: T.int32, m: T.int32, n: T.int32): mod = tvm.IRModule.from_expr(func) new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body expected_struct = { - ("tir.For", "k"): (None,), - ("tir.IfThenElse", ("i",)): (("tir.For", "k"), ("tir.For", "k")), - ("tir.For", "j"): (None,), - ("tir.For", "i"): (("tir.For", "j"),), + ("tirx.For", "k"): (None,), + ("tirx.IfThenElse", ("i",)): (("tirx.For", "k"), ("tirx.For", "k")), + ("tirx.For", "j"): (None,), + ("tirx.For", "i"): (("tirx.For", "j"),), } verify_structure(new_stmt, expected_struct) @@ -148,10 +148,10 @@ def func(l: T.int32, m: T.int32, n: T.int32): mod = tvm.IRModule.from_expr(func) new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body expected_struct = { - ("tir.For", "k"): (None,), - ("tir.For", "j"): (("tir.For", "k"),), - ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None), - ("tir.For", "i"): (("tir.IfThenElse", ("i",)),), + ("tirx.For", "k"): (None,), + ("tirx.For", "j"): (("tirx.For", "k"),), + ("tirx.IfThenElse", ("i",)): (("tirx.For", "j"), None), + ("tirx.For", "i"): (("tirx.IfThenElse", ("i",)),), } verify_structure(new_stmt, expected_struct) @@ -179,12 +179,12 @@ def func(data: T.handle("float32"), l: T.int32, m: T.int32, n: T.int32): mod = tvm.IRModule.from_expr(func) new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body expected_struct = { - ("tir.For", "k"): (None,), - ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")), - ("tir.For", "j"): (("tir.IfThenElse", ("i", "j")),), - ("tir.For", "i"): (("tir.For", "j"),), - ("tir.AttrStmt", "thread_extent", 64): (("tir.For", "i"),), - ("tir.AttrStmt", "thread_extent", 32): (("tir.AttrStmt", "thread_extent", 64),), + ("tirx.For", "k"): (None,), + ("tirx.IfThenElse", ("i", "j")): (("tirx.For", "k"), ("tirx.For", "k")), + ("tirx.For", "j"): (("tirx.IfThenElse", ("i", "j")),), + ("tirx.For", "i"): (("tirx.For", "j"),), + ("tirx.AttrStmt", "thread_extent", 64): (("tirx.For", "i"),), + ("tirx.AttrStmt", "thread_extent", 32): (("tirx.AttrStmt", "thread_extent", 64),), } verify_structure(new_stmt, expected_struct) @@ -211,12 +211,12 @@ def func(data: T.handle("float32")): mod = tvm.IRModule.from_expr(func) new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body expected_struct = { - ("tir.For", "l"): (None,), - ("tir.For", "k"): (("tir.For", "l"),), - ("tir.IfThenElse", ("i", "j")): (("tir.For", "k"), ("tir.For", "k")), - ("tir.For", "j"): (None,), - ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None), - ("tir.For", "i"): (("tir.IfThenElse", ("i",)),), + ("tirx.For", "l"): (None,), + ("tirx.For", "k"): (("tirx.For", "l"),), + ("tirx.IfThenElse", ("i", "j")): (("tirx.For", "k"), ("tirx.For", "k")), + ("tirx.For", "j"): (None,), + ("tirx.IfThenElse", ("i",)): (("tirx.For", "j"), None), + ("tirx.For", "i"): (("tirx.IfThenElse", ("i",)),), } verify_structure(new_stmt, expected_struct) @@ -253,17 +253,17 @@ def main(data: T.Buffer((1,), "float32"), n: T.int32): new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body # Updated expected_struct with renamed second nest variables expected_struct = { - ("tir.IfThenElse", ("i", "j")): (None, None), - ("tir.IfThenElse", ("j",)): (None, None), - ("tir.For", "l"): (None,), - ("tir.For", "k"): (("tir.For", "l"),), - ("tir.For", "j"): (None,), - ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None), - ("tir.For", "i"): (("tir.IfThenElse", ("i",)),), - ("tir.For", "k2"): (None,), - ("tir.For", "j2"): (("tir.For", "k2"),), - ("tir.For", "i2"): (("tir.For", "j2"),), - ("tir.IfThenElse", ("n",)): (("tir.For", "i2"), None), + ("tirx.IfThenElse", ("i", "j")): (None, None), + ("tirx.IfThenElse", ("j",)): (None, None), + ("tirx.For", "l"): (None,), + ("tirx.For", "k"): (("tirx.For", "l"),), + ("tirx.For", "j"): (None,), + ("tirx.IfThenElse", ("i",)): (("tirx.For", "j"), None), + ("tirx.For", "i"): (("tirx.IfThenElse", ("i",)),), + ("tirx.For", "k2"): (None,), + ("tirx.For", "j2"): (("tirx.For", "k2"),), + ("tirx.For", "i2"): (("tirx.For", "j2"),), + ("tirx.IfThenElse", ("n",)): (("tirx.For", "i2"), None), } verify_structure(new_stmt, expected_struct) @@ -285,11 +285,11 @@ def func(data: T.handle("float32")): new_mod = tvm.s_tir.transform.HoistIfThenElse()(mod) new_stmt = new_mod["main"].body expected_struct = { - ("tir.For", "k"): (None,), - ("tir.IfThenElse", ("j",)): (("tir.For", "k"), None), - ("tir.For", "j"): (("tir.IfThenElse", ("j",)),), - ("tir.IfThenElse", ("i",)): (("tir.For", "j"), None), - ("tir.For", "i"): (("tir.IfThenElse", ("i",)),), + ("tirx.For", "k"): (None,), + ("tirx.IfThenElse", ("j",)): (("tirx.For", "k"), None), + ("tirx.For", "j"): (("tirx.IfThenElse", ("j",)),), + ("tirx.IfThenElse", ("i",)): (("tirx.For", "j"), None), + ("tirx.For", "i"): (("tirx.IfThenElse", ("i",)),), } verify_structure(new_stmt, expected_struct) @@ -348,9 +348,9 @@ def test_no_hoisting_4(): dshape_inner = (33, 63) # Create iter_var for tx (used inside loop with T.attr) - tx_var = tvm.tir.Var("threadIdx.x", "int32") - tx_iter = tvm.tir.IterVar( - tvm.ir.Range(0, dshape_inner[0]), tx_var, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" + tx_var = tvm.tirx.Var("threadIdx.x", "int32") + tx_iter = tvm.tirx.IterVar( + tvm.ir.Range(0, dshape_inner[0]), tx_var, tvm.tirx.IterVar.ThreadIndex, "threadIdx.x" ) @I.ir_module @@ -443,9 +443,9 @@ def test_hoisting_block_scope_2(): dshape = (32, 64) # Create iter_var for bx (used inside loop with T.attr) - bx_var = tvm.tir.Var("blockIdx.x", "int32") - bx_iter = tvm.tir.IterVar( - tvm.ir.Range(0, dshape[1]), bx_var, tvm.tir.IterVar.ThreadIndex, "blockIdx.x" + bx_var = tvm.tirx.Var("blockIdx.x", "int32") + bx_iter = tvm.tirx.IterVar( + tvm.ir.Range(0, dshape[1]), bx_var, tvm.tirx.IterVar.ThreadIndex, "blockIdx.x" ) @I.ir_module @@ -467,8 +467,8 @@ def main(data: T.Buffer((1,), "float32"), l: T.int32, m: T.int32, n: T.int32): ] + T.float32(1.3) mod = Module - mod = tvm.tir.transform.Simplify()(mod) - mod = tvm.tir.transform.RemoveNoOp()(mod) + mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.RemoveNoOp()(mod) stmt = mod["main"].body new_stmt = tvm.s_tir.transform.HoistIfThenElse()(mod)["main"].body @@ -498,7 +498,7 @@ def main(data: T.Buffer((1,), "float32"), l: T.int32, m: T.int32, n: T.int32, g: new_stmt = tvm.s_tir.transform.HoistIfThenElse()(Module)["main"].body assert not tvm.ir.structural_equal(new_stmt, stmt) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], new_stmt)) + mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([], new_stmt)) stmt = new_stmt with tvm.transform.PassContext( diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py index 97266310843d..62357a537c9a 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_double_buffer.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_double_buffer(): @@ -44,7 +44,7 @@ def db(A: T.handle("float32"), C: T.handle("float32")): mod = Module opt = tvm.transform.Sequential( - [tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tir.transform.Simplify()] + [tvm.s_tir.transform.InjectDoubleBuffer(), tvm.tirx.transform.Simplify()] ) with tvm.transform.PassContext(config={"s_tir.InjectDoubleBuffer": {"split_loop": 2}}): @@ -56,10 +56,10 @@ def db(A: T.handle("float32"), C: T.handle("float32")): def visitor(op): nonlocal allocate_node - if isinstance(op, tvm.tir.AllocBuffer) and "B" in str(op.buffer.data): + if isinstance(op, tvm.tirx.AllocBuffer) and "B" in str(op.buffer.data): allocate_node = op - tvm.tir.stmt_functor.post_order_visit(stmt, visitor) + tvm.tirx.stmt_functor.post_order_visit(stmt, visitor) assert allocate_node is not None assert list(allocate_node.buffer.shape) == [m * 2] @@ -67,10 +67,10 @@ def visitor(op): count = [0] def count_sync(op): - if isinstance(op, tvm.tir.Call) and op.op.same_as(tvm.ir.Op.get("tir.tvm_storage_sync")): + if isinstance(op, tvm.tirx.Call) and op.op.same_as(tvm.ir.Op.get("tirx.tvm_storage_sync")): count[0] += 1 - tvm.tir.stmt_functor.post_order_visit(f.body, count_sync) + tvm.tirx.stmt_functor.post_order_visit(f.body, count_sync) assert count[0] == 4 @@ -78,7 +78,7 @@ def test_double_buffer_transform(): transform = tvm.ir.transform.Sequential( [ tvm.s_tir.transform.InjectDoubleBuffer(), - tvm.tir.transform.Simplify(), + tvm.tirx.transform.Simplify(), ] ) @@ -104,10 +104,10 @@ def main(A: T.Buffer([16, 32], "float32"), B: T.Buffer(16, "float32")): def visitor(op): nonlocal allocate_node - if isinstance(op, tvm.tir.AllocBuffer): + if isinstance(op, tvm.tirx.AllocBuffer): allocate_node = op - tvm.tir.stmt_functor.post_order_visit(After["main"].body, visitor) + tvm.tirx.stmt_functor.post_order_visit(After["main"].body, visitor) assert allocate_node is not None assert list(allocate_node.buffer.shape) == [64] @@ -118,7 +118,7 @@ def test_double_buffer_with_decl_buffer(): transform = tvm.ir.transform.Sequential( [ tvm.s_tir.transform.InjectDoubleBuffer(), - tvm.tir.transform.Simplify(), + tvm.tirx.transform.Simplify(), ] ) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py index 49b0d5ca46b8..6a7cd9bb2c36 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_permuted_layout.py @@ -19,8 +19,8 @@ import tvm.s_tir import tvm.testing from tvm import IRModule -from tvm.script import tir as T -from tvm.tir import PrimFunc +from tvm.script import tirx as T +from tvm.tirx import PrimFunc def _check_primfunc_transform(before: PrimFunc, expected: PrimFunc): @@ -79,7 +79,7 @@ def expected(X: T.Buffer((4096, 4096), "float16")): with T.sblock(""): X_reindex_shared_dyn = T.sblock_alloc_buffer((128, 32), "float16", strides=(32, 1), scope="shared.dyn") with T.sblock("X_reindex_shared.dyn"): - # annotate the reads and writes because they cannot be inferred from tir.bitwise_xor + # annotate the reads and writes because they cannot be inferred from tirx.bitwise_xor T.reads(X[blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4:blockIdx_y // 8 * 128 + threadIdx_y * 8 + threadIdx_x // 4 + 97, ax2_0_0 * 32 + threadIdx_x % 4 * 8:ax2_0_0 * 32 + threadIdx_x % 4 * 8 + 8]) T.writes(X_reindex_shared_dyn[threadIdx_y * 8 + threadIdx_x // 4:threadIdx_y * 8 + threadIdx_x // 4 + 97, threadIdx_x % 4 * 8:threadIdx_x % 4 * 8 + 8]) for ax0_ax1_fused_0 in range(4): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py index f266ee539aae..8b93c128b154 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_async_copy.py @@ -24,17 +24,17 @@ import tvm.testing from tvm import s_tir from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def count_cp_async(stmt): num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_cp_async": + if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx_cp_async": num_alloc[0] += 1 - tvm.tir.stmt_functor.post_order_visit(stmt, verify) + tvm.tirx.stmt_functor.post_order_visit(stmt, verify) return num_alloc[0] @@ -46,7 +46,7 @@ def generate_global_to_shared_vectorized_copy(dtype, vector_size): def ptx_global_to_shared_copy( A: T.Buffer((32, 128), dtype), B: T.Buffer((32, 128), dtype) ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -74,7 +74,7 @@ def ptx_global_to_shared_copy( def ptx_global_to_shared_copy_fp32x1( A: T.Buffer((32, 128), "float32"), B: T.Buffer((32, 128), "float32") ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -101,7 +101,7 @@ def ptx_global_to_shared_dyn_copy_fp16x8( B: T.Buffer((32, 128), "float16"), C: T.Buffer((32, 128), "float16"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -135,9 +135,9 @@ def test_inject_async_copy(): mod = tvm.IRModule.from_expr(f) mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tir.transform.FlattenBuffer()(mod) + mod = tvm.tirx.transform.FlattenBuffer()(mod) if vec_size > 1: - mod = tvm.tir.transform.VectorizeLoop()(mod) + mod = tvm.tirx.transform.VectorizeLoop()(mod) mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod) assert count_cp_async(mod["main"].body) == 1 @@ -145,7 +145,7 @@ def test_inject_async_copy(): if not tvm.testing.is_ampere_or_newer(): continue - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext(config={"tirx.use_async_copy": 1}): mod = tvm.compile(tvm.IRModule.from_expr(f), target="cuda") A_np = np.random.rand(32, 128).astype(dtype) @@ -163,8 +163,8 @@ def test_inject_async_copy_shared_dyn(): mod = tvm.IRModule.from_expr(f) mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tir.transform.FlattenBuffer()(mod) - mod = tvm.tir.transform.VectorizeLoop()(mod) + mod = tvm.tirx.transform.FlattenBuffer()(mod) + mod = tvm.tirx.transform.VectorizeLoop()(mod) mod = tvm.s_tir.transform.MergeSharedMemoryAllocations()(mod) mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod) @@ -173,7 +173,7 @@ def test_inject_async_copy_shared_dyn(): if not tvm.testing.is_ampere_or_newer(): return - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext(config={"tirx.use_async_copy": 1}): mod = tvm.compile(tvm.IRModule.from_expr(f), target="cuda") A_np = np.random.rand(32, 128).astype("float16") @@ -191,7 +191,7 @@ def test_inject_async_copy_shared_dyn(): def ptx_global_to_shared_copy_fp32x1_barrier( A: T.Buffer((32, 128), "float32"), B: T.Buffer((32, 128), "float32") ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -225,13 +225,13 @@ def test_inject_async_copy_barrier(): mod = tvm.IRModule.from_expr(f) mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tir.transform.FlattenBuffer()(mod) + mod = tvm.tirx.transform.FlattenBuffer()(mod) mod = tvm.s_tir.transform.InjectPTXAsyncCopy()(mod) assert count_cp_async(mod["main"].body) == 1 if tvm.testing.is_ampere_or_newer(): - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext(config={"tirx.use_async_copy": 1}): mod = tvm.compile(tvm.IRModule.from_expr(f), target="cuda") A_np = np.random.rand(32, 128).astype(dtype) @@ -449,7 +449,7 @@ def simple_compute( B: T.Buffer((16, 14), "float32"), C: T.Buffer((16, 16), "float32"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for tx in T.thread_binding(0, 16, thread="threadIdx.x"): for i in T.serial( 16, @@ -482,7 +482,7 @@ def simple_compute( C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] mod = tvm.IRModule.from_expr(simple_compute) - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext(config={"tirx.use_async_copy": 1}): tvm.compile(mod, target="cuda") generated_code = postproc_if_missing_async_support() print(generated_code) @@ -503,7 +503,7 @@ def complex_compute( W: T.Buffer((1280, 3, 3, 1280), "float16"), Conv: T.Buffer((512, 1280), "float16"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # with T.sblock("root"): data_im2col_reindex_shared_dyn = T.sblock_alloc_buffer( (512, 11520), "float16", scope="shared.dyn" @@ -946,7 +946,7 @@ def complex_compute( ) mod = tvm.IRModule.from_expr(complex_compute) - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext(config={"tirx.use_async_copy": 1}): tvm.compile(mod, target="cuda") generated_code = postproc_if_missing_async_support() # generated_code must contain " setp.ne.b32 p, %0, 0;" diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py index ad307bb647dc..e067c5125a3d 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_ptx_ldg32.py @@ -19,17 +19,17 @@ import tvm import tvm.testing from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T def _count_alloc(stmt): num_alloc = [0] def visit(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): num_alloc[0] += 1 - tvm.tir.stmt_functor.post_order_visit(stmt, visit) + tvm.tirx.stmt_functor.post_order_visit(stmt, visit) return num_alloc[0] @@ -37,23 +37,23 @@ def _count_ptx_ldg32(stmt): num_call = [0] def visit(n): - if isinstance(n, tvm.tir.Call) and n.op.name == "tir.ptx_ldg32": + if isinstance(n, tvm.tirx.Call) and n.op.name == "tirx.ptx_ldg32": num_call[0] += 1 - tvm.tir.stmt_functor.post_order_visit(stmt, visit) + tvm.tirx.stmt_functor.post_order_visit(stmt, visit) return num_call[0] @T.prim_func def where_no_alloc(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("cuda")}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "target": T.target("cuda")}) for i in range(4): C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0)) @T.prim_func def where_no_alloc_cpu(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "float32")) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": T.target("llvm")}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "target": T.target("llvm")}) for i in range(4): C[i] = T.if_then_else(A[i] > T.float32(0), A[i], T.float32(0)) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py index 70d4f666a389..c05c731495cb 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_software_pipeline.py @@ -23,7 +23,7 @@ import tvm import tvm.s_tir.tensor_intrin.cuda import tvm.testing -from tvm import TVMError, te, tir +from tvm import TVMError, te, tirx from tvm.s_tir.meta_schedule.testing import te_workload from tvm.s_tir.tensor_intrin.cuda import ( LDMATRIX_f16_A_DYN_INTRIN, @@ -33,7 +33,7 @@ MMA_store_16x16_f32_global_INTRIN, shared_16x16_to_ldmatrix_32x8_layout, ) -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.testing.tir import mma_schedule @@ -41,7 +41,7 @@ def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.InjectSoftwarePipeline()(mod) - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.Simplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) @@ -1533,7 +1533,7 @@ def index_map(i, j): def build_and_run(sch): if tvm.testing.is_ampere_or_newer(): - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext(config={"tirx.use_async_copy": 1}): f = tvm.compile(sch.mod["main"], target="cuda") dev = tvm.device("cuda", 0) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py b/tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py index 897c61fc010d..2c251b15559b 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_inject_virtual_thread.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_vthread(): @@ -43,7 +43,7 @@ def main(A: T.handle("float32"), C: T.handle("float32")): "int32", "Run", B.access_ptr("r"), - T.call_intrin("int32", "tir.tvm_context_id"), + T.call_intrin("int32", "tirx.tvm_context_id"), ) ) C_buf[i * nthread + vt_x] = B[i] + T.float32(1) @@ -57,10 +57,10 @@ def main(A: T.handle("float32"), C: T.handle("float32")): allocates = [] def find_allocates(node): - if isinstance(node, tvm.tir.AllocBuffer): + if isinstance(node, tvm.tirx.AllocBuffer): allocates.append(node) - tvm.tir.stmt_functor.post_order_visit(stmt.body, find_allocates) + tvm.tirx.stmt_functor.post_order_visit(stmt.body, find_allocates) assert len(allocates) == 1 assert list(allocates[0].buffer.shape) == [B_expected_alloc] @@ -106,10 +106,10 @@ def main(): allocates = [] def find_allocates(node): - if isinstance(node, tvm.tir.AllocBuffer): + if isinstance(node, tvm.tirx.AllocBuffer): allocates.append(node) - tvm.tir.stmt_functor.post_order_visit(stmt.body, find_allocates) + tvm.tirx.stmt_functor.post_order_visit(stmt.body, find_allocates) assert len(allocates) == 3 # Check that we have the expected extents (order may vary) extents = sorted([int(a.buffer.shape[0]) for a in allocates]) @@ -142,10 +142,10 @@ def main(A: T.handle("float32")): if_nodes = [] def find_ifs(node): - if isinstance(node, tvm.tir.IfThenElse): + if isinstance(node, tvm.tirx.IfThenElse): if_nodes.append(node) - tvm.tir.stmt_functor.post_order_visit(stmt.body, find_ifs) + tvm.tirx.stmt_functor.post_order_visit(stmt.body, find_ifs) assert len(if_nodes) == 2 # First if has else_case, second does not @@ -197,7 +197,7 @@ def before_func(): before_mod = tvm.IRModule.from_expr(before_func.with_attr("global_symbol", "main")) intermediate_mod = tvm.s_tir.transform.InjectVirtualThread()(before_mod) - after_mod = tvm.tir.transform.StorageRewrite()(intermediate_mod) + after_mod = tvm.tirx.transform.StorageRewrite()(intermediate_mod) after_func = after_mod["main"] # Verify the vectorized allocation has the expected shape and dtype @@ -205,10 +205,10 @@ def before_func(): def visitor(op): nonlocal allocate_node - if isinstance(op, tvm.tir.AllocBuffer) and "shared" in str(op.buffer.data.type_annotation): + if isinstance(op, tvm.tirx.AllocBuffer) and "shared" in str(op.buffer.data.type_annotation): allocate_node = op - tvm.tir.stmt_functor.post_order_visit(after_func.body, visitor) + tvm.tirx.stmt_functor.post_order_visit(after_func.body, visitor) assert allocate_node is not None assert list(allocate_node.buffer.shape) == [4] assert allocate_node.buffer.dtype == "int32x4" diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py b/tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py index cdef794a3a45..40fcbb61886a 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lift_thread_binding.py @@ -16,8 +16,8 @@ # under the License. # ruff: noqa: E501, F401 import tvm -from tvm import s_tir, tir -from tvm.script import tir as T +from tvm import s_tir, tirx +from tvm.script import tirx as T def test_lift_tx_beyond_local(): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py b/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py index 442d782ba0f0..505123d210b6 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_loop_partition.py @@ -21,12 +21,12 @@ import tvm import tvm.testing from tvm.ir.module import IRModule -from tvm.script import tir as T +from tvm.script import tirx as T def collect_visit(stmt, f): ret = [] - tvm.tir.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) + tvm.tirx.stmt_functor.post_order_visit(stmt, lambda x: ret.append(f(x))) return ret @@ -43,9 +43,9 @@ def func(n: T.int64, m: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) + assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tirx.IfThenElse))) def test_multi_if(): @@ -65,9 +65,9 @@ def func(n: T.int64, m: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) + assert not any(collect_visit(stmt.body[0], lambda x: isinstance(x, tvm.tirx.IfThenElse))) def test_condition(): @@ -79,9 +79,9 @@ def func(m: T.int64, n: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))) + assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tirx.Select))) def test_condition_EQ(): @@ -93,9 +93,9 @@ def func(m: T.int64, n: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tir.Select))) + assert not any(collect_visit(stmt[0], lambda x: isinstance(x, tvm.tirx.Select))) def test_everything_during_deduction(): @@ -109,9 +109,9 @@ def func(m: T.int64, n: T.int64): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert isinstance(stmt.body.body, tvm.tir.IfThenElse) + assert isinstance(stmt.body.body, tvm.tirx.IfThenElse) def test_oneD_pool(): @@ -139,9 +139,9 @@ def func(m: T.int64, data: T.handle("float32"), out: T.handle("float32")): with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) + assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) def test_cce_loop_1(): @@ -160,9 +160,9 @@ def func(A: T.Buffer((n * m,), "float16"), B: T.Buffer((n * m,), "float16")): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) + assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) def test_cce_loop_2(): @@ -181,9 +181,9 @@ def func(): mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) + assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) def test_cce_loop_3(): @@ -202,16 +202,16 @@ def func(): with tvm.transform.PassContext(config={"s_tir.LoopPartition": {"partition_const_loop": True}}): mod = tvm.s_tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tirx.transform.Simplify()(mod)["main"].body - assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tir.IfThenElse))) + assert not any(collect_visit(stmt, lambda x: isinstance(x, tvm.tirx.IfThenElse))) @T.prim_func def partitioned_concat( A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32"), C: T.Buffer((32,), "float32") ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i in T.serial(0, 16): C[i] = A[i] for i in T.serial(0, 16): @@ -223,10 +223,10 @@ def partition_from_scheduled_tir(prim_func, pass_cfg, do_flatten=True): mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) if do_flatten: - mod = tvm.tir.transform.FlattenBuffer()(mod) + mod = tvm.tirx.transform.FlattenBuffer()(mod) mod = tvm.s_tir.transform.LoopPartition()(mod) - mod = tvm.tir.transform.Simplify()(mod) - mod = tvm.tir.transform.RemoveNoOp()(mod) + mod = tvm.tirx.transform.Simplify()(mod) + mod = tvm.tirx.transform.RemoveNoOp()(mod) return mod @@ -327,9 +327,9 @@ def partitioned_main( } }, ) - mod = tvm.tir.transform.UnrollLoop()(mod) - mod = tvm.tir.transform.RemoveNoOp()(mod) - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.UnrollLoop()(mod) + mod = tvm.tirx.transform.RemoveNoOp()(mod) + mod = tvm.tirx.transform.Simplify()(mod) tvm.ir.assert_structural_equal(mod["main"], partitioned_main.with_attr("global_symbol", "main")) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py index 43c57955cc79..d5ddc47dbcff 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py @@ -23,7 +23,7 @@ import tvm import tvm.testing from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py index b48fbbebe02f..6ceb561687f8 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_init_block.py @@ -17,7 +17,7 @@ # ruff: noqa: F401, F821 import tvm from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=no-self-argument diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py index 40a595db5b04..cc3c98c377c1 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_match_buffer.py @@ -20,13 +20,13 @@ import tvm import tvm.s_tir import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LowerMatchBuffer()(mod) - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.Simplify()(mod) tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main")) @@ -64,7 +64,7 @@ def transformed_buffer_load_store(a: T.handle, c: T.handle) -> None: A[i * 4 + ii, j, k * 2 + kk] += C[i * 4 + ii, k * 2 + kk] -@tvm.ir.register_op_attr("tir.intrin_test", "") +@tvm.ir.register_op_attr("tirx.intrin_test", "") def intrin_test(data, elem_offset, stride_0, stride_1, shape_0, shape_1): return 0 diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py index f765874753d1..c4b212842be7 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_opaque_block.py @@ -17,14 +17,14 @@ import tvm import tvm.s_tir import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T def _check(original, transformed): func = original mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.LowerOpaqueBlock()(mod) - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.Simplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) @@ -352,9 +352,9 @@ def test_annotated_loops(): attr3 = attr2.body assert attr1.attr_key == "pragma_1" and attr1.value == "str_value" assert attr2.attr_key == "pragma_2" - tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1)) + tvm.ir.assert_structural_equal(attr2.value, tvm.tirx.IntImm("int32", 1)) assert attr3.attr_key == "pragma_3" - tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) + tvm.ir.assert_structural_equal(attr3.value, tvm.tirx.FloatImm("float32", 0.0)) def test_annotated_block(): @@ -371,9 +371,9 @@ def annotated_block() -> None: attr3 = attr2.body assert attr1.attr_key == "pragma_1" and attr1.value == "str_value" assert attr2.attr_key == "pragma_2" - tvm.ir.assert_structural_equal(attr2.value, tvm.tir.IntImm("int32", 1)) + tvm.ir.assert_structural_equal(attr2.value, tvm.tirx.IntImm("int32", 1)) assert attr3.attr_key == "pragma_3" - tvm.ir.assert_structural_equal(attr3.value, tvm.tir.FloatImm("float32", 0.0)) + tvm.ir.assert_structural_equal(attr3.value, tvm.tirx.FloatImm("float32", 0.0)) def test_preserved_annotations(): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py index e5e17801fe1e..aff3376052bf 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py @@ -20,7 +20,7 @@ import tvm.testing from tvm import s_tir from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_basic(): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py b/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py index 10df113fb88c..cb82237cb259 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_manifest_shared_memory_local_stage.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks @@ -29,7 +29,7 @@ class MatmulBefore: @T.prim_func def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) # body # with T.sblock("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): @@ -47,14 +47,14 @@ def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float3 with T.sblock("A_shared"): T.reads(A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) - T.sblock_attr({"tir.manifest_shared_memory_local_stage":1}) + T.sblock_attr({"tirx.manifest_shared_memory_local_stage":1}) A_shared[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = A[blockIdx_y * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for ax0_ax1_fused_0 in T.serial(64): for ax0_ax1_fused_3 in T.vectorized(4): with T.sblock("B_shared"): T.reads(B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) T.writes(B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32]) - T.sblock_attr({"tir.manifest_shared_memory_local_stage":1}) + T.sblock_attr({"tirx.manifest_shared_memory_local_stage":1}) B_shared[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] = B[k_0 * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) // 32, blockIdx_x * 32 + (ax0_ax1_fused_0 * 16 + threadIdx_y * 8 + threadIdx_x * 4 + ax0_ax1_fused_3) % 32] for k_1, i_2, j_2, k_2 in T.grid(2, 16, 16, 16): with T.sblock("C"): @@ -70,7 +70,7 @@ class MatmulAfter: @T.prim_func def main(A: T.Buffer((1024, 1024), "float32"), B: T.Buffer((1024, 1024), "float32"), C: T.Buffer((1024, 1024), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) # body # with T.sblock("root") for blockIdx_y in T.thread_binding(32, thread="blockIdx.y"): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py b/tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py index d62cd62ed8a3..fdde44e00db7 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_memhammer_lower_auto_copy.py @@ -22,7 +22,7 @@ import tvm from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.script.ir_module @@ -1138,7 +1138,7 @@ def verify_single_allocation(stmt, alloc_size=None): def verify(n): if ( - isinstance(n, tvm.tir.SBlock) + isinstance(n, tvm.tirx.SBlock) and n.alloc_buffers is not None and (True in ((buf.scope() == "shared.dyn") for buf in n.alloc_buffers)) ): @@ -1146,7 +1146,7 @@ def verify(n): for buf in n.alloc_buffers: alloc_extents.append(buf.shape) - tvm.tir.stmt_functor.post_order_visit(stmt, verify) + tvm.tirx.stmt_functor.post_order_visit(stmt, verify) assert num_alloc[0] == 1 if alloc_size: @@ -1162,7 +1162,7 @@ def prod(arr): def test_auto_padding(): mod = tvm.s_tir.transform.LowerAutoCopy()(Transpose) - mod = tvm.tir.transform.FlattenBuffer()(mod) + mod = tvm.tirx.transform.FlattenBuffer()(mod) verify_single_allocation(mod["main"].body, 16 * 130) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py index 9d5a05cd0449..83d71f078377 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -21,7 +21,7 @@ import tvm.testing from tvm import s_tir from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.topi.math import cast diff --git a/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py index c92528e5a2da..88475eba8acf 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_plan_update_buffer_allocation_location.py @@ -19,9 +19,9 @@ import tvm import tvm.testing -from tvm import s_tir, tir +from tvm import s_tir, tirx from tvm.s_tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN -from tvm.script import tir as T +from tvm.script import tirx as T def _check(original, transformed): @@ -356,14 +356,14 @@ def after(A: T.Buffer((4, 16), "int32"), C: T.Buffer((4, 8), "int32")): def test_buffer_conditional_lowering(): """Buffers passed as pointer arguments are unmodified - Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass + Confirm that the `tirx.PlanAndUpdateBufferAllocationLocation` pass leaves (Buffer nodes corresponding to pointer-typed PrimFunc arguments) unchanged, rather than lowering them to `reads`, `writes`, and `alloc_buffer` nodes. """ @T.prim_func def before(A: T.handle("float32")): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i in range(1): A_1 = T.decl_buffer((1,), data=A) A_1[i] = 0 @@ -375,7 +375,7 @@ def before(A: T.handle("float32")): def test_dltensor_buffer_is_unlowered(): """Buffers allocated with a Bind are unmodified - Confirm that the `tir.PlanAndUpdateBufferAllocationLocation` pass + Confirm that the `tirx.PlanAndUpdateBufferAllocationLocation` pass leaves (Buffer nodes corresponding to PrimFunc DLTensor arguments) unchanged, rather than lowering them to `reads`, `writes`, and `alloc_buffer` nodes. diff --git a/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py b/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py index 2581a7d7dfba..693111cdfe49 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_profiling_instr.py @@ -22,10 +22,10 @@ import tvm.testing from tvm import s_tir from tvm.ir.module import IRModule -from tvm.script import tir as T +from tvm.script import tirx as T default_lwp_test_config = { - "tir.instrument_lwp": True, + "tirx.instrument_lwp": True, "s_tir.lwp_disable_func_prof": True, "s_tir.reset_start_id": True, } diff --git a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py index 7d15fe82799d..529f09bdf663 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py @@ -21,7 +21,7 @@ import tvm.testing from tvm import TVMError from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_remove_store_undef(): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_remove_weight_layout_rewrite_block.py b/tests/python/s_tir/transform/test_s_tir_transform_remove_weight_layout_rewrite_block.py index 3ea6d8e28f01..656d0f28996c 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_remove_weight_layout_rewrite_block.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_remove_weight_layout_rewrite_block.py @@ -19,8 +19,8 @@ import tvm from tvm.ir.module import IRModule -from tvm.script import tir as T -from tvm.tir.function import PrimFunc +from tvm.script import tirx as T +from tvm.tirx.function import PrimFunc def _check(before, expect): diff --git a/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py index 7ee25fede0cb..759aad2fa2b6 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_renormalize_split_pattern.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T # fmt: off # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,redundant-keyword-arg @@ -29,7 +29,7 @@ class Before: @T.prim_func def main(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) inputs_flat = T.decl_buffer([8192], dtype="float32", data=inputs.data) weight_flat = T.decl_buffer([2097152], dtype="float32", data=weight.data) conv2d_transpose_nhwc_flat = T.decl_buffer([16384], dtype="float32", data=conv2d_transpose_nhwc.data) @@ -60,7 +60,7 @@ class After: @T.prim_func def main(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) inputs_flat = T.decl_buffer([8192], dtype="float32", data=inputs.data) weight_flat = T.decl_buffer([2097152], dtype="float32", data=weight.data) conv2d_transpose_nhwc_flat = T.decl_buffer([16384], dtype="float32", data=conv2d_transpose_nhwc.data) @@ -91,7 +91,7 @@ class After_simplified: @T.prim_func def main(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 512, 256), "float32"), conv2d_transpose_nhwc: T.Buffer((1, 8, 8, 256), "float32")) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # var definition threadIdx_x = T.env_thread("threadIdx.x") blockIdx_x = T.env_thread("blockIdx.x") @@ -123,7 +123,7 @@ def main(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 51 def test_renormalize_split_pattern(): after = tvm.s_tir.transform.RenormalizeSplitPattern()(Before) tvm.ir.assert_structural_equal(after, After) - after = tvm.tir.transform.Simplify()(after) + after = tvm.tirx.transform.Simplify()(after) tvm.ir.assert_structural_equal(after, After_simplified) @@ -166,7 +166,7 @@ def test_analyze_inside_integer_conditional(integer_condition): """ # Similar issue would occur in most transformations that subclass - # IRMutatorWithAnalyzer. tir.transform.Simplify() is an + # IRMutatorWithAnalyzer. tirx.transform.Simplify() is an # exception, as it rewrites the integer conditionals first. These # tests are written using RenormalizeSplitPattern as it is the # first case identified. diff --git a/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py b/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py index fa7ddc7c3cf1..e3f153c9afb6 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_rewrite_unsafe_select.py @@ -18,7 +18,7 @@ import tvm from tvm import s_tir from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_rewrite_Select(): @@ -63,11 +63,11 @@ def main(i: T.int32): ) aa = tvm.s_tir.transform.RewriteUnsafeSelect()(ModuleA)["main"].body.seq[-1].value - builtin_if_then_else = tvm.ir.Op.get("tir.if_then_else") + builtin_if_then_else = tvm.ir.Op.get("tirx.if_then_else") assert yy.op.same_as(builtin_if_then_else) assert yy.op.same_as(builtin_if_then_else) - assert isinstance(aa, tvm.tir.Select) + assert isinstance(aa, tvm.tirx.Select) if __name__ == "__main__": diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index 08a51d265556..37c67c83f1ef 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -18,20 +18,20 @@ import tvm import tvm.testing from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T -def run_passes(func: tvm.tir.PrimFunc): +def run_passes(func: tvm.tirx.PrimFunc): mod = tvm.IRModule.from_expr(func) cuda_target = tvm.target.Target("cuda", host="llvm") - mod = tvm.tir.transform.Apply( + mod = tvm.tirx.transform.Apply( lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) )(mod) - mod = tvm.tir.transform.AnnotateDeviceRegions()(mod) - mod = tvm.tir.transform.SplitHostDevice()(mod) + mod = tvm.tirx.transform.AnnotateDeviceRegions()(mod) + mod = tvm.tirx.transform.SplitHostDevice()(mod) return tvm.s_tir.transform.ThreadSync("shared")(mod) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py b/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py index 959a55637173..bb6820d3bf6a 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_unify_thread_binding.py @@ -22,13 +22,13 @@ import tvm import tvm.testing from tvm import s_tir -from tvm.script import tir as T +from tvm.script import tirx as T def _check(original, transformed): mod = tvm.IRModule.from_expr(original.with_attr("global_symbol", "main")) mod = tvm.s_tir.transform.UnifyThreadBinding()(mod) - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.Simplify()(mod) tvm.ir.assert_structural_equal( mod["main"], transformed.with_attr("global_symbol", "main"), True ) diff --git a/tests/python/target/test_arm_target.py b/tests/python/target/test_arm_target.py index 6b6156d4ec91..96321bb6e449 100644 --- a/tests/python/target/test_arm_target.py +++ b/tests/python/target/test_arm_target.py @@ -24,7 +24,7 @@ import pytest import tvm -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import codegen llvm_version, arm_target, input_dtype, kernel_dtype, is_supported = tvm.testing.parameters( @@ -116,7 +116,7 @@ def test_scalable_div(sve_device_vector_length): @T.prim_func def my_func(a: T.handle): A = T.match_buffer(a, (1,), "int32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) A[0] = T.Div(10000, 4 * T.vscale()) mod = tvm.compile(my_func, target=target) @@ -139,7 +139,7 @@ def test_scalable_buffer_load_store(sve_device_vector_length): def my_func(a: T.handle, b: T.handle): A = T.match_buffer(a, (num_elements,), "float32") B = T.match_buffer(b, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) B[T.ramp(0, 1, 4 * T.vscale())] = A[T.ramp(0, 1, 4 * T.vscale())] mod = tvm.compile(my_func, target=target) @@ -166,7 +166,7 @@ def test_scalable_loop_bound(sve_device_vector_length): def my_func(a: T.handle, b: T.handle): A = T.match_buffer(a, (num_elements,), "float32") B = T.match_buffer(b, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) for i in T.serial(0, 4 * T.vscale()): B[i] = A[i] @@ -190,7 +190,7 @@ def test_scalable_broadcast(sve_device_vector_length): @T.prim_func def my_func(a: T.handle): A = T.match_buffer(a, (num_elements,), "float32") - T.func_attr({"global_symbol": "my_module", "tir.noalias": True}) + T.func_attr({"global_symbol": "my_module", "tirx.noalias": True}) A[T.ramp(0, 1, 4 * T.vscale())] = T.broadcast(1, 4 * T.vscale()) mod = tvm.compile(my_func, target=target) diff --git a/tests/python/target/test_target_target.py b/tests/python/target/test_target_target.py index fe86982e5ed0..05c79abea7a2 100644 --- a/tests/python/target/test_target_target.py +++ b/tests/python/target/test_target_target.py @@ -384,7 +384,7 @@ def test_target_from_device_opencl(input_device): def test_module_dict_from_deserialized_targets(): target = Target("llvm") - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def func(): diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index 3a9bcd3957fd..e1fa7301b5da 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -21,8 +21,8 @@ import tvm import tvm.testing -from tvm import s_tir, te, tir, topi -from tvm.script import tir as T +from tvm import s_tir, te, tirx, topi +from tvm.script import tirx as T def test_unique_name_complete_block(): @@ -50,7 +50,7 @@ def test_unique_name_reduction_block(): def _check_workload(te_workload, tir_workload, index_dtype_override=None, do_simplify=False): func = te.create_prim_func(te_workload(), index_dtype_override) if do_simplify: - simplify = tir.transform.Simplify() + simplify = tirx.transform.Simplify() func = simplify(tvm.IRModule.from_expr(func))["main"] tir_workload = simplify(tvm.IRModule.from_expr(tir_workload))["main"] tvm.ir.assert_structural_equal(func, tir_workload) @@ -69,7 +69,7 @@ def te_matmul(): @T.prim_func def tir_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) @@ -88,7 +88,7 @@ def tir_matmul_int64( B: T.Buffer((T.int64(128), T.int64(128)), "float32"), C: T.Buffer((T.int64(128), T.int64(128)), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0, j0, k0 in T.grid(T.int64(128), T.int64(128), T.int64(128)): with T.sblock(): i, j, k = T.axis.remap("SSR", [i0, j0, k0]) @@ -114,7 +114,7 @@ def te_element_wise(): @T.prim_func def tir_element_wise(a: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (128, 128)) C = T.match_buffer(c, (128, 128)) B = T.sblock_alloc_buffer((128, 128)) @@ -144,8 +144,8 @@ def te_conv2d(): W = te.placeholder((in_channel, kernel, kernel, out_channel), name="W") Apad = te.compute( (batch, in_channel, size + 2, size + 2), - lambda nn, cc, yy, xx: tvm.tir.if_then_else( - tvm.tir.all(yy >= 1, yy - 1 < size, xx >= 1, xx - 1 < size), + lambda nn, cc, yy, xx: tvm.tirx.if_then_else( + tvm.tirx.all(yy >= 1, yy - 1 < size, xx >= 1, xx - 1 < size), A[nn, cc, yy - 1, xx - 1], 0.0, ), @@ -166,7 +166,7 @@ def te_conv2d(): @T.prim_func def tir_conv2d(a: T.handle, w: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16, 14, 14]) W = T.match_buffer(w, [16, 3, 3, 32]) B = T.match_buffer(b, [16, 32, 14, 14]) @@ -204,7 +204,7 @@ def te_multi_output(): @T.prim_func def tir_multi_output(a0: T.handle, a1: T.handle, b0: T.handle, b1: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) m = T.int32() n = T.int32() A0 = T.match_buffer(a0, (m, n)) @@ -231,7 +231,7 @@ def te_extern(): C = te.extern( (128, 128), [A, B], - lambda ins, outs: tvm.tir.call_packed( + lambda ins, outs: tvm.tirx.call_packed( "tvm.contrib.cblas.matmul", ins[0], ins[1], outs[0], 0, 0 ), name="C", @@ -241,7 +241,7 @@ def te_extern(): @T.prim_func def tir_extern(a: T.handle, b: T.handle, c: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) off1 = te.var("elem_offset") off2 = te.var("elem_offset_1") off3 = te.var("elem_offset_2") @@ -303,7 +303,7 @@ def te_reordered_matmul(): @T.prim_func def tir_reordered_matmul(c: T.handle, a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) C = T.match_buffer(c, (128, 128)) @@ -348,7 +348,7 @@ def test_constant(): B = te.compute(tuple(), lambda: 2, name="B") # Manually craft ProducerLoad because `B[]` is not allowed. C = te.compute( - (M,), lambda x: A[x] + tvm.tir.expr.ProducerLoad(B, []), name="C", tag="broadcast" + (M,), lambda x: A[x] + tvm.tirx.expr.ProducerLoad(B, []), name="C", tag="broadcast" ) func = te.create_prim_func([C, A]) @@ -432,7 +432,7 @@ def expected_layout_attr( B: T.Buffer((128, 128), "float32"), D: T.Buffer((128, 128), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) C = T.sblock_alloc_buffer([128, 128], dtype="float32") for i0, i1, i2 in T.grid(128, 128, 128): with T.sblock("C"): @@ -453,7 +453,7 @@ def expected_layout_attr_int64( B: T.Buffer((T.int64(128), T.int64(128)), "float32"), D: T.Buffer((T.int64(128), T.int64(128)), "float32"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) C = T.sblock_alloc_buffer([T.int64(128), T.int64(128)], dtype="float32") for x, y, k in T.grid(T.int64(128), T.int64(128), T.int64(128)): with T.sblock("C"): @@ -498,12 +498,12 @@ def test_tensor_layout_attr(index_dtype_override, expected): def te_argmax_idx_val(): def f_combine(x, y): - lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0]) - rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1]) + lhs = tvm.tirx.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.tirx.Select((x[1] >= y[1]), x[1], y[1]) return lhs, rhs def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): - return tvm.tir.const(-1, dtype0), tvm.te.min_value(dtype1) + return tvm.tirx.const(-1, dtype0), tvm.te.min_value(dtype1) argmax = te.comm_reducer(f_combine, f_identity, name="argmax") @@ -522,7 +522,7 @@ def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): def tir_argmax_idx_val( var_idx: T.handle, var_val: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) m = T.int32() n = T.int32() idx = T.match_buffer(var_idx, [m, n], dtype="int32") @@ -545,12 +545,12 @@ def tir_argmax_idx_val( def te_argmax_val_idx(): def f_combine(x, y): - lhs = tvm.tir.Select((x[0] >= y[0]), x[0], y[0]) - rhs = tvm.tir.Select((x[0] >= y[0]), x[1], y[1]) + lhs = tvm.tirx.Select((x[0] >= y[0]), x[0], y[0]) + rhs = tvm.tirx.Select((x[0] >= y[0]), x[1], y[1]) return lhs, rhs def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): - return tvm.te.min_value(dtype0), tvm.tir.const(-1, dtype1) + return tvm.te.min_value(dtype0), tvm.tirx.const(-1, dtype1) argmax = te.comm_reducer(f_combine, f_identity, name="argmax") @@ -569,7 +569,7 @@ def f_identity(dtype0: tvm.DataType, dtype1: tvm.DataType): def tir_argmax_val_idx( var_val: T.handle, var_idx: T.handle, var_argmax_v0: T.handle, var_argmax_v1: T.handle ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) m = T.int32() n = T.int32() val = T.match_buffer(var_val, [m, n], dtype="float32") @@ -622,7 +622,7 @@ def expected( b: T.Buffer((), "int32"), c: T.Buffer((), "int32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.sblock("root"): T.reads() T.writes() @@ -637,7 +637,7 @@ def expected( def te_reshape(): # The following is possible to be generated by TOPI. So we test this case. - A = te.placeholder((tvm.tir.IntImm("int64", 2), tvm.tir.IntImm("int64", 4)), name="A") + A = te.placeholder((tvm.tirx.IntImm("int64", 2), tvm.tirx.IntImm("int64", 4)), name="A") B = topi.reshape(A, (4, 2)) return [A, B] @@ -647,7 +647,7 @@ def tir_reshape( A: T.Buffer((T.int64(2), T.int64(4)), "float32"), T_reshape: T.Buffer((T.int64(4), T.int64(2)), "float32"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0, i1 in T.grid(T.int64(4), T.int64(2)): with T.sblock("T_reshape"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) @@ -669,8 +669,8 @@ def test_reshape(): def te_resize2d_symbolic(): - oh = tir.Var("oh", "int64") - ow = tir.Var("ow", "int64") + oh = tirx.Var("oh", "int64") + ow = tirx.Var("ow", "int64") roi = (0.0, 0.0, 0.0, 0.0) A = te.placeholder((2, 3, 128, 128), "float32", name="A") B = topi.image.resize2d( @@ -689,7 +689,7 @@ def tir_resize2d_symbolic( A: T.Buffer((T.int64(2), T.int64(3), T.int64(128), T.int64(128)), "float32"), var_resize: T.handle, ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) oh = T.int64() ow = T.int64() resize = T.match_buffer(var_resize, [T.int64(2), T.int64(3), oh, ow], dtype="float32") @@ -742,7 +742,7 @@ def te_extern(): C = te.extern( (128, 128), [A, B, P], - lambda ins, outs: tvm.tir.call_extern( + lambda ins, outs: tvm.tirx.call_extern( "", "myfunc", ins[0].data, ins[1].data, outs[0].data, ins[2][0] ), name="C", @@ -751,7 +751,7 @@ def te_extern(): @T.prim_func def tir_extern(var_A: T.handle, var_B: T.handle, var_P: T.handle, var_C: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(var_A, [128, 128], dtype="float32", offset_factor=1) B = T.match_buffer(var_B, [128, 128], dtype="float32", offset_factor=1) P = T.match_buffer(var_P, [1], dtype="float32", offset_factor=1) @@ -775,7 +775,7 @@ def te_slice_with_var_input(): @T.prim_func def tir_slice_with_var_input(var_tensor: T.handle, idx: T.int64, var_slice: T.handle): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) m, n = T.int64(), T.int64() tensor = T.match_buffer(var_tensor, (m, n)) slice = T.match_buffer(var_slice, (idx, n)) @@ -798,7 +798,7 @@ def test_loop_aware_initial_value(): @T.prim_func def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) a = T.match_buffer(var_a, (5, 5)) b = T.match_buffer(var_b, (5,)) sum_red = T.match_buffer(var_sum_red, (5,)) @@ -833,7 +833,7 @@ def test_loop_aware_reducer_combiner(): @T.prim_func def tir_workload(var_a: T.handle, var_b: T.handle, var_sum_red: T.handle): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) a = T.match_buffer(var_a, (5, 5)) b = T.match_buffer(var_b, (5,)) sum_red = T.match_buffer(var_sum_red, (5,)) @@ -872,7 +872,7 @@ def tir_workload( x: T.Buffer((1, 1024, 16, 40), "float32"), adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"), ): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) # fmt: off adaptive_pool_sum = T.sblock_alloc_buffer((1, 1024, 12, 30)) for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30): @@ -923,7 +923,7 @@ def test_nested_reduce_domain_dependency(): def tir_workload( x: T.Buffer((8, 8, 8, 8, 8), "float32"), compute: T.Buffer((8, 8, 8), "float32") ): - T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"tirx.noalias": True, "global_symbol": "main"}) for i0, i1, i2 in T.grid(8, 8, 8): with T.sblock("compute_2"): v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) diff --git a/tests/python/te/test_te_tensor.py b/tests/python/te/test_te_tensor.py index e6e2190f1803..4d0baafccb4c 100644 --- a/tests/python/te/test_te_tensor.py +++ b/tests/python/te/test_te_tensor.py @@ -85,14 +85,14 @@ def test_tensor_comm_reducer(): n = te.size_var("n") A = te.placeholder((m, n), name="A") k = te.reduce_axis((0, n), "k") - mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tir.const(0, dtype=t)) + mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tirx.const(0, dtype=t)) C = te.compute((m,), lambda i: mysum(A[i, k], axis=k)) def test_tensor_comm_reducer_overload(): m = te.size_var("m") n = te.size_var("n") - mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tir.const(0, dtype=t)) + mysum = te.comm_reducer(lambda x, y: x + y, lambda t: tvm.tirx.const(0, dtype=t)) sum_res = mysum(m, n) @@ -117,7 +117,7 @@ def fcombine(x, y): return x[0] + y[0], x[1] + y[1] def fidentity(t0, t1): - return tvm.tir.const(0, t0), tvm.tir.const(1, t1) + return tvm.tirx.const(0, t0), tvm.tirx.const(1, t1) mysum = te.comm_reducer(fcombine, fidentity, name="mysum") @@ -168,8 +168,8 @@ def test_extern(): A = te.placeholder((m,), name="A") def extern_func(ins, outs): - assert isinstance(ins[0], tvm.tir.Buffer) - return tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, m) + assert isinstance(ins[0], tvm.tirx.Buffer) + return tvm.tirx.call_packed("myadd", ins[0].data, outs[0].data, m) B = te.extern((m,), [A], extern_func) assert tuple(B.shape) == (m,) @@ -181,8 +181,8 @@ def test_extern_multi_out(): B = te.compute((m,), lambda i: A[i] * 10) def extern_func(ins, outs): - assert isinstance(ins[0], tvm.tir.Buffer) - return tvm.tir.call_packed("myadd", ins[0].data, outs[0].data, outs[1].data, m) + assert isinstance(ins[0], tvm.tirx.Buffer) + return tvm.tirx.call_packed("myadd", ins[0].data, outs[0].data, outs[1].data, m) res = te.extern([A.shape, A.shape], [A, B], extern_func) assert len(res) == 2 diff --git a/tests/python/testing/test_tvm_testing_before_after.py b/tests/python/testing/test_tvm_testing_before_after.py index a1f438c61d04..195d13808c38 100644 --- a/tests/python/testing/test_tvm_testing_before_after.py +++ b/tests/python/testing/test_tvm_testing_before_after.py @@ -19,7 +19,7 @@ import tvm import tvm.testing from tvm.script import ir_module -from tvm.script import tir as T +from tvm.script import tirx as T def test_before_after_prim_func(): diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py deleted file mode 100644 index 654c22ab3515..000000000000 --- a/tests/python/tir-base/test_tir_constructor.py +++ /dev/null @@ -1,218 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E711 - -import pytest - -import tvm -from tvm import te, topi - - -def test_expr_constructor(): - x = tvm.tir.Var("xx", "float32") - assert isinstance(x, tvm.tir.Var) - assert x.name == "xx" - - x = tvm.tir.Reduce(None, [1], [tvm.tir.IterVar((0, 1), "x", 2)], None, 0) - assert isinstance(x, tvm.tir.Reduce) - assert x.combiner == None - assert x.value_index == 0 - - x = tvm.tir.FloatImm("float32", 1.0) - assert isinstance(x, tvm.tir.FloatImm) - assert x.value == 1.0 - assert x.dtype == "float32" - - x = tvm.tir.IntImm("int64", 2) - assert isinstance(x, tvm.tir.IntImm) - assert x.value == 2 - assert x.dtype == "int64" - - x = tvm.tir.StringImm("xyza") - assert isinstance(x, tvm.tir.StringImm) - assert x.value == "xyza" - - x = tvm.tir.Cast("float32", tvm.tir.IntImm("uint32", 1)) - assert isinstance(x, tvm.tir.Cast) - assert x.dtype == "float32" - assert x.value.value == 1 - - a = tvm.tir.const(1.0, dtype="float32") - b = tvm.tir.Var("x", "float32") - - for cls in [ - tvm.tir.Add, - tvm.tir.Sub, - tvm.tir.Mul, - tvm.tir.Div, - tvm.tir.Mod, - tvm.tir.Min, - tvm.tir.Max, - tvm.tir.LT, - tvm.tir.LE, - tvm.tir.GT, - tvm.tir.GE, - ]: - x = cls(a, b) - assert isinstance(x, cls) - assert x.a == a - assert x.b.same_as(b) - - a = tvm.runtime.convert(tvm.tir.Var("x", "int32") > 1) - b = tvm.runtime.convert(tvm.tir.Var("x", "int32") == 1) - - for cls in [tvm.tir.And, tvm.tir.Or]: - x = cls(a, b) - assert isinstance(x, cls) - assert x.a == a - assert x.b.same_as(b) - - x = tvm.tir.Not(a) - assert isinstance(x, tvm.tir.Not) - assert x.a == a - - x = tvm.tir.Select(a, a, b) - assert isinstance(x, tvm.tir.Select) - assert x.true_value == a - assert x.false_value == b - assert x.condition == a - - buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) - buffer = tvm.tir.decl_buffer([16], "float32", data=buffer_var) - x = tvm.tir.BufferLoad(buffer, [1]) - assert isinstance(x, tvm.tir.BufferLoad) - assert x.dtype == "float32" - assert x.buffer == buffer - assert x.buffer.data == buffer_var - assert list(x.indices) == [1] - - x = tvm.tir.Ramp(1, 2, 10) - assert isinstance(x, tvm.tir.Ramp) - assert x.base.value == 1 - assert x.stride.value == 2 - assert x.lanes == 10 - - x = tvm.tir.Broadcast(a, 10) - assert isinstance(x, tvm.tir.Broadcast) - assert x.value == a - assert x.lanes == 10 - - x = tvm.tir.Shuffle([a], [0]) - assert isinstance(x, tvm.tir.Shuffle) - assert x.vectors[0] == a - assert x.indices[0].value == 0 - - x = tvm.tir.Call("float32", "tir.call_extern", [tvm.tir.StringImm("xyz"), a]) - assert isinstance(x, tvm.tir.Call) - assert x.dtype == "float32" - assert x.op.name == "tir.call_extern" - assert x.args[1] == a - - v = tvm.tir.Var("aa", "int32") - x = tvm.tir.Let(v, 1, v) - assert x.var == v - assert x.value.value == 1 - assert x.body == v - - -def test_stmt_constructor(): - v = tvm.tir.Var("aa", "int32") - nop = tvm.tir.Evaluate(1) - x = tvm.tir.Bind(v, 1) - assert isinstance(x, tvm.tir.Bind) - assert x.var == v - assert x.value.value == 1 - - x = tvm.tir.AttrStmt(v == 1, "xx", 1, tvm.tir.Evaluate(1)) - assert isinstance(x, tvm.tir.AttrStmt) - assert x.value.value == 1 - - x = tvm.tir.AssertStmt( - tvm.tir.const(1, "bool"), - tvm.tir.StringImm("RuntimeError"), - [tvm.tir.StringImm("hellow")], - ) - assert isinstance(x, tvm.tir.AssertStmt) - assert x.error_kind.value == "RuntimeError" - assert len(x.message_parts) == 1 - assert x.message_parts[0].value == "hellow" - - x = tvm.tir.For(tvm.tir.Var("x", "int32"), 0, 10, tvm.tir.ForKind.SERIAL, nop) - assert isinstance(x, tvm.tir.For) - assert x.min.value == 0 - assert x.extent.value == 10 - assert x.body == nop - - buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("bool"))) - buffer = tvm.tir.decl_buffer([16], "bool", data=buffer_var) - x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10]) - assert isinstance(x, tvm.tir.BufferStore) - assert x.buffer == buffer - assert x.buffer.data == buffer_var - assert list(x.indices) == [10] - assert x.value.value == 1 - - buf = tvm.tir.decl_buffer([10], "float32") - x = tvm.tir.AllocBuffer(buf) - assert isinstance(x, tvm.tir.AllocBuffer) - assert x.buffer == buf - - x = tvm.tir.AttrStmt(buffer_var, "xyz", 1, nop) - assert isinstance(x, tvm.tir.AttrStmt) - assert x.node == buffer_var - assert x.attr_key == "xyz" - assert x.body == nop - - x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop) - assert isinstance(x, tvm.tir.IfThenElse) - assert x.then_case.value.value == 11 - assert x.else_case == nop - - -def test_float_constructor_requires_float_dtype(): - with pytest.raises(tvm.TVMError): - tvm.tir.FloatImm("int32", 1.0) - - -def test_math_unary_constructor_requires_float_dtype(): - x = tvm.tir.Var("x", "int32") - - with pytest.raises(TypeError, match=r"tir\.tan only supports floating-point inputs"): - tvm.tir.tan(x) - - with pytest.raises(TypeError, match=r"tir\.sin only supports floating-point inputs"): - tvm.tir.sin(x) - - y = tvm.tir.Var("y", "float32") - assert tvm.tir.tan(y).dtype == "float32" - - -def test_topi_tan_requires_float_dtype(): - x = te.placeholder((2, 2), dtype="int32", name="x") - - with pytest.raises(TypeError, match=r"tir\.tan only supports floating-point inputs"): - topi.tan(x) - - -def test_math_unary_constructor_preserves_bfloat16(): - x = tvm.tir.Var("x", "bfloat16") - y = tvm.tir.exp(x) - assert y.dtype == "bfloat16" - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_op_types.py b/tests/python/tir-base/test_tir_op_types.py deleted file mode 100644 index aefab62559c2..000000000000 --- a/tests/python/tir-base/test_tir_op_types.py +++ /dev/null @@ -1,352 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=missing-docstring -import tvm -import tvm.testing -from tvm import tir - - -def test_tir_op_tvm_tuple(): - x = tir.Var("x", dtype="float32") - y = tir.Var("y", dtype="float32") - z = tir.Var("z", dtype="float32") - expr = tir.tvm_tuple(x, y, z, 1, 2, 3) - assert expr.op.name == "tir.tvm_tuple" - - -def test_tir_op_tvm_struct_get(): - x = tir.Var("x", dtype="handle") - expr = tir.tvm_struct_get(x, 1, 2, dtype="int32") - assert expr.op.name == "tir.tvm_struct_get" - - -def test_tir_op_tvm_struct_set(): - x = tir.Var("x", dtype="handle") - expr = tir.tvm_struct_set(x, 1, 2, 3) - assert expr.op.name == "tir.tvm_struct_set" - - -def test_tir_op_address_of(): - buffer = tir.decl_buffer((128), "float32") - expr = tir.address_of(buffer[0]) - assert expr.op.name == "tir.address_of" - - -def test_tir_op_lookup_param(): - expr = tir.lookup_param("p0") - assert expr.op.name == "tir.lookup_param" - - -def test_tir_op_reinterpret(): - x = tir.Var("x", dtype="int32") - expr = tir.reinterpret("float32", x) - assert expr.op.name == "tir.reinterpret" - - -def test_tir_op_isnullptr(): - x = tir.Var("x", dtype="int32") - expr = tir.isnullptr(x) - assert expr.op.name == "tir.isnullptr" - - -def test_tir_op_call_assume(): - x = tir.Var("x", dtype="int32") - expr = tir.assume(cond=x) - assert expr.op.name == "tir.assume" - - -def test_tir_op_call_undef(): - expr = tir.undef() - assert expr.op.name == "tir.undef" - - -def test_tir_op_call_likely(): - x = tir.Var("x", dtype="int32") - expr = tir.likely(cond=x) - assert expr.op.name == "tir.likely" - - -def test_tir_op_tvm_thread_allreduce(): - x = tir.Var("x", "int32") - buffer = tir.decl_buffer((128), "float32") - y = tir.Var("y", "handle") - z = tir.Var("z", "int32") - expr = tir.tvm_thread_allreduce(x, buffer[0], True, y, z) - assert expr.op.name == "tir.tvm_thread_allreduce" - - -def test_tir_op_type_annotation(): - expr = tir.type_annotation("int32") - assert expr.op.name == "tir.type_annotation" - - -def test_tir_op_tvm_access_ptr(): - buffer = tir.decl_buffer((128), "float32") - expr = tir.tvm_access_ptr("float32", buffer.data, 0, 1, 2) - assert expr.op.name == "tir.tvm_access_ptr" - - -def test_tir_op_tvm_throw_last_error(): - expr = tir.tvm_throw_last_error() - assert expr.op.name == "tir.tvm_throw_last_error" - - -def test_tir_op_tvm_load_matrix_sync(): - buffer = tir.decl_buffer((16, 16), "float32") - x = tir.Var("x", "handle") - expr = tir.tvm_load_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major") - assert expr.op.name == "tir.tvm_load_matrix_sync" - - -def test_tir_op_tvm_store_matrix_sync(): - buffer = tir.decl_buffer((16, 16), "float32") - x = tir.Var("x", "handle") - expr = tir.tvm_store_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major") - assert expr.op.name == "tir.tvm_store_matrix_sync" - - -def test_tir_op_tvm_mma_sync(): - buffer_0 = tir.decl_buffer((16, 16), "float32") - buffer_1 = tir.decl_buffer((16, 16), "float32") - buffer_2 = tir.decl_buffer((16, 16), "float32") - buffer_3 = tir.decl_buffer((16, 16), "float32") - expr = tir.tvm_mma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0) - assert expr.op.name == "tir.tvm_mma_sync" - - -def test_tir_op_tvm_bmma_sync(): - buffer_0 = tir.decl_buffer((16, 16), "float32") - buffer_1 = tir.decl_buffer((16, 16), "float32") - buffer_2 = tir.decl_buffer((16, 16), "float32") - buffer_3 = tir.decl_buffer((16, 16), "float32") - expr = tir.tvm_bmma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0) - assert expr.op.name == "tir.tvm_bmma_sync" - - -def test_tir_op_tvm_fill_fragment(): - buffer = tir.decl_buffer((16, 16), "float32") - expr = tir.tvm_fill_fragment(buffer.data, 16, 16, 16, 0, 0) - assert expr.op.name == "tir.tvm_fill_fragment" - - -def test_tir_op_ptx_mma(): - buffer_a = tir.decl_buffer([32], "int4", scope="local") - buffer_b = tir.decl_buffer([16], "uint4", scope="local") - buffer_c = tir.decl_buffer([4], "int32", scope="local") - expr = tir.ptx_mma( - "int32", - "m8n8k32", - "row", - "col", - "int4", - "uint4", - "int32", - buffer_a.data, - 0, - buffer_b.data, - 0, - buffer_c.data, - 0, - False, - ) - assert expr.op.name == "tir.ptx_mma" - - -def test_tir_op_ptx_mma_sp(): - buffer_a = tir.decl_buffer([32], "int4", scope="local") - buffer_b = tir.decl_buffer([16], "uint4", scope="local") - buffer_c = tir.decl_buffer([4], "int32", scope="local") - buffer_d = tir.decl_buffer([1], "uint32", scope="local") - expr = tir.ptx_mma_sp( - "int32", - "m8n8k32", - "row", - "col", - "int4", - "uint4", - "int32", - buffer_a.data, - 0, - buffer_b.data, - 0, - buffer_c.data, - 0, - buffer_d.data, - 0, - 0, - False, - ) - assert expr.op.name == "tir.ptx_mma_sp" - - -def test_tir_op_mma_store(): - x = tir.Var("x", dtype="int32") - y = tir.Var("y", dtype="int32") - buffer_w = tir.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1) - buffer = tir.decl_buffer( - [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[x, y] - ) - expr = tir.mma_store( - "int32", - 16, - 16, - buffer.access_ptr("w"), - buffer_w.data, - buffer_w.elem_offset, - x, - ) - assert expr.op.name == "tir.mma_store" - - -def test_tir_op_mma_fill(): - buffer_w = tir.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1) - expr = tir.mma_fill("int32", 8, buffer_w.data, buffer_w.elem_offset) - assert expr.op.name == "tir.mma_fill" - - -def test_op_ptx_ldmatrix(): - buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared") - buffer_local = tir.decl_buffer([8], "float16", scope="local") - expr = tir.ptx_ldmatrix( - "float16", False, 4, ".b16", buffer_local.data, 0, buffer_shared.data, 0 - ) - assert expr.op.name == "tir.ptx_ldmatrix" - - -def test_op_ptx_cp_async(): - buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared") - buffer_local = tir.decl_buffer([8], "float16", scope="local") - expr = tir.ptx_cp_async("float16", buffer_shared.data, 0, buffer_local.data, 0, 16) - assert expr.op.name == "tir.ptx_cp_async" - - -def test_op_ptx_cp_async_bulk(): - buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared") - buffer_local = tir.decl_buffer([8], "float16", scope="local") - expr = tir.ptx_cp_async_bulk("float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 0) - assert expr.op.name == "tir.ptx_cp_async_bulk" - - -def test_op_ptx_commit_group(): - expr = tir.ptx_commit_group() - assert expr.op.name == "tir.ptx_commit_group" - - -def test_op_ptx_wait_group(): - expr = tir.ptx_wait_group(8) - assert expr.op.name == "tir.ptx_wait_group" - - -def test_op_ptx_cp_async_barrier(): - expr = tir.ptx_cp_async_barrier(0) - assert expr.op.name == "tir.ptx_cp_async_barrier" - - -def test_op_ptx_init_barrier_thread_count(): - expr = tir.ptx_init_barrier_thread_count(0, 32) - assert expr.op.name == "tir.ptx_init_barrier_thread_count" - - -def test_op_ptx_arrive_barrier(): - expr = tir.ptx_arrive_barrier(0) - assert expr.op.name == "tir.ptx_arrive_barrier" - - -def test_op_ptx_arrive_barrier_expect_tx(): - expr = tir.ptx_arrive_barrier_expect_tx(0, 32) - assert expr.op.name == "tir.ptx_arrive_barrier_expect_tx" - - -def test_op_ptx_wait_barrier(): - expr = tir.ptx_wait_barrier(0) - assert expr.op.name == "tir.ptx_wait_barrier" - - -def test_op_create_barriers(): - expr = tir.create_barriers(16) - assert expr.op.name == "tir.create_barriers" - - -def test_tir_op_vectorlow(): - buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1) - vec = buffer.vload([0, 0], dtype="int8x16") - expr = tir.vectorlow("int8x8", vec) - assert expr.op.name == "tir.vectorlow" - - -def test_tir_op_vectorhigh(): - buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1) - vec = buffer.vload([0, 0], dtype="int8x16") - expr = tir.vectorhigh("int8x8", vec) - assert expr.op.name == "tir.vectorhigh" - - -def test_tir_op_dp4a(): - vec1 = tir.Var("vec1", dtype="int8x4") - vec2 = tir.Var("vec2", dtype="int8x4") - acc = tir.Var("acc", dtype="int32") - expr = tir.dp4a(vec1, vec2, acc) - assert expr.op.name == "tir.dp4a" - - -def test_tir_op_vectorcombine(): - buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1) - vec = buffer.vload([0, 0], dtype="int8x16") - expr = tir.vectorcombine("int8x8", vec, vec) - assert expr.op.name == "tir.vectorcombine" - - -def test_tir_op_shift_left(): - x = tir.Var("x", dtype="int32") - y = tir.Var("x", dtype="int32") - expr = tir.shift_left(x, y) - assert expr.op.name == "tir.shift_left" - - -def test_tir_op_shift_right(): - x = tir.Var("x", dtype="int32") - y = tir.Var("x", dtype="int32") - expr = tir.shift_right(x, y) - assert expr.op.name == "tir.shift_right" - - -def test_tir_op_bitwise(): - x = tir.Var("x", dtype="int32") - y = tir.Var("y", dtype="int32") - expr = tir.bitwise_and(x, y) - assert expr.op.name == "tir.bitwise_and" - expr = tir.bitwise_or(x, y) - assert expr.op.name == "tir.bitwise_or" - expr = tir.bitwise_not(x) - assert expr.op.name == "tir.bitwise_not" - expr = tir.bitwise_xor(x, y) - assert expr.op.name == "tir.bitwise_xor" - - -def test_tir_op_TVMBackendAllocWorkspace(): - expr = tir.TVMBackendAllocWorkspace(0, 1, 2, 3, 4) - assert expr.op.name == "tir.TVMBackendAllocWorkspace" - - -def test_tir_op_TVMBackendFreeWorkspace(): - buffer = tir.decl_buffer((128), "float32") - expr = tir.TVMBackendFreeWorkspace(0, 1, buffer.data) - assert expr.op.name == "tir.TVMBackendFreeWorkspace" - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py b/tests/python/tirx-analysis/test_tir_analysis_expr_deep_equal.py similarity index 73% rename from tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py rename to tests/python/tirx-analysis/test_tir_analysis_expr_deep_equal.py index 5ed8314d1aaf..0ca88575a4d9 100644 --- a/tests/python/tir-analysis/test_tir_analysis_expr_deep_equal.py +++ b/tests/python/tirx-analysis/test_tir_analysis_expr_deep_equal.py @@ -18,18 +18,18 @@ def test_equal_expr(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") def func1(): return x + y + 1 def func2(): - return tvm.tir.exp(tvm.tir.truncdiv((x + y + 1) * y, 4)) + return tvm.tirx.exp(tvm.tirx.truncdiv((x + y + 1) * y, 4)) - assert tvm.tir.analysis.expr_deep_equal(func1(), func1()) - assert tvm.tir.analysis.expr_deep_equal(func2(), func2()) - assert not tvm.tir.analysis.expr_deep_equal(func2(), func1()) + assert tvm.tirx.analysis.expr_deep_equal(func1(), func1()) + assert tvm.tirx.analysis.expr_deep_equal(func2(), func2()) + assert not tvm.tirx.analysis.expr_deep_equal(func2(), func1()) if __name__ == "__main__": diff --git a/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py b/tests/python/tirx-analysis/test_tir_analysis_undefined_vars.py similarity index 70% rename from tests/python/tir-analysis/test_tir_analysis_undefined_vars.py rename to tests/python/tirx-analysis/test_tir_analysis_undefined_vars.py index c975e6afb824..47a974791661 100644 --- a/tests/python/tir-analysis/test_tir_analysis_undefined_vars.py +++ b/tests/python/tirx-analysis/test_tir_analysis_undefined_vars.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Tests for tir.analysis.undefined_vars (VarUseDefAnalyzer).""" +"""Tests for tirx.analysis.undefined_vars (VarUseDefAnalyzer).""" import tvm import tvm.testing -from tvm import tir +from tvm import tirx def test_decl_buffer_data_is_use(): @@ -28,17 +28,17 @@ def test_decl_buffer_data_is_use(): an existing variable from the enclosing scope. It must appear in the undefined list so that callers (e.g., CreateComputeScope) capture it. """ - n = tir.SizeVar("n", "int32") + n = tirx.SizeVar("n", "int32") from tvm.ir import PointerType, PrimType - data_ptr = tir.Var("buf_data", PointerType(PrimType("float32"))) - buf = tir.decl_buffer((n,), "float32", "buf", data=data_ptr) + data_ptr = tirx.Var("buf_data", PointerType(PrimType("float32"))) + buf = tirx.decl_buffer((n,), "float32", "buf", data=data_ptr) - body = tir.Evaluate(tir.BufferLoad(buf, [0])) - decl = tir.DeclBuffer(buf) - stmt = tir.SeqStmt([decl, body]) + body = tirx.Evaluate(tirx.BufferLoad(buf, [0])) + decl = tirx.DeclBuffer(buf) + stmt = tirx.SeqStmt([decl, body]) - undef = tvm.tir.analysis.undefined_vars(stmt, []) + undef = tvm.tirx.analysis.undefined_vars(stmt, []) undef_names = {v.name for v in undef} # data_ptr must be undefined (it comes from outside the DeclBuffer) assert "buf_data" in undef_names, f"Expected buf_data in undefined vars, got {undef_names}" @@ -52,16 +52,16 @@ def test_decl_buffer_elem_offset_is_use(): """ from tvm.ir import PointerType, PrimType - n = tir.SizeVar("n", "int32") - data_ptr = tir.Var("buf_data", PointerType(PrimType("float32"))) - elem_off = tir.Var("buf_elem_offset", "int32") - buf = tir.decl_buffer((n,), "float32", "buf", data=data_ptr, elem_offset=elem_off) + n = tirx.SizeVar("n", "int32") + data_ptr = tirx.Var("buf_data", PointerType(PrimType("float32"))) + elem_off = tirx.Var("buf_elem_offset", "int32") + buf = tirx.decl_buffer((n,), "float32", "buf", data=data_ptr, elem_offset=elem_off) - body = tir.Evaluate(tir.BufferLoad(buf, [0])) - decl = tir.DeclBuffer(buf) - stmt = tir.SeqStmt([decl, body]) + body = tirx.Evaluate(tirx.BufferLoad(buf, [0])) + decl = tirx.DeclBuffer(buf) + stmt = tirx.SeqStmt([decl, body]) - undef = tvm.tir.analysis.undefined_vars(stmt, []) + undef = tvm.tirx.analysis.undefined_vars(stmt, []) undef_names = {v.name for v in undef} assert "buf_data" in undef_names, f"Expected buf_data in undefined vars, got {undef_names}" assert "buf_elem_offset" in undef_names, ( @@ -75,14 +75,14 @@ def test_alloc_buffer_data_is_def(): AllocBuffer allocates new storage — the data pointer is a new definition, not a reference to an external variable. """ - n = tir.SizeVar("n", "int32") - buf = tir.decl_buffer((n,), "float32", "buf") + n = tirx.SizeVar("n", "int32") + buf = tirx.decl_buffer((n,), "float32", "buf") - body = tir.Evaluate(tir.BufferLoad(buf, [0])) - alloc = tir.AllocBuffer(buf) - stmt = tir.SeqStmt([alloc, body]) + body = tirx.Evaluate(tirx.BufferLoad(buf, [0])) + alloc = tirx.AllocBuffer(buf) + stmt = tirx.SeqStmt([alloc, body]) - undef = tvm.tir.analysis.undefined_vars(stmt, []) + undef = tvm.tirx.analysis.undefined_vars(stmt, []) undef_names = {v.name for v in undef} # data should NOT be undefined — AllocBuffer defines it assert buf.data.name not in undef_names, ( diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py b/tests/python/tirx-analysis/test_tir_analysis_verify_ssa.py similarity index 61% rename from tests/python/tir-analysis/test_tir_analysis_verify_ssa.py rename to tests/python/tirx-analysis/test_tir_analysis_verify_ssa.py index 07611e55ace7..d81582c9e3c6 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_ssa.py +++ b/tests/python/tirx-analysis/test_tir_analysis_verify_ssa.py @@ -18,23 +18,23 @@ def test_verify_ssa(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("tindex", "int32") - z = tvm.tir.Evaluate(x + y) - assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([x, y], z)) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("tindex", "int32") + z = tvm.tirx.Evaluate(x + y) + assert tvm.tirx.analysis.verify_ssa(tvm.tirx.PrimFunc([x, y], z)) - assert not tvm.tir.analysis.verify_ssa( - tvm.tir.PrimFunc([x, y], tvm.tir.SeqStmt([tvm.tir.Bind(x, 1), z])) + assert not tvm.tirx.analysis.verify_ssa( + tvm.tirx.PrimFunc([x, y], tvm.tirx.SeqStmt([tvm.tirx.Bind(x, 1), z])) ) def test_verify_weak_let_ssa(): - x = tvm.tir.Var("x", "int32") - z1 = tvm.tir.Let(x, 1, x + 1) - z2 = tvm.tir.Let(x, 2, x + 2) + x = tvm.tirx.Var("x", "int32") + z1 = tvm.tirx.Let(x, 1, x + 1) + z2 = tvm.tirx.Let(x, 2, x + 2) - assert tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 + z1))) - assert not tvm.tir.analysis.verify_ssa(tvm.tir.PrimFunc([], tvm.tir.Evaluate(z1 * z2))) + assert tvm.tirx.analysis.verify_ssa(tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(z1 + z1))) + assert not tvm.tirx.analysis.verify_ssa(tvm.tirx.PrimFunc([], tvm.tirx.Evaluate(z1 * z2))) if __name__ == "__main__": diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tirx-analysis/test_tir_analysis_verify_well_formed.py similarity index 84% rename from tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py rename to tests/python/tirx-analysis/test_tir_analysis_verify_well_formed.py index d6c1dae3b64c..b0541eb8ef69 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tirx-analysis/test_tir_analysis_verify_well_formed.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_pass_simple(): @@ -40,8 +40,8 @@ def element_wise( # It's a opaque block , so it can use outside variables C[i, j] = B[i, j] * 2.0 - assert tvm.tir.analysis.verify_well_formed(element_wise) - assert tvm.tir.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise)) + assert tvm.tirx.analysis.verify_well_formed(element_wise) + assert tvm.tirx.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise)) def test_fail_use_out_loop_var(): @@ -56,7 +56,7 @@ def element_wise( # we cannot use `i` since it's defined outside the block B[vi, vj] = A[i, vj] * 2.0 - assert not tvm.tir.analysis.verify_well_formed(element_wise, assert_mode=False) + assert not tvm.tirx.analysis.verify_well_formed(element_wise, assert_mode=False) def test_error_for_out_of_scope_usage(): @@ -67,23 +67,23 @@ def test_error_for_out_of_scope_usage(): the Bind is inside a child scope (e.g., ForNode body) and the variable is used outside that scope. """ - i = tvm.tir.Var("i", "int32") + i = tvm.tirx.Var("i", "int32") # Bind i inside a For loop body - for_stmt = tvm.tir.For( - tvm.tir.Var("j", "int32"), + for_stmt = tvm.tirx.For( + tvm.tirx.Var("j", "int32"), 0, 1, - tvm.tir.ForKind.SERIAL, - tvm.tir.SeqStmt([tvm.tir.Bind(i, 42), tvm.tir.Evaluate(i)]), + tvm.tirx.ForKind.SERIAL, + tvm.tirx.SeqStmt([tvm.tirx.Bind(i, 42), tvm.tirx.Evaluate(i)]), ) # Use i outside the For loop — this is out of scope - body = tvm.tir.SeqStmt([for_stmt, tvm.tir.Evaluate(i)]) - func = tvm.tir.PrimFunc([], body) + body = tvm.tirx.SeqStmt([for_stmt, tvm.tirx.Evaluate(i)]) + func = tvm.tirx.PrimFunc([], body) with pytest.raises( ValueError, match="Invalid use of undefined variable i at .* no longer in-scope." ): - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_error_for_nested_rebind_usage(): @@ -99,7 +99,7 @@ def func(): with pytest.raises( ValueError, match="ill-formed, due to multiple nested definitions of variable i" ): - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_error_for_repeated_binding(): @@ -119,13 +119,13 @@ def func(): T.evaluate(i) with pytest.raises(ValueError, match="multiple nested definitions of variable i"): - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_error_for_cross_function_reuse(): """A variable may not be re-defined in another function""" - i = tvm.tir.Var("i", "int32") + i = tvm.tirx.Var("i", "int32") @I.ir_module(check_well_formed=False) class mod: @@ -140,7 +140,7 @@ def func2(): T.evaluate(i) with pytest.raises(ValueError, match="multiple definitions of variable i"): - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_reuse_of_env_thread_in_function_is_well_formed(): @@ -159,7 +159,7 @@ def func(A: T.Buffer([256], "float32")): with T.launch_thread(threadIdx_x, 256): A[threadIdx_x] = A[threadIdx_x] + 2.0 - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_reuse_of_env_thread_in_function_is_mandatory(): @@ -168,7 +168,7 @@ def test_reuse_of_env_thread_in_function_is_mandatory(): Not only are environment threads allowed to have multiple definition sites, it is mandatory for them to have multiple definition sites. If a PrimFunc contains more than one - `"thread_extent"` with the same name, but with different `tir.Var` + `"thread_extent"` with the same name, but with different `tirx.Var` instances, it is ill-formed. """ @@ -180,18 +180,18 @@ def func(A: T.Buffer([256], "float32")): with T.launch_thread("threadIdx.x", 256) as threadIdx_x: A[threadIdx_x] = A[threadIdx_x] + 2.0 - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_reuse_of_env_thread_across_functions_is_ill_formed(): """An env thread may not be reused across PrimFunc - However, each function must have its own `tir.Var` representing + However, each function must have its own `tirx.Var` representing the environment thread, and may not share these variables across PrimFuncs. """ - threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") + threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32") @I.ir_module(check_well_formed=False) class mod: @@ -214,7 +214,7 @@ def kernel_2(A: T.Buffer([256], "float32")): A[threadIdx_x] = A[threadIdx_x] + T.float32(1) with pytest.raises(ValueError, match="multiple definitions of variable threadIdx_x"): - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_multiple_buffer_arguments_may_share_allocation(): @@ -234,7 +234,7 @@ def func(A_handle: T.handle, B_handle: T.handle): pass - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_block_match_buffer_defines_buffer_obj(): @@ -253,7 +253,7 @@ def func(A: T.Buffer([256, 256], "float32")): ) B[i, j] = 0.0 - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_block_match_buffer_defines_symbolic_variables(): @@ -276,7 +276,7 @@ def func(A: T.Buffer([256, 256], "int32")): B[i, j] = elem_offset - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_error_message_without_previous_definition_location(): @@ -302,7 +302,7 @@ def func(): T.evaluate(x) with pytest.raises(ValueError) as exc_info: - tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + tvm.tirx.analysis.verify_well_formed(func, assert_mode=True) error_msg = str(exc_info.value) @@ -327,7 +327,7 @@ def func(): T.evaluate(x) with pytest.raises(ValueError) as exc_info: - tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + tvm.tirx.analysis.verify_well_formed(func, assert_mode=True) error_msg = str(exc_info.value) @@ -358,7 +358,7 @@ def func(): T.evaluate(x) with pytest.raises(ValueError) as exc_info: - tvm.tir.analysis.verify_well_formed(func, assert_mode=True) + tvm.tirx.analysis.verify_well_formed(func, assert_mode=True) error_msg = str(exc_info.value) @@ -376,7 +376,7 @@ def func(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): for i in T.grid(128): B[i] = A[i] * 2.0 - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_decl_buffer_is_well_formed(): @@ -388,7 +388,7 @@ def func(A: T.Buffer((128,), "float32")): for i in T.grid(128): B[i] = A[i] * 2.0 - tvm.tir.analysis.verify_well_formed(func) + tvm.tirx.analysis.verify_well_formed(func) def test_alloc_buffer_in_block_is_well_formed(): @@ -405,7 +405,7 @@ def func(A: T.Buffer((128,), "float32")): vi = T.axis.remap("S", [i]) B[vi] = A[vi] * 2.0 - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_match_buffer_in_block_is_well_formed(): @@ -424,40 +424,40 @@ def func(A: T.Buffer((128, 128), "float32")): ) A_tile[i, j] = A_tile[i, j] * 2.0 - tvm.tir.analysis.verify_well_formed(mod) + tvm.tirx.analysis.verify_well_formed(mod) def test_error_undeclared_buffer_in_schedulable_tir(): """In schedule-level TIR (with SBlock nodes), all buffers must be declared.""" # Manually construct a BufferStore that uses a buffer without any declaration # inside a block context. - n = tvm.tir.SizeVar("n", "int32") - A = tvm.tir.decl_buffer([n], "float32", name="A") - i = tvm.tir.Var("i", "int32") + n = tvm.tirx.SizeVar("n", "int32") + A = tvm.tirx.decl_buffer([n], "float32", name="A") + i = tvm.tirx.Var("i", "int32") # Create an undeclared buffer using an explicit data pointer that is NOT # in the buffer_map and NOT wrapped with DeclBuffer. - B_data = tvm.tir.Var("B_data", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) - B = tvm.tir.decl_buffer([n], "float32", name="B", data=B_data) + B_data = tvm.tirx.Var("B_data", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) + B = tvm.tirx.decl_buffer([n], "float32", name="B", data=B_data) # Build a block that writes to B without any declaration of B. - bi = tvm.tir.SizeVar("bi", "int32") - block = tvm.tir.SBlock( - iter_vars=[tvm.tir.IterVar(tvm.ir.Range(0, n), bi, 0)], # 0 = kDataPar - reads=[tvm.tir.BufferRegion(A, [tvm.ir.Range(bi, bi + 1)])], - writes=[tvm.tir.BufferRegion(B, [tvm.ir.Range(bi, bi + 1)])], - body=tvm.tir.BufferStore(B, tvm.tir.BufferLoad(A, [bi]), [bi]), + bi = tvm.tirx.SizeVar("bi", "int32") + block = tvm.tirx.SBlock( + iter_vars=[tvm.tirx.IterVar(tvm.ir.Range(0, n), bi, 0)], # 0 = kDataPar + reads=[tvm.tirx.BufferRegion(A, [tvm.ir.Range(bi, bi + 1)])], + writes=[tvm.tirx.BufferRegion(B, [tvm.ir.Range(bi, bi + 1)])], + body=tvm.tirx.BufferStore(B, tvm.tirx.BufferLoad(A, [bi]), [bi]), name_hint="write_B", ) - block_realize = tvm.tir.SBlockRealize( + block_realize = tvm.tirx.SBlockRealize( iter_values=[i], - predicate=tvm.tir.const(True), + predicate=tvm.tirx.const(True), block=block, ) - prim_func = tvm.tir.PrimFunc( + prim_func = tvm.tirx.PrimFunc( params=[A.data, B_data], - body=tvm.tir.For(i, 0, n, tvm.tir.ForKind.SERIAL, block_realize), + body=tvm.tirx.For(i, 0, n, tvm.tirx.ForKind.SERIAL, block_realize), buffer_map={A.data: A}, # Note: B is NOT in buffer_map, so its declaration scope is only # within a DeclBuffer node (which we intentionally omit here). @@ -465,7 +465,7 @@ def test_error_undeclared_buffer_in_schedulable_tir(): # B is used in the block but was never declared — should fail. with pytest.raises(ValueError, match="buffer B.*without a prior DeclBuffer"): - tvm.tir.analysis.verify_well_formed(prim_func) + tvm.tirx.analysis.verify_well_formed(prim_func) if __name__ == "__main__": diff --git a/tests/python/tir-base/test_tir_base.py b/tests/python/tirx-base/test_tir_base.py similarity index 86% rename from tests/python/tir-base/test_tir_base.py rename to tests/python/tirx-base/test_tir_base.py index 53d2a60f7abd..7607e9639994 100644 --- a/tests/python/tir-base/test_tir_base.py +++ b/tests/python/tirx-base/test_tir_base.py @@ -21,17 +21,17 @@ import pytest import tvm -from tvm import tir +from tvm import tirx from tvm.base import TVMError from tvm.ir.transform import PassContext -from tvm.script import tir as T +from tvm.script import tirx as T def build_tir_func(func): func = func.with_attr("global_symbol", "main") pass_ctx = PassContext.current() - if pass_ctx.config.get("tir.noalias", True): - func = func.with_attr("tir.noalias", True) + if pass_ctx.config.get("tirx.noalias", True): + func = func.with_attr("tirx.noalias", True) mod = tvm.IRModule({"main": func}) func = tvm.compile(mod) return func @@ -46,23 +46,23 @@ def test_scalar_add(): rhs_types = ["float32", "float16"] for lhs_type, rhs_type in itertools.product(lhs_types, rhs_types): # Input vars should be float32, we will cast to test for upcasting between them - lhs_input = tir.Var("lhs", "float32") - rhs_input = tir.Var("rhs", "float32") - lhs = tir.Cast(lhs_type, lhs_input) - rhs = tir.Cast(rhs_type, rhs_input) + lhs_input = tirx.Var("lhs", "float32") + rhs_input = tirx.Var("rhs", "float32") + lhs = tirx.Cast(lhs_type, lhs_input) + rhs = tirx.Cast(rhs_type, rhs_input) output = lhs + rhs - output = tir.ret(output) - output = tir.Evaluate(output) - func = tir.PrimFunc([lhs_input, rhs_input], output) + output = tirx.ret(output) + output = tirx.Evaluate(output) + func = tirx.PrimFunc([lhs_input, rhs_input], output) func = build_tir_func(func) out = func(1.0, 2.0) assert out == 3.0 def assignment_helper(store_dtype, value_dtype): - store = tir.Var("store", dtype=store_dtype) - value = tir.Var("value", dtype=value_dtype) - tir.Let(store, value, body=store) + store = tirx.Var("store", dtype=store_dtype) + value = tirx.Var("value", dtype=value_dtype) + tirx.Let(store, value, body=store) def test_fail_implicit_downcasts_same_type(): @@ -94,10 +94,10 @@ def test_cast_between_types(): def test_ret_const(): - a = tir.const(0) - b = tir.ret(a) - b = tir.Evaluate(b) - func = tir.PrimFunc([], b) + a = tirx.const(0) + b = tirx.ret(a) + b = tirx.Evaluate(b) + func = tirx.PrimFunc([], b) func = build_tir_func(func) out = func() assert out == 0 @@ -172,16 +172,16 @@ def func(Out: T.Buffer[(2,), "int32"]): def test_exception(): with pytest.raises(TypeError): - x = tir.Var(name=1, dtype="int") + x = tirx.Var(name=1, dtype="int") def test_eq_ops(): - a = tir.IntImm("int8", 1) + a = tirx.IntImm("int8", 1) with pytest.raises(ValueError): assert a != None with pytest.raises(ValueError): assert not a == None - b = tir.StringImm("abc") + b = tirx.StringImm("abc") assert b != None assert not b == None diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tirx-base/test_tir_buffer.py similarity index 74% rename from tests/python/tir-base/test_tir_buffer.py rename to tests/python/tirx-base/test_tir_buffer.py index 49bdb243cad1..bcdd0830a7f3 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tirx-base/test_tir_buffer.py @@ -21,26 +21,26 @@ import tvm import tvm.testing -from tvm.script import tir as T -from tvm.tir import Buffer +from tvm.script import tirx as T +from tvm.tirx import Buffer def test_buffer(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - l = tvm.tir.SizeVar("l", "int32") - Ab = tvm.tir.decl_buffer((m, n), "float32") - Bb = tvm.tir.decl_buffer((n, l), "float32") + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + l = tvm.tirx.SizeVar("l", "int32") + Ab = tvm.tirx.decl_buffer((m, n), "float32") + Bb = tvm.tirx.decl_buffer((n, l), "float32") - assert isinstance(Ab, tvm.tir.Buffer) + assert isinstance(Ab, tvm.tirx.Buffer) assert Ab.dtype == "float32" assert tuple(Ab.shape) == (m, n) def test_buffer_access_ptr(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1]) + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + Ab = tvm.tirx.decl_buffer((m, n), "float32", strides=[n + 1, 1]) aptr = Ab.access_ptr("rw") tvm.ir.assert_structural_equal(aptr.args[3], Ab.strides[0] * m) assert aptr.args[0].dtype == Ab.dtype @@ -50,32 +50,32 @@ def test_buffer_access_ptr(): def test_buffer_access_ptr_offset(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - Ab = tvm.tir.decl_buffer((m, n), "float32") + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + Ab = tvm.tirx.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw", offset=100) tvm.testing.assert_prim_expr_equal(aptr.args[2], 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE - v = tvm.tir.SizeVar("int32", "int32") + v = tvm.tirx.SizeVar("int32", "int32") aptr = Ab.access_ptr("rw", offset=100 + 100 + v) tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE - aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern("int32", "test_call", 100 + 100 + v)) + aptr = Ab.access_ptr("rw", offset=tvm.tirx.call_extern("int32", "test_call", 100 + 100 + v)) tvm.testing.assert_prim_expr_equal( - aptr.args[2], tvm.tir.call_extern("int32", "test_call", 200 + v) + aptr.args[2], tvm.tirx.call_extern("int32", "test_call", 200 + v) ) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE def test_buffer_access_ptr_extent(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - Ab = tvm.tir.decl_buffer((m, n), "float32") + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + Ab = tvm.tirx.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw") tvm.ir.assert_structural_equal(aptr.args[3], m * n) aptr = Ab.access_ptr("rw", offset=100) tvm.ir.assert_structural_equal(aptr.args[3], m * n - 100) - Ab = tvm.tir.decl_buffer((m, n), "float32", strides=[n + 1, 1]) + Ab = tvm.tirx.decl_buffer((m, n), "float32", strides=[n + 1, 1]) aptr = Ab.access_ptr("rw", offset=100) tvm.ir.assert_structural_equal(aptr.args[3], Ab.strides[0] * m - 100) @@ -87,29 +87,29 @@ def test_buffer_access_ptr_extent(): def test_buffer_vload(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + Ab = tvm.tirx.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)]) def test_buffer_offset_of(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + Ab = tvm.tirx.decl_buffer((m, n), "float32", elem_offset=100) offset = Ab.offset_of([2, 3]) tvm.ir.assert_structural_equal(offset, [n * 2 + 103]) def test_buffer_index_merge_mult_mod(): - m = tvm.tir.SizeVar("m", "int32") - n = tvm.tir.SizeVar("n", "int32") - s = tvm.tir.SizeVar("s", "int32") - k0 = tvm.tir.SizeVar("k0", "int32") - k1 = tvm.tir.SizeVar("k1", "int32") - A = tvm.tir.decl_buffer((m, n), "float32") - A_stride = tvm.tir.decl_buffer((m, n), "float32", strides=(s, 1)) + m = tvm.tirx.SizeVar("m", "int32") + n = tvm.tirx.SizeVar("n", "int32") + s = tvm.tirx.SizeVar("s", "int32") + k0 = tvm.tirx.SizeVar("k0", "int32") + k1 = tvm.tirx.SizeVar("k1", "int32") + A = tvm.tirx.decl_buffer((m, n), "float32") + A_stride = tvm.tirx.decl_buffer((m, n), "float32", strides=(s, 1)) def assert_simplified_equal(index_simplified, index_direct): ( @@ -117,8 +117,8 @@ def assert_simplified_equal(index_simplified, index_direct): f"index_simplified={index_simplified}, index_direct={index_direct}", ) - idxd = tvm.tir.indexdiv - idxm = tvm.tir.indexmod + idxd = tvm.tirx.indexdiv + idxm = tvm.tirx.indexmod # Test Case1 index_simplified = A_stride.offset_of( @@ -152,10 +152,10 @@ def assert_simplified_equal(index_simplified, index_direct): assert_simplified_equal(index_simplified, index_direct) # Test Case5 - B = tvm.tir.decl_buffer((1, 14, 14, 1024)) - i = tvm.tir.SizeVar("i", "int32") - j = tvm.tir.SizeVar("j", "int32") - k = tvm.tir.SizeVar("k", "int32") + B = tvm.tirx.decl_buffer((1, 14, 14, 1024)) + i = tvm.tirx.SizeVar("i", "int32") + j = tvm.tirx.SizeVar("j", "int32") + k = tvm.tirx.SizeVar("k", "int32") index_simplified1 = B.offset_of( ( @@ -180,7 +180,7 @@ def assert_simplified_equal(index_simplified, index_direct): def test_buffer_flatten(): """A buffer should flatten to a 1-d shape""" - buf = tvm.tir.decl_buffer([16, 32]) + buf = tvm.tirx.decl_buffer([16, 32]) flat = buf.get_flattened_buffer() assert buf.data.same_as(flat.data) tvm.ir.assert_structural_equal(flat.shape, [T.int32(16 * 32)]) @@ -188,14 +188,14 @@ def test_buffer_flatten(): def test_buffer_flatten_preserves_identity(): """Flattening a 1-d buffer should return the original""" - buf = tvm.tir.decl_buffer([16]) + buf = tvm.tirx.decl_buffer([16]) flat = buf.get_flattened_buffer() assert buf.same_as(flat) def test_buffer_flatten_uses_axis_separators(): """Flattening to N-d physical buffers uses the axis separators""" - buf = tvm.tir.decl_buffer([4, 16, 32], axis_separators=[2]) + buf = tvm.tirx.decl_buffer([4, 16, 32], axis_separators=[2]) flat = buf.get_flattened_buffer() tvm.ir.assert_structural_equal(flat.axis_separators, [T.int32(1)]) tvm.ir.assert_structural_equal(flat.shape, [T.int32(4 * 16), T.int32(32)]) @@ -203,7 +203,7 @@ def test_buffer_flatten_uses_axis_separators(): def test_invalid_axis_separators_raises_exception(): with pytest.raises(ValueError): - tvm.tir.decl_buffer([1], axis_separators=[1, 2]) + tvm.tirx.decl_buffer([1], axis_separators=[1, 2]) if __name__ == "__main__": diff --git a/tests/python/tirx-base/test_tir_constructor.py b/tests/python/tirx-base/test_tir_constructor.py new file mode 100644 index 000000000000..f8a30c75893f --- /dev/null +++ b/tests/python/tirx-base/test_tir_constructor.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ruff: noqa: E711 + +import pytest + +import tvm +from tvm import te, topi + + +def test_expr_constructor(): + x = tvm.tirx.Var("xx", "float32") + assert isinstance(x, tvm.tirx.Var) + assert x.name == "xx" + + x = tvm.tirx.Reduce(None, [1], [tvm.tirx.IterVar((0, 1), "x", 2)], None, 0) + assert isinstance(x, tvm.tirx.Reduce) + assert x.combiner == None + assert x.value_index == 0 + + x = tvm.tirx.FloatImm("float32", 1.0) + assert isinstance(x, tvm.tirx.FloatImm) + assert x.value == 1.0 + assert x.dtype == "float32" + + x = tvm.tirx.IntImm("int64", 2) + assert isinstance(x, tvm.tirx.IntImm) + assert x.value == 2 + assert x.dtype == "int64" + + x = tvm.tirx.StringImm("xyza") + assert isinstance(x, tvm.tirx.StringImm) + assert x.value == "xyza" + + x = tvm.tirx.Cast("float32", tvm.tirx.IntImm("uint32", 1)) + assert isinstance(x, tvm.tirx.Cast) + assert x.dtype == "float32" + assert x.value.value == 1 + + a = tvm.tirx.const(1.0, dtype="float32") + b = tvm.tirx.Var("x", "float32") + + for cls in [ + tvm.tirx.Add, + tvm.tirx.Sub, + tvm.tirx.Mul, + tvm.tirx.Div, + tvm.tirx.Mod, + tvm.tirx.Min, + tvm.tirx.Max, + tvm.tirx.LT, + tvm.tirx.LE, + tvm.tirx.GT, + tvm.tirx.GE, + ]: + x = cls(a, b) + assert isinstance(x, cls) + assert x.a == a + assert x.b.same_as(b) + + a = tvm.runtime.convert(tvm.tirx.Var("x", "int32") > 1) + b = tvm.runtime.convert(tvm.tirx.Var("x", "int32") == 1) + + for cls in [tvm.tirx.And, tvm.tirx.Or]: + x = cls(a, b) + assert isinstance(x, cls) + assert x.a == a + assert x.b.same_as(b) + + x = tvm.tirx.Not(a) + assert isinstance(x, tvm.tirx.Not) + assert x.a == a + + x = tvm.tirx.Select(a, a, b) + assert isinstance(x, tvm.tirx.Select) + assert x.true_value == a + assert x.false_value == b + assert x.condition == a + + buffer_var = tvm.tirx.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"))) + buffer = tvm.tirx.decl_buffer([16], "float32", data=buffer_var) + x = tvm.tirx.BufferLoad(buffer, [1]) + assert isinstance(x, tvm.tirx.BufferLoad) + assert x.dtype == "float32" + assert x.buffer == buffer + assert x.buffer.data == buffer_var + assert list(x.indices) == [1] + + x = tvm.tirx.Ramp(1, 2, 10) + assert isinstance(x, tvm.tirx.Ramp) + assert x.base.value == 1 + assert x.stride.value == 2 + assert x.lanes == 10 + + x = tvm.tirx.Broadcast(a, 10) + assert isinstance(x, tvm.tirx.Broadcast) + assert x.value == a + assert x.lanes == 10 + + x = tvm.tirx.Shuffle([a], [0]) + assert isinstance(x, tvm.tirx.Shuffle) + assert x.vectors[0] == a + assert x.indices[0].value == 0 + + x = tvm.tirx.Call("float32", "tirx.call_extern", [tvm.tirx.StringImm("xyz"), a]) + assert isinstance(x, tvm.tirx.Call) + assert x.dtype == "float32" + assert x.op.name == "tirx.call_extern" + assert x.args[1] == a + + v = tvm.tirx.Var("aa", "int32") + x = tvm.tirx.Let(v, 1, v) + assert x.var == v + assert x.value.value == 1 + assert x.body == v + + +def test_stmt_constructor(): + v = tvm.tirx.Var("aa", "int32") + nop = tvm.tirx.Evaluate(1) + x = tvm.tirx.Bind(v, 1) + assert isinstance(x, tvm.tirx.Bind) + assert x.var == v + assert x.value.value == 1 + + x = tvm.tirx.AttrStmt(v == 1, "xx", 1, tvm.tirx.Evaluate(1)) + assert isinstance(x, tvm.tirx.AttrStmt) + assert x.value.value == 1 + + x = tvm.tirx.AssertStmt( + tvm.tirx.const(1, "bool"), + tvm.tirx.StringImm("RuntimeError"), + [tvm.tirx.StringImm("hellow")], + ) + assert isinstance(x, tvm.tirx.AssertStmt) + assert x.error_kind.value == "RuntimeError" + assert len(x.message_parts) == 1 + assert x.message_parts[0].value == "hellow" + + x = tvm.tirx.For(tvm.tirx.Var("x", "int32"), 0, 10, tvm.tirx.ForKind.SERIAL, nop) + assert isinstance(x, tvm.tirx.For) + assert x.min.value == 0 + assert x.extent.value == 10 + assert x.body == nop + + buffer_var = tvm.tirx.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("bool"))) + buffer = tvm.tirx.decl_buffer([16], "bool", data=buffer_var) + x = tvm.tirx.BufferStore(buffer, tvm.tirx.IntImm("bool", 1), [10]) + assert isinstance(x, tvm.tirx.BufferStore) + assert x.buffer == buffer + assert x.buffer.data == buffer_var + assert list(x.indices) == [10] + assert x.value.value == 1 + + buf = tvm.tirx.decl_buffer([10], "float32") + x = tvm.tirx.AllocBuffer(buf) + assert isinstance(x, tvm.tirx.AllocBuffer) + assert x.buffer == buf + + x = tvm.tirx.AttrStmt(buffer_var, "xyz", 1, nop) + assert isinstance(x, tvm.tirx.AttrStmt) + assert x.node == buffer_var + assert x.attr_key == "xyz" + assert x.body == nop + + x = tvm.tirx.IfThenElse(tvm.tirx.const(1, "bool"), tvm.tirx.Evaluate(11), nop) + assert isinstance(x, tvm.tirx.IfThenElse) + assert x.then_case.value.value == 11 + assert x.else_case == nop + + +def test_float_constructor_requires_float_dtype(): + with pytest.raises(tvm.TVMError): + tvm.tirx.FloatImm("int32", 1.0) + + +def test_math_unary_constructor_requires_float_dtype(): + x = tvm.tirx.Var("x", "int32") + + with pytest.raises(TypeError, match=r"tirx\.tan only supports floating-point inputs"): + tvm.tirx.tan(x) + + with pytest.raises(TypeError, match=r"tirx\.sin only supports floating-point inputs"): + tvm.tirx.sin(x) + + y = tvm.tirx.Var("y", "float32") + assert tvm.tirx.tan(y).dtype == "float32" + + +def test_topi_tan_requires_float_dtype(): + x = te.placeholder((2, 2), dtype="int32", name="x") + + with pytest.raises(TypeError, match=r"tirx\.tan only supports floating-point inputs"): + topi.tan(x) + + +def test_math_unary_constructor_preserves_bfloat16(): + x = tvm.tirx.Var("x", "bfloat16") + y = tvm.tirx.exp(x) + assert y.dtype == "bfloat16" + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_host_func.py b/tests/python/tirx-base/test_tir_host_func.py similarity index 90% rename from tests/python/tir-base/test_tir_host_func.py rename to tests/python/tirx-base/test_tir_host_func.py index 8f9a51942d0b..023517d8f56c 100644 --- a/tests/python/tir-base/test_tir_host_func.py +++ b/tests/python/tirx-base/test_tir_host_func.py @@ -17,7 +17,7 @@ import tvm from tvm.s_tir.meta_schedule.testing import te_workload from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,missing-class-docstring,missing-function-docstring # fmt: off @@ -35,7 +35,7 @@ def main( { "global_symbol": "test", "target": tvm.target.Target("llvm", host="llvm"), - "tir.noalias": True, + "tirx.noalias": True, } ) # with T.sblock("root"): @@ -61,17 +61,17 @@ def test_host_func(): ) mod = tvm.ir.IRModule({"main": func}) target = tvm.target.Target("cuda") - mod = tvm.tir.transform.Apply( + mod = tvm.tirx.transform.Apply( lambda f: f.with_attr( { "global_symbol": "test", - "tir.is_host_func": True, + "tirx.is_host_func": True, } ) )(mod) - mod = tvm.tir.transform.BindTarget(target)(mod) + mod = tvm.tirx.transform.BindTarget(target)(mod) tvm.ir.assert_structural_equal(mod, Module) - assert "tir.is_host_func" not in mod["main"].attrs, ( + assert "tirx.is_host_func" not in mod["main"].attrs, ( """Target and is_host_func attributes should be mutually exclusive""" ) diff --git a/tests/python/tir-base/test_tir_imm_values.py b/tests/python/tirx-base/test_tir_imm_values.py similarity index 83% rename from tests/python/tir-base/test_tir_imm_values.py rename to tests/python/tirx-base/test_tir_imm_values.py index de5798a31687..2e940c0964e6 100644 --- a/tests/python/tir-base/test_tir_imm_values.py +++ b/tests/python/tirx-base/test_tir_imm_values.py @@ -23,8 +23,8 @@ import tvm import tvm.testing -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T @pytest.mark.parametrize( @@ -40,7 +40,7 @@ ) def test_tir_make_intimm(dtype, literals): for l in literals: - imm = tir.const(l, dtype) + imm = tirx.const(l, dtype) assert imm.value == l, imm @@ -57,7 +57,7 @@ def test_tir_make_intimm(dtype, literals): def test_tir_invalid_intimm(dtype, literals): for l in literals: with pytest.raises(tvm.TVMError): - tir.const(l, dtype) + tirx.const(l, dtype) @pytest.mark.parametrize( @@ -77,8 +77,8 @@ def test_tir_large_py_int_literals(dtype, literals): For large uint value, use LargeUIntImm intrin, """ for l in literals: - x = tir.const(l, dtype) - if isinstance(x, tir.IntImm | tir.FloatImm): + x = tirx.const(l, dtype) + if isinstance(x, tirx.IntImm | tirx.FloatImm): assert x.value == literals[l] else: # LargeUIntImm(low32, hi32) @@ -86,14 +86,14 @@ def test_tir_large_py_int_literals(dtype, literals): def test_tir_intimm_overflow(): - assert int(tir.const(255, "uint8") + tir.const(1, "uint8")) == 0 - assert int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) == -(2**31) - assert int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) == 0 - assert int(tir.const(2**63 - 1, "int64") + tir.const(1, "int64")) == -(2**63) - assert int(tir.const(2**32, "uint64") * tir.const(2**32, "uint64")) == 0 + assert int(tirx.const(255, "uint8") + tirx.const(1, "uint8")) == 0 + assert int(tirx.const(2**31 - 1, "int32") + tirx.const(1, "int32")) == -(2**31) + assert int(tirx.const(2**32 - 1, "uint32") + tirx.const(1, "uint32")) == 0 + assert int(tirx.const(2**63 - 1, "int64") + tirx.const(1, "int64")) == -(2**63) + assert int(tirx.const(2**32, "uint64") * tirx.const(2**32, "uint64")) == 0 # customized int types - assert int(tir.const(7, "int4") + tir.const(1, "int4")) == -8 - assert int(tir.const(2**39 - 1, "int40") + tir.const(1, "int40")) == -(2**39) + assert int(tirx.const(7, "int4") + tirx.const(1, "int4")) == -8 + assert int(tirx.const(2**39 - 1, "int40") + tirx.const(1, "int40")) == -(2**39) def compare_float_value(value, expect, msg): @@ -116,7 +116,7 @@ def compare_float_value(value, expect, msg): ) def test_tir_make_floatimm(dtype, literals): for l in literals: - imm = tir.const(l, dtype) + imm = tirx.const(l, dtype) compare_float_value(imm.value, l, "imm value should match feed value") @@ -131,13 +131,13 @@ def test_tir_invalid_floatimm(dtype, literals): """Currently only fp16 and fp32 have range check.""" for l in literals: with pytest.raises(tvm.TVMError): - tir.const(l, dtype) + tirx.const(l, dtype) @pytest.mark.parametrize("dtype", ["float16", "float32", "float64"]) @pytest.mark.parametrize("literal", [3.14, np.nan, np.inf]) def test_tir_special_floatimms(dtype, literal): - x = tir.const(literal, dtype) + x = tirx.const(literal, dtype) compare_float_value(x.value, literal, "imm value should match feed value") @@ -169,7 +169,7 @@ def imm_overflow_fp64() -> T.float64: ], ) def test_tir_const_auto_dtype(literal, expect_dtype): - x = tir.const(literal, dtype=None) + x = tirx.const(literal, dtype=None) assert x.dtype == expect_dtype assert x.value == literal @@ -226,14 +226,14 @@ def check_tir_const_fold( if skip_overflow: py_res = foldf(x, y) - if isinstance(py_res, tir.IntImm | tir.FloatImm): + if isinstance(py_res, tirx.IntImm | tirx.FloatImm): py_res = py_res.value if not (ninfo.min <= py_res <= ninfo.max): # If the result overflow, certain arithmetics is non-defined # thus we intentionally do not make the test failed. return - fold_res = foldf(tir.const(x, dtype), tir.const(y, dtype)) + fold_res = foldf(tirx.const(x, dtype), tirx.const(y, dtype)) calc_res = calcf(x, y) flaky_msg = ( @@ -345,23 +345,23 @@ def imm_floordiv(x: T.int8, y: T.int8) -> T.int8: # divide by zero with pytest.raises(tvm.TVMError): - check_tir_const_fold("int8", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + check_tir_const_fold("int8", lambda x, y: tirx.floordiv(x, y), ffloordiv, 1, 0) with pytest.raises(tvm.TVMError): - check_tir_const_fold("int8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + check_tir_const_fold("int8", lambda x, y: tirx.truncdiv(x, y), ftruncdiv, 1, 0) # i8 mod folding is not implemented - assert not isinstance(tir.floormod(tir.const(7, "int8"), tir.const(3, "int8")), tir.IntImm) - assert not isinstance(tir.truncmod(tir.const(7, "int8"), tir.const(3, "int8")), tir.IntImm) + assert not isinstance(tirx.floormod(tirx.const(7, "int8"), tirx.const(3, "int8")), tirx.IntImm) + assert not isinstance(tirx.truncmod(tirx.const(7, "int8"), tirx.const(3, "int8")), tirx.IntImm) # randomized check check_tir_const_fold("int8", lambda x, y: x * y, fmul) check_tir_const_fold("int8", lambda x, y: x + y, fadd) check_tir_const_fold("int8", lambda x, y: x - y, fsub) check_tir_const_fold( - "int8", lambda x, y: tir.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("int8").max) + "int8", lambda x, y: tirx.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("int8").max) ) check_tir_const_fold( - "int8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("int8").max) + "int8", lambda x, y: tirx.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("int8").max) ) @@ -404,23 +404,27 @@ def imm_floordiv(x: T.uint8, y: T.uint8) -> T.uint8: # divide by zero with pytest.raises(tvm.TVMError): - check_tir_const_fold("uint8", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + check_tir_const_fold("uint8", lambda x, y: tirx.floordiv(x, y), ffloordiv, 1, 0) with pytest.raises(tvm.TVMError): - check_tir_const_fold("uint8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + check_tir_const_fold("uint8", lambda x, y: tirx.truncdiv(x, y), ftruncdiv, 1, 0) # u8 mod folding is not implemented - assert not isinstance(tir.floormod(tir.const(7, "uint8"), tir.const(3, "uint8")), tir.IntImm) - assert not isinstance(tir.truncmod(tir.const(7, "uint8"), tir.const(3, "uint8")), tir.IntImm) + assert not isinstance( + tirx.floormod(tirx.const(7, "uint8"), tirx.const(3, "uint8")), tirx.IntImm + ) + assert not isinstance( + tirx.truncmod(tirx.const(7, "uint8"), tirx.const(3, "uint8")), tirx.IntImm + ) # randomized check check_tir_const_fold("uint8", lambda x, y: x * y, fmul) check_tir_const_fold("uint8", lambda x, y: x + y, fadd) check_tir_const_fold("uint8", lambda x, y: x - y, fsub) check_tir_const_fold( - "uint8", lambda x, y: tir.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("uint8").max) + "uint8", lambda x, y: tirx.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("uint8").max) ) check_tir_const_fold( - "uint8", lambda x, y: tir.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("uint8").max) + "uint8", lambda x, y: tirx.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("uint8").max) ) @@ -465,18 +469,18 @@ def imm_floormod(x: T.int32, y: T.int32) -> T.int32: ftruncmod = tvm.compile(imm_truncmod, target="llvm") # i32 overflow is not specified, only check for range - assert -(2**31) <= int(tir.const(2**31 - 1, "int32") + tir.const(1, "int32")) < 2**31 - assert -(2**31) <= int(tir.const(-(2**31), "int32") - tir.const(1, "int32")) < 2**31 + assert -(2**31) <= int(tirx.const(2**31 - 1, "int32") + tirx.const(1, "int32")) < 2**31 + assert -(2**31) <= int(tirx.const(-(2**31), "int32") - tirx.const(1, "int32")) < 2**31 # divide by zero with pytest.raises(tvm.TVMError): - check_tir_const_fold("int32", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + check_tir_const_fold("int32", lambda x, y: tirx.floordiv(x, y), ffloordiv, 1, 0) with pytest.raises(tvm.TVMError): - check_tir_const_fold("int32", lambda x, y: tir.floormod(x, y), ffloormod, 1, 0) + check_tir_const_fold("int32", lambda x, y: tirx.floormod(x, y), ffloormod, 1, 0) with pytest.raises(tvm.TVMError): - check_tir_const_fold("int32", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + check_tir_const_fold("int32", lambda x, y: tirx.truncdiv(x, y), ftruncdiv, 1, 0) with pytest.raises(tvm.TVMError): - check_tir_const_fold("int32", lambda x, y: tir.truncmod(x, y), ftruncmod, 1, 0) + check_tir_const_fold("int32", lambda x, y: tirx.truncmod(x, y), ftruncmod, 1, 0) # randomized check check_tir_const_fold("int32", lambda x, y: x * y, fmul, skip_overflow=True) @@ -484,28 +488,28 @@ def imm_floormod(x: T.int32, y: T.int32) -> T.int32: check_tir_const_fold("int32", lambda x, y: x - y, fsub, skip_overflow=True) check_tir_const_fold( "int32", - lambda x, y: tir.floordiv(x, y), + lambda x, y: tirx.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("int32").max), skip_overflow=True, ) check_tir_const_fold( "int32", - lambda x, y: tir.truncdiv(x, y), + lambda x, y: tirx.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("int32").max), skip_overflow=True, ) check_tir_const_fold( "int32", - lambda x, y: tir.floormod(x, y), + lambda x, y: tirx.floormod(x, y), ffloormod, y_range=(1, np.iinfo("int32").max), skip_overflow=False, ) check_tir_const_fold( "int32", - lambda x, y: tir.truncmod(x, y), + lambda x, y: tirx.truncmod(x, y), ftruncmod, y_range=(1, np.iinfo("int32").max), skip_overflow=False, @@ -543,17 +547,21 @@ def imm_floordiv(x: T.uint32, y: T.uint32) -> T.uint32: ftruncdiv = tvm.compile(imm_truncdiv, target="llvm") # u32 overflow is not specified, only check for range - assert 0 <= int(tir.const(2**32 - 1, "uint32") + tir.const(1, "uint32")) < 2**32 + assert 0 <= int(tirx.const(2**32 - 1, "uint32") + tirx.const(1, "uint32")) < 2**32 # divide by zero with pytest.raises(tvm.TVMError): - check_tir_const_fold("uint32", lambda x, y: tir.floordiv(x, y), ffloordiv, 1, 0) + check_tir_const_fold("uint32", lambda x, y: tirx.floordiv(x, y), ffloordiv, 1, 0) with pytest.raises(tvm.TVMError): - check_tir_const_fold("uint32", lambda x, y: tir.truncdiv(x, y), ftruncdiv, 1, 0) + check_tir_const_fold("uint32", lambda x, y: tirx.truncdiv(x, y), ftruncdiv, 1, 0) # u8 mod folding is not implemented - assert not isinstance(tir.floormod(tir.const(7, "uint32"), tir.const(3, "uint32")), tir.IntImm) - assert not isinstance(tir.truncmod(tir.const(7, "uint32"), tir.const(3, "uint32")), tir.IntImm) + assert not isinstance( + tirx.floormod(tirx.const(7, "uint32"), tirx.const(3, "uint32")), tirx.IntImm + ) + assert not isinstance( + tirx.truncmod(tirx.const(7, "uint32"), tirx.const(3, "uint32")), tirx.IntImm + ) # randomized check check_tir_const_fold("uint32", lambda x, y: x * y, fmul, skip_overflow=True) @@ -561,14 +569,14 @@ def imm_floordiv(x: T.uint32, y: T.uint32) -> T.uint32: check_tir_const_fold("uint32", lambda x, y: x - y, fsub, skip_overflow=True) check_tir_const_fold( "uint32", - lambda x, y: tir.floordiv(x, y), + lambda x, y: tirx.floordiv(x, y), ffloordiv, y_range=(1, np.iinfo("uint32").max), skip_overflow=False, ) check_tir_const_fold( "uint32", - lambda x, y: tir.truncdiv(x, y), + lambda x, y: tirx.truncdiv(x, y), ftruncdiv, y_range=(1, np.iinfo("uint32").max), skip_overflow=False, diff --git a/tests/python/tir-base/test_tir_index_map.py b/tests/python/tirx-base/test_tir_index_map.py similarity index 88% rename from tests/python/tir-base/test_tir_index_map.py rename to tests/python/tirx-base/test_tir_index_map.py index bc063a3b44a9..28b75d8f62c2 100644 --- a/tests/python/tir-base/test_tir_index_map.py +++ b/tests/python/tirx-base/test_tir_index_map.py @@ -22,8 +22,8 @@ import tvm.testing from tvm.ir import assert_structural_equal from tvm.runtime import const -from tvm.script import tir as T -from tvm.tir import IndexMap, IntImm, floordiv, floormod +from tvm.script import tirx as T +from tvm.tirx import IndexMap, IntImm, floordiv, floormod def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -71,7 +71,7 @@ def test_nonbijective_inverse_gives_error(): index_map.inverse([14]) -dynamic_N = tvm.tir.Var("N", "int32") +dynamic_N = tvm.tirx.Var("N", "int32") padding_test_case = tvm.testing.parameter( by_dict={ "no_padding": dict( @@ -86,23 +86,23 @@ def test_nonbijective_inverse_gives_error(): inverse=lambda i, j: [4 * i + j], pre_shape=[15], post_shape=[T.int32(4), T.int32(4)], - padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), + padding=lambda i, j: tvm.tirx.And(i == 3, tvm.runtime.convert(3) == j), ), "left_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[15], post_shape=[T.int32(4), T.int32(4)], - padding=lambda i, j: tvm.tir.And(i == 0, j < 1), + padding=lambda i, j: tvm.tirx.And(i == 0, j < 1), ), "left_and_right_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[14], post_shape=[T.int32(4), T.int32(4)], - padding=lambda i, j: tvm.tir.Or( - tvm.tir.And(i == 0, j < 1), - tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), + padding=lambda i, j: tvm.tirx.Or( + tvm.tirx.And(i == 0, j < 1), + tvm.tirx.And(i == 3, tvm.runtime.convert(3) == j), ), ), "dynamic_size": dict( @@ -110,9 +110,9 @@ def test_nonbijective_inverse_gives_error(): inverse=lambda i, j: [4 * i + j], pre_shape=[dynamic_N], post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, T.int32(4)], - padding=lambda i, j: tvm.tir.And( + padding=lambda i, j: tvm.tirx.And( dynamic_N % (-4) != 0, - tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), + tvm.tirx.And(i == dynamic_N // 4, j >= dynamic_N % 4), ), ), "2d_padding": dict( @@ -128,14 +128,14 @@ def test_nonbijective_inverse_gives_error(): T.int32(4), # Range of iter%4 T.int32(8), # Range of iter%8 ], - padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( - tvm.tir.Or( - tvm.tir.And(i_outer == 0, i_inner < 1), - tvm.tir.And(i_outer == 3, tvm.runtime.convert(3) == i_inner), + padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tirx.Or( + tvm.tirx.Or( + tvm.tirx.And(i_outer == 0, i_inner < 1), + tvm.tirx.And(i_outer == 3, tvm.runtime.convert(3) == i_inner), ), - tvm.tir.Or( - tvm.tir.And(j_outer == 0, j_inner < 5), - tvm.tir.And(j_outer == 4, j_inner >= 4), + tvm.tirx.Or( + tvm.tirx.And(j_outer == 0, j_inner < 5), + tvm.tirx.And(j_outer == 4, j_inner >= 4), ), ), ), @@ -144,28 +144,28 @@ def test_nonbijective_inverse_gives_error(): inverse=lambda i, j, k: [32 * i + 4 * j + k], pre_shape=[116], post_shape=[T.int32(4), T.int32(8), T.int32(4)], - padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), + padding=lambda i, j, k: tvm.tirx.And(i == 3, 4 * j + k >= 20), ), "multiple_right_padding_transpose": dict( forward=lambda i: [(i // 4) % 8, i // 32, i % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k], pre_shape=[116], post_shape=[T.int32(8), T.int32(4), T.int32(4)], - padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), + padding=lambda j, i, k: tvm.tirx.And(i == 3, 4 * j + k >= 20), ), "multiple_left_padding": dict( forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) % 8, (i + 5) % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], pre_shape=[123], post_shape=[T.int32(4), T.int32(8), T.int32(4)], - padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), + padding=lambda i, j, k: tvm.tirx.And(i == 0, j * 4 + k < 5), ), "multiple_left_padding_with_transpose": dict( forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], pre_shape=[123], post_shape=[T.int32(8), T.int32(4), T.int32(4)], - padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), + padding=lambda j, i, k: tvm.tirx.And(i == 0, j * 4 + k < 5), ), "outer_loop_extent_one": dict( forward=lambda i: [i // 4, i % 4], diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tirx-base/test_tir_intrin.py similarity index 84% rename from tests/python/tir-base/test_tir_intrin.py rename to tests/python/tirx-base/test_tir_intrin.py index 24b0bc90cc2a..0dd06dee934a 100644 --- a/tests/python/tir-base/test_tir_intrin.py +++ b/tests/python/tirx-base/test_tir_intrin.py @@ -23,9 +23,9 @@ import tvm import tvm.testing -from tvm import te, tir, topi +from tvm import te, tirx, topi from tvm.contrib import clang, utils -from tvm.script import tir as T +from tvm.script import tirx as T def test_nearbyint(): @@ -33,7 +33,7 @@ def test_nearbyint(): "m", ) A = te.placeholder((m,), name="A") - A_rounded = te.compute((m,), lambda *i: tvm.tir.nearbyint(A(*i)), name="A") + A_rounded = te.compute((m,), lambda *i: tvm.tirx.nearbyint(A(*i)), name="A") # Convert to TIR and create schedule mod = te.create_prim_func([A, A_rounded]) @@ -57,31 +57,31 @@ def test_nearbyint(): def test_round_intrinsics_on_int(): - i = tvm.tir.Var("i", "int32") - for op in [tvm.tir.round, tvm.tir.trunc, tvm.tir.ceil, tvm.tir.floor, tvm.tir.nearbyint]: - assert op(tvm.tir.const(10, "int32")).value == 10 - assert op(tvm.tir.const(True, "bool")).value == True + i = tvm.tirx.Var("i", "int32") + for op in [tvm.tirx.round, tvm.tirx.trunc, tvm.tirx.ceil, tvm.tirx.floor, tvm.tirx.nearbyint]: + assert op(tvm.tirx.const(10, "int32")).value == 10 + assert op(tvm.tirx.const(True, "bool")).value == True assert op(i).same_as(i) - assert tvm.tir.isnan(tvm.tir.const(10, "int32")).value == False + assert tvm.tirx.isnan(tvm.tirx.const(10, "int32")).value == False def test_unary_intrin(): test_funcs = [ - (tvm.tir.exp, lambda x: np.exp(x)), - (tvm.tir.exp10, lambda x: np.power(10, x)), - (tvm.tir.log2, lambda x: np.log2(x)), - (tvm.tir.log10, lambda x: np.log10(x)), - (tvm.tir.sinh, lambda x: np.sinh(x)), - (tvm.tir.cosh, lambda x: np.cosh(x)), - (tvm.tir.log1p, lambda x: np.log1p(x)), - (tvm.tir.asin, lambda x: np.arcsin(x)), - (tvm.tir.acos, lambda x: np.arccos(x)), - (tvm.tir.atan, lambda x: np.arctan(x)), - (tvm.tir.asinh, lambda x: np.arcsinh(x)), - (tvm.tir.acosh, lambda x: np.arccosh(x)), - (tvm.tir.atanh, lambda x: np.arctanh(x)), - (tvm.tir.erf, lambda x: scipy.special.erf(x)), + (tvm.tirx.exp, lambda x: np.exp(x)), + (tvm.tirx.exp10, lambda x: np.power(10, x)), + (tvm.tirx.log2, lambda x: np.log2(x)), + (tvm.tirx.log10, lambda x: np.log10(x)), + (tvm.tirx.sinh, lambda x: np.sinh(x)), + (tvm.tirx.cosh, lambda x: np.cosh(x)), + (tvm.tirx.log1p, lambda x: np.log1p(x)), + (tvm.tirx.asin, lambda x: np.arcsin(x)), + (tvm.tirx.acos, lambda x: np.arccos(x)), + (tvm.tirx.atan, lambda x: np.arctan(x)), + (tvm.tirx.asinh, lambda x: np.arcsinh(x)), + (tvm.tirx.acosh, lambda x: np.arccosh(x)), + (tvm.tirx.atanh, lambda x: np.arctanh(x)), + (tvm.tirx.erf, lambda x: scipy.special.erf(x)), ] def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): @@ -140,8 +140,8 @@ def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5): def test_asin_acos_boundary_values(): """Test asin and acos with boundary values and threshold switching.""" test_funcs = [ - (tvm.tir.asin, lambda x: np.arcsin(x)), - (tvm.tir.acos, lambda x: np.arccos(x)), + (tvm.tirx.asin, lambda x: np.arcsin(x)), + (tvm.tirx.acos, lambda x: np.arccos(x)), ] def run_test(tvm_intrin, np_func): @@ -191,10 +191,10 @@ def run_test(tvm_intrin, np_func): def test_binary_intrin(): test_funcs = [ - (tvm.tir.atan2, lambda x1, x2: np.arctan2(x1, x2)), - (tvm.tir.nextafter, lambda x1, x2: np.nextafter(x1, x2)), - (tvm.tir.copysign, lambda x1, x2: np.copysign(x1, x2)), - (tvm.tir.hypot, lambda x1, x2: np.hypot(x1, x2)), + (tvm.tirx.atan2, lambda x1, x2: np.arctan2(x1, x2)), + (tvm.tirx.nextafter, lambda x1, x2: np.nextafter(x1, x2)), + (tvm.tirx.copysign, lambda x1, x2: np.copysign(x1, x2)), + (tvm.tirx.hypot, lambda x1, x2: np.hypot(x1, x2)), ] def run_test(tvm_intrin, np_func): @@ -230,7 +230,7 @@ def test_ldexp(): ) A = te.placeholder((m,), name="A") B = te.placeholder((m,), name="B", dtype="int32") - C = te.compute((m,), lambda *i: tvm.tir.ldexp(A(*i), B(*i)), name="C") + C = te.compute((m,), lambda *i: tvm.tirx.ldexp(A(*i), B(*i)), name="C") # Convert to TIR and create schedule mod = te.create_prim_func([A, B, C]) @@ -270,7 +270,7 @@ def clz_np(x, dtype): m = te.var("m") A = te.placeholder((m,), name="A", dtype=dtype) - B = te.compute((m,), lambda *i: tvm.tir.clz(A(*i)), name="B") + B = te.compute((m,), lambda *i: tvm.tirx.clz(A(*i)), name="B") # Convert to TIR and create schedule mod = te.create_prim_func([A, B]) @@ -307,7 +307,7 @@ class Module: @T.prim_func def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) + T.func_attr({"global_symbol": "test_fma", "tirx.noalias": True}) n = T.int32() stride = T.int32() stride_1 = T.int32() @@ -357,12 +357,12 @@ def test_tir_fma(A: T.handle, B: T.handle, C: T.handle, d: T.handle) -> None: def test_fma(): opt = tvm.transform.Sequential( [ - tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))), - tvm.tir.transform.LowerIntrin(), + tvm.tirx.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))), + tvm.tirx.transform.LowerIntrin(), ] ) mod = opt(Module) - assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" + assert mod["test_tir_fma"].body.body.value.op.name == "tirx.call_llvm_pure_intrin" if __name__ == "__main__": diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tirx-base/test_tir_nodes.py similarity index 52% rename from tests/python/tir-base/test_tir_nodes.py rename to tests/python/tirx-base/test_tir_nodes.py index a20e96b5668f..4a5ee3ae3569 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tirx-base/test_tir_nodes.py @@ -23,15 +23,15 @@ def test_const(): - x = tvm.tir.const(1, "int32") + x = tvm.tirx.const(1, "int32") assert x.dtype == "int32" - assert isinstance(x, tvm.tir.IntImm) + assert isinstance(x, tvm.tirx.IntImm) def test_te_const(): - x = tvm.tir.const(1, "int32") + x = tvm.tirx.const(1, "int32") assert x.dtype == "int32" - assert isinstance(x, tvm.tir.IntImm) + assert isinstance(x, tvm.tirx.IntImm) def test_tir_const_dtype_inference(): @@ -50,59 +50,59 @@ def test_tir_const_dtype_inference(): np.float32(1), np.float64(1), ]: - assert tvm.tir.const(data).dtype == str(np.array(data).dtype) + assert tvm.tirx.const(data).dtype == str(np.array(data).dtype) - assert tvm.tir.const(True).dtype == "bool" - assert tvm.tir.const(1).dtype == "int32" - assert tvm.tir.const(1.0).dtype == "float32" + assert tvm.tirx.const(True).dtype == "bool" + assert tvm.tirx.const(1).dtype == "int32" + assert tvm.tirx.const(1.0).dtype == "float32" def test_make(): - x = tvm.tir.const(1, "int32") - y = tvm.tir.Var("x", "int32") + x = tvm.tirx.const(1, "int32") + y = tvm.tirx.Var("x", "int32") z = x + y - assert isinstance(tvm.tir.max(x, y), tvm.tir.Max) - assert isinstance(tvm.tir.min(x, y), tvm.tir.Min) + assert isinstance(tvm.tirx.max(x, y), tvm.tirx.Max) + assert isinstance(tvm.tirx.min(x, y), tvm.tirx.Min) def test_ir(): - x = tvm.tir.const(1, "int32") - y = tvm.tir.IntImm("int32", 1) + x = tvm.tirx.const(1, "int32") + y = tvm.tirx.IntImm("int32", 1) z = x + y - stmt = tvm.tir.Evaluate(z) - assert isinstance(stmt, tvm.tir.Evaluate) + stmt = tvm.tirx.Evaluate(z) + assert isinstance(stmt, tvm.tirx.Evaluate) def test_ir2(): - buf_size = tvm.tir.Var("size", "int32") - x = tvm.tir.Var("n", "int32") + buf_size = tvm.tirx.Var("size", "int32") + x = tvm.tirx.Var("n", "int32") storage_type = ir.PrimType("int32") handle_type = ir.PointerType(storage_type) - array = tvm.tir.Var("array", handle_type) - buf = tvm.tir.decl_buffer([buf_size], "int32", data=array) + array = tvm.tirx.Var("array", handle_type) + buf = tvm.tirx.decl_buffer([buf_size], "int32", data=array) - st = tvm.tir.BufferStore(buf, x + 1, [1]) - assert isinstance(st, tvm.tir.BufferStore) + st = tvm.tirx.BufferStore(buf, x + 1, [1]) + assert isinstance(st, tvm.tirx.BufferStore) assert st.buffer == buf assert st.buffer.data == array def test_let(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - stmt = tvm.tir.Bind(x, 10) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + stmt = tvm.tirx.Bind(x, 10) def test_cast(): - x = tvm.tir.Var("x", "float32") + x = tvm.tirx.Var("x", "float32") y = x.astype("int32") z = x.astype("float32x4") - assert isinstance(y, tvm.tir.Cast) - assert isinstance(z, tvm.tir.Broadcast) + assert isinstance(y, tvm.tirx.Cast) + assert isinstance(z, tvm.tirx.Broadcast) assert z.lanes == 4 - s = tvm.tir.StringImm("s") + s = tvm.tirx.StringImm("s") with pytest.raises(tvm.error.TVMError): try: s.astype("int") @@ -112,9 +112,9 @@ def test_cast(): def test_attr(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - stmt = tvm.tir.AttrStmt(y, "stride", 10, tvm.tir.Evaluate(x + 1)) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + stmt = tvm.tirx.AttrStmt(y, "stride", 10, tvm.tirx.Evaluate(x + 1)) assert stmt.node == y a = tvm.runtime.convert(1) @@ -127,77 +127,77 @@ def test_attr(): def test_basic(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") c = a + b assert str(c) == f"{a.name} + {b.name}" def test_stmt(): - x = tvm.tir.Evaluate(0) - tvm.tir.For(tvm.tir.Var("i", "int32"), 0, 1, tvm.tir.ForKind.SERIAL, x) - tvm.tir.For(tvm.tir.Var("i", "int32"), 0, 1, tvm.tir.ForKind.UNROLLED, x, step=2) + x = tvm.tirx.Evaluate(0) + tvm.tirx.For(tvm.tirx.Var("i", "int32"), 0, 1, tvm.tirx.ForKind.SERIAL, x) + tvm.tirx.For(tvm.tirx.Var("i", "int32"), 0, 1, tvm.tirx.ForKind.UNROLLED, x, step=2) def test_dir(): - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") dir(x) def test_dtype(): - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") assert x.dtype == "int32" - y = tvm.tir.Var("y", "int32") + y = tvm.tirx.Var("y", "int32") assert (x > y).dtype == "bool" def test_any(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") try: t = x or x assert False except ValueError: pass try: - tvm.tir.any() + tvm.tirx.any() assert False except ValueError: pass - assert str(tvm.tir.any(x < y)) == f"{x.name} < {y.name}" - assert str(tvm.tir.any(x < y, x > z)) == f"{x.name} < {y.name} or {x.name} > {z.name}" + assert str(tvm.tirx.any(x < y)) == f"{x.name} < {y.name}" + assert str(tvm.tirx.any(x < y, x > z)) == f"{x.name} < {y.name} or {x.name} > {z.name}" assert ( - str(tvm.tir.any(x < y, y > z + 1, x < z * 2)) + str(tvm.tirx.any(x < y, y > z + 1, x < z * 2)) == f"{x.name} < {y.name} or {y.name} > {z.name} + 1 or {x.name} < {z.name} * 2" ) def test_all(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") try: t = x and x assert False except ValueError: pass try: - tvm.tir.all() + tvm.tirx.all() assert False except ValueError: pass - assert str(tvm.tir.all(x < y)) == f"{x.name} < {y.name}" - assert str(tvm.tir.all(x < y, x > z)) == f"{x.name} < {y.name} and {x.name} > {z.name}" + assert str(tvm.tirx.all(x < y)) == f"{x.name} < {y.name}" + assert str(tvm.tirx.all(x < y, x > z)) == f"{x.name} < {y.name} and {x.name} > {z.name}" assert ( - str(tvm.tir.all(x < y, y > z + 1, x < z * 2)) + str(tvm.tirx.all(x < y, y > z + 1, x < z * 2)) == f"{x.name} < {y.name} and {y.name} > {z.name} + 1 and {x.name} < {z.name} * 2" ) def test_bitwise(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") assert str(x << y) == "T.shift_left(x, y)" assert str(x >> y) == "T.shift_right(x, y)" assert str(x & y) == "T.bitwise_and(x, y)" @@ -211,13 +211,13 @@ def test_bitwise(): assert str(10 % x) == "10 % x" assert str(~x) == "T.bitwise_not(x)" - assert (tvm.tir.const(1, "int8x2") >> 1).dtype == "int8x2" - assert (x >> tvm.tir.const(1, "int32x2")).dtype == "int32x2" - assert (tvm.tir.Var("z", "int8x2") << tvm.tir.const(1, "int8x2")).dtype == "int8x2" + assert (tvm.tirx.const(1, "int8x2") >> 1).dtype == "int8x2" + assert (x >> tvm.tirx.const(1, "int32x2")).dtype == "int32x2" + assert (tvm.tirx.Var("z", "int8x2") << tvm.tirx.const(1, "int8x2")).dtype == "int8x2" def test_float_bitwise(): - t = tvm.tir.const(1.5, dtype="float32") + t = tvm.tirx.const(1.5, dtype="float32") for test in [ lambda lhs, rhs: lhs << rhs, lambda lhs, rhs: lhs >> rhs, @@ -238,7 +238,7 @@ def test_float_bitwise(): def test_shift_bounds(): - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") for test in [lambda lhs, rhs: lhs << rhs, lambda lhs, rhs: lhs >> rhs]: # negative case for testcase in [(x, -1), (x, 32)]: @@ -255,40 +255,40 @@ def test_shift_bounds(): def test_divide_by_zero(): for test in [ - lambda lhs, rhs: tvm.tir.floormod(lhs, rhs), - lambda lhs, rhs: tvm.tir.floordiv(lhs, rhs), - lambda lhs, rhs: tvm.tir.truncmod(lhs, rhs), - lambda lhs, rhs: tvm.tir.truncdiv(lhs, rhs), - lambda lhs, rhs: tvm.tir.div(lhs, rhs), + lambda lhs, rhs: tvm.tirx.floormod(lhs, rhs), + lambda lhs, rhs: tvm.tirx.floordiv(lhs, rhs), + lambda lhs, rhs: tvm.tirx.truncmod(lhs, rhs), + lambda lhs, rhs: tvm.tirx.truncdiv(lhs, rhs), + lambda lhs, rhs: tvm.tirx.div(lhs, rhs), ]: try: - test(tvm.tir.const(5, "int32"), tvm.tir.const(0, "int32")) + test(tvm.tirx.const(5, "int32"), tvm.tirx.const(0, "int32")) assert False except tvm.TVMError: pass def test_infinity(): - assert str(tvm.tir.infinity("float16")) == 'T.float16("inf")' - assert str(tvm.tir.infinity("float32")) == 'T.float32("inf")' - assert str(tvm.tir.infinity("float64")) == 'T.float64("inf")' + assert str(tvm.tirx.infinity("float16")) == 'T.float16("inf")' + assert str(tvm.tirx.infinity("float32")) == 'T.float32("inf")' + assert str(tvm.tirx.infinity("float64")) == 'T.float64("inf")' def test_isnan(): - x = tvm.tir.Var("x", "float32") - assert str(tvm.tir.isnan(x)) == "T.isnan(x)" - assert str(tvm.tir.isnan(x).dtype) == "bool" - y = tvm.tir.Var("y", "float16") - assert str(tvm.tir.isnan(y)) == 'T.isnan(T.Cast("float32", y))' - z = tvm.tir.Var("z", "int32") - assert str(tvm.tir.isnan(z)) == "T.bool(False)" - k = tvm.tir.Var("k", "int8x2") - assert str(tvm.tir.isnan(k).dtype) == "boolx2" + x = tvm.tirx.Var("x", "float32") + assert str(tvm.tirx.isnan(x)) == "T.isnan(x)" + assert str(tvm.tirx.isnan(x).dtype) == "bool" + y = tvm.tirx.Var("y", "float16") + assert str(tvm.tirx.isnan(y)) == 'T.isnan(T.Cast("float32", y))' + z = tvm.tirx.Var("z", "int32") + assert str(tvm.tirx.isnan(z)) == "T.bool(False)" + k = tvm.tirx.Var("k", "int8x2") + assert str(tvm.tirx.isnan(k).dtype) == "boolx2" def test_equality(): - a = tvm.tir.Var("a", "int32") - b = tvm.tir.Var("b", "int32") + a = tvm.tirx.Var("a", "int32") + b = tvm.tirx.Var("b", "int32") c = a == b assert not c d = c != c @@ -297,32 +297,32 @@ def test_equality(): def test_equality_string_imm(): x = "a" - y = tvm.tir.StringImm(x) + y = tvm.tirx.StringImm(x) x == y.value x == y def test_prim_func(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + b = tvm.tirx.decl_buffer((x,), "float32") + stmt = tvm.tirx.SeqStmt([tvm.tirx.Bind(x, 10), tvm.tirx.Evaluate(x + 1)]) - func = tvm.tir.PrimFunc([x, y, b], stmt) + func = tvm.tirx.PrimFunc([x, y, b], stmt) # make sure we can print assert func.buffer_map[func.params[2]].same_as(b) assert len(func.buffer_map) == 1 - f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) + f2 = func.with_attr({"calling_conv": 1, "tirx.noalias": True}) assert f2.attrs["calling_conv"] == 1 assert not func.attrs def test_vars(): - x = tvm.tir.Var("xyz", "int8") + x = tvm.tirx.Var("xyz", "int8") assert x.dtype == "int8" ptype = tvm.ir.PointerType(tvm.ir.PrimType("float")) - x = tvm.tir.Var("xyz", ptype) + x = tvm.tirx.Var("xyz", ptype) assert x.dtype == "handle" assert x.type_annotation == ptype assert isinstance(ptype.element_type, tvm.ir.PrimType) @@ -332,7 +332,7 @@ def test_scoped_storage_vars(): dtype = "float" storage_scope = "global.texture" ptype = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) - x = tvm.tir.Var("xyz", ptype) + x = tvm.tirx.Var("xyz", ptype) assert x.dtype == "handle" assert x.type_annotation == ptype assert x.type_annotation.storage_scope == storage_scope @@ -340,13 +340,13 @@ def test_scoped_storage_vars(): def test_buffer_load_store(): - b = tvm.tir.decl_buffer((10,), "float32") - x = tvm.tir.BufferLoad(b, [0]) - assert isinstance(x, tvm.tir.BufferLoad) + b = tvm.tirx.decl_buffer((10,), "float32") + x = tvm.tirx.BufferLoad(b, [0]) + assert isinstance(x, tvm.tirx.BufferLoad) assert x.dtype == "float32" assert x.buffer == b - s = tvm.tir.BufferStore(b, 0.1, [0]) - assert isinstance(s, tvm.tir.BufferStore) + s = tvm.tirx.BufferStore(b, 0.1, [0]) + assert isinstance(s, tvm.tirx.BufferStore) def test_intimm_cond(): @@ -363,14 +363,14 @@ def test_intimm_cond(): def _create_ramp(lanes): - return tvm.tir.Ramp(0, 1, lanes) + return tvm.tirx.Ramp(0, 1, lanes) def _create_broadcast(lanes): - return tvm.tir.Broadcast(0, lanes) + return tvm.tirx.Broadcast(0, lanes) -@pytest.mark.parametrize("lanes", [tvm.tir.IntImm(dtype="int64", value=11)]) +@pytest.mark.parametrize("lanes", [tvm.tirx.IntImm(dtype="int64", value=11)]) @pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) def test_lane_types(lanes, node_func): def _check_dtype(node): @@ -380,18 +380,18 @@ def _check_dtype(node): _check_dtype(node_func(lanes)) -@pytest.mark.parametrize("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)]) +@pytest.mark.parametrize("lanes", [(11 * tvm.tirx.vscale()), (tvm.tirx.vscale() * 11)]) @pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) def test_scalable_vec(lanes, node_func): def _check_dtype(node): - assert node.lanes.a.equal(tvm.tir.vscale()) + assert node.lanes.a.equal(tvm.tirx.vscale()) assert node.lanes.b == 11 _check_dtype(node_func(lanes)) @pytest.mark.parametrize( - "lanes", [(tvm.tir.vscale()), (tvm.tir.vscale() + 3), (tvm.tir.vscale() * 2 + 5)] + "lanes", [(tvm.tirx.vscale()), (tvm.tirx.vscale() + 3), (tvm.tirx.vscale() * 2 + 5)] ) @pytest.mark.parametrize("node_func", [_create_ramp, _create_broadcast]) def test_scalable_vec_error(lanes, node_func): @@ -400,111 +400,111 @@ def test_scalable_vec_error(lanes, node_func): def test_broadcast_to_scalable_vec(): - vec = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + 3 + vec = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + 3 broadcast = vec.b - assert isinstance(broadcast, tvm.tir.expr.Broadcast) + assert isinstance(broadcast, tvm.tirx.expr.Broadcast) assert broadcast.value == 3 - assert broadcast.lanes.a.equal(tvm.tir.vscale()) + assert broadcast.lanes.a.equal(tvm.tirx.vscale()) assert broadcast.lanes.b == 4 def test_buffer_load_scalable_vec(): - buf = tvm.tir.decl_buffer((24,), "float32") - index = tvm.tir.expr.Ramp(1, 1, 8 * tvm.tir.vscale()) - load = tvm.tir.BufferLoad(buf, [index]) + buf = tvm.tirx.decl_buffer((24,), "float32") + index = tvm.tirx.expr.Ramp(1, 1, 8 * tvm.tirx.vscale()) + load = tvm.tirx.BufferLoad(buf, [index]) - assert isinstance(load, tvm.tir.BufferLoad) + assert isinstance(load, tvm.tirx.BufferLoad) assert load.dtype == "float32xvscalex8" def test_buffer_store_scalable_vec(): - b = tvm.tir.decl_buffer((24,), "int32") - value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - store = tvm.tir.BufferStore(b, value, [index]) + b = tvm.tirx.decl_buffer((24,), "int32") + value = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale()) + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + store = tvm.tirx.BufferStore(b, value, [index]) - assert isinstance(store, tvm.tir.BufferStore) + assert isinstance(store, tvm.tirx.BufferStore) assert store.value.dtype == "int32xvscalex4" def test_buffer_store_predicate_invalid_scalability(): - b = tvm.tir.decl_buffer((24,), "int32") - value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4) + b = tvm.tirx.decl_buffer((24,), "int32") + value = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale()) + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + predicate = tvm.tirx.expr.Broadcast(tvm.tirx.IntImm("int1", 1), 4) err_msg = "Predicate mask dtype and value dtype must both be scalable." with pytest.raises(tvm.TVMError, match=err_msg): - tvm.tir.BufferStore(b, value, [index], predicate) + tvm.tirx.BufferStore(b, value, [index], predicate) def test_buffer_store_predicate_invalid_lanes(): - b = tvm.tir.decl_buffer((24,), "int32") - value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale()) + b = tvm.tirx.decl_buffer((24,), "int32") + value = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale()) + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + predicate = tvm.tirx.expr.Broadcast(tvm.tirx.IntImm("int1", 1), 8 * tvm.tirx.vscale()) err_msg = ( "Got a predicate mask with 8 lanes, but trying to store a " "value with 4 lanes. The number of lanes must match." ) with pytest.raises(tvm.TVMError, match=err_msg): - tvm.tir.BufferStore(b, value, [index], predicate) + tvm.tirx.BufferStore(b, value, [index], predicate) def test_buffer_store_predicate_elements_invalid_type(): - b = tvm.tir.decl_buffer((24,), "int32") - value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - predicate = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + b = tvm.tirx.decl_buffer((24,), "int32") + value = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale()) + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + predicate = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale()) err_msg = "Predicate mask elements must be boolean values, but got int32." with pytest.raises(tvm.TVMError, match=err_msg): - tvm.tir.BufferStore(b, value, [index], predicate) + tvm.tirx.BufferStore(b, value, [index], predicate) def test_buffer_load_predicate_elements_invalid_type(): - b = tvm.tir.decl_buffer((24,), "int32") - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - predicate = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale()) + b = tvm.tirx.decl_buffer((24,), "int32") + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + predicate = tvm.tirx.expr.Broadcast(1, 4 * tvm.tirx.vscale()) err_msg = "Predicate mask elements must be boolean values, but got int32." with pytest.raises(tvm.TVMError, match=err_msg): - tvm.tir.BufferLoad(b, [index], predicate) + tvm.tirx.BufferLoad(b, [index], predicate) def test_buffer_store_predicate_invalid_scalability(): - b = tvm.tir.decl_buffer((24,), "int32") - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 4) + b = tvm.tirx.decl_buffer((24,), "int32") + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + predicate = tvm.tirx.expr.Broadcast(tvm.tirx.IntImm("int1", 1), 4) err_msg = "Predicate mask dtype and load indices must both be scalable." with pytest.raises(tvm.TVMError, match=err_msg): - tvm.tir.BufferLoad(b, [index], predicate) + tvm.tirx.BufferLoad(b, [index], predicate) def test_buffer_store_predicate_invalid_lanes(): - b = tvm.tir.decl_buffer((24,), "int32") - index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) - predicate = tvm.tir.expr.Broadcast(tvm.tir.IntImm("int1", 1), 8 * tvm.tir.vscale()) + b = tvm.tirx.decl_buffer((24,), "int32") + index = tvm.tirx.expr.Ramp(0, 1, 4 * tvm.tirx.vscale()) + predicate = tvm.tirx.expr.Broadcast(tvm.tirx.IntImm("int1", 1), 8 * tvm.tirx.vscale()) err_msg = ( "Got a predicate mask with 8 lanes, but trying to load a " "vector with 4 lanes. The number of lanes must match." ) with pytest.raises(tvm.TVMError, match=err_msg): - tvm.tir.BufferLoad(b, [index], predicate) + tvm.tirx.BufferLoad(b, [index], predicate) def test_scalable_vec_cast(): - b = tvm.tir.decl_buffer((24,), "float32") - value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32xvscalex12") - index = tvm.tir.expr.Ramp(0, 1, 12 * tvm.tir.vscale()) + b = tvm.tirx.decl_buffer((24,), "float32") + value = tvm.tirx.expr.Broadcast(1, 12 * tvm.tirx.vscale()).astype("float32xvscalex12") + index = tvm.tirx.expr.Ramp(0, 1, 12 * tvm.tirx.vscale()) - store = tvm.tir.BufferStore(b, value, [index]) + store = tvm.tirx.BufferStore(b, value, [index]) - assert isinstance(store.value.value, tvm.tir.expr.FloatImm) + assert isinstance(store.value.value, tvm.tirx.expr.FloatImm) if __name__ == "__main__": diff --git a/tests/python/tirx-base/test_tir_op_types.py b/tests/python/tirx-base/test_tir_op_types.py new file mode 100644 index 000000000000..bf2c75a1e0e4 --- /dev/null +++ b/tests/python/tirx-base/test_tir_op_types.py @@ -0,0 +1,354 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import tvm +import tvm.testing +from tvm import tirx + + +def test_tir_op_tvm_tuple(): + x = tirx.Var("x", dtype="float32") + y = tirx.Var("y", dtype="float32") + z = tirx.Var("z", dtype="float32") + expr = tirx.tvm_tuple(x, y, z, 1, 2, 3) + assert expr.op.name == "tirx.tvm_tuple" + + +def test_tir_op_tvm_struct_get(): + x = tirx.Var("x", dtype="handle") + expr = tirx.tvm_struct_get(x, 1, 2, dtype="int32") + assert expr.op.name == "tirx.tvm_struct_get" + + +def test_tir_op_tvm_struct_set(): + x = tirx.Var("x", dtype="handle") + expr = tirx.tvm_struct_set(x, 1, 2, 3) + assert expr.op.name == "tirx.tvm_struct_set" + + +def test_tir_op_address_of(): + buffer = tirx.decl_buffer((128), "float32") + expr = tirx.address_of(buffer[0]) + assert expr.op.name == "tirx.address_of" + + +def test_tir_op_lookup_param(): + expr = tirx.lookup_param("p0") + assert expr.op.name == "tirx.lookup_param" + + +def test_tir_op_reinterpret(): + x = tirx.Var("x", dtype="int32") + expr = tirx.reinterpret("float32", x) + assert expr.op.name == "tirx.reinterpret" + + +def test_tir_op_isnullptr(): + x = tirx.Var("x", dtype="int32") + expr = tirx.isnullptr(x) + assert expr.op.name == "tirx.isnullptr" + + +def test_tir_op_call_assume(): + x = tirx.Var("x", dtype="int32") + expr = tirx.assume(cond=x) + assert expr.op.name == "tirx.assume" + + +def test_tir_op_call_undef(): + expr = tirx.undef() + assert expr.op.name == "tirx.undef" + + +def test_tir_op_call_likely(): + x = tirx.Var("x", dtype="int32") + expr = tirx.likely(cond=x) + assert expr.op.name == "tirx.likely" + + +def test_tir_op_tvm_thread_allreduce(): + x = tirx.Var("x", "int32") + buffer = tirx.decl_buffer((128), "float32") + y = tirx.Var("y", "handle") + z = tirx.Var("z", "int32") + expr = tirx.tvm_thread_allreduce(x, buffer[0], True, y, z) + assert expr.op.name == "tirx.tvm_thread_allreduce" + + +def test_tir_op_type_annotation(): + expr = tirx.type_annotation("int32") + assert expr.op.name == "tirx.type_annotation" + + +def test_tir_op_tvm_access_ptr(): + buffer = tirx.decl_buffer((128), "float32") + expr = tirx.tvm_access_ptr("float32", buffer.data, 0, 1, 2) + assert expr.op.name == "tirx.tvm_access_ptr" + + +def test_tir_op_tvm_throw_last_error(): + expr = tirx.tvm_throw_last_error() + assert expr.op.name == "tirx.tvm_throw_last_error" + + +def test_tir_op_tvm_load_matrix_sync(): + buffer = tirx.decl_buffer((16, 16), "float32") + x = tirx.Var("x", "handle") + expr = tirx.tvm_load_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major") + assert expr.op.name == "tirx.tvm_load_matrix_sync" + + +def test_tir_op_tvm_store_matrix_sync(): + buffer = tirx.decl_buffer((16, 16), "float32") + x = tirx.Var("x", "handle") + expr = tirx.tvm_store_matrix_sync(buffer.data, 16, 16, 16, 0, x, 128, "row_major") + assert expr.op.name == "tirx.tvm_store_matrix_sync" + + +def test_tir_op_tvm_mma_sync(): + buffer_0 = tirx.decl_buffer((16, 16), "float32") + buffer_1 = tirx.decl_buffer((16, 16), "float32") + buffer_2 = tirx.decl_buffer((16, 16), "float32") + buffer_3 = tirx.decl_buffer((16, 16), "float32") + expr = tirx.tvm_mma_sync(buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0) + assert expr.op.name == "tirx.tvm_mma_sync" + + +def test_tir_op_tvm_bmma_sync(): + buffer_0 = tirx.decl_buffer((16, 16), "float32") + buffer_1 = tirx.decl_buffer((16, 16), "float32") + buffer_2 = tirx.decl_buffer((16, 16), "float32") + buffer_3 = tirx.decl_buffer((16, 16), "float32") + expr = tirx.tvm_bmma_sync( + buffer_0.data, 0, buffer_1.data, 0, buffer_2.data, 0, buffer_3.data, 0 + ) + assert expr.op.name == "tirx.tvm_bmma_sync" + + +def test_tir_op_tvm_fill_fragment(): + buffer = tirx.decl_buffer((16, 16), "float32") + expr = tirx.tvm_fill_fragment(buffer.data, 16, 16, 16, 0, 0) + assert expr.op.name == "tirx.tvm_fill_fragment" + + +def test_tir_op_ptx_mma(): + buffer_a = tirx.decl_buffer([32], "int4", scope="local") + buffer_b = tirx.decl_buffer([16], "uint4", scope="local") + buffer_c = tirx.decl_buffer([4], "int32", scope="local") + expr = tirx.ptx_mma( + "int32", + "m8n8k32", + "row", + "col", + "int4", + "uint4", + "int32", + buffer_a.data, + 0, + buffer_b.data, + 0, + buffer_c.data, + 0, + False, + ) + assert expr.op.name == "tirx.ptx_mma" + + +def test_tir_op_ptx_mma_sp(): + buffer_a = tirx.decl_buffer([32], "int4", scope="local") + buffer_b = tirx.decl_buffer([16], "uint4", scope="local") + buffer_c = tirx.decl_buffer([4], "int32", scope="local") + buffer_d = tirx.decl_buffer([1], "uint32", scope="local") + expr = tirx.ptx_mma_sp( + "int32", + "m8n8k32", + "row", + "col", + "int4", + "uint4", + "int32", + buffer_a.data, + 0, + buffer_b.data, + 0, + buffer_c.data, + 0, + buffer_d.data, + 0, + 0, + False, + ) + assert expr.op.name == "tirx.ptx_mma_sp" + + +def test_tir_op_mma_store(): + x = tirx.Var("x", dtype="int32") + y = tirx.Var("y", dtype="int32") + buffer_w = tirx.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1) + buffer = tirx.decl_buffer( + [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[x, y] + ) + expr = tirx.mma_store( + "int32", + 16, + 16, + buffer.access_ptr("w"), + buffer_w.data, + buffer_w.elem_offset, + x, + ) + assert expr.op.name == "tirx.mma_store" + + +def test_tir_op_mma_fill(): + buffer_w = tirx.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1) + expr = tirx.mma_fill("int32", 8, buffer_w.data, buffer_w.elem_offset) + assert expr.op.name == "tirx.mma_fill" + + +def test_op_ptx_ldmatrix(): + buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") + buffer_local = tirx.decl_buffer([8], "float16", scope="local") + expr = tirx.ptx_ldmatrix( + "float16", False, 4, ".b16", buffer_local.data, 0, buffer_shared.data, 0 + ) + assert expr.op.name == "tirx.ptx_ldmatrix" + + +def test_op_ptx_cp_async(): + buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") + buffer_local = tirx.decl_buffer([8], "float16", scope="local") + expr = tirx.ptx_cp_async("float16", buffer_shared.data, 0, buffer_local.data, 0, 16) + assert expr.op.name == "tirx.ptx_cp_async" + + +def test_op_ptx_cp_async_bulk(): + buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") + buffer_local = tirx.decl_buffer([8], "float16", scope="local") + expr = tirx.ptx_cp_async_bulk("float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 0) + assert expr.op.name == "tirx.ptx_cp_async_bulk" + + +def test_op_ptx_commit_group(): + expr = tirx.ptx_commit_group() + assert expr.op.name == "tirx.ptx_commit_group" + + +def test_op_ptx_wait_group(): + expr = tirx.ptx_wait_group(8) + assert expr.op.name == "tirx.ptx_wait_group" + + +def test_op_ptx_cp_async_barrier(): + expr = tirx.ptx_cp_async_barrier(0) + assert expr.op.name == "tirx.ptx_cp_async_barrier" + + +def test_op_ptx_init_barrier_thread_count(): + expr = tirx.ptx_init_barrier_thread_count(0, 32) + assert expr.op.name == "tirx.ptx_init_barrier_thread_count" + + +def test_op_ptx_arrive_barrier(): + expr = tirx.ptx_arrive_barrier(0) + assert expr.op.name == "tirx.ptx_arrive_barrier" + + +def test_op_ptx_arrive_barrier_expect_tx(): + expr = tirx.ptx_arrive_barrier_expect_tx(0, 32) + assert expr.op.name == "tirx.ptx_arrive_barrier_expect_tx" + + +def test_op_ptx_wait_barrier(): + expr = tirx.ptx_wait_barrier(0) + assert expr.op.name == "tirx.ptx_wait_barrier" + + +def test_op_create_barriers(): + expr = tirx.create_barriers(16) + assert expr.op.name == "tirx.create_barriers" + + +def test_tir_op_vectorlow(): + buffer = tirx.decl_buffer((4, 4), "int8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="int8x16") + expr = tirx.vectorlow("int8x8", vec) + assert expr.op.name == "tirx.vectorlow" + + +def test_tir_op_vectorhigh(): + buffer = tirx.decl_buffer((4, 4), "int8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="int8x16") + expr = tirx.vectorhigh("int8x8", vec) + assert expr.op.name == "tirx.vectorhigh" + + +def test_tir_op_dp4a(): + vec1 = tirx.Var("vec1", dtype="int8x4") + vec2 = tirx.Var("vec2", dtype="int8x4") + acc = tirx.Var("acc", dtype="int32") + expr = tirx.dp4a(vec1, vec2, acc) + assert expr.op.name == "tirx.dp4a" + + +def test_tir_op_vectorcombine(): + buffer = tirx.decl_buffer((4, 4), "int8", offset_factor=1) + vec = buffer.vload([0, 0], dtype="int8x16") + expr = tirx.vectorcombine("int8x8", vec, vec) + assert expr.op.name == "tirx.vectorcombine" + + +def test_tir_op_shift_left(): + x = tirx.Var("x", dtype="int32") + y = tirx.Var("x", dtype="int32") + expr = tirx.shift_left(x, y) + assert expr.op.name == "tirx.shift_left" + + +def test_tir_op_shift_right(): + x = tirx.Var("x", dtype="int32") + y = tirx.Var("x", dtype="int32") + expr = tirx.shift_right(x, y) + assert expr.op.name == "tirx.shift_right" + + +def test_tir_op_bitwise(): + x = tirx.Var("x", dtype="int32") + y = tirx.Var("y", dtype="int32") + expr = tirx.bitwise_and(x, y) + assert expr.op.name == "tirx.bitwise_and" + expr = tirx.bitwise_or(x, y) + assert expr.op.name == "tirx.bitwise_or" + expr = tirx.bitwise_not(x) + assert expr.op.name == "tirx.bitwise_not" + expr = tirx.bitwise_xor(x, y) + assert expr.op.name == "tirx.bitwise_xor" + + +def test_tir_op_TVMBackendAllocWorkspace(): + expr = tirx.TVMBackendAllocWorkspace(0, 1, 2, 3, 4) + assert expr.op.name == "tirx.TVMBackendAllocWorkspace" + + +def test_tir_op_TVMBackendFreeWorkspace(): + buffer = tirx.decl_buffer((128), "float32") + expr = tirx.TVMBackendFreeWorkspace(0, 1, buffer.data) + assert expr.op.name == "tirx.TVMBackendFreeWorkspace" + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tirx-base/test_tir_ops.py similarity index 67% rename from tests/python/tir-base/test_tir_ops.py rename to tests/python/tirx-base/test_tir_ops.py index 9eeac7030164..34b62e0fd1ed 100644 --- a/tests/python/tir-base/test_tir_ops.py +++ b/tests/python/tirx-base/test_tir_ops.py @@ -32,12 +32,12 @@ def check_throws(f): def test_const_fold(): def check(f, *args): - x = f(*[tvm.tir.const(x, "int32") for x in args]) + x = f(*[tvm.tirx.const(x, "int32") for x in args]) y = f(*args) - if not isinstance(x, tvm.tir.IntImm) or x.value != int(y): + if not isinstance(x, tvm.tirx.IntImm) or x.value != int(y): raise ValueError(f"check error: {x} vs {y} ") - tmod = tvm.tir.truncmod + tmod = tvm.tirx.truncmod check(lambda x, y: x + y, 3, 4) check(lambda x, y: x * y, 3, 12) check(lambda x, y: x * y - 10, 3, 12) @@ -52,68 +52,68 @@ def check(f, *args): def test_const_fold2(): - x = tvm.tir.Var("x", "int32") - tmod = tvm.tir.truncmod - tdiv = tvm.tir.truncdiv + x = tvm.tirx.Var("x", "int32") + tmod = tvm.tirx.truncmod + tdiv = tvm.tirx.truncdiv assert (x + 0).same_as(x) assert (0 + x).same_as(x) assert (x - 0).same_as(x) assert tmod(x, 1).value == 0 assert (x * 1).same_as(x) assert (1 * x).same_as(x) - assert isinstance(tdiv(1, x), tvm.tir.Div) + assert isinstance(tdiv(1, x), tvm.tirx.Div) def test_const_fold3(): # Test that using ints with logic operations is forbidden - x = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") for val in [0, 1]: - for func in [tvm.tir.all, tvm.tir.any]: - check_throws(lambda: func(tvm.tir.const(val, "bool"), x)) - check_throws(lambda: func(x, tvm.tir.const(val, "bool"))) + for func in [tvm.tirx.all, tvm.tirx.any]: + check_throws(lambda: func(tvm.tirx.const(val, "bool"), x)) + check_throws(lambda: func(x, tvm.tirx.const(val, "bool"))) # Test const folding when both arguments are const for tvm_func, py_func in [ - (tvm.tir.all, lambda a, b: a and b), - (tvm.tir.any, lambda a, b: a or b), + (tvm.tirx.all, lambda a, b: a and b), + (tvm.tirx.any, lambda a, b: a or b), ]: for v1 in [0, 1]: for v2 in [0, 1]: tvm.ir.assert_structural_equal( - tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, "bool")), - tvm.tir.const(py_func(v1, v2), "bool"), + tvm_func(tvm.tirx.const(v1, "bool"), tvm.tirx.const(v2, "bool")), + tvm.tirx.const(py_func(v1, v2), "bool"), ) - x = tvm.tir.Var("x", "bool") - true = tvm.tir.const(1, "bool") - false = tvm.tir.const(0, "bool") + x = tvm.tirx.Var("x", "bool") + true = tvm.tirx.const(1, "bool") + false = tvm.tirx.const(0, "bool") - assert tvm.tir.all(x, true).same_as(x) - assert tvm.tir.all(true, x).same_as(x) - assert tvm.tir.any(x, false).same_as(x) - assert tvm.tir.any(false, x).same_as(x) + assert tvm.tirx.all(x, true).same_as(x) + assert tvm.tirx.all(true, x).same_as(x) + assert tvm.tirx.any(x, false).same_as(x) + assert tvm.tirx.any(false, x).same_as(x) - assert tvm.tir.all(x, false).same_as(false) - assert tvm.tir.all(false, x).same_as(false) - assert tvm.tir.any(x, true).same_as(true) - assert tvm.tir.any(true, x).same_as(true) + assert tvm.tirx.all(x, false).same_as(false) + assert tvm.tirx.all(false, x).same_as(false) + assert tvm.tirx.any(x, true).same_as(true) + assert tvm.tirx.any(true, x).same_as(true) def test_const_fold4(): - x1 = tvm.tir.const(4, "int32") + x1 = tvm.tirx.const(4, "int32") x2 = x1 + 5 - tdiv = tvm.tir.truncdiv - assert isinstance(x2, tvm.tir.IntImm) and x2.value == 9 + tdiv = tvm.tirx.truncdiv + assert isinstance(x2, tvm.tirx.IntImm) and x2.value == 9 x3 = tdiv(x2, 3) - assert isinstance(x3, tvm.tir.IntImm) and x3.value == 3 + assert isinstance(x3, tvm.tirx.IntImm) and x3.value == 3 x4 = x3 + 0.55 - assert isinstance(x4, tvm.tir.FloatImm) and abs(x4.value - 3.55) < 1e-6 - x5 = tvm.tir.ceil(x4) - assert isinstance(x5, tvm.tir.FloatImm) and x5.value == 4 + assert isinstance(x4, tvm.tirx.FloatImm) and abs(x4.value - 3.55) < 1e-6 + x5 = tvm.tirx.ceil(x4) + assert isinstance(x5, tvm.tirx.FloatImm) and x5.value == 4 x6 = x5.astype("int") - assert isinstance(x6, tvm.tir.IntImm) and x6.value == 4, f"x6={x6}" - y = (tvm.tir.round((tvm.tir.const(6.5, "float32") - 1) / 1.5) + 2).astype("int") - assert isinstance(y, tvm.tir.IntImm) and y.value == 6 + assert isinstance(x6, tvm.tirx.IntImm) and x6.value == 4, f"x6={x6}" + y = (tvm.tirx.round((tvm.tirx.const(6.5, "float32") - 1) / 1.5) + 2).astype("int") + assert isinstance(y, tvm.tirx.IntImm) and y.value == 6 def test_binary_dtype_match(): @@ -126,8 +126,8 @@ def verify_general_dtype_support(f, is_conditional=False): [("uint32", "int32"), "uint32"], ] for (lhs_dtype, rhs_dtype), out_dtype in rules: - lhs = tvm.tir.Var("lhs", lhs_dtype) - rhs = tvm.tir.Var("rhs", rhs_dtype) + lhs = tvm.tirx.Var("lhs", lhs_dtype) + rhs = tvm.tirx.Var("rhs", rhs_dtype) out = f(lhs, rhs) if not is_conditional: assert out.dtype == out_dtype @@ -146,8 +146,8 @@ def verify_general_dtype_support(f, is_conditional=False): def verify_callop_float_only(f): for lhs_dtype in ["int32", "float32", "float64"]: for rhs_dtype in ["int32", "float32", "float64"]: - lhs = tvm.tir.Var("lhs", lhs_dtype) - rhs = tvm.tir.Var("rhs", rhs_dtype) + lhs = tvm.tirx.Var("lhs", lhs_dtype) + rhs = tvm.tirx.Var("rhs", rhs_dtype) if "float" not in lhs_dtype and "float" not in rhs_dtype: check_throws(lambda: f(lhs, rhs)) elif "float" in lhs_dtype: @@ -176,28 +176,28 @@ def verify_callop_float_only(f): verify_general_dtype_support(lambda a, b: a * b) verify_general_dtype_support(lambda a, b: a >= b, is_conditional=True) verify_general_dtype_support(lambda a, b: a <= b, is_conditional=True) - verify_callop_float_only(lambda a, b: tvm.tir.power(a, b)) + verify_callop_float_only(lambda a, b: tvm.tirx.power(a, b)) # verify bool & int32 constant folding - assert tvm.tir.const(1) == tvm.tir.const(True) - assert tvm.tir.const(2) != tvm.tir.const(True) + assert tvm.tirx.const(1) == tvm.tirx.const(True) + assert tvm.tirx.const(2) != tvm.tirx.const(True) def test_if_then_else(): cases = [ - [(tvm.tir.Var("cond", "bool"), "bool", "int32"), "int32"], + [(tvm.tirx.Var("cond", "bool"), "bool", "int32"), "int32"], [(True, "int32", "float32"), "float32"], [(False, "int32", "int64"), "int64"], - [(tvm.tir.Var("cond", "bool"), "uint32", "int32"), "uint32"], - [(tvm.tir.Var("cond", "int32"), "uint32", "int32"), "uint32"], + [(tvm.tirx.Var("cond", "bool"), "uint32", "int32"), "uint32"], + [(tvm.tirx.Var("cond", "int32"), "uint32", "int32"), "uint32"], ] for (cond, lhs_dtype, rhs_dtype), out_dtype in cases: - lhs = tvm.tir.Var("lhs", lhs_dtype) - rhs = tvm.tir.Var("rhs", rhs_dtype) + lhs = tvm.tirx.Var("lhs", lhs_dtype) + rhs = tvm.tirx.Var("rhs", rhs_dtype) if cond is True or cond is False: - out = tvm.tir.if_then_else(cond, lhs, rhs) - out2 = tvm.tir.if_then_else(not cond, rhs, lhs) - out3 = tvm.tir.if_then_else(not cond, lhs, rhs) + out = tvm.tirx.if_then_else(cond, lhs, rhs) + out2 = tvm.tirx.if_then_else(not cond, rhs, lhs) + out3 = tvm.tirx.if_then_else(not cond, lhs, rhs) tvm.ir.assert_structural_equal(out, out2) == 1 if cond: tvm.ir.assert_structural_equal(out, lhs.astype(out_dtype)) == 1 @@ -206,39 +206,39 @@ def test_if_then_else(): tvm.ir.assert_structural_equal(out, rhs.astype(out_dtype)) == 1 tvm.ir.assert_structural_equal(out3, lhs.astype(out_dtype)) == 1 elif cond.dtype == "bool": - out = tvm.tir.if_then_else(cond, lhs, rhs) + out = tvm.tirx.if_then_else(cond, lhs, rhs) assert out.dtype == out_dtype assert out.args[1].dtype == out_dtype assert out.args[2].dtype == out_dtype elif cond.dtype != "bool": - check_throws(lambda: tvm.tir.if_then_else(cond, lhs, rhs)) + check_throws(lambda: tvm.tirx.if_then_else(cond, lhs, rhs)) else: raise ValueError("Unknown combinations") @pytest.mark.parametrize("num_args", list(range(2, 10))) def test_comm_reducer(num_args): - """Handle all arguments in tir comm_reducer + """Handle all arguments in tirx comm_reducer - The `tir.comm_reducer` API has two distinct usages. It can reduce + The `tirx.comm_reducer` API has two distinct usages. It can reduce a tensor along a specified axis, similar to numpy.max, or it can reduce several arguments together, simililar to Python's built-in max(). This choice is based on the type of the second argument. - If the `tir.comm_reducer` is reducing all arguments, then all + If the `tirx.comm_reducer` is reducing all arguments, then all arguments should be used. In the past, the introduction of new arguments intended for use when reducing along a tensor axis has failed to forward these arguments when reducing along a list of items. """ - assert tvm.tir.max(*range(num_args)) == num_args - 1 + assert tvm.tirx.max(*range(num_args)) == num_args - 1 def test_llvm_intrin(): with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"): - a = tvm.tir.call_llvm_intrin("int32x4", "llvm.dummy") + a = tvm.tirx.call_llvm_intrin("int32x4", "llvm.dummy") with pytest.raises(ValueError, match=r"Unknown llvm intrinsic function llvm.dummy"): - a = tvm.tir.call_llvm_pure_intrin("int32x4", "llvm.dummy") + a = tvm.tirx.call_llvm_pure_intrin("int32x4", "llvm.dummy") if __name__ == "__main__": diff --git a/tests/python/tir-base/test_tir_ptx_cp_async.py b/tests/python/tirx-base/test_tir_ptx_cp_async.py similarity index 94% rename from tests/python/tir-base/test_tir_ptx_cp_async.py rename to tests/python/tirx-base/test_tir_ptx_cp_async.py index bba17b6315cb..dd47446b68e0 100644 --- a/tests/python/tir-base/test_tir_ptx_cp_async.py +++ b/tests/python/tirx-base/test_tir_ptx_cp_async.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func def ptx_cp_async(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")) -> None: - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -68,7 +68,7 @@ def test_ptx_cp_async(): def ptx_cp_async_barrier( A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16") ) -> None: - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) @@ -113,7 +113,7 @@ def test_ptx_cp_async_barrier(): @T.prim_func def ptx_cp_async_bulk(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")) -> None: - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) diff --git a/tests/python/tir-base/test_tir_ptx_ldmatrix.py b/tests/python/tirx-base/test_tir_ptx_ldmatrix.py similarity index 96% rename from tests/python/tir-base/test_tir_ptx_ldmatrix.py rename to tests/python/tirx-base/test_tir_ptx_ldmatrix.py index 1acf7d3f0caf..afab98f8282c 100644 --- a/tests/python/tir-base/test_tir_ptx_ldmatrix.py +++ b/tests/python/tirx-base/test_tir_ptx_ldmatrix.py @@ -19,14 +19,14 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func def ptx_ldmatrix( A: T.Buffer((16, 16), "float16"), B: T.Buffer((16, 16), "float16"), num: T.int32, trans: T.uint8 ) -> None: - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") T.launch_thread(bx, 1) diff --git a/tests/python/tir-base/test_tir_ptx_mma.py b/tests/python/tirx-base/test_tir_ptx_mma.py similarity index 96% rename from tests/python/tir-base/test_tir_ptx_mma.py rename to tests/python/tirx-base/test_tir_ptx_mma.py index 8dbf1a1f22f3..e1816125f28d 100644 --- a/tests/python/tir-base/test_tir_ptx_mma.py +++ b/tests/python/tirx-base/test_tir_ptx_mma.py @@ -19,12 +19,12 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func def gemm_mma_m8n8k4_row_col_fp64pf64fp64(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [8, 4], dtype="float64") B = T.match_buffer(b, [8, 4], dtype="float64") C = T.match_buffer(c, [8, 8], dtype="float64") @@ -89,7 +89,7 @@ def test_gemm_mma_m8n8k4_row_col_fp64pf64fp64(): @T.prim_func def gemm_mma_m8n8k4_row_row_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 4], dtype="float16") B = T.match_buffer(b, [4, 16], dtype="float16") C = T.match_buffer(c, [16, 16], dtype="float16") @@ -165,7 +165,7 @@ def test_gemm_mma_m8n8k4_row_row_fp16fp16fp16(): @T.prim_func def gemm_mma_m8n8k4_row_row_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 4], dtype="float16") B = T.match_buffer(b, [4, 16], dtype="float16") C = T.match_buffer(c, [16, 16], dtype="float32") @@ -248,7 +248,7 @@ def test_gemm_mma_m8n8k4_row_row_fp16fp16fp32(): @T.prim_func def gemm_mma_m8n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [8, 16], dtype="int8") B = T.match_buffer(b, [8, 16], dtype="int8") C = T.match_buffer(c, [8, 8], dtype="int32") @@ -319,7 +319,7 @@ def test_gemm_mma_m8n8k16_row_col_s8s8s32(): @T.prim_func def gemm_mma_m8n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [8, 16], dtype="int8") B = T.match_buffer(b, [8, 16], dtype="uint8") C = T.match_buffer(c, [8, 8], dtype="int32") @@ -390,7 +390,7 @@ def test_gemm_mma_m8n8k16_row_col_s8u8s32(): @T.prim_func def gemm_mma_m8n8k32_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [8, 32], dtype="int4") B = T.match_buffer(b, [8, 32], dtype="int4") C = T.match_buffer(c, [8, 8], dtype="int32") @@ -453,7 +453,7 @@ def test_gemm_mma_m8n8k32_row_col_s4s4s32(): @T.prim_func def gemm_mma_m8n8k32_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [8, 32], dtype="int4") B = T.match_buffer(b, [8, 32], dtype="uint4") C = T.match_buffer(c, [8, 8], dtype="int32") @@ -516,7 +516,7 @@ def test_gemm_mma_m8n8k32_row_col_s4u4s32(): @T.prim_func def gemm_mma_m16n8k8_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 8], dtype="float16") B = T.match_buffer(b, [8, 8], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float32") @@ -589,7 +589,7 @@ def test_gemm_mma_m16n8k8_row_col_fp16fp16fp32(): @T.prim_func def gemm_mma_m16n8k16_row_col_fp16fp16fp16(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="float16") B = T.match_buffer(b, [8, 16], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float16") @@ -665,7 +665,7 @@ def test_gemm_mma_m16n8k16_row_col_fp16fp16fp16(): @T.prim_func def gemm_mma_m16n8k16_row_col_fp16fp16fp32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="float16") B = T.match_buffer(b, [8, 16], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float32") @@ -741,7 +741,7 @@ def test_gemm_mma_m16n8k16_row_col_fp16fp16fp32(): @T.prim_func def gemm_mma_m16n8k16_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="int8") B = T.match_buffer(b, [8, 16], dtype="int8") C = T.match_buffer(c, [16, 8], dtype="int32") @@ -817,7 +817,7 @@ def test_gemm_mma_m16n8k16_row_col_s8s8s32(): @T.prim_func def gemm_mma_m16n8k16_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="int8") B = T.match_buffer(b, [8, 16], dtype="uint8") C = T.match_buffer(c, [16, 8], dtype="int32") @@ -893,7 +893,7 @@ def test_gemm_mma_m16n8k16_row_col_s8u8s32(): @T.prim_func def gemm_mma_m16n8k32_row_col_s8s8s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 32], dtype="int8") B = T.match_buffer(b, [8, 32], dtype="int8") C = T.match_buffer(c, [16, 8], dtype="int32") @@ -969,7 +969,7 @@ def test_gemm_mma_m16n8k32_row_col_s8s8s32(): @T.prim_func def gemm_mma_m16n8k32_row_col_s8u8s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 32], dtype="int8") B = T.match_buffer(b, [8, 32], dtype="uint8") C = T.match_buffer(c, [16, 8], dtype="int32") @@ -1045,7 +1045,7 @@ def test_gemm_mma_m16n8k32_row_col_s8u8s32(): @T.prim_func def gemm_mma_m16n8k64_row_col_s4s4s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 64], dtype="int4") B = T.match_buffer(b, [8, 64], dtype="int4") C = T.match_buffer(c, [16, 8], dtype="int32") @@ -1113,7 +1113,7 @@ def test_gemm_mma_m16n8k64_row_col_s4s4s32(): @T.prim_func def gemm_mma_m16n8k64_row_col_s4u4s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 64], dtype="int4") B = T.match_buffer(b, [8, 64], dtype="uint4") C = T.match_buffer(c, [16, 8], dtype="int32") @@ -1181,7 +1181,7 @@ def test_gemm_mma_m16n8k64_row_col_s4u4s32(): @T.prim_func def gemm_mma_m16n8k256_row_col_b1b1s32(a: T.handle, b: T.handle, c: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 256], dtype="int1") B = T.match_buffer(b, [8, 256], dtype="int1") C = T.match_buffer(c, [16, 8], dtype="int32") diff --git a/tests/python/tir-base/test_tir_ptx_mma_sp.py b/tests/python/tirx-base/test_tir_ptx_mma_sp.py similarity index 96% rename from tests/python/tir-base/test_tir_ptx_mma_sp.py rename to tests/python/tirx-base/test_tir_ptx_mma_sp.py index 6111c930479e..1f8322d7affc 100644 --- a/tests/python/tir-base/test_tir_ptx_mma_sp.py +++ b/tests/python/tirx-base/test_tir_ptx_mma_sp.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T def gen_2in4_mask(m: int, n: int): @@ -42,7 +42,7 @@ def get_dense_mat_by_mask(val, mask): @T.prim_func def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 8], dtype="float16") B = T.match_buffer(b, [16, 8], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float16") @@ -96,7 +96,7 @@ def mma_sp_m16n8k16_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: @T.prim_func def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 8], dtype="float16") B = T.match_buffer(b, [16, 8], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float32") @@ -150,7 +150,7 @@ def mma_sp_m16n8k16_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: @T.prim_func def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="float16") B = T.match_buffer(b, [32, 8], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float16") @@ -204,7 +204,7 @@ def mma_sp_m16n8k32_f16f16f16(a: T.handle, b: T.handle, c: T.handle, _metadata: @T.prim_func def mma_sp_m16n8k32_f16f16f32(a: T.handle, b: T.handle, c: T.handle, _metadata: T.handle): - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) A = T.match_buffer(a, [16, 16], dtype="float16") B = T.match_buffer(b, [32, 8], dtype="float16") C = T.match_buffer(c, [16, 8], dtype="float32") diff --git a/tests/python/tir-base/test_tir_scalable_datatype.py b/tests/python/tirx-base/test_tir_scalable_datatype.py similarity index 93% rename from tests/python/tir-base/test_tir_scalable_datatype.py rename to tests/python/tirx-base/test_tir_scalable_datatype.py index 140ed047079b..f05110e2e83e 100644 --- a/tests/python/tir-base/test_tir_scalable_datatype.py +++ b/tests/python/tirx-base/test_tir_scalable_datatype.py @@ -18,8 +18,8 @@ import pytest import tvm -from tvm import tir -from tvm.script import tir as T +from tvm import tirx +from tvm.script import tirx as T from tvm.target.codegen import llvm_version_major """ @@ -34,7 +34,7 @@ def test_create_scalable_data_type_python_api(): @pytest.mark.skipif(llvm_version_major() < 13, reason="Stepvector intrinsic was added in LLVM 13.") def test_create_scalable_tir_intrin(): - intrin = tir.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector") + intrin = tirx.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector") assert intrin.dtype == "int32xvscalex4" assert str(intrin) == 'T.call_llvm_intrin("int32xvscalex4", "llvm.experimental.stepvector")' diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tirx-base/test_tir_specialize.py similarity index 94% rename from tests/python/tir-base/test_tir_specialize.py rename to tests/python/tirx-base/test_tir_specialize.py index fb0515c9a2e0..471d99ef0a75 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tirx-base/test_tir_specialize.py @@ -21,7 +21,7 @@ import tvm from tvm.s_tir.schedule.testing import assert_structural_equal_ignore_global_symbol -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -182,13 +182,13 @@ def test_specialize_nothing(): def test_specialize_matmul(): a, _, _, n = matmul.params # fully specialized - func = matmul.specialize({a: tvm.tir.decl_buffer((128, 128))}) + func = matmul.specialize({a: tvm.tirx.decl_buffer((128, 128))}) assert_structural_equal_ignore_global_symbol(func, matmul_128) # partially specialized func = matmul.specialize({n: 128}) assert_structural_equal_ignore_global_symbol(func, matmul_m_128) # symbolic specialized - func = matmul.specialize({n: tvm.tir.Var("x", "int32") * 8}) + func = matmul.specialize({n: tvm.tirx.Var("x", "int32") * 8}) assert_structural_equal_ignore_global_symbol(func, matmul_m_8x) @@ -196,17 +196,17 @@ def test_specialize_elemwise(): a, c = element_wise.params C = element_wise.buffer_map[c] # fully specialized - func = element_wise.specialize({a: tvm.tir.decl_buffer((128, 64))}) + func = element_wise.specialize({a: tvm.tirx.decl_buffer((128, 64))}) assert_structural_equal_ignore_global_symbol(func, element_wise_128_64) # partially specialized - func = element_wise.specialize({c: tvm.tir.decl_buffer((128, C.shape[1]))}) + func = element_wise.specialize({c: tvm.tirx.decl_buffer((128, C.shape[1]))}) assert_structural_equal_ignore_global_symbol(func, element_wise_128_n) def test_specialize_mem_copy(): a, _, m, n, p, q = mem_copy.params # fully specialized - func = mem_copy.specialize({a: tvm.tir.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) + func = mem_copy.specialize({a: tvm.tirx.decl_buffer((16, 16), strides=[8, 1], elem_offset=4)}) assert_structural_equal_ignore_global_symbol(func, mem_copy_16_16_8_4) func = mem_copy.specialize({n: 16, m: 16, p: 8, q: 4}) assert_structural_equal_ignore_global_symbol(func, mem_copy_16_16_8_4) @@ -241,7 +241,7 @@ def expected(a: T.handle, b: T.handle): B[vi] = A[vi // 8, vi % 8] + 714 b = before.params[1] - after = before.specialize({b: tvm.tir.decl_buffer([16], dtype="int32")}) + after = before.specialize({b: tvm.tirx.decl_buffer([16], dtype="int32")}) assert_structural_equal_ignore_global_symbol(expected, after) @@ -301,10 +301,10 @@ def expected(A: T.Buffer([16, 16], "float32"), B_handle: T.handle): def test_specialize_buffer_var_to_expr(): """Handle specialization of buffer var - The `tir::Buffer::data` field must be an explicit `tir::Var`, and - cannot be replaced with a `tir::PrimExpr` of type + The `tirx::Buffer::data` field must be an explicit `tirx::Var`, and + cannot be replaced with a `tirx::PrimExpr` of type `DataType::Handle()`. However, these substitutions are useful - when lowering. If these occur, a binding of the `tir::Var` is + when lowering. If these occur, a binding of the `tirx::Var` is included in the specialized function. """ @@ -326,7 +326,7 @@ def expected(A_data: T.handle("float32")): B_data = before.params[1] # body is a SeqStmt; the first statement is DeclBuffer for A_buf A_buf = before.body[0].buffer - param_map = {B_data: tvm.tir.address_of(A_buf[16])} + param_map = {B_data: tvm.tirx.address_of(A_buf[16])} after = before.specialize(param_map) tvm.ir.assert_structural_equal(expected, after) diff --git a/tests/python/tir-base/test_tir_stmt_functor_ir_transform.py b/tests/python/tirx-base/test_tir_stmt_functor_ir_transform.py similarity index 83% rename from tests/python/tir-base/test_tir_stmt_functor_ir_transform.py rename to tests/python/tirx-base/test_tir_stmt_functor_ir_transform.py index 8affe28825ff..83e676fccb79 100644 --- a/tests/python/tir-base/test_tir_stmt_functor_ir_transform.py +++ b/tests/python/tirx-base/test_tir_stmt_functor_ir_transform.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_ir_transform(): @@ -39,21 +39,21 @@ def main(n: T.int32): ) body = Module["main"].body - builtin_call_extern = tvm.ir.Op.get("tir.call_extern") + builtin_call_extern = tvm.ir.Op.get("tirx.call_extern") def preorder(op): if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestC": - return tvm.tir.const(42, "int32") + return tvm.tirx.const(42, "int32") return None def postorder(op): - assert isinstance(op, tvm.tir.Call) + assert isinstance(op, tvm.tirx.Call) if op.op.same_as(builtin_call_extern) and op.args[0].value == "TestA": - return tvm.tir.call_extern("int32", "TestB", op.args[1] + 1) + return tvm.tirx.call_extern("int32", "TestB", op.args[1] + 1) return op - body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder, ["tir.Call"]) - stmt_list = tvm.tir.stmt_list(body.body.body) + body = tvm.tirx.stmt_functor.ir_transform(body, preorder, postorder, ["tirx.Call"]) + stmt_list = tvm.tirx.stmt_list(body.body.body) assert stmt_list[0].value.args[1].args[0].value == "TestB" assert stmt_list[1].value.value == 42 diff --git a/tests/python/tir-base/test_tir_stmt_functor_substitute.py b/tests/python/tirx-base/test_tir_stmt_functor_substitute.py similarity index 94% rename from tests/python/tir-base/test_tir_stmt_functor_substitute.py rename to tests/python/tirx-base/test_tir_stmt_functor_substitute.py index 7a26844f1690..11db657a3ebb 100644 --- a/tests/python/tir-base/test_tir_stmt_functor_substitute.py +++ b/tests/python/tirx-base/test_tir_stmt_functor_substitute.py @@ -18,15 +18,15 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T -from tvm.tir.stmt_functor import substitute +from tvm.script import tirx as T +from tvm.tirx.stmt_functor import substitute def _apply_substitute(mod): """Apply substitute transform to replace the first parameter with 16.""" func = mod["main"] vmap = {func.params[0]: 16} - new_func = tvm.tir.PrimFunc(params=[], body=substitute(func.body, vmap)).with_attr( + new_func = tvm.tirx.PrimFunc(params=[], body=substitute(func.body, vmap)).with_attr( "global_symbol", func.attrs["global_symbol"] ) return tvm.IRModule.from_expr(new_func) diff --git a/tests/python/tir-base/test_tir_structural_equal_hash.py b/tests/python/tirx-base/test_tir_structural_equal_hash.py similarity index 69% rename from tests/python/tir-base/test_tir_structural_equal_hash.py rename to tests/python/tirx-base/test_tir_structural_equal_hash.py index 6f44138300b8..86fccad482b1 100644 --- a/tests/python/tir-base/test_tir_structural_equal_hash.py +++ b/tests/python/tirx-base/test_tir_structural_equal_hash.py @@ -20,7 +20,7 @@ import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def consistent_equal(x, y, map_free_vars=False): @@ -66,11 +66,11 @@ def get_sequal_mismatch(x, y, map_free_vars=False): def test_exprs(): # save load json - x = tvm.tir.const(1, "int32") - y = tvm.tir.const(10, "int32") - vx = tvm.tir.Var("x", "int32") - vy = tvm.tir.Var("y", "int32") - vz = tvm.tir.Var("z", "int32") + x = tvm.tirx.const(1, "int32") + y = tvm.tirx.const(10, "int32") + vx = tvm.tirx.Var("x", "int32") + vy = tvm.tirx.Var("y", "int32") + vz = tvm.tirx.Var("z", "int32") zx = vx + vx zy = vy + vy @@ -90,9 +90,9 @@ def test_exprs(): assert consistent_equal(vx + vy + vz, vy + vz + vx, map_free_vars=True) assert not consistent_equal(vx + 1, vy + 1, map_free_vars=False) # Defintition remap - assert consistent_equal(tvm.tir.Let(vx, 1, vx - 1), tvm.tir.Let(vy, 1, vy - 1)) + assert consistent_equal(tvm.tirx.Let(vx, 1, vx - 1), tvm.tirx.Let(vy, 1, vy - 1)) # Default same address free var remap - assert consistent_equal(tvm.tir.Let(vx, 1, vx // vz), tvm.tir.Let(vy, 1, vy // vz)) + assert consistent_equal(tvm.tirx.Let(vx, 1, vx // vz), tvm.tirx.Let(vy, 1, vy // vz)) assert consistent_equal(zx * zx, zx * zx) assert consistent_equal(zx * zx, zy * zy, map_free_vars=True) @@ -100,17 +100,17 @@ def test_exprs(): def test_prim_func(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") # counter example of same equality - func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x + y)) - func1 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(y + x)) + func0 = tvm.tirx.PrimFunc([x, y], tvm.tirx.Evaluate(x + y)) + func1 = tvm.tirx.PrimFunc([x, y], tvm.tirx.Evaluate(y + x)) assert not consistent_equal(func0, func1) # new cases - b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) - func0 = tvm.tir.PrimFunc([x, y, b], stmt) + b = tvm.tirx.decl_buffer((x,), "float32") + stmt = tvm.tirx.SeqStmt([tvm.tirx.Bind(x, 10), tvm.tirx.Evaluate(x + 1)]) + func0 = tvm.tirx.PrimFunc([x, y, b], stmt) # easiest way to deep copy is via save/load func1 = tvm.ir.load_json(tvm.ir.save_json(func0)) tvm.ir.assert_structural_equal(func0, func1) @@ -127,12 +127,12 @@ def test_prim_func(): def test_prim_func_param_count_mismatch(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") # counter example of same equality - func0 = tvm.tir.PrimFunc([x, y], tvm.tir.Evaluate(x)) - func1 = tvm.tir.PrimFunc([x, y, z], tvm.tir.Evaluate(x)) + func0 = tvm.tirx.PrimFunc([x, y], tvm.tirx.Evaluate(x)) + func1 = tvm.tirx.PrimFunc([x, y, z], tvm.tirx.Evaluate(x)) lhs_path, rhs_path = get_sequal_mismatch(func0, func1) expected_lhs_path = AccessPath.root().attr("params").array_item_missing(2) expected_rhs_path = AccessPath.root().attr("params").array_item(2) @@ -141,12 +141,12 @@ def test_prim_func_param_count_mismatch(): def test_prim_func_param_dtype_mismatch(): - x = tvm.tir.Var("x", "int32") - y_0 = tvm.tir.Var("y", "int32") - y_1 = tvm.tir.Var("z", "float32") + x = tvm.tirx.Var("x", "int32") + y_0 = tvm.tirx.Var("y", "int32") + y_1 = tvm.tirx.Var("z", "float32") # counter example of same equality - func0 = tvm.tir.PrimFunc([x, y_0], tvm.tir.Evaluate(x)) - func1 = tvm.tir.PrimFunc([x, y_1], tvm.tir.Evaluate(x)) + func0 = tvm.tirx.PrimFunc([x, y_0], tvm.tirx.Evaluate(x)) + func1 = tvm.tirx.PrimFunc([x, y_1], tvm.tirx.Evaluate(x)) lhs_path, rhs_path = get_sequal_mismatch(func0, func1) expected_path = AccessPath.root().attr("params").array_item(1).attr("dtype") assert lhs_path == expected_path @@ -154,13 +154,13 @@ def test_prim_func_param_dtype_mismatch(): def test_prim_func_body_mismatch(): - x_0 = tvm.tir.Var("x", "int32") - y_0 = tvm.tir.Var("y", "int32") - x_1 = tvm.tir.Var("x", "int32") - y_1 = tvm.tir.Var("y", "int32") + x_0 = tvm.tirx.Var("x", "int32") + y_0 = tvm.tirx.Var("y", "int32") + x_1 = tvm.tirx.Var("x", "int32") + y_1 = tvm.tirx.Var("y", "int32") # counter example of same equality - func0 = tvm.tir.PrimFunc([x_0, y_0], tvm.tir.Evaluate(x_0 + x_0)) - func1 = tvm.tir.PrimFunc([x_1, y_1], tvm.tir.Evaluate(x_1 + y_1)) + func0 = tvm.tirx.PrimFunc([x_0, y_0], tvm.tirx.Evaluate(x_0 + x_0)) + func1 = tvm.tirx.PrimFunc([x_1, y_1], tvm.tirx.Evaluate(x_1 + y_1)) lhs_path, rhs_path = get_sequal_mismatch(func0, func1) expected_path = AccessPath.root().attr("body").attr("value").attr("b") assert lhs_path == expected_path @@ -215,17 +215,17 @@ def func2(A: T.handle, n_param: T.int32): def test_buffer_storage_scope(): - x = tvm.tir.Var("x", "handle") + x = tvm.tirx.Var("x", "handle") - buffer_local_0 = tvm.tir.decl_buffer((10, 10), "float32", scope="local") - buffer_local_1 = tvm.tir.decl_buffer((10, 10), "float32", scope="local") - buffer_global = tvm.tir.decl_buffer((10, 10), "float32") - buffer_empty = tvm.tir.decl_buffer((10, 10), "float32", scope="") + buffer_local_0 = tvm.tirx.decl_buffer((10, 10), "float32", scope="local") + buffer_local_1 = tvm.tirx.decl_buffer((10, 10), "float32", scope="local") + buffer_global = tvm.tirx.decl_buffer((10, 10), "float32") + buffer_empty = tvm.tirx.decl_buffer((10, 10), "float32", scope="") - func0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_0}) - func1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_local_1}) - func2 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_global}) - func3 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_empty}) + func0 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_local_0}) + func1 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_local_1}) + func2 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_global}) + func3 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_empty}) assert consistent_equal(func0, func1) assert consistent_equal(func2, func3) @@ -233,14 +233,14 @@ def test_buffer_storage_scope(): def test_buffer_map_mismatch(): - x = tvm.tir.Var("x", "int32") - buffer_0 = tvm.tir.decl_buffer((10, 10)) - buffer_0_clone = tvm.tir.decl_buffer((10, 10)) - buffer_1 = tvm.tir.decl_buffer((10, 20)) + x = tvm.tirx.Var("x", "int32") + buffer_0 = tvm.tirx.decl_buffer((10, 10)) + buffer_0_clone = tvm.tirx.decl_buffer((10, 10)) + buffer_1 = tvm.tirx.decl_buffer((10, 20)) - func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0}) - func_0_clone = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0_clone}) - func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_1}) + func_0 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_0}) + func_0_clone = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_0_clone}) + func_1 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_1}) lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) expected_path = ( @@ -253,14 +253,14 @@ def test_buffer_map_mismatch(): def test_buffer_map_length_mismatch(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("x", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("x", "int32") - buffer_0 = tvm.tir.decl_buffer((10, 10)) - buffer_1 = tvm.tir.decl_buffer((10, 20)) + buffer_0 = tvm.tirx.decl_buffer((10, 10)) + buffer_1 = tvm.tirx.decl_buffer((10, 20)) - func_0 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0}) - func_1 = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1}) + func_0 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_0}) + func_1 = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x), buffer_map={x: buffer_0, y: buffer_1}) lhs_path, rhs_path = get_sequal_mismatch(func_0, func_1) @@ -271,34 +271,34 @@ def test_buffer_map_length_mismatch(): def test_buffer_load_store(): - b = tvm.tir.decl_buffer((10, 10), "float32") - x = tvm.tir.BufferLoad(b, [0, 1]) - y = tvm.tir.BufferLoad(b, [0, 1]) - z = tvm.tir.BufferLoad(b, [1, 2]) + b = tvm.tirx.decl_buffer((10, 10), "float32") + x = tvm.tirx.BufferLoad(b, [0, 1]) + y = tvm.tirx.BufferLoad(b, [0, 1]) + z = tvm.tirx.BufferLoad(b, [1, 2]) assert consistent_equal(y, x) assert not consistent_equal(y, z) - i = tvm.tir.Var("x", "int32") - sx = tvm.tir.BufferStore(b, 0.1, [0, i]) - sy = tvm.tir.BufferStore(b, 0.1, [0, i]) - sz = tvm.tir.BufferStore(b, 0.1, [1, i]) + i = tvm.tirx.Var("x", "int32") + sx = tvm.tirx.BufferStore(b, 0.1, [0, i]) + sy = tvm.tirx.BufferStore(b, 0.1, [0, i]) + sz = tvm.tirx.BufferStore(b, 0.1, [1, i]) assert consistent_equal(sy, sx) assert not consistent_equal(sy, sz) def test_while(): - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - wx = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) - wy = tvm.tir.While(y > 0, tvm.tir.Evaluate(y)) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + wx = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x)) + wy = tvm.tirx.While(y > 0, tvm.tirx.Evaluate(y)) assert not consistent_equal(wx, wy) assert consistent_equal(wx, wy, map_free_vars=True) def test_while_condition_mismatch(): - x = tvm.tir.Var("x", "int32") - w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) - w_1 = tvm.tir.While(x < 0, tvm.tir.Evaluate(x)) + x = tvm.tirx.Var("x", "int32") + w_0 = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x)) + w_1 = tvm.tirx.While(x < 0, tvm.tirx.Evaluate(x)) lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) expected_path = AccessPath.root().attr("condition") assert lhs_path == expected_path @@ -306,9 +306,9 @@ def test_while_condition_mismatch(): def test_while_body_mismatch(): - x = tvm.tir.Var("x", "int32") - w_0 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x)) - w_1 = tvm.tir.While(x > 0, tvm.tir.Evaluate(x + 1)) + x = tvm.tirx.Var("x", "int32") + w_0 = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x)) + w_1 = tvm.tirx.While(x > 0, tvm.tirx.Evaluate(x + 1)) lhs_path, rhs_path = get_sequal_mismatch(w_0, w_1) expected_path = AccessPath.root().attr("body").attr("value") assert lhs_path == expected_path @@ -316,21 +316,21 @@ def test_while_body_mismatch(): def test_seq_mismatch(): - x = tvm.tir.Var("x", "int32") - seq_0 = tvm.tir.SeqStmt( + x = tvm.tirx.Var("x", "int32") + seq_0 = tvm.tirx.SeqStmt( [ - tvm.tir.Evaluate(x), - tvm.tir.Evaluate(x + 1), - tvm.tir.Evaluate(x + 2), - tvm.tir.Evaluate(x + 3), + tvm.tirx.Evaluate(x), + tvm.tirx.Evaluate(x + 1), + tvm.tirx.Evaluate(x + 2), + tvm.tirx.Evaluate(x + 3), ] ) - seq_1 = tvm.tir.SeqStmt( + seq_1 = tvm.tirx.SeqStmt( [ - tvm.tir.Evaluate(x), - tvm.tir.Evaluate(x + 1), - tvm.tir.Evaluate(x + 99), - tvm.tir.Evaluate(x + 3), + tvm.tirx.Evaluate(x), + tvm.tirx.Evaluate(x + 1), + tvm.tirx.Evaluate(x + 99), + tvm.tirx.Evaluate(x + 3), ] ) lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) @@ -343,16 +343,18 @@ def test_seq_mismatch(): def test_seq_mismatch_different_lengths(): # Make sure we report a difference inside the array first, rather than the difference in length - x = tvm.tir.Var("x", "int32") - seq_0 = tvm.tir.SeqStmt( + x = tvm.tirx.Var("x", "int32") + seq_0 = tvm.tirx.SeqStmt( [ - tvm.tir.Evaluate(x), - tvm.tir.Evaluate(x + 1), - tvm.tir.Evaluate(x + 2), - tvm.tir.Evaluate(x + 3), + tvm.tirx.Evaluate(x), + tvm.tirx.Evaluate(x + 1), + tvm.tirx.Evaluate(x + 2), + tvm.tirx.Evaluate(x + 3), ] ) - seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 3)]) + seq_1 = tvm.tirx.SeqStmt( + [tvm.tirx.Evaluate(x), tvm.tirx.Evaluate(x + 1), tvm.tirx.Evaluate(x + 3)] + ) lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) expected_path = ( AccessPath.root().attr("seq").array_item(2).attr("value").attr("b").attr("value") @@ -362,16 +364,18 @@ def test_seq_mismatch_different_lengths(): def test_seq_length_mismatch(): - x = tvm.tir.Var("x", "int32") - seq_0 = tvm.tir.SeqStmt( + x = tvm.tirx.Var("x", "int32") + seq_0 = tvm.tirx.SeqStmt( [ - tvm.tir.Evaluate(x), - tvm.tir.Evaluate(x + 1), - tvm.tir.Evaluate(x + 2), - tvm.tir.Evaluate(x + 3), + tvm.tirx.Evaluate(x), + tvm.tirx.Evaluate(x + 1), + tvm.tirx.Evaluate(x + 2), + tvm.tirx.Evaluate(x + 3), ] ) - seq_1 = tvm.tir.SeqStmt([tvm.tir.Evaluate(x), tvm.tir.Evaluate(x + 1), tvm.tir.Evaluate(x + 2)]) + seq_1 = tvm.tirx.SeqStmt( + [tvm.tirx.Evaluate(x), tvm.tirx.Evaluate(x + 1), tvm.tirx.Evaluate(x + 2)] + ) lhs_path, rhs_path = get_sequal_mismatch(seq_0, seq_1) expected_lhs_path = AccessPath.root().attr("seq").array_item(3) expected_rhs_path = AccessPath.root().attr("seq").array_item_missing(3) diff --git a/tests/python/tir-base/test_tir_texture_scope.py b/tests/python/tirx-base/test_tir_texture_scope.py similarity index 96% rename from tests/python/tir-base/test_tir_texture_scope.py rename to tests/python/tirx-base/test_tir_texture_scope.py index e9699c6b2269..ce1b717b5de6 100644 --- a/tests/python/tir-base/test_tir_texture_scope.py +++ b/tests/python/tirx-base/test_tir_texture_scope.py @@ -20,9 +20,9 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir.module import IRModule -from tvm.script import tir as T +from tvm.script import tirx as T def test_texture_scope(): @@ -30,7 +30,7 @@ def test_texture_scope(): class PlusOneMultTwo: @T.prim_func def main(a: T.handle, b: T.handle) -> None: - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A = T.match_buffer(a, (128, 128, 4), dtype="float32", scope="global.texture") B = T.sblock_alloc_buffer((128, 128, 4), dtype="float32", scope="global.texture") C = T.match_buffer(b, (128, 128, 4), dtype="float32", scope="global.texture") diff --git a/tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py b/tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py similarity index 98% rename from tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py rename to tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py index 46a1c26b92c5..081ba4993316 100644 --- a/tests/python/tir-base/test_tir_unsafe_hide_buffer_access.py +++ b/tests/python/tirx-base/test_tir_unsafe_hide_buffer_access.py @@ -20,12 +20,12 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.s_tir.schedule.testing import ( assert_structural_equal_ignore_global_symbol, verify_trace_roundtrip, ) -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func diff --git a/tests/python/tir-transform/test_tir_functor.py b/tests/python/tirx-transform/test_tir_functor.py similarity index 97% rename from tests/python/tir-transform/test_tir_functor.py rename to tests/python/tirx-transform/test_tir_functor.py index 0683fd453155..021acd8fb60b 100644 --- a/tests/python/tir-transform/test_tir_functor.py +++ b/tests/python/tirx-transform/test_tir_functor.py @@ -18,8 +18,8 @@ import tvm import tvm.testing -from tvm import tir -from tvm.tir import ( +from tvm import tirx +from tvm.tirx import ( EQ, LT, Add, @@ -61,9 +61,9 @@ def __str__(self) -> str: return "\n".join(self.log) -@tir.functor.visitor +@tirx.functor.visitor class ASTPrinter(PyStmtExprVisitor): - """Print tir AST in structured format. The shape of Node is ignored.""" + """Print tirx AST in structured format. The shape of Node is ignored.""" def __init__(self) -> None: super().__init__() @@ -78,7 +78,7 @@ def visit_add_(self, op: Add) -> None: super().visit_add_(op) -@tir.functor.visitor +@tirx.functor.visitor class SimpleExprCounter(PyStmtExprVisitor): """Count expressions without recursion""" @@ -103,7 +103,7 @@ def visit_mul_(self, op: Mul): super().visit_mul_(op) -@tir.functor.mutator +@tirx.functor.mutator class VariableReplacer(PyStmtExprMutator): """Replace variables with constants""" @@ -117,7 +117,7 @@ def visit_var_(self, op: Var): return op -@tir.functor.mutator +@tirx.functor.mutator class AddToSubMutator(PyStmtExprMutator): """Convert Add operations to Sub operations""" @@ -129,7 +129,7 @@ def visit_add_(self, op: Add): return Sub(a, b) -@tir.functor.visitor +@tirx.functor.visitor class SimpleStmtCounter(PyStmtExprVisitor): """Count statements without recursion""" @@ -152,7 +152,7 @@ def visit_evaluate_(self, op: Evaluate): super().visit_evaluate_(op) -@tir.functor.mutator +@tirx.functor.mutator class ForLoopUnroller(PyStmtExprMutator): """Simple loop unroller for demonstration""" @@ -166,7 +166,7 @@ def visit_for_(self, op: For): return super().visit_for_(op) -@tir.functor.visitor +@tirx.functor.visitor class SimpleStmtExprVisitor(PyStmtExprVisitor): """Visitor that handles both statements and expressions""" @@ -186,7 +186,7 @@ def visit_evaluate_(self, op: Evaluate): self.visit_expr(op.value) -@tir.functor.mutator +@tirx.functor.mutator class ComplexMutator(PyStmtExprMutator): """Mutator that handles both statements and expressions""" @@ -265,7 +265,7 @@ def test_simple_stmt_counter(): # Create a simple for loop loop_body = Evaluate(IntImm("int32", 0)) - for_stmt = For(i, IntImm("int32", 0), IntImm("int32", 10), tir.ForKind.SERIAL, loop_body) + for_stmt = For(i, IntImm("int32", 0), IntImm("int32", 10), tirx.ForKind.SERIAL, loop_body) counter = SimpleStmtCounter() counter.visit_stmt(for_stmt) diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tirx-transform/test_tir_inline_private_functions.py similarity index 95% rename from tests/python/tir-transform/test_tir_inline_private_functions.py rename to tests/python/tirx-transform/test_tir_inline_private_functions.py index e2f41fda16a1..54669c9977b7 100644 --- a/tests/python/tir-transform/test_tir_inline_private_functions.py +++ b/tests/python/tirx-transform/test_tir_inline_private_functions.py @@ -20,16 +20,16 @@ import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T class BaseTestCase: def test_well_formed(self): - After = tvm.tir.transform.InlinePrivateFunctions()(self.Before) - tvm.tir.analysis.verify_well_formed(After) + After = tvm.tirx.transform.InlinePrivateFunctions()(self.Before) + tvm.tirx.analysis.verify_well_formed(After) def test_produces_expected(self): - After = tvm.tir.transform.InlinePrivateFunctions()(self.Before) + After = tvm.tirx.transform.InlinePrivateFunctions()(self.Before) tvm.ir.assert_structural_equal(self.Expected, After) @@ -170,8 +170,8 @@ def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): class TestInlineCallOccurringInExpression(BaseTestCase): """Inline a Call node that is used in a function - The current implementation only replaces `tir.Call` instances that - occur in a `tir.Evaluate` context. This is the primary use case, + The current implementation only replaces `tirx.Call` instances that + occur in a `tirx.Evaluate` context. This is the primary use case, used in destination-passing style. This unit test is marked as xfail. If/when the implementation @@ -179,7 +179,7 @@ class TestInlineCallOccurringInExpression(BaseTestCase): expression, the annotation can be removed. """ - @pytest.mark.xfail(reason="Inlining of PrimFuncs outside of tir.Evaluate is not yet supported") + @pytest.mark.xfail(reason="Inlining of PrimFuncs outside of tirx.Evaluate is not yet supported") def test_produces_expected(self): super().test_produces_expected(self) diff --git a/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py b/tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py similarity index 93% rename from tests/python/tir-transform/test_tir_transform_annotate_device_regions.py rename to tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py index bf6ba4e3c831..6d0b91015ec5 100644 --- a/tests/python/tir-transform/test_tir_transform_annotate_device_regions.py +++ b/tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_annotate_thread_extent(): @@ -41,7 +41,7 @@ def main(A: T.Buffer(16, "float32")): i = T.launch_thread("threadIdx.x", 16) A[i] = 0.0 - After = tvm.tir.transform.AnnotateDeviceRegions()(Before) + After = tvm.tirx.transform.AnnotateDeviceRegions()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -65,7 +65,7 @@ def main(A: T.Buffer(1, "float32")): T.attr(0, "device_scope", 0) A[0] = 0.0 - After = tvm.tir.transform.AnnotateDeviceRegions()(Before) + After = tvm.tirx.transform.AnnotateDeviceRegions()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tirx-transform/test_tir_transform_bf16_legalize.py similarity index 93% rename from tests/python/tir-transform/test_tir_transform_bf16_legalize.py rename to tests/python/tirx-transform/test_tir_transform_bf16_legalize.py index 62f9dd332047..fdaa51622b6b 100644 --- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py +++ b/tests/python/tirx-transform/test_tir_transform_bf16_legalize.py @@ -16,14 +16,14 @@ # under the License. import tvm import tvm.script -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target -from tvm.tir.transform.transform import BindTarget +from tvm.tirx.transform.transform import BindTarget def u16tof32(v): uint32_v = v.astype("uint32") - uint32_v = uint32_v << tvm.tir.const(16, "uint32") + uint32_v = uint32_v << tvm.tirx.const(16, "uint32") return T.reinterpret("float32", uint32_v) @@ -33,10 +33,10 @@ def bf16tof32(v): def f32tou16(v): uint32_v = T.reinterpret("uint32", v) - rounding_bias = (uint32_v >> tvm.tir.const(16, "uint32")) & tvm.tir.const(1, "uint32") - rounding_bias += tvm.tir.const(0x7FFF, "uint32") + rounding_bias = (uint32_v >> tvm.tirx.const(16, "uint32")) & tvm.tirx.const(1, "uint32") + rounding_bias += tvm.tirx.const(0x7FFF, "uint32") uint32_v = uint32_v + rounding_bias - return (uint32_v >> tvm.tir.const(16, "uint32")).astype("uint16") + return (uint32_v >> tvm.tirx.const(16, "uint32")).astype("uint16") def f32tobf16(v): @@ -100,8 +100,8 @@ def main( target = Target("nvidia/geforce-rtx-2080-ti") before = BindTarget(target)(get_before()) - after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) - after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + after_compute = tvm.tirx.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tirx.transform.BF16StorageLegalize()(after_compute) tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) @@ -169,8 +169,8 @@ def main( target = Target("nvidia/geforce-rtx-2080-ti") before = BindTarget(target)(get_before()) - after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) - after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + after_compute = tvm.tirx.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tirx.transform.BF16StorageLegalize()(after_compute) tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) @@ -238,8 +238,8 @@ def main( target = Target("nvidia/geforce-rtx-3090-ti") before = BindTarget(target)(get_before()) - after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) - after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + after_compute = tvm.tirx.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tirx.transform.BF16StorageLegalize()(after_compute) tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) @@ -346,8 +346,8 @@ def main( target = Target("nvidia/geforce-rtx-2080-ti") before = BindTarget(target)(get_before()) - after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) - after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + after_compute = tvm.tirx.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tirx.transform.BF16StorageLegalize()(after_compute) tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) @@ -442,8 +442,8 @@ def main( target = Target("nvidia/geforce-rtx-3090-ti") before = BindTarget(target)(get_before()) - after_compute = tvm.tir.transform.BF16ComputeLegalize()(before) - after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute) + after_compute = tvm.tirx.transform.BF16ComputeLegalize()(before) + after_storage = tvm.tirx.transform.BF16StorageLegalize()(after_compute) tvm.ir.assert_structural_equal(after_compute, BindTarget(target)(after_compute_legalize())) tvm.ir.assert_structural_equal(after_storage, BindTarget(target)(after_storage_legalize())) diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py similarity index 83% rename from tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py rename to tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py index 237ee8ba2492..8786720a2522 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py @@ -18,7 +18,7 @@ import tvm from tvm.ir.base import save_json -from tvm.script import tir as T +from tvm.script import tirx as T # ===================================================================== @@ -54,7 +54,7 @@ def main(B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, z3: T.int32): b = T.bind(cse_v2 + z3) B[i2] = a + b - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -100,7 +100,7 @@ def main( else: B[i3] = y - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -146,7 +146,7 @@ def main( else: B[i3] = cse_v1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -189,7 +189,7 @@ def main( B[i2] = cse_v1 B[i3] = cse_v2 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -212,7 +212,7 @@ def main(x: T.int32, y: T.int32, z: T.int32): a = T.bind(x + (y + z)) T.evaluate(a) - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -225,8 +225,8 @@ def test_deterministic(): NUM_TERMS = 10 REPEATS = 10 - x = tvm.tir.Var("x", "int32") - result = tvm.tir.Var("result", "int32") + x = tvm.tirx.Var("x", "int32") + result = tvm.tirx.Var("result", "int32") offsets = sorted([i + 1 for i in range(NUM_TERMS)]) inc1 = [(x + offsets[i]) for i in range(NUM_TERMS)] @@ -235,12 +235,12 @@ def test_deterministic(): expression = x for add in inc1 + inc2: expression = expression + add - body = tvm.tir.SeqStmt([tvm.tir.Bind(result, expression), tvm.tir.Evaluate(result)]) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([x], body)) + body = tvm.tirx.SeqStmt([tvm.tirx.Bind(result, expression), tvm.tirx.Evaluate(result)]) + mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([x], body)) initial_hash = None for _ in range(REPEATS): - out = tvm.tir.transform.CommonSubexprElim()(mod) + out = tvm.tirx.transform.CommonSubexprElim()(mod) func = out["main"] json_val = save_json(func) json_hash = hashlib.sha256(json_val.encode()).hexdigest() @@ -271,7 +271,7 @@ def main(B: T.Buffer((50,), "int32"), y: T.int32, z: T.int32): B[i] = cse_v1 B[i + 10] = cse_v1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -298,7 +298,7 @@ def main(B: T.Buffer((50,), "int32"), y: T.int32, z: T.int32): for i in range(10): B[i + 1] = cse_v1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -322,7 +322,7 @@ def main(A: T.Buffer((50,), "int32"), B: T.Buffer((50,), "int32")): B[0] = A[0] + A[0] B[1] = A[0] + A[0] - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -369,7 +369,7 @@ def main( else: B[2] = y - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -410,7 +410,7 @@ def main( B[2] = cse_v1 B[3] = cse_v2 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -435,7 +435,7 @@ def main(B: T.Buffer((50,), "int32"), y: T.int32, z: T.int32): if cse_v1 > 0: B[0] = cse_v1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -458,7 +458,7 @@ def main(B: T.Buffer((50,), "int32"), x: T.int32): B[0] = T.call_extern("my_func", x, dtype="int32") + 1 B[1] = T.call_extern("my_func", x, dtype="int32") + 1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -494,7 +494,7 @@ def main( B[0] = cse_v1 B[1] = cse_v1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -519,7 +519,7 @@ def main(B: T.Buffer((50,), "int32"), y: T.int32, z: T.int32): for i in range(cse_v1): B[i] = cse_v1 - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -550,7 +550,7 @@ def main( cse_v1 = T.bind(i * 4) A[cse_v1] = B[cse_v1] - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -573,7 +573,7 @@ def main(x: T.int32, y: T.int32, z: T.int32): a = T.bind(x + (y + z)) T.evaluate(a) - after = tvm.tir.transform.CommonSubexprElim()(Before) + after = tvm.tirx.transform.CommonSubexprElim()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -585,23 +585,23 @@ def main(x: T.int32, y: T.int32, z: T.int32): # ===================================================================== def test_let_body_no_extraction(): """CSE must not extract expressions from Let bodies that use Let-bound vars.""" - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") # Let(x, 1, (x+y) + (x+y)) -- x+y appears twice but x is Let-bound - let_expr = tvm.tir.Let(x, tvm.tir.IntImm("int32", 1), (x + y) + (x + y)) - buf = tvm.tir.decl_buffer((10,), "int32", name="B") - i = tvm.tir.Var("i", "int32") - store = tvm.tir.BufferStore(buf, let_expr, [i]) - loop = tvm.tir.For( + let_expr = tvm.tirx.Let(x, tvm.tirx.IntImm("int32", 1), (x + y) + (x + y)) + buf = tvm.tirx.decl_buffer((10,), "int32", name="B") + i = tvm.tirx.Var("i", "int32") + store = tvm.tirx.BufferStore(buf, let_expr, [i]) + loop = tvm.tirx.For( i, - tvm.tir.const(0, "int32"), - tvm.tir.const(10, "int32"), - tvm.tir.ForKind.SERIAL, + tvm.tirx.const(0, "int32"), + tvm.tirx.const(10, "int32"), + tvm.tirx.ForKind.SERIAL, store, ) - func = tvm.tir.PrimFunc([buf, y], loop) + func = tvm.tirx.PrimFunc([buf, y], loop) mod = tvm.IRModule({"main": func}) - mod_after = tvm.tir.transform.CommonSubexprElim()(mod) + mod_after = tvm.tirx.transform.CommonSubexprElim()(mod) # No CSE variables should be introduced script = mod_after["main"].script() assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}" @@ -614,24 +614,24 @@ def test_let_body_no_extraction(): # ===================================================================== def test_let_value_cse(): """CSE can extract from Let values (computed before binding).""" - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") # Let(x, y+z, x+1) with y+z also appearing outside the Let - let_expr = tvm.tir.Let(x, y + z, x + 1) - buf = tvm.tir.decl_buffer((10,), "int32", name="B") - i = tvm.tir.Var("i", "int32") - store = tvm.tir.BufferStore(buf, (y + z) + let_expr, [i]) - loop = tvm.tir.For( + let_expr = tvm.tirx.Let(x, y + z, x + 1) + buf = tvm.tirx.decl_buffer((10,), "int32", name="B") + i = tvm.tirx.Var("i", "int32") + store = tvm.tirx.BufferStore(buf, (y + z) + let_expr, [i]) + loop = tvm.tirx.For( i, - tvm.tir.const(0, "int32"), - tvm.tir.const(10, "int32"), - tvm.tir.ForKind.SERIAL, + tvm.tirx.const(0, "int32"), + tvm.tirx.const(10, "int32"), + tvm.tirx.ForKind.SERIAL, store, ) - func = tvm.tir.PrimFunc([buf, y, z], loop) + func = tvm.tirx.PrimFunc([buf, y, z], loop) mod = tvm.IRModule({"main": func}) - mod_after = tvm.tir.transform.CommonSubexprElim()(mod) + mod_after = tvm.tirx.transform.CommonSubexprElim()(mod) # y+z should be extracted (appears in Let value AND outside) script = mod_after["main"].script() assert "cse_v" in script, f"CSE should extract y+z from Let value:\n{script}" @@ -644,27 +644,27 @@ def test_let_value_cse(): # ===================================================================== def test_nested_let_no_extraction(): """CSE must not extract from nested Let bodies.""" - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - z = tvm.tir.Var("z", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + z = tvm.tirx.Var("z", "int32") # Let(x, 1, Let(y, 2, (x+y+z) + (x+y+z))) inner = (x + y + z) + (x + y + z) - nested_let = tvm.tir.Let( - x, tvm.tir.IntImm("int32", 1), tvm.tir.Let(y, tvm.tir.IntImm("int32", 2), inner) + nested_let = tvm.tirx.Let( + x, tvm.tirx.IntImm("int32", 1), tvm.tirx.Let(y, tvm.tirx.IntImm("int32", 2), inner) ) - buf = tvm.tir.decl_buffer((10,), "int32", name="B") - i = tvm.tir.Var("i", "int32") - store = tvm.tir.BufferStore(buf, nested_let, [i]) - loop = tvm.tir.For( + buf = tvm.tirx.decl_buffer((10,), "int32", name="B") + i = tvm.tirx.Var("i", "int32") + store = tvm.tirx.BufferStore(buf, nested_let, [i]) + loop = tvm.tirx.For( i, - tvm.tir.const(0, "int32"), - tvm.tir.const(10, "int32"), - tvm.tir.ForKind.SERIAL, + tvm.tirx.const(0, "int32"), + tvm.tirx.const(10, "int32"), + tvm.tirx.ForKind.SERIAL, store, ) - func = tvm.tir.PrimFunc([buf, z], loop) + func = tvm.tirx.PrimFunc([buf, z], loop) mod = tvm.IRModule({"main": func}) - mod_after = tvm.tir.transform.CommonSubexprElim()(mod) + mod_after = tvm.tirx.transform.CommonSubexprElim()(mod) script = mod_after["main"].script() assert "cse_v" not in script, f"CSE incorrectly extracted from nested Let body:\n{script}" @@ -678,37 +678,37 @@ def test_nested_let_no_extraction(): # ===================================================================== def test_let_floordiv_pattern(): """CSE must handle the Let pattern from LowerIntrin's floordiv lowering.""" - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - rmod = tvm.tir.Var("rmod", "int32") - rdiv = tvm.tir.Var("rdiv", "int32") + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + rmod = tvm.tirx.Var("rmod", "int32") + rdiv = tvm.tirx.Var("rdiv", "int32") # Simulate lowered floordiv: Let(rmod, x%y, Let(rdiv, x/y, Select(...))) - select_cond = tvm.tir.And(y >= 0, rmod >= 0) | tvm.tir.And(y < 0, rmod <= 0) - select_expr = tvm.tir.Select(select_cond, rdiv, rdiv - 1) - inner_let = tvm.tir.Let(rdiv, tvm.tir.Div(x, y), select_expr) - outer_let = tvm.tir.Let(rmod, tvm.tir.Mod(x, y), inner_let) + select_cond = tvm.tirx.And(y >= 0, rmod >= 0) | tvm.tirx.And(y < 0, rmod <= 0) + select_expr = tvm.tirx.Select(select_cond, rdiv, rdiv - 1) + inner_let = tvm.tirx.Let(rdiv, tvm.tirx.Div(x, y), select_expr) + outer_let = tvm.tirx.Let(rmod, tvm.tirx.Mod(x, y), inner_let) # Wrap in Let(x, load, Let(y, load, ...)) - buf_a = tvm.tir.decl_buffer((10,), "int32", name="A") - buf_b = tvm.tir.decl_buffer((10,), "int32", name="B") - buf_c = tvm.tir.decl_buffer((10,), "int32", name="C") - i = tvm.tir.Var("i", "int32") - full_expr = tvm.tir.Let( + buf_a = tvm.tirx.decl_buffer((10,), "int32", name="A") + buf_b = tvm.tirx.decl_buffer((10,), "int32", name="B") + buf_c = tvm.tirx.decl_buffer((10,), "int32", name="C") + i = tvm.tirx.Var("i", "int32") + full_expr = tvm.tirx.Let( x, - tvm.tir.BufferLoad(buf_a, [i]), - tvm.tir.Let(y, tvm.tir.BufferLoad(buf_b, [i]), outer_let), + tvm.tirx.BufferLoad(buf_a, [i]), + tvm.tirx.Let(y, tvm.tirx.BufferLoad(buf_b, [i]), outer_let), ) - store = tvm.tir.BufferStore(buf_c, full_expr, [i]) - loop = tvm.tir.For( + store = tvm.tirx.BufferStore(buf_c, full_expr, [i]) + loop = tvm.tirx.For( i, - tvm.tir.const(0, "int32"), - tvm.tir.const(10, "int32"), - tvm.tir.ForKind.SERIAL, + tvm.tirx.const(0, "int32"), + tvm.tirx.const(10, "int32"), + tvm.tirx.ForKind.SERIAL, store, ) - func = tvm.tir.PrimFunc([buf_a, buf_b, buf_c], loop) + func = tvm.tirx.PrimFunc([buf_a, buf_b, buf_c], loop) mod = tvm.IRModule({"main": func}) # Should not crash and should not extract Let-bound vars - mod_after = tvm.tir.transform.CommonSubexprElim()(mod) + mod_after = tvm.tirx.transform.CommonSubexprElim()(mod) script = mod_after["main"].script() assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}" diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tirx-transform/test_tir_transform_convert_ssa.py similarity index 79% rename from tests/python/tir-transform/test_tir_transform_convert_ssa.py rename to tests/python/tirx-transform/test_tir_transform_convert_ssa.py index 625001bf9f8f..fd92753d6bfa 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tirx-transform/test_tir_transform_convert_ssa.py @@ -18,9 +18,9 @@ import tvm import tvm.testing -from tvm import ir, tir +from tvm import ir, tirx from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_reuse_in_sequential_bind(): @@ -29,16 +29,16 @@ def test_reuse_in_sequential_bind(): # Manually construct the PrimFunc body, as SSA violations are # not valid TIR, and may not be expressible in future versions # of TVMSCript. - var = tir.Var("var", "int32") - sequential_bindings = tir.SeqStmt( + var = tirx.Var("var", "int32") + sequential_bindings = tirx.SeqStmt( [ - tir.Bind(var, 16), - tir.Evaluate(var), - tir.Bind(var, 32), - tir.Evaluate(var), + tirx.Bind(var, 16), + tirx.Evaluate(var), + tirx.Bind(var, 32), + tirx.Evaluate(var), ] ) - before = tir.PrimFunc([], sequential_bindings) + before = tirx.PrimFunc([], sequential_bindings) @T.prim_func(private=True) def expected(): @@ -48,7 +48,7 @@ def expected(): T.evaluate(var2) mod = tvm.IRModule.from_expr(before) - mod = tvm.tir.transform.ConvertSSA()(mod) + mod = tvm.tirx.transform.ConvertSSA()(mod) tvm.ir.assert_structural_equal(mod["main"], expected) @@ -64,42 +64,42 @@ def test_reuse_in_nested_bind(): # Manually construct the PrimFunc body, as SSA violations are # not valid TIR, and may not be expressible in future versions # of TVMScript. - var = tir.Var("var", "int32") + var = tirx.Var("var", "int32") # Note: nested SeqStmt is flattened by the IR builder, so the input # is actually a flat SeqStmt with 5 elements. - inner_seq = tir.SeqStmt( + inner_seq = tirx.SeqStmt( [ - tir.Bind(var, 16), - tir.Evaluate(var), + tirx.Bind(var, 16), + tirx.Evaluate(var), ] ) - outer_seq = tir.SeqStmt( + outer_seq = tirx.SeqStmt( [ - tir.Bind(var, 32), - tir.Evaluate(var), + tirx.Bind(var, 32), + tirx.Evaluate(var), inner_seq, - tir.Evaluate(var), + tirx.Evaluate(var), ] ) - before = tir.PrimFunc([], outer_seq) + before = tirx.PrimFunc([], outer_seq) # In the flat model, the second Bind(var, 16) redefines var for # ALL subsequent siblings including the last Evaluate. - var1 = tir.Var("var", "int32") - var2 = tir.Var("var", "int32") - expected_body = tir.SeqStmt( + var1 = tirx.Var("var", "int32") + var2 = tirx.Var("var", "int32") + expected_body = tirx.SeqStmt( [ - tir.Bind(var1, 32), - tir.Evaluate(var1), - tir.Bind(var2, 16), - tir.Evaluate(var2), - tir.Evaluate(var2), + tirx.Bind(var1, 32), + tirx.Evaluate(var1), + tirx.Bind(var2, 16), + tirx.Evaluate(var2), + tirx.Evaluate(var2), ] ) - expected = tir.PrimFunc([], expected_body) + expected = tirx.PrimFunc([], expected_body) mod = tvm.IRModule.from_expr(before) - mod = tvm.tir.transform.ConvertSSA()(mod) + mod = tvm.tirx.transform.ConvertSSA()(mod) tvm.ir.assert_structural_equal(mod["main"], expected) @@ -130,14 +130,14 @@ def func_b(): var = T.int32(10) T.evaluate(var) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) def test_reused_parameter(): """De-duplicate Var usage in parameters - In this test, the same `tir.Var` instance is used for the + In this test, the same `tirx.Var` instance is used for the parameter `n` in both functions. """ @@ -162,7 +162,7 @@ def func_a(n: T.int32): def func_b(n: T.int32): T.evaluate(n) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) @@ -193,7 +193,7 @@ def func_b(a: T.handle("float32")): A = T.decl_buffer(shape=1, dtype="float32", data=a) T.evaluate(A[0]) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) @@ -221,7 +221,7 @@ def func_a(A: T.Buffer(1, "float32")): def func_b(A: T.Buffer(1, "float32")): T.evaluate(A[0]) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) @@ -234,7 +234,7 @@ class before: def func(A: T.Buffer(1, "float32")): T.evaluate(A[0]) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(before, after) assert before.same_as(after) @@ -243,11 +243,11 @@ def test_keep_duplicate_thread_idx_in_same_function(): """Environment threads are treated as being at function scope The `"thread_extent"` attribute has some unique semantics. It - serves as the definition of the `tir::Var` representing the + serves as the definition of the `tirx::Var` representing the environment thread (e.g. `threadIdx.x` in CUDA). However, multiple `"thread_extent"` attributes may co-exist in the same PrimFunc. For the purpose of variable scope, use of the - `tir::Var` is only allowed within the body of the `AttrStmt`. + `tirx::Var` is only allowed within the body of the `AttrStmt`. However, for the purpose of well-formed-ness, all `"thread_extent"` attributes must use the same IterVar instance (e.g. `WarpIndexFinder` in `lower_warp_memory.cc` may throw an @@ -270,7 +270,7 @@ def main(A: T.Buffer([256], "float32")): with T.launch_thread(threadIdx_x, 256): A[threadIdx_x] = A[threadIdx_x] + 2.0 - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, before) @@ -292,7 +292,7 @@ def test_de_duplicate_thread_idx_across_multiple_functions(): Var/IterVar usage across the two PrimFuncs. """ - threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") + threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32") # threadIdx_x is defined outside @I.ir_module(check_well_formed=False) @@ -337,7 +337,7 @@ def kernel_2(A: T.Buffer([256], "float32")): ) A[threadIdx_x] = A[threadIdx_x] + T.float32(1) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) @@ -346,12 +346,12 @@ def test_de_duplicate_thread_idx_iter_var_across_multiple_functions(): Like `test_de_duplicate_thread_idx_across_multiple_functions`, except the `IterVar` for the environment thread is duplicated across multiple - PrimFuncs, not just the `tir.Var` inside the `IterVar`. + PrimFuncs, not just the `tirx.Var` inside the `IterVar`. """ - threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") - iter_var = tvm.tir.IterVar( - tvm.ir.Range(0, 256), threadIdx_x, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" + threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32") + iter_var = tvm.tirx.IterVar( + tvm.ir.Range(0, 256), threadIdx_x, tvm.tirx.IterVar.ThreadIndex, "threadIdx.x" ) # complaints of multiple definitions for threadIdx_x @@ -389,7 +389,7 @@ def kernel_2(A: T.Buffer([256], "float32")): ) A[threadIdx_x] = A[threadIdx_x] + T.float32(1) - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) @@ -403,9 +403,9 @@ def test_thread_idx_reused_within_and_across_functions(): de-duplicated. """ - threadIdx_x = tvm.tir.Var("threadIdx_x", "int32") - iter_var = tvm.tir.IterVar( - tvm.ir.Range(0, 256), threadIdx_x, tvm.tir.IterVar.ThreadIndex, "threadIdx.x" + threadIdx_x = tvm.tirx.Var("threadIdx_x", "int32") + iter_var = tvm.tirx.IterVar( + tvm.ir.Range(0, 256), threadIdx_x, tvm.tirx.IterVar.ThreadIndex, "threadIdx.x" ) # complaints of multiple definitions of threadIdx_x @@ -443,64 +443,64 @@ def kernel_2(A: T.Buffer([256], "float32")): with T.launch_thread(threadIdx_x, 256): A[threadIdx_x] = A[threadIdx_x] + 2.0 - after = tvm.tir.transform.ConvertSSA()(before) + after = tvm.tirx.transform.ConvertSSA()(before) tvm.ir.assert_structural_equal(after, expected) def test_track_forward_declarations_in_attr_stmt(): - """T.attr statements may refer to a about-to-be-defined tir.Var""" + """T.attr statements may refer to a about-to-be-defined tirx.Var""" # Generate the PrimFunc, which is already SSA # # This is constructed directly, rather than using TVMScript. - # This test case requires a `tir.AttrStmt` that references a - # variable, followed by the `tir.For` defining that variable. + # This test case requires a `tirx.AttrStmt` that references a + # variable, followed by the `tirx.For` defining that variable. # This is not expressible in TVMScript, as it only provides the # loop iterator within the body of the loop. - i0_outer_outer = tir.Var("i0_outer_outer", "int32") - i0_outer_inner = tir.Var("i0_outer_inner", "int32") - i0_inner = tir.Var("i0_inner", "int32") + i0_outer_outer = tirx.Var("i0_outer_outer", "int32") + i0_outer_inner = tirx.Var("i0_outer_inner", "int32") + i0_inner = tirx.Var("i0_inner", "int32") - A = tir.decl_buffer(1024, "float32", "A") - B = tir.decl_buffer(1024, "float32", "B") + A = tirx.decl_buffer(1024, "float32", "A") + B = tirx.decl_buffer(1024, "float32", "B") index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner - stmt = tir.BufferStore(B, tir.BufferLoad(A, [index]), [index]) - stmt = tir.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None) - stmt = tir.For(i0_inner, 0, 4, tir.ForKind.VECTORIZED, stmt) - stmt = tir.For(i0_outer_inner, 0, 13, tir.ForKind.PARALLEL, stmt) - stmt = tir.AttrStmt( + stmt = tirx.BufferStore(B, tirx.BufferLoad(A, [index]), [index]) + stmt = tirx.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None) + stmt = tirx.For(i0_inner, 0, 4, tirx.ForKind.VECTORIZED, stmt) + stmt = tirx.For(i0_outer_inner, 0, 13, tirx.ForKind.PARALLEL, stmt) + stmt = tirx.AttrStmt( T.iter_var(i0_outer_inner, None, "DataPar", ""), "pragma_parallal_barrier_when_finish", 1, stmt, ) - stmt = tir.AttrStmt( + stmt = tirx.AttrStmt( T.iter_var(i0_outer_inner, None, "DataPar", ""), "pragma_parallal_stride_pattern", 1, stmt, ) - stmt = tir.For(i0_outer_outer, 0, 20, tir.ForKind.SERIAL, stmt) - stmt = tir.AttrStmt( + stmt = tirx.For(i0_outer_outer, 0, 20, tirx.ForKind.SERIAL, stmt) + stmt = tirx.AttrStmt( T.iter_var(i0_outer_outer, None, "DataPar", ""), "pragma_parallal_launch_point", 1, stmt, ) - A_handle = tir.Var("A_handle", "handle") - B_handle = tir.Var("B_handle", "handle") + A_handle = tirx.Var("A_handle", "handle") + B_handle = tirx.Var("B_handle", "handle") - before = tir.PrimFunc( + before = tirx.PrimFunc( [A_handle, B_handle], stmt, buffer_map={A_handle: A, B_handle: B}, ) mod = tvm.IRModule.from_expr(before) - after = tvm.tir.transform.ConvertSSA()(mod) + after = tvm.tirx.transform.ConvertSSA()(mod) tvm.ir.assert_structural_equal(after["main"], before) @@ -513,24 +513,24 @@ def test_shared_shape_var_in_buffer_map_and_alloc_buffer(): function body (including AllocBuffer shapes) must remain the same Var object so that MakePackedAPI can bind it from the DLTensor shape. """ - n = tir.SizeVar("n", "int32") - A_handle = tir.Var("A_handle", "handle") - B_handle = tir.Var("B_handle", "handle") - A = tir.decl_buffer((n,), "float32", "A") - B = tir.decl_buffer((n,), "float32", "B") + n = tirx.SizeVar("n", "int32") + A_handle = tirx.Var("A_handle", "handle") + B_handle = tirx.Var("B_handle", "handle") + A = tirx.decl_buffer((n,), "float32", "A") + B = tirx.decl_buffer((n,), "float32", "B") # AllocBuffer with shape [n] in the body (flat, no body) - C = tir.decl_buffer((n,), "float32", "C") - body = tir.SeqStmt([tir.AllocBuffer(C), tir.Evaluate(1)]) + C = tirx.decl_buffer((n,), "float32", "C") + body = tirx.SeqStmt([tirx.AllocBuffer(C), tirx.Evaluate(1)]) - before = tir.PrimFunc( + before = tirx.PrimFunc( [A_handle, B_handle], body, buffer_map={A_handle: A, B_handle: B}, ) mod = tvm.IRModule.from_expr(before) - after = tvm.tir.transform.ConvertSSA()(mod) + after = tvm.tirx.transform.ConvertSSA()(mod) # The function is already SSA — ConvertSSA should not change it. tvm.ir.assert_structural_equal(after["main"], before) diff --git a/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py similarity index 88% rename from tests/python/tir-transform/test_tir_transform_device_kernel_launch.py rename to tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py index e9c5d9776409..6d77d7e87164 100644 --- a/tests/python/tir-transform/test_tir_transform_device_kernel_launch.py +++ b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py @@ -18,17 +18,17 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_lower_device_kernel_launch(): """Kernel launch parameters are added at the call site - The "tir.kernel_launch_params" determines which parameters belong + The "tirx.kernel_launch_params" determines which parameters belong to the runtime, and which below to the device-side PrimFunc. Parameters that are required prior to launching a kernel (e.g. the number of CUDA threads to use) are stored in the - `"tir.kernel_launch_params"` attribute, and are used by the + `"tirx.kernel_launch_params"` attribute, and are used by the runtime prior in order to launch the generated kernel. """ @@ -58,15 +58,15 @@ def kernel(A_data: T.handle("float32")): { "target": T.target("cuda"), "calling_conv": 2, - "tir.kernel_launch_params": [], + "tirx.kernel_launch_params": [], "global_symbol": "kernel", - "tir.is_global_func": True, + "tirx.is_global_func": True, } ) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 - After = tvm.tir.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -109,26 +109,26 @@ def kernel(A_data: T.handle("float32")): { "target": T.target("cuda"), "calling_conv": 2, - "tir.kernel_launch_params": [], + "tirx.kernel_launch_params": [], "global_symbol": "kernel_by_another_name", - "tir.is_global_func": True, + "tirx.is_global_func": True, } ) A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 - After = tvm.tir.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) tvm.ir.assert_structural_equal(After, Expected) def test_collect_launch_parameter(): """Kernel launch parameters are added at the call site - The "tir.kernel_launch_params" determines which parameters belong + The "tirx.kernel_launch_params" determines which parameters belong to the runtime, and which below to the device-side PrimFunc. Parameters that are required prior to launching a kernel (e.g. the number of CUDA threads to use) are stored in the - `"tir.kernel_launch_params"` attribute, and are used by the + `"tirx.kernel_launch_params"` attribute, and are used by the runtime prior in order to launch the generated kernel. """ @@ -164,16 +164,16 @@ def kernel(A_data: T.handle("float32")): { "target": T.target("cuda"), "calling_conv": 2, - "tir.kernel_launch_params": ["threadIdx.x"], + "tirx.kernel_launch_params": ["threadIdx.x"], "global_symbol": "kernel", - "tir.is_global_func": True, + "tirx.is_global_func": True, } ) A = T.decl_buffer(16, dtype="float32", data=A_data) i = T.launch_thread("threadIdx.x", 16) A[i] = 0.0 - After = tvm.tir.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -213,13 +213,13 @@ def kernel(A_data: T.handle("float32")): { "target": T.target("c"), "global_symbol": "kernel", - "tir.is_global_func": True, + "tirx.is_global_func": True, } ) A = T.decl_buffer(16, dtype="float32", data=A_data) A[0] = 0.0 - After = tvm.tir.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py similarity index 99% rename from tests/python/tir-transform/test_tir_transform_flatten_buffer.py rename to tests/python/tirx-transform/test_tir_transform_flatten_buffer.py index 36ce63d0da41..06f041ce25e8 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py @@ -17,14 +17,14 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def _transform(): return tvm.transform.Sequential( [ - tvm.tir.transform.FlattenBuffer(), - tvm.tir.transform.Simplify(), + tvm.tirx.transform.FlattenBuffer(), + tvm.tirx.transform.Simplify(), ] ) diff --git a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py b/tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py similarity index 93% rename from tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py rename to tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py index 0fdbac91473d..41232a8694bc 100644 --- a/tests/python/tir-transform/test_tir_transform_force_narrow_index_to_i32.py +++ b/tests/python/tirx-transform/test_tir_transform_force_narrow_index_to_i32.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm import TVMError -from tvm.script import tir as T +from tvm.script import tirx as T def test_thread_axis1(): @@ -43,7 +43,7 @@ def expected(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): B[blockIdx_x * 32 + threadIdx_x] = A[blockIdx_x * 32 + threadIdx_x] + T.float32(1) mod = tvm.IRModule.from_expr(before) - func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + func = tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"] tvm.ir.assert_structural_equal(func, expected) @@ -54,7 +54,7 @@ def before( placeholder_1: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "bool"), T_where: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)): @@ -112,7 +112,7 @@ def expected( placeholder_1: T.Buffer((1, 12, 384, 384), "bool"), T_where: T.Buffer((1, 12, 384, 384), "float32"), ): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): for i0_i1_i2_i3_fused_0 in range(7): @@ -158,7 +158,7 @@ def expected( ) mod = tvm.IRModule.from_expr(before) - func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + func = tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"] tvm.ir.assert_structural_equal(func, expected) @@ -180,7 +180,7 @@ def expected(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): B[vi] = A[vi] + T.float32(1) mod = tvm.IRModule.from_expr(before) - func = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + func = tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"] tvm.ir.assert_structural_equal(func, expected) @@ -202,7 +202,7 @@ def expected(A: T.Buffer((128,), "int16"), B: T.Buffer((128,), "int16")): B[vi] = A[vi] + T.int16(1) mod = tvm.IRModule.from_expr(before) - after = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + after = tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -217,7 +217,7 @@ def func(A: T.Buffer((128,), "int64"), B: T.Buffer((128,), "int64")): mod = tvm.IRModule.from_expr(func) with pytest.raises(TVMError): - tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"] def test_fail_on_buffer_map(): @@ -237,7 +237,7 @@ def func(A: T.Buffer((128,), "int32"), B: T.Buffer((128,), "int32")): mod = tvm.IRModule.from_expr(func) with pytest.raises(TVMError): - tvm.tir.transform.ForceNarrowIndexToInt32()(mod)["main"] + tvm.tirx.transform.ForceNarrowIndexToInt32()(mod)["main"] def test_pod_params_and_select(): @@ -257,7 +257,7 @@ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32"), n: T.int32) for i in range(4): B[i] = T.Select(1 <= i, A[i + n], T.Cast("float32", i)) - after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before) + after = tvm.tirx.transform.ForceNarrowIndexToInt32()(Before) tvm.ir.assert_structural_equal(Expected, after) @@ -276,7 +276,7 @@ def main(B: T.Buffer((4,), "int32")): for i in range(4): B[i] = T.clz(i) - 32 + 64 - after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before) + after = tvm.tirx.transform.ForceNarrowIndexToInt32()(Before) tvm.ir.assert_structural_equal(Expected, after) @@ -301,7 +301,7 @@ def main(buf: T.handle): for i in range(ceil_log2): T.evaluate(0) - after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before) + after = tvm.tirx.transform.ForceNarrowIndexToInt32()(Before) tvm.ir.assert_structural_equal(Expected, after) diff --git a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py b/tests/python/tirx-transform/test_tir_transform_fp8_legalize.py similarity index 97% rename from tests/python/tir-transform/test_tir_transform_fp8_legalize.py rename to tests/python/tirx-transform/test_tir_transform_fp8_legalize.py index 0b10fe5c2199..39a149a2e0b7 100644 --- a/tests/python/tir-transform/test_tir_transform_fp8_legalize.py +++ b/tests/python/tirx-transform/test_tir_transform_fp8_legalize.py @@ -17,9 +17,9 @@ import tvm import tvm.script import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.target import Target -from tvm.tir.transform.transform import BindTarget +from tvm.tirx.transform.transform import BindTarget # pylint: disable=no-member,invalid-name,unused-variable @@ -211,15 +211,15 @@ def test_fp8_compute_legalize(dtype, promote_dtype): expected = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) # run the transform twice to ensure we can afford to deal # with this repeative optimizations - after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(before) - after = tvm.tir.transform.FP8ComputeLegalize(promote_dtype)(after) + after = tvm.tirx.transform.FP8ComputeLegalize(promote_dtype)(before) + after = tvm.tirx.transform.FP8ComputeLegalize(promote_dtype)(after) tvm.ir.assert_structural_equal(after, expected) def test_fp8_storage_legalize(dtype, promote_dtype): target = Target("nvidia/nvidia-a100") before = BindTarget(target)(get_after_compute_legalize(dtype, promote_dtype)) - after = tvm.tir.transform.FP8StorageLegalize()(before) + after = tvm.tirx.transform.FP8StorageLegalize()(before) expected = BindTarget(target)(get_after_storage_legalize(dtype, promote_dtype)) tvm.ir.assert_structural_equal(after, expected) diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tirx-transform/test_tir_transform_helpers.py similarity index 88% rename from tests/python/tir-transform/test_tir_transform_helpers.py rename to tests/python/tirx-transform/test_tir_transform_helpers.py index 949bfc2179a3..cef440ea80d9 100644 --- a/tests/python/tir-transform/test_tir_transform_helpers.py +++ b/tests/python/tirx-transform/test_tir_transform_helpers.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_annotate_entry_func_single_primfunc(): @@ -36,11 +36,11 @@ def func1(A: T.Buffer((16,), "float32")): mod = MockModule assert mod assert not mod["func1"].attrs - after = tvm.tir.transform.AnnotateEntryFunc()(mod) + after = tvm.tirx.transform.AnnotateEntryFunc()(mod) assert ( after["func1"].attrs - and "tir.is_entry_func" in after["func1"].attrs - and after["func1"].attrs["tir.is_entry_func"] + and "tirx.is_entry_func" in after["func1"].attrs + and after["func1"].attrs["tirx.is_entry_func"] ) @@ -69,7 +69,7 @@ def test_annotate_entry_func_multiple_primfunc(): assert not mod["func1"].attrs assert not mod["func2"].attrs # This should fail - after = tvm.tir.transform.AnnotateEntryFunc()(mod) + after = tvm.tirx.transform.AnnotateEntryFunc()(mod) def test_bind_target(): @@ -79,7 +79,7 @@ def test_bind_target(): target = tvm.target.Target("cuda") assert not mod["func1"].attrs assert not mod["func2"].attrs - after = tvm.tir.transform.BindTarget(target)(mod) + after = tvm.tirx.transform.BindTarget(target)(mod) assert "target" in after["func1"].attrs assert after["func1"].attrs["target"] == target @@ -103,7 +103,7 @@ def main(): T.func_attr({"target": T.target("cuda")}) T.evaluate(0) - After = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))(Before) + After = tvm.tirx.transform.BindTarget(tvm.target.Target("cuda"))(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -124,7 +124,7 @@ def main(): T.func_attr({"global_symbol": "main", "target": T.target("cuda", host="llvm")}) T.evaluate(0) - After = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))(Before) + After = tvm.tirx.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -151,7 +151,7 @@ def main(): T.func_attr({"target": T.target("cuda")}) T.evaluate(0) - After = tvm.tir.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))(Before) + After = tvm.tirx.transform.BindTarget(tvm.target.Target("cuda", host="llvm"))(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -167,7 +167,7 @@ def main(): Expected = Before - After = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))(Before) + After = tvm.tirx.transform.BindTarget(tvm.target.Target("cuda"))(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -193,7 +193,7 @@ def main(): ) T.evaluate(0) - After = tvm.tir.transform.BindTarget( + After = tvm.tirx.transform.BindTarget( tvm.target.Target("cuda", host={"kind": "llvm", "opt-level": 0}) )(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -224,7 +224,7 @@ def func2(): T.func_attr({"target": T.target("cuda")}) T.evaluate(0) - After = tvm.tir.transform.BindTarget(tvm.target.Target("cuda"))(Before) + After = tvm.tirx.transform.BindTarget(tvm.target.Target("cuda"))(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -278,7 +278,7 @@ def main( for tx in T.thread_binding(length, "threadIdx.x"): C[bx, tx] = Expected.add(A[bx, tx], B[bx, tx]) # Call from device - After = tvm.tir.transform.BindTarget( + After = tvm.tirx.transform.BindTarget( tvm.target.Target("cuda", host={"kind": "llvm", "opt-level": 0}) )(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -292,29 +292,29 @@ def test_filter_primfunc(): mod["func2"] = mod["func2"].with_attr("temp", "test2") # Test condition that does not filter out anything - def checker_filter_out_none(func: tvm.tir.PrimFunc): + def checker_filter_out_none(func: tvm.tirx.PrimFunc): return "temp" in func.attrs - after = tvm.tir.transform.Filter(checker_filter_out_none)(mod) + after = tvm.tirx.transform.Filter(checker_filter_out_none)(mod) assert len(after.functions) == 2 # Filtered functions should satisfy the given condition. assert checker_filter_out_none(after["func1"]) assert checker_filter_out_none(after["func2"]) # Test condition that selectively filters out primfuncs - def checker_filter_out_one(func: tvm.tir.PrimFunc): + def checker_filter_out_one(func: tvm.tirx.PrimFunc): return ("temp" in func.attrs) and func.attrs["temp"] == "test1" - after = tvm.tir.transform.Filter(checker_filter_out_one)(mod) + after = tvm.tirx.transform.Filter(checker_filter_out_one)(mod) assert len(after.functions) == 1 # Filtered functions should satisfy the given condition. assert checker_filter_out_one(after["func1"]) # Test condition that filters out everything - def checker_filter_out_both(func: tvm.tir.PrimFunc): + def checker_filter_out_both(func: tvm.tirx.PrimFunc): return "invalid_attr" in func.attrs - after = tvm.tir.transform.Filter(checker_filter_out_both)(mod) + after = tvm.tirx.transform.Filter(checker_filter_out_both)(mod) assert len(after.functions) == 0 @@ -337,7 +337,7 @@ def func(): class Expected: pass - After = tvm.tir.transform.Filter(lambda prim_func: False)(Before) + After = tvm.tirx.transform.Filter(lambda prim_func: False)(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py b/tests/python/tirx-transform/test_tir_transform_lower_intrin.py similarity index 66% rename from tests/python/tir-transform/test_tir_transform_lower_intrin.py rename to tests/python/tirx-transform/test_tir_transform_lower_intrin.py index 90f6a3a47d96..75e801dfd3e5 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py +++ b/tests/python/tirx-transform/test_tir_transform_lower_intrin.py @@ -23,14 +23,14 @@ def lower_intrin(params, stmt): """wrapper to call transformation in stmt""" - lower_expr = isinstance(stmt, tvm.tir.PrimExpr) - stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt + lower_expr = isinstance(stmt, tvm.tirx.PrimExpr) + stmt = tvm.tirx.Evaluate(stmt) if lower_expr else stmt mod = tvm.IRModule.from_expr( - tvm.tir.PrimFunc(params, stmt).with_attr("target", tvm.target.Target("llvm")) - ) - mod = tvm.transform.Sequential([tvm.tir.transform.Simplify(), tvm.tir.transform.LowerIntrin()])( - mod + tvm.tirx.PrimFunc(params, stmt).with_attr("target", tvm.target.Target("llvm")) ) + mod = tvm.transform.Sequential( + [tvm.tirx.transform.Simplify(), tvm.tirx.transform.LowerIntrin()] + )(mod) func = mod["main"] stmt = func.body return stmt.value if lower_expr else stmt.body @@ -48,30 +48,30 @@ def check_value(expr, variables, data, fref): # Build input and output buffers input_bufs = [ - tvm.tir.decl_buffer((n,), dtype=variables[i].dtype, name=f"v{i}") for i in range(num_vars) + tvm.tirx.decl_buffer((n,), dtype=variables[i].dtype, name=f"v{i}") for i in range(num_vars) ] - out_buf = tvm.tir.decl_buffer((n,), dtype=expr.dtype, name="C") + out_buf = tvm.tirx.decl_buffer((n,), dtype=expr.dtype, name="C") # Build loop body: for each i, bind variables[j] = input_bufs[j][i], then store expr to out - loop_var = tvm.tir.Var("i", "int32") + loop_var = tvm.tirx.Var("i", "int32") def make_store(i_var): # Build the expression with each variable bound to the corresponding buffer load result = expr for j in range(num_vars - 1, -1, -1): - result = tvm.tir.Let(variables[j], tvm.tir.BufferLoad(input_bufs[j], [i_var]), result) - return tvm.tir.BufferStore(out_buf, result, [i_var]) + result = tvm.tirx.Let(variables[j], tvm.tirx.BufferLoad(input_bufs[j], [i_var]), result) + return tvm.tirx.BufferStore(out_buf, result, [i_var]) - loop = tvm.tir.For( + loop = tvm.tirx.For( loop_var, - tvm.tir.const(0, "int32"), - tvm.tir.const(n, "int32"), - tvm.tir.ForKind.SERIAL, + tvm.tirx.const(0, "int32"), + tvm.tirx.const(n, "int32"), + tvm.tirx.ForKind.SERIAL, make_store(loop_var), ) - prim_func = tvm.tir.PrimFunc(input_bufs + [out_buf], loop) - prim_func = prim_func.with_attr({"tir.noalias": True, "global_symbol": "main"}) + prim_func = tvm.tirx.PrimFunc(input_bufs + [out_buf], loop) + prim_func = prim_func.with_attr({"tirx.noalias": True, "global_symbol": "main"}) f = tvm.compile(prim_func, "llvm") arrays = [ @@ -98,32 +98,32 @@ def get_ref_data(): def test_lower_floordiv(): data = get_ref_data() for dtype in ["int32", "int64", "int16"]: - x = tvm.tir.Var("x", dtype) - y = tvm.tir.Var("y", dtype) - zero = tvm.tir.const(0, dtype) + x = tvm.tirx.Var("x", dtype) + y = tvm.tirx.Var("y", dtype) + zero = tvm.tirx.const(0, dtype) # no constraints - res = lower_intrin([x, y], tvm.tir.floordiv(x, y)) + res = lower_intrin([x, y], tvm.tirx.floordiv(x, y)) check_value(res, [x, y], data, lambda a, b: a // b) # rhs >= 0 - res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.tir.floordiv(x, y), zero)) + res = lower_intrin([x, y], tvm.tirx.Select(y >= 0, tvm.tirx.floordiv(x, y), zero)) check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0) # involves max res = lower_intrin( - [x, y], tvm.tir.Select(y >= 0, tvm.tir.max(tvm.tir.floordiv(x, y), zero), zero) + [x, y], tvm.tirx.Select(y >= 0, tvm.tirx.max(tvm.tirx.floordiv(x, y), zero), zero) ) check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0 else 0) # lhs >= 0 res = lower_intrin( - [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.tir.floordiv(x, y), zero) + [x, y], tvm.tirx.Select(tvm.tirx.all(y >= 0, x >= 0), tvm.tirx.floordiv(x, y), zero) ) check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0 else 0) # const power of two - res = lower_intrin([x, y], tvm.tir.floordiv(x, tvm.tir.const(8, dtype=dtype))) + res = lower_intrin([x, y], tvm.tirx.floordiv(x, tvm.tirx.const(8, dtype=dtype))) check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda a, b: a // b) # floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1. res = lower_intrin( [x, y], - tvm.tir.floordiv(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), + tvm.tirx.floordiv(x + tvm.tirx.const(4, dtype=dtype), tvm.tirx.const(5, dtype=dtype)), ) check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) // b) @@ -132,27 +132,27 @@ def test_lower_floordiv(): def test_lower_floormod(): data = get_ref_data() for dtype in ["int32", "int64", "int16"]: - x = tvm.tir.Var("x", dtype) - y = tvm.tir.Var("y", dtype) - zero = tvm.tir.const(0, dtype) + x = tvm.tirx.Var("x", dtype) + y = tvm.tirx.Var("y", dtype) + zero = tvm.tirx.const(0, dtype) # no constraints - res = lower_intrin([x, y], tvm.tir.floormod(x, y)) + res = lower_intrin([x, y], tvm.tirx.floormod(x, y)) check_value(res, [x, y], data, lambda a, b: a % b) # rhs >= 0 - res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.tir.floormod(x, y), zero)) + res = lower_intrin([x, y], tvm.tirx.Select(y >= 0, tvm.tirx.floormod(x, y), zero)) check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0) # lhs >= 0 res = lower_intrin( - [x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0), tvm.tir.floormod(x, y), zero) + [x, y], tvm.tirx.Select(tvm.tirx.all(y >= 0, x >= 0), tvm.tirx.floormod(x, y), zero) ) check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0 else 0) # const power of two - res = lower_intrin([x, y], tvm.tir.floormod(x, tvm.tir.const(8, dtype=dtype))) + res = lower_intrin([x, y], tvm.tirx.floormod(x, tvm.tirx.const(8, dtype=dtype))) check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda a, b: a % b) # floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1. res = lower_intrin( [x, y], - tvm.tir.floormod(x + tvm.tir.const(4, dtype=dtype), tvm.tir.const(5, dtype=dtype)), + tvm.tirx.floormod(x + tvm.tirx.const(4, dtype=dtype), tvm.tirx.const(5, dtype=dtype)), ) check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda a, b: (a + 4) % b) @@ -166,26 +166,26 @@ def test_lower_floordiv_overflow_checks(): """ # Check 3: (b-1) - a_min must not overflow (numerator and C++ int64). # x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 - (-2^63) > LLONG_MAX. - x = tvm.tir.Var("x", "int64") - res = lower_intrin([x], tvm.tir.floordiv(x, tvm.tir.const(3, "int64"))) + x = tvm.tirx.Var("x", "int64") + res = lower_intrin([x], tvm.tirx.floordiv(x, tvm.tirx.const(3, "int64"))) data_check3 = [(-(2**63),), (0,), (100,)] check_value(res, [x], data_check3, lambda a: a // 3) # Check 4: c_value * b_value must not overflow dtype. # x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923; 10923*3 > 32767. - x = tvm.tir.Var("x", "int16") - res = lower_intrin([x], tvm.tir.floordiv(x, tvm.tir.const(3, "int16"))) + x = tvm.tirx.Var("x", "int16") + res = lower_intrin([x], tvm.tirx.floordiv(x, tvm.tirx.const(3, "int16"))) data_check4 = [(-32768,), (0,), (100,)] check_value(res, [x], data_check4, lambda a: a // 3) # Check 5: a_max + b*c must not overflow (offset numerator). - # tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4; a_max + 12 > 32767. + # tirx.min(tirx.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4; a_max + 12 > 32767. # In practice this path may not be triggered. This test still validates correct lowering. - x = tvm.tir.Var("x", "int16") - clamped = tvm.tir.min( - tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758, "int16") + x = tvm.tirx.Var("x", "int16") + clamped = tvm.tirx.min( + tvm.tirx.max(x, tvm.tirx.const(-10, "int16")), tvm.tirx.const(32758, "int16") ) - res = lower_intrin([x], tvm.tir.floordiv(clamped, tvm.tir.const(3, "int16"))) + res = lower_intrin([x], tvm.tirx.floordiv(clamped, tvm.tirx.const(3, "int16"))) data_check5 = [(-10,), (0,), (32758,), (32757,)] check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) // 3) diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py similarity index 89% rename from tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py rename to tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py index 14d54a623423..d3eded149358 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tirx-transform/test_tir_transform_lower_tvm_builtin.py @@ -21,7 +21,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T @tvm.register_global_func("tvm.test_matmul") @@ -108,7 +108,7 @@ def main( T.tvm_struct_set(stack_ffi_any, 3, 15, T.int64(0)) T.call_packed_lowered("tvm.test_matmul", stack_ffi_any, 0, 3) - After = tvm.tir.transform.LowerTVMBuiltin()(Before) + After = tvm.tirx.transform.LowerTVMBuiltin()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -118,29 +118,29 @@ def test_call_packed_return_non_i32(): expected_value = np.array([1.2, 1.4], dtype="float32") def packed_echo(value): - return tvm.tir.call_intrin( - value.dtype, tvm.ir.Op.get("tir.tvm_call_packed"), "testing.echo", value + return tvm.tirx.call_intrin( + value.dtype, tvm.ir.Op.get("tirx.tvm_call_packed"), "testing.echo", value ) def build_tir(): - Ab = tvm.tir.decl_buffer((2,), "float32") + Ab = tvm.tirx.decl_buffer((2,), "float32") # Build statements using direct TIR construction (no ir_builder) # 1. Store packed_echo(const) result into Ab[0] - store0 = tvm.tir.BufferStore( - Ab, packed_echo(tvm.tir.const(expected_value[0], "float32")), [0] + store0 = tvm.tirx.BufferStore( + Ab, packed_echo(tvm.tirx.const(expected_value[0], "float32")), [0] ) # 2. Let binding: Aptr_dup = packed_echo(Ab.data), then store const into Ab[1] - Aptr_dup = tvm.tir.Var("Aptr_dup", "handle") - store1 = tvm.tir.BufferStore(Ab, tvm.tir.const(expected_value[1], "float32"), [1]) - bind_stmt = tvm.tir.Bind(Aptr_dup, packed_echo(Ab.data)) + Aptr_dup = tvm.tirx.Var("Aptr_dup", "handle") + store1 = tvm.tirx.BufferStore(Ab, tvm.tirx.const(expected_value[1], "float32"), [1]) + bind_stmt = tvm.tirx.Bind(Aptr_dup, packed_echo(Ab.data)) # Combine into sequence - stmt = tvm.tir.SeqStmt([store0, bind_stmt, store1]) + stmt = tvm.tirx.SeqStmt([store0, bind_stmt, store1]) return tvm.IRModule.from_expr( - tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test") + tvm.tirx.PrimFunc([Ab], stmt).with_attr("global_symbol", "packed_test") ) mod = build_tir() @@ -153,7 +153,7 @@ def build_tir(): def test_lower_overflow_int32(): @T.prim_func(check_well_formed=False) def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")): - T.func_attr({"global_symbol": "variance4", "tir.noalias": True}) + T.func_attr({"global_symbol": "variance4", "tirx.noalias": True}) rxplaceholder_red = T.alloc_buffer((32,), "float32") T_subtract = T.alloc_buffer((822083584,), "float32") rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red.data) @@ -189,7 +189,7 @@ def main(): buf = T.decl_buffer(16, "float32", data=ptr.data) buf[0] = 0.0 - After = tvm.tir.transform.LowerTVMBuiltin()(Before) + After = tvm.tirx.transform.LowerTVMBuiltin()(Before) # Verify the lowered module can be printed (no crash) script_output = After.script() # Should contain TVMBackendAllocWorkspace and TVMBackendFreeWorkspace @@ -222,7 +222,7 @@ def main(): buf = T.decl_buffer(16, "float32", data=ptr.data) buf[0] = 0.0 - After = tvm.tir.transform.LowerTVMBuiltin()(Before) + After = tvm.tirx.transform.LowerTVMBuiltin()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -240,7 +240,7 @@ def main(): buf[0] = 0.0 with pytest.raises(tvm.TVMError): - tvm.tir.transform.LowerTVMBuiltin()(Before) + tvm.tirx.transform.LowerTVMBuiltin()(Before) def test_lower_allocate_requires_device_type(): @@ -248,7 +248,7 @@ def test_lower_allocate_requires_device_type(): The device type can be inferred either from the `"device_type"` statement attribute, or from the `"target"` function attribute. - Here, we provide neither. The `"tir.is_host_func"` attribute is + Here, we provide neither. The `"tirx.is_host_func"` attribute is provided as otherwise the function would be skipped altogether by LowerTVMBuiltin. """ @@ -257,14 +257,14 @@ def test_lower_allocate_requires_device_type(): class Before: @T.prim_func def main(): - T.func_attr({"tir.is_host_func": True}) + T.func_attr({"tirx.is_host_func": True}) T.attr("dummy", "device_id", 0) ptr = T.alloc_buffer((1024 * 1024,), "float32") buf = T.decl_buffer(1024 * 1024, "float32", data=ptr.data) buf[0] = 0.0 with pytest.raises(tvm.TVMError): - tvm.tir.transform.LowerTVMBuiltin()(Before) + tvm.tirx.transform.LowerTVMBuiltin()(Before) def test_lower_cpu_alloc_with_function_attr(): @@ -288,7 +288,7 @@ def main(): # Expected is same as before for this transform Expected = Before - After = tvm.tir.transform.LowerTVMBuiltin()(Before) + After = tvm.tirx.transform.LowerTVMBuiltin()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tirx-transform/test_tir_transform_make_packed_api.py similarity index 94% rename from tests/python/tir-transform/test_tir_transform_make_packed_api.py rename to tests/python/tirx-transform/test_tir_transform_make_packed_api.py index 0d39da0d9ed6..90d7e25bbf40 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tirx-transform/test_tir_transform_make_packed_api.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Tests for tir.transform.MakePackedAPI TIR transform. +"""Tests for tirx.transform.MakePackedAPI TIR transform. Tests verify the transform output using TVMScript before/after patterns. Runtime error tests are in tests/python/codegen/test_codegen_error_handling.py. @@ -24,20 +24,20 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def _find_compute_scope(func): result = None def _visitor(stmt): - if isinstance(stmt, tir.AttrStmt) and stmt.attr_key == "compute_scope": + if isinstance(stmt, tirx.AttrStmt) and stmt.attr_key == "compute_scope": nonlocal result result = stmt - tir.stmt_functor.post_order_visit(func.body, _visitor) + tirx.stmt_functor.post_order_visit(func.body, _visitor) return result @@ -54,7 +54,7 @@ def before(): if use_global_symbol: before = before.with_attr("global_symbol", "main") - after = tvm.tir.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"] + after = tvm.tirx.transform.MakePackedAPI()(tvm.IRModule.from_expr(before))["main"] if use_global_symbol: assert len(after.params) == 4 else: @@ -78,7 +78,7 @@ def main(A: T.Buffer(1, "float32")): T.func_attr({"global_symbol": "main", "target": T.target("cuda", host=host)}) T.evaluate(0) - after = tvm.tir.transform.MakePackedAPI()(before) + after = tvm.tirx.transform.MakePackedAPI()(before) target_attr = after["main"].attrs["target"] assert str(host) == str(target_attr) @@ -105,7 +105,7 @@ def subroutine(A_data: T.handle("float32")): T.func_attr({"target": T.target("llvm")}) T.evaluate(A_data) - after = tvm.tir.transform.MakePackedAPI()(before) + after = tvm.tirx.transform.MakePackedAPI()(before) tvm.ir.assert_structural_equal(before["subroutine"], after["subroutine"]) compute_scope = _find_compute_scope(after["main"]) @@ -137,7 +137,7 @@ def subroutine(A_data: T.handle("float32")): T.func_attr({"global_symbol": "subroutine", "target": T.target("llvm", host="llvm")}) T.evaluate(A_data) - after = tvm.tir.transform.MakePackedAPI()(before) + after = tvm.tirx.transform.MakePackedAPI()(before) main_compute_scope = _find_compute_scope(after["main"]) assert main_compute_scope is not None @@ -147,9 +147,9 @@ def subroutine(A_data: T.handle("float32")): subroutine_call_op = main_compute_scope.body.value.op assert ( isinstance(subroutine_call_op, tvm.ir.Op) - and subroutine_call_op.name == "tir.tvm_call_cpacked" + and subroutine_call_op.name == "tirx.tvm_call_cpacked" ), ( - f"The main function's CallNode should be lowered to the builtin 'tir.tvm_call_cpacked', " + f"The main function's CallNode should be lowered to the builtin 'tirx.tvm_call_cpacked', " f"but instead has an operation of type {subroutine_call_op}" ) @@ -191,7 +191,7 @@ def func_without_arg( return 0 return 0 - After = tvm.tir.transform.MakePackedAPI()(Before) + After = tvm.tirx.transform.MakePackedAPI()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -258,7 +258,7 @@ def main( return 0 return 0 - After = tvm.tir.transform.MakePackedAPI()(Before) + After = tvm.tirx.transform.MakePackedAPI()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -325,7 +325,7 @@ def main( return 0 return 0 - After = tvm.tir.transform.MakePackedAPI()(Before) + After = tvm.tirx.transform.MakePackedAPI()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -396,7 +396,7 @@ def main( return 0 return 0 - After = tvm.tir.transform.MakePackedAPI()(Before) + After = tvm.tirx.transform.MakePackedAPI()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -421,7 +421,7 @@ def main(a: T.handle, b: T.handle): B[i] = A[i] + A[i + 1] # Should not raise "variable batch_size has been used before definition" - After = tvm.tir.transform.MakePackedAPI()(Before) + After = tvm.tirx.transform.MakePackedAPI()(Before) assert len(After["main"].params) == 4 diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tirx-transform/test_tir_transform_narrow_datatype.py similarity index 94% rename from tests/python/tir-transform/test_tir_transform_narrow_datatype.py rename to tests/python/tirx-transform/test_tir_transform_narrow_datatype.py index a22efb83f290..dbd31e25ed17 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tirx-transform/test_tir_transform_narrow_datatype.py @@ -16,13 +16,13 @@ # under the License. import tvm import tvm.testing -from tvm.script import tir as T -from tvm.tir import const +from tvm.script import tirx as T +from tvm.tirx import const def lower_stmt(params, stmt, target_bits): - func = tvm.tir.PrimFunc(params, stmt) - func = tvm.tir.transform.NarrowDataType(target_bits)(tvm.IRModule.from_expr(func))["main"] + func = tvm.tirx.PrimFunc(params, stmt) + func = tvm.tirx.transform.NarrowDataType(target_bits)(tvm.IRModule.from_expr(func))["main"] stmt = func.body return stmt @@ -31,14 +31,14 @@ def lower_func_body(func, target_bits): """Lower a TVMScript function and return the first For loop in the body.""" mod = tvm.IRModule.from_expr(func) gvar = next(iter(mod.functions.keys())) - func = tvm.tir.transform.NarrowDataType(target_bits)(mod)[gvar] + func = tvm.tirx.transform.NarrowDataType(target_bits)(mod)[gvar] body = func.body # With flat buffer semantics, navigate to the first For node - if isinstance(body, tvm.tir.SeqStmt): + if isinstance(body, tvm.tirx.SeqStmt): for stmt in body: - if isinstance(stmt, tvm.tir.For): + if isinstance(stmt, tvm.tirx.For): return stmt - while hasattr(body, "body") and not isinstance(body, tvm.tir.For): + while hasattr(body, "body") and not isinstance(body, tvm.tirx.For): body = body.body return body @@ -110,7 +110,7 @@ def func(A: T.Buffer((m * n,), "float32"), B: T.Buffer((m * n,), "float32")): mod = tvm.IRModule.from_expr(func) gvar = next(iter(mod.functions.keys())) - func_narrowed = tvm.tir.transform.NarrowDataType(target_bits)(mod)[gvar] + func_narrowed = tvm.tirx.transform.NarrowDataType(target_bits)(mod)[gvar] stmt = func_narrowed.body assert stmt.node.var.dtype == target_dtype assert stmt.body.node.var.dtype == target_dtype @@ -148,7 +148,7 @@ def func( mod = tvm.IRModule.from_expr(func) gvar = next(iter(mod.functions.keys())) - func_narrowed = tvm.tir.transform.NarrowDataType(target_bits)(mod)[gvar] + func_narrowed = tvm.tirx.transform.NarrowDataType(target_bits)(mod)[gvar] stmt = func_narrowed.body assert stmt.seq[0].loop_var.dtype == target_dtype @@ -209,7 +209,7 @@ def expected_after(A: T.Buffer(128, "float32"), B: T.Buffer(130, "float32")): i * 65 + j >= 0 and i * 65 + j < 128, A[i * 65 + j], T.float32(0), dtype="float32" ) - after = tvm.tir.transform.NarrowDataType(32)( + after = tvm.tirx.transform.NarrowDataType(32)( tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) )["main"] tvm.ir.assert_structural_equal(after, expected_after.with_attr("global_symbol", "main")) @@ -232,7 +232,7 @@ def expected_after(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32" vi = T.axis.spatial(T.int32(128), i * T.int32(8) + j) B[vi] = A[vi] + T.float32(1) - after = tvm.tir.transform.NarrowDataType(32)( + after = tvm.tirx.transform.NarrowDataType(32)( tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) )["main"] tvm.ir.assert_structural_equal(after, expected_after.with_attr("global_symbol", "main")) @@ -294,10 +294,10 @@ def expected_after(PSUM: T.Buffer((313600,), "int32"), PAVG: T.Buffer((313600,), ), ) - after = tvm.tir.transform.NarrowDataType(32)( + after = tvm.tirx.transform.NarrowDataType(32)( tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) ) - after = tvm.tir.transform.Simplify()(after) + after = tvm.tirx.transform.Simplify()(after) tvm.ir.assert_structural_equal(after["main"], expected_after.with_attr("global_symbol", "main")) @@ -312,7 +312,7 @@ def expect(A: T.Buffer((16,), "int64")): for i in range(15): A[i + 1] = A[i] + T.int64(1) - after = tvm.tir.transform.NarrowDataType(32)( + after = tvm.tirx.transform.NarrowDataType(32)( tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) )["main"] tvm.ir.assert_structural_equal(after, expect.with_attr("global_symbol", "main")) diff --git a/tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py b/tests/python/tirx-transform/test_tir_transform_pointer_value_type_rewrite.py similarity index 94% rename from tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py rename to tests/python/tirx-transform/test_tir_transform_pointer_value_type_rewrite.py index d7b41a1ed382..1fa78faf48f4 100644 --- a/tests/python/tir-transform/test_tir_transform_pointer_value_type_rewrite.py +++ b/tests/python/tirx-transform/test_tir_transform_pointer_value_type_rewrite.py @@ -19,11 +19,11 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_rewrite_to_shuffle_0(): - transform = tvm.tir.transform.PointerValueTypeRewrite() + transform = tvm.tirx.transform.PointerValueTypeRewrite() @I.ir_module class Before: @@ -55,7 +55,7 @@ def main(A: T.Buffer((4,), "float32x4"), B: T.Buffer((4,), "float32")): def test_rewrite_to_shuffle_1(): - transform = tvm.tir.transform.PointerValueTypeRewrite() + transform = tvm.tirx.transform.PointerValueTypeRewrite() @I.ir_module class Before: @@ -98,7 +98,7 @@ def main(A: T.Buffer((2,), "float32x4"), B: T.Buffer((1,), "float32")): def test_address_of(): - transform = tvm.tir.transform.PointerValueTypeRewrite() + transform = tvm.tirx.transform.PointerValueTypeRewrite() @I.ir_module class Before: @@ -121,7 +121,7 @@ def main(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32x4")): def test_scalar_read_without_write(): - transform = tvm.tir.transform.PointerValueTypeRewrite() + transform = tvm.tirx.transform.PointerValueTypeRewrite() @I.ir_module class Before: diff --git a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py b/tests/python/tirx-transform/test_tir_transform_prim_func_pass.py similarity index 72% rename from tests/python/tir-transform/test_tir_transform_prim_func_pass.py rename to tests/python/tirx-transform/test_tir_transform_prim_func_pass.py index 72b168414263..db4d7cd3e372 100644 --- a/tests/python/tir-transform/test_tir_transform_prim_func_pass.py +++ b/tests/python/tirx-transform/test_tir_transform_prim_func_pass.py @@ -19,7 +19,7 @@ def test_prim_func_pass(): - @tvm.tir.transform.prim_func_pass(opt_level=1) + @tvm.tirx.transform.prim_func_pass(opt_level=1) class TestReplaceFunc: """Simple test function to replace one argument to another.""" @@ -29,14 +29,14 @@ def __init__(self, new_func): def transform_function(self, func, mod, ctx): return self.new_func - x = tvm.tir.Var("x", "int32") - y = tvm.tir.Var("y", "int32") - b = tvm.tir.decl_buffer((x,), "float32") - stmt = tvm.tir.SeqStmt([tvm.tir.Bind(x, 10), tvm.tir.Evaluate(x + 1)]) + x = tvm.tirx.Var("x", "int32") + y = tvm.tirx.Var("y", "int32") + b = tvm.tirx.decl_buffer((x,), "float32") + stmt = tvm.tirx.SeqStmt([tvm.tirx.Bind(x, 10), tvm.tirx.Evaluate(x + 1)]) - func = tvm.tir.PrimFunc([x, y, b], stmt) + func = tvm.tirx.PrimFunc([x, y, b], stmt) - new_func = tvm.tir.PrimFunc([x, y, b], tvm.tir.Evaluate(0)) + new_func = tvm.tirx.PrimFunc([x, y, b], tvm.tirx.Evaluate(0)) mod = tvm.IRModule({"main": func}) mod = TestReplaceFunc(new_func)(mod) @@ -49,15 +49,15 @@ def fapply(f): assert tvm.testing.object_use_count(f) == 1 return f - pidentity = tvm.tir.transform.Apply(fapply) - x = tvm.tir.Var("x", "int32") - func = tvm.tir.PrimFunc([x], tvm.tir.Evaluate(x)).with_attr("target_bits", 32) + pidentity = tvm.tirx.transform.Apply(fapply) + x = tvm.tirx.Var("x", "int32") + func = tvm.tirx.PrimFunc([x], tvm.tirx.Evaluate(x)).with_attr("target_bits", 32) func_hash = func.__hash__() mod = tvm.IRModule({"main": func}) del func # copy on write mod_hash = mod.__hash__() - mod = tvm.transform.Sequential([pidentity, tvm.tir.transform.NarrowDataType(32)])(mod._move()) + mod = tvm.transform.Sequential([pidentity, tvm.tirx.transform.NarrowDataType(32)])(mod._move()) assert mod_hash == mod.__hash__() assert func_hash == mod["main"].__hash__() diff --git a/tests/python/tir-transform/test_tir_transform_remove_assume.py b/tests/python/tirx-transform/test_tir_transform_remove_assume.py similarity index 92% rename from tests/python/tir-transform/test_tir_transform_remove_assume.py rename to tests/python/tirx-transform/test_tir_transform_remove_assume.py index 26846ebf30d5..3e92b7c5e8b1 100644 --- a/tests/python/tir-transform/test_tir_transform_remove_assume.py +++ b/tests/python/tirx-transform/test_tir_transform_remove_assume.py @@ -18,7 +18,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_remove_assume(): @@ -37,7 +37,7 @@ class Expected: def main(A: T.Buffer(1, "int32")): A[0] = 10 - After = tvm.tir.transform.RemoveAssume()(Before) + After = tvm.tirx.transform.RemoveAssume()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -61,7 +61,7 @@ def main(A: T.Buffer(16, "int32")): for i in T.serial(16): A[i] = 10 - After = tvm.tir.transform.RemoveAssume()(Before) + After = tvm.tirx.transform.RemoveAssume()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_remove_no_op.py b/tests/python/tirx-transform/test_tir_transform_remove_no_op.py similarity index 93% rename from tests/python/tir-transform/test_tir_transform_remove_no_op.py rename to tests/python/tirx-transform/test_tir_transform_remove_no_op.py index ebb284b80a01..17eff408a505 100644 --- a/tests/python/tir-transform/test_tir_transform_remove_no_op.py +++ b/tests/python/tirx-transform/test_tir_transform_remove_no_op.py @@ -19,57 +19,59 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T def nop(): - return tvm.tir.Evaluate(1) + return tvm.tirx.Evaluate(1) def test_remove_no_op(): - i = tvm.tir.Var("i", "int32") - j = tvm.tir.Var("j", "int32") - k = tvm.tir.Var("k", "int32") - m = tvm.tir.Var("m", "int32") - n = tvm.tir.Var("n", "int32") + i = tvm.tirx.Var("i", "int32") + j = tvm.tirx.Var("j", "int32") + k = tvm.tirx.Var("k", "int32") + m = tvm.tirx.Var("m", "int32") + n = tvm.tirx.Var("n", "int32") dtype = "int64" - Ab = tvm.tir.decl_buffer((n,), dtype) - stmt = tvm.tir.For( + Ab = tvm.tirx.decl_buffer((n,), dtype) + stmt = tvm.tirx.For( i, 0, 4, - tvm.tir.ForKind.SERIAL, - tvm.tir.For( + tvm.tirx.ForKind.SERIAL, + tvm.tirx.For( j, 0, n, - tvm.tir.ForKind.SERIAL, - tvm.tir.For( + tvm.tirx.ForKind.SERIAL, + tvm.tirx.For( k, 0, m, - tvm.tir.ForKind.SERIAL, - tvm.tir.IfThenElse((i * m + j + k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)), + tvm.tirx.ForKind.SERIAL, + tvm.tirx.IfThenElse( + (i * m + j + k < n), tvm.tirx.Evaluate(m), tvm.tirx.Evaluate(n) + ), ), ), ) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt)) - ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body + mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([Ab], stmt)) + ret = tvm.tirx.transform.RemoveNoOp()(mod)["main"].body - assert isinstance(ret, tvm.tir.Evaluate) - store = tvm.tir.BufferStore(Ab, tvm.tir.BufferLoad(Ab, [i]) + 1, [i + 1]) - stmt2 = tvm.tir.SeqStmt([nop(), tvm.tir.SeqStmt([store, nop()])]) + assert isinstance(ret, tvm.tirx.Evaluate) + store = tvm.tirx.BufferStore(Ab, tvm.tirx.BufferLoad(Ab, [i]) + 1, [i + 1]) + stmt2 = tvm.tirx.SeqStmt([nop(), tvm.tirx.SeqStmt([store, nop()])]) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt2)) - ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body + mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([Ab], stmt2)) + ret = tvm.tirx.transform.RemoveNoOp()(mod)["main"].body assert ret == store # remove zero extent loop - stmt3 = tvm.tir.For(i, 0, 0, tvm.tir.ForKind.SERIAL, store) - mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3)) - ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body - assert isinstance(ret, tvm.tir.Evaluate) + stmt3 = tvm.tirx.For(i, 0, 0, tvm.tirx.ForKind.SERIAL, store) + mod = tvm.IRModule.from_expr(tvm.tirx.PrimFunc([Ab], stmt3)) + ret = tvm.tirx.transform.RemoveNoOp()(mod)["main"].body + assert isinstance(ret, tvm.tirx.Evaluate) def test_remove_no_op_with_invalid_extent(): @@ -80,20 +82,20 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None: B[i] = A[i] + j mod = tvm.ir.module.IRModule.from_expr(main) - ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body - assert isinstance(ret, tvm.tir.Evaluate) + ret = tvm.tirx.transform.RemoveNoOp()(mod)["main"].body + assert isinstance(ret, tvm.tirx.Evaluate) def _apply_remove_no_op(mod, use_dataflow_analysis=False, max_simplification_steps=0): """Helper function to apply RemoveNoOp transform with config.""" config = { - "tir.RemoveNoOp": { + "tirx.RemoveNoOp": { "use_dataflow_analysis": use_dataflow_analysis, "max_simplification_steps": max_simplification_steps, } } with tvm.transform.PassContext(config=config): - mod = tvm.tir.transform.RemoveNoOp()(mod) + mod = tvm.tirx.transform.RemoveNoOp()(mod) return mod diff --git a/tests/python/tir-transform/test_tir_transform_simplify.py b/tests/python/tirx-transform/test_tir_transform_simplify.py similarity index 93% rename from tests/python/tir-transform/test_tir_transform_simplify.py rename to tests/python/tirx-transform/test_tir_transform_simplify.py index 650833dbdf8c..3b28a42bc27c 100644 --- a/tests/python/tir-transform/test_tir_transform_simplify.py +++ b/tests/python/tirx-transform/test_tir_transform_simplify.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_stmt_simplify(): @@ -31,18 +31,18 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): A_ptr[i] = C_ptr[i] mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.Simplify()(mod)["main"].body + body = tvm.tirx.transform.Simplify()(mod)["main"].body # Navigate through DeclBuffer nodes to reach the inner body - while isinstance(body, tvm.tir.DeclBuffer): + while isinstance(body, tvm.tirx.DeclBuffer): body = body.body # After simplification, Bind is kept (not inlined) but the if is eliminated # since i < 12 is always true for i in 0..10. # Body is SeqStmt(Bind(n_val, 10), For(i, ...)) - stmts = body if isinstance(body, tvm.tir.SeqStmt) else [body] + stmts = body if isinstance(body, tvm.tirx.SeqStmt) else [body] # Find the For loop in the sequence - for_stmt = [s for s in stmts if isinstance(s, tvm.tir.For)] + for_stmt = [s for s in stmts if isinstance(s, tvm.tirx.For)] assert len(for_stmt) == 1, f"Expected one For loop, got {len(for_stmt)}" - assert isinstance(for_stmt[0].body, tvm.tir.BufferStore) + assert isinstance(for_stmt[0].body, tvm.tirx.BufferStore) def test_thread_extent_simplify(): @@ -57,20 +57,20 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): A_ptr[tx] = C_ptr[tx + ty] mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.Simplify()(mod)["main"].body + body = tvm.tirx.transform.Simplify()(mod)["main"].body # Navigate through DeclBuffer nodes to reach the inner body - while isinstance(body, tvm.tir.DeclBuffer): + while isinstance(body, tvm.tirx.DeclBuffer): body = body.body # After simplification: Bind is kept but the if is eliminated # since tx + ty < 12 is always true for tx in 0..10 and ty = 0. - stmts = list(body) if isinstance(body, tvm.tir.SeqStmt) else [body] - for_stmts = [s for s in stmts if isinstance(s, tvm.tir.For)] + stmts = list(body) if isinstance(body, tvm.tirx.SeqStmt) else [body] + for_stmts = [s for s in stmts if isinstance(s, tvm.tirx.For)] assert len(for_stmts) >= 1, f"Expected For loop, got stmts: {[type(s).__name__ for s in stmts]}" # The outermost For is the tx loop tx_loop = for_stmts[0] - assert isinstance(tx_loop, tvm.tir.For) # tx loop - assert isinstance(tx_loop.body, tvm.tir.For) # ty loop - assert isinstance(tx_loop.body.body, tvm.tir.BufferStore) # The if was eliminated + assert isinstance(tx_loop, tvm.tirx.For) # tx loop + assert isinstance(tx_loop.body, tvm.tirx.For) # ty loop + assert isinstance(tx_loop.body.body, tvm.tirx.BufferStore) # The if was eliminated def test_if_likely(): @@ -85,14 +85,14 @@ def func(A: T.handle("float32"), C: T.handle("float32"), n: T.int32): A_ptr[tx] = C_ptr[tx * 32 + ty] mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.Simplify()(mod)["main"].body + body = tvm.tirx.transform.Simplify()(mod)["main"].body # With flat semantics, skip DeclBuffer/AllocBuffer siblings to find the For - if isinstance(body, tvm.tir.SeqStmt): - for_stmts = [s for s in body.seq if isinstance(s, tvm.tir.For)] + if isinstance(body, tvm.tirx.SeqStmt): + for_stmts = [s for s in body.seq if isinstance(s, tvm.tirx.For)] body = for_stmts[0] if for_stmts else body # Structure: For(tx) -> For(ty) -> IfThenElse - assert isinstance(body.body.body, tvm.tir.IfThenElse) - assert not isinstance(body.body.body.then_case, tvm.tir.IfThenElse) + assert isinstance(body.body.body, tvm.tirx.IfThenElse) + assert not isinstance(body.body.body.then_case, tvm.tirx.IfThenElse) def _apply_simplify( @@ -105,7 +105,7 @@ def _apply_simplify( ): """Helper to apply simplify transform with config options.""" config = { - "tir.Simplify": { + "tirx.Simplify": { "transitively_prove_inequalities": transitively_prove_inequalities, "convert_boolean_to_and_of_ors": convert_boolean_to_and_of_ors, "apply_constraints_to_boolean_branches": apply_constraints_to_boolean_branches, @@ -115,7 +115,7 @@ def _apply_simplify( } mod = tvm.IRModule.from_expr(func) with tvm.transform.PassContext(config=config): - mod = tvm.tir.transform.Simplify()(mod) + mod = tvm.tirx.transform.Simplify()(mod) return mod["main"] @@ -628,60 +628,60 @@ def test_remove_transitively_provable_condition(): For example, the `0 < i` and `i <= j` conditions can be used to prove that `0 < j`. """ - i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"] - zero = tvm.tir.IntImm("int32", 0) + i, j, k = [tvm.tirx.Var(name, "int32") for name in "ijk"] + zero = tvm.tirx.IntImm("int32", 0) test_cases = [ - (tvm.tir.all(zero < i, i <= j), zero < j, True), + (tvm.tirx.all(zero < i, i <= j), zero < j, True), # Transitive comparisons from LT - (tvm.tir.all(i < j, j < k), i < k, True), - (tvm.tir.all(i < j, j == k), i < k, True), - (tvm.tir.all(i < j, j <= k), i < k, True), - (tvm.tir.all(i < j, j > k), i < k, False), - (tvm.tir.all(i < j, j >= k), i < k, False), - (tvm.tir.all(i < j, j != k), i < k, False), + (tvm.tirx.all(i < j, j < k), i < k, True), + (tvm.tirx.all(i < j, j == k), i < k, True), + (tvm.tirx.all(i < j, j <= k), i < k, True), + (tvm.tirx.all(i < j, j > k), i < k, False), + (tvm.tirx.all(i < j, j >= k), i < k, False), + (tvm.tirx.all(i < j, j != k), i < k, False), # Transitive comparisons from LE - (tvm.tir.all(i <= j, j < k), i < k, True), - (tvm.tir.all(i <= j, j == k), i == k, False), - (tvm.tir.all(i <= j, j == k), i <= k, True), - (tvm.tir.all(i <= j, j <= k), i <= k, True), - (tvm.tir.all(i <= j, j <= k), i < k, False), - (tvm.tir.all(i <= j, j > k), i < k, False), - (tvm.tir.all(i <= j, j >= k), i < k, False), - (tvm.tir.all(i <= j, j != k), i < k, False), + (tvm.tirx.all(i <= j, j < k), i < k, True), + (tvm.tirx.all(i <= j, j == k), i == k, False), + (tvm.tirx.all(i <= j, j == k), i <= k, True), + (tvm.tirx.all(i <= j, j <= k), i <= k, True), + (tvm.tirx.all(i <= j, j <= k), i < k, False), + (tvm.tirx.all(i <= j, j > k), i < k, False), + (tvm.tirx.all(i <= j, j >= k), i < k, False), + (tvm.tirx.all(i <= j, j != k), i < k, False), # Transitive comparisons from GT - (tvm.tir.all(i > j, j > k), i > k, True), - (tvm.tir.all(i > j, j == k), i > k, True), - (tvm.tir.all(i > j, j >= k), i > k, True), - (tvm.tir.all(i > j, j < k), i > k, False), - (tvm.tir.all(i > j, j <= k), i > k, False), - (tvm.tir.all(i > j, j != k), i > k, False), + (tvm.tirx.all(i > j, j > k), i > k, True), + (tvm.tirx.all(i > j, j == k), i > k, True), + (tvm.tirx.all(i > j, j >= k), i > k, True), + (tvm.tirx.all(i > j, j < k), i > k, False), + (tvm.tirx.all(i > j, j <= k), i > k, False), + (tvm.tirx.all(i > j, j != k), i > k, False), # Transitive comparisons from GE - (tvm.tir.all(i >= j, j > k), i > k, True), - (tvm.tir.all(i >= j, j == k), i == k, False), - (tvm.tir.all(i >= j, j == k), i >= k, True), - (tvm.tir.all(i >= j, j >= k), i >= k, True), - (tvm.tir.all(i >= j, j >= k), i > k, False), - (tvm.tir.all(i >= j, j < k), i > k, False), - (tvm.tir.all(i >= j, j <= k), i > k, False), - (tvm.tir.all(i >= j, j != k), i > k, False), + (tvm.tirx.all(i >= j, j > k), i > k, True), + (tvm.tirx.all(i >= j, j == k), i == k, False), + (tvm.tirx.all(i >= j, j == k), i >= k, True), + (tvm.tirx.all(i >= j, j >= k), i >= k, True), + (tvm.tirx.all(i >= j, j >= k), i > k, False), + (tvm.tirx.all(i >= j, j < k), i > k, False), + (tvm.tirx.all(i >= j, j <= k), i > k, False), + (tvm.tirx.all(i >= j, j != k), i > k, False), # GT or LT may be used to prove NE - (tvm.tir.all(i == j, j != k), i != k, True), - (tvm.tir.all(i == j, j < k), i != k, True), - (tvm.tir.all(i == j, j > k), i != k, True), - (tvm.tir.all(i == j, j != k), i < k, False), - (tvm.tir.all(i == j, j != k), i > k, False), + (tvm.tirx.all(i == j, j != k), i != k, True), + (tvm.tirx.all(i == j, j < k), i != k, True), + (tvm.tirx.all(i == j, j > k), i != k, True), + (tvm.tirx.all(i == j, j != k), i < k, False), + (tvm.tirx.all(i == j, j != k), i > k, False), # Because these are integers, x k), i < k, False), - (tvm.tir.all(i <= j - 1, j >= k), i < k, False), - (tvm.tir.all(i <= j - 1, j != k), i < k, False), + (tvm.tirx.all(i <= j - 1, j < k), i < k, True), + (tvm.tirx.all(i <= j - 1, j == k), i < k, True), + (tvm.tirx.all(i <= j - 1, j <= k), i < k, True), + (tvm.tirx.all(i <= j - 1, j > k), i < k, False), + (tvm.tirx.all(i <= j - 1, j >= k), i < k, False), + (tvm.tirx.all(i <= j - 1, j != k), i < k, False), # Either or both inequalities may have an additive offset. - (tvm.tir.all(i <= j + 5, j <= k + 7), i <= k + 12, True), - (tvm.tir.all(i <= j + 5, j <= k + 7), i <= k + 11, False), + (tvm.tirx.all(i <= j + 5, j <= k + 7), i <= k + 12, True), + (tvm.tirx.all(i <= j + 5, j <= k + 7), i <= k + 11, False), # For floats, x < y + c1 and y < z + c2 implies that x < z + (c1 + c2). # Because this simplification applies to integers, transitive # application of LT or GT can give a tighter constraint. @@ -693,8 +693,8 @@ def test_remove_transitively_provable_condition(): # i <= k + c1 + c2 - 2 # i < k + (c1 + c2 - 1) # - (tvm.tir.all(i < j + 5, j < k + 7), i < k + 11, True), - (tvm.tir.all(i < j + 5, j < k + 7), i < k + 10, False), + (tvm.tirx.all(i < j + 5, j < k + 7), i < k + 11, True), + (tvm.tirx.all(i < j + 5, j < k + 7), i < k + 10, False), ] analyzer = tvm.arith.Analyzer() @@ -1117,8 +1117,8 @@ def test_most_restrictive_conditional(): allow for later rewrites. For example, if it is known that `a <= b`, then `a >= b` cannot be proven, but can be reduced to `a == b`. """ - i, j, k = [tvm.tir.Var(name, "int32") for name in "ijk"] - tir_int = tvm.tir.IntImm("int32", 0) + i, j, k = [tvm.tirx.Var(name, "int32") for name in "ijk"] + tir_int = tvm.tirx.IntImm("int32", 0) test_cases = [ (i <= tir_int, tir_int <= i, i == tir_int), @@ -1728,7 +1728,7 @@ def before(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, "i B[i] = B[i] + A[i - f] * F[f] # Which means that this loop is unnecessary. It would be - # removed entirely in tir.transform.RemoveNoOp, but here we + # removed entirely in tirx.transform.RemoveNoOp, but here we # want to test that the simplification works as intended. for i in T.serial(24): if i < 3 or 19 <= i: @@ -1781,7 +1781,7 @@ def before(A: T.Buffer(24, "int32"), B: T.Buffer(24, "int32"), F: T.Buffer(3, "i B[i + f] = B[i + f] + A[i] * F[f] # Which means that this loop is unnecessary. It actually gets - # removed in tir.transform.RemoveNoOp, but here we want to + # removed in tirx.transform.RemoveNoOp, but here we want to # test that the simplification works as intended. for i in T.serial(24): if i < 3 or 19 <= i: @@ -1916,7 +1916,7 @@ def before(A_ptr: T.handle("float32"), B_ptr: T.handle("float32"), n: T.int32): B[0] = A[0] after = _apply_simplify(before) - tvm.tir.analysis.verify_well_formed(after) + tvm.tirx.analysis.verify_well_formed(after) def test_buffer_shape_constraint(): @@ -1936,7 +1936,7 @@ def main(a: T.handle): A = T.match_buffer(a, (n * 32,), "float32") A[T.int64(0)] = T.float32(0) - after = tvm.tir.transform.Simplify()(Before) + after = tvm.tirx.transform.Simplify()(Before) tvm.ir.assert_structural_equal(after["main"], Expected["main"]) @@ -1957,7 +1957,7 @@ def main(a: T.handle): A = T.match_buffer(a, (n * 32 + 1 - 2,), "float32") A[T.int64(1)] = T.float32(0) - after = tvm.tir.transform.Simplify()(Before) + after = tvm.tirx.transform.Simplify()(Before) tvm.ir.assert_structural_equal(after["main"], Expected["main"]) diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tirx-transform/test_tir_transform_split_host_device.py similarity index 88% rename from tests/python/tir-transform/test_tir_transform_split_host_device.py rename to tests/python/tirx-transform/test_tir_transform_split_host_device.py index 6e426d0e58e8..3cf0f1699f73 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tirx-transform/test_tir_transform_split_host_device.py @@ -17,7 +17,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_ssa_across_entire_module(): @@ -39,9 +39,9 @@ def main(): after = tvm.ir.transform.Sequential( [ - tvm.tir.transform.AnnotateDeviceRegions(), - tvm.tir.transform.SplitHostDevice(), - tvm.tir.transform.LowerDeviceKernelLaunch(), + tvm.tirx.transform.AnnotateDeviceRegions(), + tvm.tirx.transform.SplitHostDevice(), + tvm.tirx.transform.LowerDeviceKernelLaunch(), ] )(before) loop_var = after["main"].body.loop_var @@ -73,13 +73,13 @@ def main_kernel(n: T.int32): T.func_attr( { "target": T.target("cuda"), - "tir.noalias": True, - "tir.is_global_func": True, + "tirx.noalias": True, + "tirx.is_global_func": True, } ) T.evaluate(n) - After = tvm.tir.transform.SplitHostDevice()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -107,14 +107,14 @@ def main_kernel(n: T.int32) -> T.int32: T.func_attr( { "target": T.target("llvm"), - "tir.noalias": True, - "tir.is_global_func": True, + "tirx.noalias": True, + "tirx.is_global_func": True, } ) T.evaluate(n) T.ret(0) - After = tvm.tir.transform.SplitHostDevice()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -145,13 +145,13 @@ def main_kernel(n: T.int32): T.func_attr( { "target": T.target("cuda"), - "tir.noalias": True, - "tir.is_global_func": True, + "tirx.noalias": True, + "tirx.is_global_func": True, } ) T.evaluate(n) - After = tvm.tir.transform.SplitHostDevice()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -170,7 +170,7 @@ def Before(): Expected = Before - After = tvm.tir.transform.SplitHostDevice()(tvm.IRModule.from_expr(Before)) + After = tvm.tirx.transform.SplitHostDevice()(tvm.IRModule.from_expr(Before)) tvm.ir.assert_structural_equal(After["Before"], Expected) @@ -207,8 +207,8 @@ def main_kernel_1(n: T.int32): T.func_attr( { "target": T.target("cuda"), - "tir.noalias": True, - "tir.is_global_func": True, + "tirx.noalias": True, + "tirx.is_global_func": True, } ) T.evaluate(n) @@ -218,7 +218,7 @@ def main_kernel(): T.func_attr({"target": T.target("llvm")}) T.evaluate(0) - After = tvm.tir.transform.SplitHostDevice()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -277,8 +277,8 @@ def default_function_kernel( T.func_attr( { "target": T.target("cuda"), - "tir.is_global_func": True, - "tir.noalias": True, + "tirx.is_global_func": True, + "tirx.noalias": True, } ) A = T.decl_buffer(seq_len, "int32", data=A_data) @@ -288,9 +288,9 @@ def default_function_kernel( if blockIdx_x * 128 + threadIdx_x < seq_len: B[blockIdx_x * 128 + threadIdx_x] = A[blockIdx_x * 128 + threadIdx_x] - after = tvm.tir.transform.SplitHostDevice()(before) + after = tvm.tirx.transform.SplitHostDevice()(before) - tvm.tir.analysis.verify_well_formed(after) + tvm.tirx.analysis.verify_well_formed(after) tvm.ir.assert_structural_equal(expected, after) @@ -309,9 +309,9 @@ def main(var_A: T.handle, var_B: T.handle): A_1 = T.decl_buffer((m,), data=A.data) B_1[blockIdx_x] = A_1[blockIdx_x] - after = tvm.tir.transform.SplitHostDevice()(Module) + after = tvm.tirx.transform.SplitHostDevice()(Module) assert len(after["main_kernel"].params) == 3 - assert isinstance(after["main_kernel"].params[2], tvm.tir.SizeVar) + assert isinstance(after["main_kernel"].params[2], tvm.tirx.SizeVar) if __name__ == "__main__": diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tirx-transform/test_tir_transform_storage_rewrite.py similarity index 89% rename from tests/python/tir-transform/test_tir_transform_storage_rewrite.py rename to tests/python/tirx-transform/test_tir_transform_storage_rewrite.py index 2c85240f2064..83a10feb30f3 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tirx-transform/test_tir_transform_storage_rewrite.py @@ -22,7 +22,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_alloc_seq(): @@ -39,16 +39,16 @@ def func(n: T.int32): B[j] = T.float32(1.3) mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.StorageRewrite()(mod)["func"].body + body = tvm.tirx.transform.StorageRewrite()(mod)["func"].body num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): num_alloc[0] += 1 assert n.buffer.shape[0].value == 200 - tvm.tir.stmt_functor.post_order_visit(body, verify) + tvm.tirx.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -99,14 +99,14 @@ def offset_generater(dtype_list, length): def dtype_test(dtype_list, length): def verify(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): assert n.buffer.shape[0].value == offset mod = make_mod(dtype_list, length) offset = offset_generater(dtype_list, length) - body = tvm.tir.transform.StorageRewrite()(mod)["func"].body - tvm.tir.stmt_functor.post_order_visit(body, verify) + body = tvm.tirx.transform.StorageRewrite()(mod)["func"].body + tvm.tirx.stmt_functor.post_order_visit(body, verify) length = 1024 dtype_list = ["float16", "int32", "uint16", "int8"] @@ -156,17 +156,17 @@ def before(A: T.Buffer(8, "float32"), E: T.Buffer(8, "float32")): ) def verify(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): total_alloc[0] += n.buffer.shape[0].value total_alloc = [0] mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify) + tvm.tirx.stmt_functor.post_order_visit(mod["main"].body, verify) assert total_alloc[0] == 24 total_alloc[0] = 0 - mod = tvm.tir.transform.StorageRewrite()(mod) - tvm.tir.stmt_functor.post_order_visit(mod["main"].body, verify) + mod = tvm.tirx.transform.StorageRewrite()(mod) + tvm.tirx.stmt_functor.post_order_visit(mod["main"].body, verify) assert total_alloc[0] == 16 @@ -179,10 +179,10 @@ def func1(n: T.int32): A[j] = A[j] + T.float32(2) mod = tvm.IRModule.from_expr(func1) - body = tvm.tir.transform.StorageRewrite()(mod)["func1"] + body = tvm.tirx.transform.StorageRewrite()(mod)["func1"] # With flat AllocBuffer, the for body is a SeqStmt; first element is AllocBuffer - assert isinstance(body.body.body[0], tvm.tir.AllocBuffer) + assert isinstance(body.body.body[0], tvm.tirx.AllocBuffer) @T.prim_func def func2(n: T.int32): @@ -194,9 +194,9 @@ def func2(n: T.int32): A[j] = A[j] + T.float32(2) mod = tvm.IRModule.from_expr(func2) - body = tvm.tir.transform.StorageRewrite()(mod)["func2"] + body = tvm.tirx.transform.StorageRewrite()(mod)["func2"] - assert isinstance(body.body.body.body.body[0], tvm.tir.AllocBuffer) + assert isinstance(body.body.body.body.body[0], tvm.tirx.AllocBuffer) def test_while_alloc(): @@ -231,7 +231,7 @@ def func_serial(n: T.int32): # j[0] = (j[0] + (j[0] + 1)) # } # } - body = tvm.tir.transform.StorageRewrite()(mod)["func_parallel"] + body = tvm.tirx.transform.StorageRewrite()(mod)["func_parallel"] # Navigate to inside the for loop, then check that allocations exist # The structure with DeclBuffer is: # parallel (i, 0, n) { DeclBuffer(j, DeclBuffer(A, ...)) } @@ -241,16 +241,16 @@ def func_serial(n: T.int32): num_alloc = [0] def count_alloc(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): num_alloc[0] += 1 - tvm.tir.stmt_functor.post_order_visit(inner, count_alloc) + tvm.tirx.stmt_functor.post_order_visit(inner, count_alloc) assert num_alloc[0] == 2 # j and A allocations mod = tvm.IRModule.from_expr(func_serial) - body = tvm.tir.transform.StorageRewrite()(mod)["func_serial"] + body = tvm.tirx.transform.StorageRewrite()(mod)["func_serial"] num_alloc[0] = 0 - tvm.tir.stmt_functor.post_order_visit(body.body, count_alloc) + tvm.tirx.stmt_functor.post_order_visit(body.body, count_alloc) assert num_alloc[0] == 2 # j and A allocations @@ -273,16 +273,16 @@ def func(n: T.int32): A2[j] = A[j] mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.StorageRewrite()(mod)["func"].body + body = tvm.tirx.transform.StorageRewrite()(mod)["func"].body num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): num_alloc[0] += 1 assert n.buffer.shape[0].value == 500 - tvm.tir.stmt_functor.post_order_visit(body, verify) + tvm.tirx.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -303,16 +303,16 @@ def func(n: T.int32): C[j] = T.float32(1.2) mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.StorageRewrite()(mod)["func"].body + body = tvm.tirx.transform.StorageRewrite()(mod)["func"].body num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): num_alloc[0] += 1 assert n.buffer.shape[0].value == 200 - tvm.tir.stmt_functor.post_order_visit(body, verify) + tvm.tirx.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -335,16 +335,16 @@ def func(n: T.int32): E[j] = C[j] mod = tvm.IRModule.from_expr(func) - body = tvm.tir.transform.StorageRewrite()(mod)["func"].body + body = tvm.tirx.transform.StorageRewrite()(mod)["func"].body num_alloc = [0] def verify(n): - if isinstance(n, tvm.tir.AllocBuffer): + if isinstance(n, tvm.tirx.AllocBuffer): num_alloc[0] += 1 assert n.buffer.shape[0].value == 800 - tvm.tir.stmt_functor.post_order_visit(body, verify) + tvm.tirx.stmt_functor.post_order_visit(body, verify) assert num_alloc[0] == 1 @@ -365,7 +365,7 @@ def func_rewritten(A: T.Buffer((8,), "float32")) -> None: x: T.float32 = T.exp(B[0], dtype="float32") A[i] = (x + 1.0) / (x - 1.0) - mod = tvm.tir.transform.StorageRewrite()( + mod = tvm.tirx.transform.StorageRewrite()( tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) ) tvm.ir.assert_structural_equal(mod["main"], func_rewritten.with_attr("global_symbol", "main")) @@ -399,7 +399,7 @@ def main() -> None: A_1 = T.Buffer([1], "int32x8", data=A_data) A_1[0] = T.broadcast(42, 8) - After = tvm.tir.transform.StorageRewrite()(Before) + After = tvm.tirx.transform.StorageRewrite()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -451,7 +451,7 @@ def main(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")): for i, j in T.grid(16, 16): D[i, j] = C[i, j] - After = tvm.tir.transform.StorageRewrite()(Before) + After = tvm.tirx.transform.StorageRewrite()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -491,7 +491,7 @@ def Before(A: T.Buffer((16, 16), "float32"), D: T.Buffer((16, 16), "float32")): Expected = Before - After = tvm.tir.transform.StorageRewrite()(tvm.IRModule.from_expr(Before)) + After = tvm.tirx.transform.StorageRewrite()(tvm.IRModule.from_expr(Before)) tvm.ir.assert_structural_equal(After["Before"], Expected) @@ -530,7 +530,7 @@ def main(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")): for i in range(16): D[i] = C[i] - After = tvm.tir.transform.StorageRewrite()(Before) + After = tvm.tirx.transform.StorageRewrite()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -575,7 +575,7 @@ def main(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")): for i in range(16): D[i] = C[i] - After = tvm.tir.transform.StorageRewrite()(Before) + After = tvm.tirx.transform.StorageRewrite()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_unroll_loop.py b/tests/python/tirx-transform/test_tir_transform_unroll_loop.py similarity index 72% rename from tests/python/tir-transform/test_tir_transform_unroll_loop.py rename to tests/python/tirx-transform/test_tir_transform_unroll_loop.py index 8a550fe741be..b38da01d5348 100644 --- a/tests/python/tir-transform/test_tir_transform_unroll_loop.py +++ b/tests/python/tirx-transform/test_tir_transform_unroll_loop.py @@ -16,7 +16,7 @@ # under the License. import tvm from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def test_unroll_loop(): @@ -32,22 +32,22 @@ def main(A: T.handle, n: T.int64): mod = Module stmt = mod["main"].body - assert isinstance(stmt, tvm.tir.For) + assert isinstance(stmt, tvm.tirx.For) - with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 16}}): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert not isinstance(ret, tvm.tir.For) + with tvm.transform.PassContext(config={"tirx.UnrollLoop": {"auto_max_step": 16}}): + ret = tvm.tirx.transform.UnrollLoop()(mod)["main"].body + assert not isinstance(ret, tvm.tirx.For) - with tvm.transform.PassContext(config={"tir.UnrollLoop": {"auto_max_step": 15}}): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert isinstance(ret, tvm.tir.For) + with tvm.transform.PassContext(config={"tirx.UnrollLoop": {"auto_max_step": 15}}): + ret = tvm.tirx.transform.UnrollLoop()(mod)["main"].body + assert isinstance(ret, tvm.tirx.For) with tvm.transform.PassContext( - config={"tir.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}} + config={"tirx.UnrollLoop": {"auto_max_step": 16, "explicit_unroll": False}} ): - ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body - assert isinstance(ret, tvm.tir.For) - assert ret.kind == tvm.tir.ForKind.UNROLLED + ret = tvm.tirx.transform.UnrollLoop()(mod)["main"].body + assert isinstance(ret, tvm.tirx.For) + assert ret.kind == tvm.tirx.ForKind.UNROLLED @I.ir_module class ModuleWithPragma: @@ -63,13 +63,13 @@ def main(A: T.handle, n: T.int64): Ab[j + 1] = Ab[i] + T.int64(1) with tvm.transform.PassContext( - config={"tir.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}} + config={"tirx.UnrollLoop": {"auto_max_depth": 8, "explicit_unroll": False}} ): - ret = tvm.tir.transform.UnrollLoop()(ModuleWithPragma)["main"].body - assert isinstance(ret[0], tvm.tir.For) - assert ret[0].kind == tvm.tir.ForKind.UNROLLED - assert isinstance(ret[1], tvm.tir.For) - assert ret[1].kind != tvm.tir.ForKind.UNROLLED + ret = tvm.tirx.transform.UnrollLoop()(ModuleWithPragma)["main"].body + assert isinstance(ret[0], tvm.tirx.For) + assert ret[0].kind == tvm.tirx.ForKind.UNROLLED + assert isinstance(ret[1], tvm.tirx.For) + assert ret[1].kind != tvm.tirx.ForKind.UNROLLED def test_unroll_fake_loop(): @@ -85,11 +85,11 @@ def main(A: T.handle, n: T.int64): with tvm.transform.PassContext( config={ - "tir.UnrollLoop": {"auto_max_depth": 8, "auto_max_extent": 1, "explicit_unroll": False} + "tirx.UnrollLoop": {"auto_max_depth": 8, "auto_max_extent": 1, "explicit_unroll": False} } ): - ret = tvm.tir.transform.UnrollLoop()(Module)["main"].body - assert isinstance(ret[0], tvm.tir.BufferStore) + ret = tvm.tirx.transform.UnrollLoop()(Module)["main"].body + assert isinstance(ret[0], tvm.tirx.BufferStore) def test_unroll_allocations(): @@ -110,7 +110,7 @@ def main(): buf2 = T.alloc_buffer([16], "float32") buf2[0] = 0.0 - after = tvm.tir.transform.UnrollLoop()(Before) + after = tvm.tirx.transform.UnrollLoop()(Before) tvm.ir.assert_structural_equal(after, Expected) @@ -140,7 +140,7 @@ def main(B: T.Buffer((64,), "float32")): with tvm.transform.PassContext( config={ - "tir.UnrollLoop": { + "tirx.UnrollLoop": { "auto_max_depth": 0, "auto_max_extent": 1, "explicit_unroll": True, @@ -148,8 +148,8 @@ def main(B: T.Buffer((64,), "float32")): } } ): - after = tvm.tir.transform.UnrollLoop()(Before) - after = tvm.tir.transform.Simplify()(after) + after = tvm.tirx.transform.UnrollLoop()(Before) + after = tvm.tirx.transform.Simplify()(after) tvm.ir.assert_structural_equal(after, Expected) diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tirx-transform/test_tir_transform_vectorize.py similarity index 87% rename from tests/python/tir-transform/test_tir_transform_vectorize.py rename to tests/python/tirx-transform/test_tir_transform_vectorize.py index 9fcc66717e18..02b82df6e4c5 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tirx-transform/test_tir_transform_vectorize.py @@ -20,7 +20,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T simple_target = tvm.target.Target({"kind": "llvm", "mtriple": "x86_64-linux-gnu"}) sve_target = tvm.target.Target( @@ -49,7 +49,7 @@ def main(A: T.Buffer((16,), "float32")): A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -62,14 +62,14 @@ def main(A: T.Buffer((4,), "float32x4"), n: T.int32): for j in T.vectorized(4): A[j] = T.Broadcast(T.float32(1), 4) - mod = tvm.tir.transform.VectorizeLoop()(Module) + mod = tvm.tirx.transform.VectorizeLoop()(Module) stmt = mod["main"].body - assert isinstance(stmt, tvm.tir.For) - assert not isinstance(stmt.body, tvm.tir.For) + assert isinstance(stmt, tvm.tirx.For) + assert not isinstance(stmt.body, tvm.tirx.For) assert len(stmt.body.indices) == 1 - assert isinstance(stmt.body.indices[0], tvm.tir.Ramp) - assert isinstance(stmt.body.value, tvm.tir.Broadcast) + assert isinstance(stmt.body.indices[0], tvm.tirx.Ramp) + assert isinstance(stmt.body.value, tvm.tirx.Broadcast) def test_vectorize_vector_scalable_error(): @@ -83,7 +83,7 @@ def main(A: T.Buffer((25,), "float32")): error_msg = "Creating scalable vectors from existing vectors is not supported." with tvm.target.Target(sve_target): with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + tvm.tirx.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error2(): @@ -96,7 +96,7 @@ def main(A: T.Buffer((25,), "float32xvscalex4")): error_msg = "Vectorizing over scalable buffer elements is not supported in vectorizer." with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Module) + tvm.tirx.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error3(): @@ -112,7 +112,7 @@ def main(A: T.Buffer((25,), "float32")): error_msg = "Vectorizing over existing scalable vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): with tvm.target.Target(sve_target): - tvm.tir.transform.VectorizeLoop()(Module) + tvm.tirx.transform.VectorizeLoop()(Module) def test_vectorize_vector_scalable_error4(): @@ -128,7 +128,7 @@ def main(A: T.Buffer((25,), "float32")): error_msg = "Creating scalable vectors from existing vectors is not supported." with pytest.raises(tvm.error.InternalError, match=error_msg): with tvm.target.Target(sve_target): - tvm.tir.transform.VectorizeLoop()(Module) + tvm.tirx.transform.VectorizeLoop()(Module) def test_vectorize_with_if(): @@ -162,7 +162,7 @@ def main(a: T.handle, n: T.int32, x: T.int32): A[i_s] = T.float32(2) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -199,7 +199,7 @@ def main(a: T.handle, n: T.int32, x: T.int32): ) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -221,7 +221,7 @@ def main(A: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -236,10 +236,10 @@ def main(A: T.Buffer((16,), "float32"), n: T.int32): A[i] = A[i] + T.float32(1) with tvm.target.Target(target): - stmt = tvm.tir.transform.VectorizeLoop()(Module)["main"].body + stmt = tvm.tirx.transform.VectorizeLoop()(Module)["main"].body # Check that the loop wasn't vectorised - assert isinstance(stmt, tvm.tir.For) + assert isinstance(stmt, tvm.tirx.For) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -253,10 +253,10 @@ def main(A: T.Buffer((16,), "float32"), n: T.int32): A[i] = A[i] + T.float32(1) with tvm.target.Target(target): - stmt = tvm.tir.transform.VectorizeLoop()(Module)["main"].body + stmt = tvm.tirx.transform.VectorizeLoop()(Module)["main"].body # Check that the loop wasn't vectorised - assert isinstance(stmt, tvm.tir.For) + assert isinstance(stmt, tvm.tirx.For) @pytest.mark.parametrize("extent, target", [(4, simple_target), (T.vscale() * 4, sve_target)]) @@ -276,7 +276,7 @@ def main(A: T.Buffer((25,), "float32")): A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s]) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -300,7 +300,7 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32): ) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -323,7 +323,7 @@ def main(): T.evaluate(0) with tvm.target.Target(simple_target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -378,7 +378,7 @@ def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)]) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -418,7 +418,7 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)]) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -439,7 +439,7 @@ def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")): A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)]) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -463,7 +463,7 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): ) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -486,7 +486,7 @@ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)]) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) @@ -501,7 +501,7 @@ def main(A: T.Buffer((25,), "int32")): error_msg = "Failed to vectorize loop with extent n for target \\(nullptr\\)" with pytest.raises(tvm.error.InternalError, match=error_msg): - tvm.tir.transform.VectorizeLoop()(Mod) + tvm.tirx.transform.VectorizeLoop()(Mod) def test_illegal_vscale_in_non_sve_compilation(): @@ -515,7 +515,7 @@ def main(A: T.Buffer((16,), "float32")): msg = "Failed to vectorize loop with extent T.vscale\\(\\) \\* 4 for target" with tvm.target.Target(simple_target): with pytest.raises(tvm.error.InternalError, match=msg): - tvm.tir.transform.VectorizeLoop()(Mod) + tvm.tirx.transform.VectorizeLoop()(Mod) def test_vectorize_and_predicate_all_buffer_loads_stores(): @@ -523,7 +523,7 @@ def test_vectorize_and_predicate_all_buffer_loads_stores(): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 < 14: @@ -533,7 +533,7 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in range(4): load_a = T.meta_var( A.vload( @@ -549,8 +549,8 @@ def expected(a: T.handle, b: T.handle): ) mod = tvm.IRModule.from_expr(before) - with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): - after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + with tvm.transform.PassContext(config={"tirx.enable_buffer_level_predication": True}): + after = tvm.tirx.transform.VectorizeLoop()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -561,7 +561,7 @@ def test_vectorize_and_predicate_some_buffer_loads_stores(): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 < 14: @@ -571,14 +571,14 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0, i_1_s in T.grid(4, 4): if i_0 * 4 + i_1_s < 14: B[i_0 * 4 + i_1_s] = A[i_0] + T.float32(1) mod = tvm.IRModule.from_expr(before) - with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): - after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + with tvm.transform.PassContext(config={"tirx.enable_buffer_level_predication": True}): + after = tvm.tirx.transform.VectorizeLoop()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -587,7 +587,7 @@ def test_vectorize_and_predicate_multiple_access_statements(): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 < 14: @@ -598,7 +598,7 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in range(4): A.vstore( [T.Ramp(i_0 * 4, 1, 4)], @@ -612,8 +612,8 @@ def expected(a: T.handle, b: T.handle): ) before_mod = tvm.IRModule.from_expr(before) - with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): - after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + with tvm.transform.PassContext(config={"tirx.enable_buffer_level_predication": True}): + after = tvm.tirx.transform.VectorizeLoop()(before_mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -622,7 +622,7 @@ def test_vectorize_and_predicate_invalid_conditions(): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 > 14: @@ -636,7 +636,7 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in range(4): for i_1_s in range(4): if i_0 * 4 + i_1_s > 14: @@ -649,8 +649,8 @@ def expected(a: T.handle, b: T.handle): A[i_0 * 4 + i_1_s] = T.float32(2) before_mod = tvm.IRModule.from_expr(before) - with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": True}): - after = tvm.tir.transform.VectorizeLoop()(before_mod)["main"] + with tvm.transform.PassContext(config={"tirx.enable_buffer_level_predication": True}): + after = tvm.tirx.transform.VectorizeLoop()(before_mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -662,7 +662,7 @@ def test_vectorize_with_explicitly_disabled_buffer_level_predication(): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 < 14: @@ -672,15 +672,15 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) for i_0, i_1_s in T.grid(4, 4): if i_0 * 4 + i_1_s < 14: B[i_0 * 4 + i_1_s] = A[i_0 * 4 + i_1_s] + T.float32(1) mod = tvm.IRModule.from_expr(before) - with tvm.transform.PassContext(config={"tir.enable_buffer_level_predication": False}): + with tvm.transform.PassContext(config={"tirx.enable_buffer_level_predication": False}): with tvm.target.Target(sve_target): - after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + after = tvm.tirx.transform.VectorizeLoop()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -689,7 +689,7 @@ def test_vectorize_and_predicate_buffer_load_stores_with_sve_func_attr_target(): def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": sve_target}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "target": sve_target}) for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): if i_0 * 4 + i_1 < 14: @@ -699,7 +699,7 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True, "target": sve_target}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "target": sve_target}) for i_0 in range(4): load_a = T.meta_var( A.vload( @@ -715,7 +715,7 @@ def expected(a: T.handle, b: T.handle): ) mod = tvm.IRModule.from_expr(before) - after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + after = tvm.tirx.transform.VectorizeLoop()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -724,7 +724,7 @@ def test_vectorize_and_predicate_buffer_load_stores_with_sve_attr_scope_target() def before(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.attr(sve_target, "target", 0): for i_0 in T.serial(T.ceildiv(14, 4)): for i_1 in T.vectorized(4): @@ -735,7 +735,7 @@ def before(a: T.handle, b: T.handle): def expected(a: T.handle, b: T.handle): A = T.match_buffer(a, (16,), "float32") B = T.match_buffer(b, (16,), "float32") - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) with T.attr(sve_target, "target", 0): for i_0 in range(4): load_a = T.meta_var( @@ -752,7 +752,7 @@ def expected(a: T.handle, b: T.handle): ) mod = tvm.IRModule.from_expr(before) - after = tvm.tir.transform.VectorizeLoop()(mod)["main"] + after = tvm.tirx.transform.VectorizeLoop()(mod)["main"] tvm.ir.assert_structural_equal(after, expected) @@ -777,7 +777,7 @@ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")): ) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) mod = tvm.compile(mod, target=target) @@ -803,7 +803,7 @@ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")): ) with tvm.target.Target(target): - mod = tvm.tir.transform.VectorizeLoop()(Before) + mod = tvm.tirx.transform.VectorizeLoop()(Before) tvm.ir.assert_structural_equal(mod, After) with pytest.raises(Exception) as e_info: ex = tvm.compile(mod, target=target) diff --git a/tests/python/tvmscript/test_tvmscript_complete.py b/tests/python/tvmscript/test_tvmscript_complete.py index f1de6bfa7b0d..b23148e45f57 100644 --- a/tests/python/tvmscript/test_tvmscript_complete.py +++ b/tests/python/tvmscript/test_tvmscript_complete.py @@ -17,7 +17,7 @@ import tvm.testing from tvm.ir import Range -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -112,11 +112,17 @@ def test_complete_matmul(): A, B, C = [func.buffer_map[x] for x in func.params] block = func.body.block.body.body.body.body.block - assert isinstance(block, tvm.tir.SBlock) + assert isinstance(block, tvm.tirx.SBlock) vi, vj, vk = [x.var for x in block.iter_vars] - access_A = tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vk, 1)]) - access_B = tvm.tir.BufferRegion(B, [Range.from_min_extent(vj, 1), Range.from_min_extent(vk, 1)]) - access_C = tvm.tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)]) + access_A = tvm.tirx.BufferRegion( + A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vk, 1)] + ) + access_B = tvm.tirx.BufferRegion( + B, [Range.from_min_extent(vj, 1), Range.from_min_extent(vk, 1)] + ) + access_C = tvm.tirx.BufferRegion( + C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)] + ) tvm.ir.assert_structural_equal(block.reads, [access_A, access_B]) tvm.ir.assert_structural_equal(block.writes, [access_C]) @@ -126,24 +132,24 @@ def test_complete_matmul_original(): A, B, C = [func.buffer_map[x] for x in func.params] block1 = func.body.block.body.body.body[0].block - assert isinstance(block1, tvm.tir.SBlock) + assert isinstance(block1, tvm.tirx.SBlock) vi, vj = [x.var for x in block1.iter_vars] - access_C = tvm.tir.BufferRegion( + access_C = tvm.tirx.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] ) tvm.ir.assert_structural_equal(block1.reads, []) tvm.ir.assert_structural_equal(block1.writes, [access_C]) block2 = func.body.block.body.body.body[1].body.block - assert isinstance(block2, tvm.tir.SBlock) + assert isinstance(block2, tvm.tirx.SBlock) vi, vj, vk = [x.var for x in block2.iter_vars] - access_A = tvm.tir.BufferRegion( + access_A = tvm.tirx.BufferRegion( A, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vk * 4, 4)] ) - access_B = tvm.tir.BufferRegion( + access_B = tvm.tirx.BufferRegion( B, [Range.from_min_extent(vj * 4, 4), Range.from_min_extent(vk * 4, 4)] ) - access_C = tvm.tir.BufferRegion( + access_C = tvm.tirx.BufferRegion( C, [Range.from_min_extent(vi * 4, 4), Range.from_min_extent(vj * 4, 4)] ) tvm.ir.assert_structural_equal(block2.reads, [access_C, access_A, access_B]) @@ -158,28 +164,28 @@ def _check_elementwise(func): assert len(root_block.writes) == 0 block1 = func.body.block.body[0].body.body.block - assert isinstance(block1, tvm.tir.SBlock) + assert isinstance(block1, tvm.tirx.SBlock) vi, vj = [x.var for x in block1.iter_vars] tvm.ir.assert_structural_equal( block1.reads, - [tvm.tir.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tirx.BufferRegion(A, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) tvm.ir.assert_structural_equal( block1.writes, - [tvm.tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tirx.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) block2 = func.body.block.body[1].body.body.block - assert isinstance(block2, tvm.tir.SBlock) + assert isinstance(block2, tvm.tirx.SBlock) vi, vj = [x.var for x in block2.iter_vars] tvm.ir.assert_structural_equal( block2.reads, - [tvm.tir.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tirx.BufferRegion(B, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) tvm.ir.assert_structural_equal( block2.writes, - [tvm.tir.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], + [tvm.tirx.BufferRegion(C, [Range.from_min_extent(vi, 1), Range.from_min_extent(vj, 1)])], ) diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index dab86cb1317f..de6c6d35b9bc 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -22,10 +22,10 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir.diagnostics import override_renderer from tvm.script import from_source -from tvm.script import tir as T +from tvm.script import tirx as T def check_error(func, rel_lineno): @@ -451,7 +451,7 @@ def test_reorder_fail_block(): with pytest.raises(tvm.s_tir.ScheduleError) as execinfo: sch.reorder(l, i) expected_sub_error_message = ( - " # tir.SBlock#0\n" + " # tirx.SBlock#0\n" ' with T.sblock("B"):\n' " ^^^^^^^^^^^^^^^^^^^\n" ) @@ -466,7 +466,7 @@ def test_reorder_fail_nested_loop_inner(): sch.reorder(k, i) expected_sub_error_message = ( " for i in range(128):\n" - " # tir.For#0\n" + " # tirx.For#0\n" " for j in range(128):\n" " ^^^^^^^^^^^^^^^^^^^^\n" ) @@ -480,7 +480,7 @@ def test_fuse_fail_nested_loop_outer(): with pytest.raises(tvm.s_tir.ScheduleError) as execinfo: sch.fuse(k, i) expected_sub_error_message = ( - " # tir.For#1\n" + " # tirx.For#1\n" " for i in range(128):\n" " ^^^^^^^^^^^^^^^^^^^^\n" " for j in range(128):\n" @@ -494,7 +494,7 @@ def test_report_error_root_block(): with pytest.raises(tvm.s_tir.ScheduleError) as execinfo: sch.compute_inline(root) expected_sub_error_message = ( - ' # tir.SBlock#0\n with T.sblock("root"):\n ^^^^^^^^^^^^^^^^^^^^^^\n' + ' # tirx.SBlock#0\n with T.sblock("root"):\n ^^^^^^^^^^^^^^^^^^^^^^\n' ) assert expected_sub_error_message in str(execinfo.value) diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 460457601ae1..af04802dc23d 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, missing-docstring # ruff: noqa: F401, F841 -"""Unittests for tvm.script.ir_builder.tir""" +"""Unittests for tvm.script.ir_builder.tirx""" import numpy as np import pytest @@ -24,10 +24,10 @@ import tvm import tvm.runtime import tvm.testing -from tvm import tir +from tvm import tirx from tvm.ir.base import assert_structural_equal from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T def test_ir_builder_tir_primfunc_base(): @@ -39,9 +39,9 @@ def test_ir_builder_tir_primfunc_base(): prim_func_actual = ib.get() # the expected prim_func - prim_func_expected = tir.PrimFunc( + prim_func_expected = tirx.PrimFunc( params=[], - body=tir.Evaluate(0), + body=tirx.Evaluate(0), ret_type=None, buffer_map=None, attrs=None, @@ -69,20 +69,20 @@ def test_ir_builder_tir_primfunc_complete(): # the expected prim_func c_handle, c_buffer = ( - tir.Var("c_handle", "handle"), - tir.decl_buffer((128, 128), "float32", name="c"), + tirx.Var("c_handle", "handle"), + tirx.decl_buffer((128, 128), "float32", name="c"), ) - d_handle, d_buffer = tir.Var("d", "handle"), tir.decl_buffer((64, 64), "int64", name="d") - e_handle, e_buffer = tir.Var("e_handle", "handle"), tir.decl_buffer((1024,), "int8", name="e") - prim_func_expected = tir.PrimFunc( + d_handle, d_buffer = tirx.Var("d", "handle"), tirx.decl_buffer((64, 64), "int64", name="d") + e_handle, e_buffer = tirx.Var("e_handle", "handle"), tirx.decl_buffer((1024,), "int8", name="e") + prim_func_expected = tirx.PrimFunc( params=[ - tir.Var("a", "handle"), - tir.Var("b", "int64"), + tirx.Var("a", "handle"), + tirx.Var("b", "int64"), c_handle, d_handle, e_handle, ], - body=tir.Evaluate(0), + body=tirx.Evaluate(0), ret_type=tvm.ir.PrimType("int64"), buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer}, attrs=tvm.ir.make_node("ir.DictAttrs", key="value"), @@ -101,17 +101,17 @@ def test_ir_builder_tir_block_base(): block_realize_actual = ib.get() # the expected block - block_expected = tir.SBlock( + block_expected = tirx.SBlock( iter_vars=[], reads=[], writes=[], name_hint="block", - body=tir.Evaluate(0), + body=tirx.Evaluate(0), alloc_buffers=None, match_buffers=None, - annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)}, + annotations={"tirx.script_parsing_detect_access": tirx.IntImm("int64", 3)}, ) - block_realize_expected = tir.SBlockRealize( + block_realize_expected = tirx.SBlockRealize( iter_values=[], predicate=True, block=block_expected, @@ -143,25 +143,25 @@ def test_ir_builder_tir_block_complete(): block_realize_actual = ib.get() # the expected block - var_a = tir.Var("a", "int64") - buffer_b = tir.decl_buffer((128, 128), "float32", name="b") - buffer_c = tir.decl_buffer((128, 128), "float32", name="c") - var_d = tir.Var("d", "int32") - buffer_e = tir.decl_buffer((128, 128), "float32", name="c") - var_f = tir.Var("f", "int32") - block_expected = tir.SBlock( - iter_vars=[tir.IterVar((0, 128), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar)], + var_a = tirx.Var("a", "int64") + buffer_b = tirx.decl_buffer((128, 128), "float32", name="b") + buffer_c = tirx.decl_buffer((128, 128), "float32", name="c") + var_d = tirx.Var("d", "int32") + buffer_e = tirx.decl_buffer((128, 128), "float32", name="c") + var_f = tirx.Var("f", "int32") + block_expected = tirx.SBlock( + iter_vars=[tirx.IterVar((0, 128), tirx.Var("", "int32"), iter_type=tirx.IterVar.DataPar)], reads=[buffer_b[0:16, 0:16]], writes=[buffer_c[var_d:128, var_d:128]], name_hint="block", - body=tir.Evaluate(0), - alloc_buffers=[tir.decl_buffer((128, 128), "float32")], + body=tirx.Evaluate(0), + alloc_buffers=[tirx.decl_buffer((128, 128), "float32")], match_buffers=[ - tir.MatchBufferRegion(tir.decl_buffer((32, 32), "float32"), buffer_e[0:32, 0:32]) + tirx.MatchBufferRegion(tirx.decl_buffer((32, 32), "float32"), buffer_e[0:32, 0:32]) ], annotations={"key": "value"}, ) - block_realize_expected = tir.SBlockRealize( + block_realize_expected = tirx.SBlockRealize( iter_values=[var_f], predicate=var_a > 1, block=block_expected, @@ -188,24 +188,24 @@ def test_ir_builder_tir_axis(): block_realize_actual = ib.get() # the expected block - var_a = tir.Var("a", "int32") - var_b = tir.Var("b", "int32") - var_c = tir.Var("c", "int32") - var_d = tir.Var("d", "int32") - block_expected = tir.SBlock( + var_a = tirx.Var("a", "int32") + var_b = tirx.Var("b", "int32") + var_c = tirx.Var("c", "int32") + var_d = tirx.Var("d", "int32") + block_expected = tirx.SBlock( iter_vars=[ - tir.IterVar((0, 8), tir.Var("", "int32"), iter_type=tir.IterVar.DataPar), - tir.IterVar((0, 16), tir.Var("", "int32"), iter_type=tir.IterVar.CommReduce), - tir.IterVar((0, 32), tir.Var("", "int32"), iter_type=tir.IterVar.Ordered), - tir.IterVar((0, 64), tir.Var("", "int32"), iter_type=tir.IterVar.Opaque), + tirx.IterVar((0, 8), tirx.Var("", "int32"), iter_type=tirx.IterVar.DataPar), + tirx.IterVar((0, 16), tirx.Var("", "int32"), iter_type=tirx.IterVar.CommReduce), + tirx.IterVar((0, 32), tirx.Var("", "int32"), iter_type=tirx.IterVar.Ordered), + tirx.IterVar((0, 64), tirx.Var("", "int32"), iter_type=tirx.IterVar.Opaque), ], reads=[], writes=[], name_hint="block", - body=tir.Evaluate(0), - annotations={"tir.script_parsing_detect_access": tir.IntImm("int64", 3)}, + body=tirx.Evaluate(0), + annotations={"tirx.script_parsing_detect_access": tirx.IntImm("int64", 3)}, ) - block_realize_expected = tir.SBlockRealize( + block_realize_expected = tirx.SBlockRealize( iter_values=[var_a, var_b, var_c, var_d], predicate=True, block=block_expected, @@ -228,42 +228,42 @@ def test_ir_builder_tir_for(): for_actual = ib.get() # the expected for - thread_binding_expected = tir.For( - loop_var=tir.Var("", "int32"), + thread_binding_expected = tirx.For( + loop_var=tirx.Var("", "int32"), min=0, extent=8, - kind=tir.ForKind.THREAD_BINDING, - body=tir.Evaluate(0), - thread_binding=tir.IterVar( - None, tir.Var("", "int32"), tir.IterVar.ThreadIndex, "threadIdx.x" + kind=tirx.ForKind.THREAD_BINDING, + body=tirx.Evaluate(0), + thread_binding=tirx.IterVar( + None, tirx.Var("", "int32"), tirx.IterVar.ThreadIndex, "threadIdx.x" ), ) - unroll_expected = tir.For( - loop_var=tir.Var("", "int32"), + unroll_expected = tirx.For( + loop_var=tirx.Var("", "int32"), min=0, extent=16, - kind=tir.ForKind.UNROLLED, + kind=tirx.ForKind.UNROLLED, body=thread_binding_expected, ) - vectorized_expected = tir.For( - loop_var=tir.Var("", "int32"), + vectorized_expected = tirx.For( + loop_var=tirx.Var("", "int32"), min=0, extent=32, - kind=tir.ForKind.VECTORIZED, + kind=tirx.ForKind.VECTORIZED, body=unroll_expected, ) - parallel_expected = tir.For( - loop_var=tir.Var("", "int32"), + parallel_expected = tirx.For( + loop_var=tirx.Var("", "int32"), min=0, extent=64, - kind=tir.ForKind.PARALLEL, + kind=tirx.ForKind.PARALLEL, body=vectorized_expected, ) - for_expected = tir.For( - loop_var=tir.Var("", "int32"), + for_expected = tirx.For( + loop_var=tirx.Var("", "int32"), min=0, extent=128, - kind=tir.ForKind.SERIAL, + kind=tirx.ForKind.SERIAL, body=parallel_expected, ) @@ -273,18 +273,18 @@ def test_ir_builder_tir_for(): def test_ir_builder_tir_for_uint(): with IRBuilder() as ib: - with T.serial(tir.const(128, "uint32")) as a: + with T.serial(tirx.const(128, "uint32")) as a: T.evaluate(0) # the for generated by IRBuilder for_actual = ib.get() - for_expected = tir.For( - loop_var=tir.Var("", "uint32"), - min=tir.const(0, "uint32"), - extent=tir.const(128, "uint32"), - kind=tir.ForKind.SERIAL, - body=tir.Evaluate(0), + for_expected = tirx.For( + loop_var=tirx.Var("", "uint32"), + min=tirx.const(0, "uint32"), + extent=tirx.const(128, "uint32"), + kind=tirx.ForKind.SERIAL, + body=tirx.Evaluate(0), ) # Check if the generated ir is expected @@ -299,14 +299,14 @@ def test_ir_builder_tir_assert(): assert_actual = ib.get() # AssertStmt is a leaf. The frame emits the assert and then the body stmts as siblings. - assert_expected = tir.SeqStmt( + assert_expected = tirx.SeqStmt( [ - tir.AssertStmt( + tirx.AssertStmt( T.int32() == 0, - tir.StringImm("RuntimeError"), - [tir.StringImm("a is 0")], + tirx.StringImm("RuntimeError"), + [tirx.StringImm("a is 0")], ), - tir.Evaluate(0), + tirx.Evaluate(0), ] ) @@ -317,17 +317,17 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_bind(): # Test that T.bind emits a flat Bind statement and returns the Var. with IRBuilder() as ib: - v = T.bind(tir.IntImm("int32", 2)) + v = T.bind(tirx.IntImm("int32", 2)) # the let binding generated by IRBuilder let_actual = ib.get() # Bind is now flat (no body), so a single Bind stmt is emitted. - let_expected = tir.Bind(T.int32(), tir.IntImm("int32", 2)) + let_expected = tirx.Bind(T.int32(), tirx.IntImm("int32", 2)) # Check if the generated ir is expected assert_structural_equal(let_actual, let_expected, map_free_vars=True) # Check that the returned value is a Var - assert isinstance(v, tir.Var) + assert isinstance(v, tirx.Var) def test_ir_builder_tir_thread(): @@ -341,9 +341,9 @@ def test_ir_builder_tir_thread(): ir_actual = ib.get() # the expected prim_func - iter_var = tir.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") - attr_stmt = tir.AttrStmt(iter_var, "thread_extent", 1, tir.Evaluate(0)) - func = tir.PrimFunc([], attr_stmt) + iter_var = tirx.IterVar((0, 1), "v", iter_type=1, thread_tag="blockIdx.y") + attr_stmt = tirx.AttrStmt(iter_var, "thread_extent", 1, tirx.Evaluate(0)) + func = tirx.PrimFunc([], attr_stmt) # Check if the generated ir is expected assert_structural_equal(ir_actual, func, map_free_vars=True) @@ -361,10 +361,10 @@ def test_ir_builder_tir_allocate(): body = ir_actual.body # AllocBuffer is flat: body should be a SeqStmt with [AllocBuffer, Evaluate(1)] - assert isinstance(body, tir.SeqStmt), f"Expected SeqStmt but got {type(body)}" + assert isinstance(body, tirx.SeqStmt), f"Expected SeqStmt but got {type(body)}" assert len(body) == 2 - assert isinstance(body[0], tir.AllocBuffer) - assert isinstance(body[1], tir.Evaluate) + assert isinstance(body[0], tirx.AllocBuffer) + assert isinstance(body[1], tirx.Evaluate) assert body[1].value.value == 1 @@ -377,7 +377,7 @@ def test_ir_builder_tir_while(): ir_actual = ib.get() # the expected while - ir_expected = tir.While(tir.Var("x", "int32") > 0, tir.Evaluate(0)) + ir_expected = tirx.While(tirx.Var("x", "int32") > 0, tirx.Evaluate(0)) # Check if the generated ir is expected assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) @@ -395,10 +395,10 @@ def test_ir_builder_tir_if_then_else(): ir_actual = ib.get() # the expected if_then_else - ir_expected = tir.IfThenElse( - tir.Var("c", "int32") < 12, - tir.Evaluate(tir.IntImm("int32", 0)), - tir.Evaluate(tir.IntImm("int32", 1)), + ir_expected = tirx.IfThenElse( + tirx.Var("c", "int32") < 12, + tirx.Evaluate(tirx.IntImm("int32", 0)), + tirx.Evaluate(tirx.IntImm("int32", 1)), ) # Check if the generated ir is expected @@ -415,7 +415,7 @@ def test_ir_builder_tir_buffer_store(): ir_actual = ib.get() # the expected buffer store - ir_expected = tir.BufferStore(buffer_a, 0.1, [0, i]) + ir_expected = tirx.BufferStore(buffer_a, 0.1, [0, i]) # Check if the generated ir is expected assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) @@ -423,8 +423,8 @@ def test_ir_builder_tir_buffer_store(): def test_ir_builder_tir_buffer_store_scalable_vec(): buffer_a = T.Buffer((30,), "float32") - value = T.broadcast(0.11, 4 * tvm.tir.vscale()) - index = T.ramp(0, 1, 4 * tvm.tir.vscale()) + value = T.broadcast(0.11, 4 * tvm.tirx.vscale()) + index = T.ramp(0, 1, 4 * tvm.tirx.vscale()) with IRBuilder() as ib: T.buffer_store(buffer_a, value, [index]) @@ -433,7 +433,7 @@ def test_ir_builder_tir_buffer_store_scalable_vec(): ir_actual = ib.get() # the expected buffer store - ir_expected = tir.BufferStore(buffer_a, value, [index]) + ir_expected = tirx.BufferStore(buffer_a, value, [index]) # Check if the generated ir is expected assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) @@ -449,7 +449,7 @@ def test_ir_builder_tir_buffer_store_predicate(): T.buffer_store(buffer_a, value, [index], predicate) ir_actual = ib.get() - ir_expected = tir.BufferStore(buffer_a, value, [index], predicate) + ir_expected = tirx.BufferStore(buffer_a, value, [index], predicate) assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) @@ -460,7 +460,7 @@ def test_ir_builder_tir_evaluate(): eval_actual = ib.get() # the expected evaluate - eval_expected = tir.Evaluate(0) + eval_expected = tirx.Evaluate(0) # Check if the generated ir is expected assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) @@ -478,10 +478,10 @@ def test_ir_builder_tir_decl_buffer(): body = ir_actual.body # decl_buffer without data emits AllocBuffer (flat): body should be SeqStmt - assert isinstance(body, tir.SeqStmt), f"Expected SeqStmt but got {type(body)}" + assert isinstance(body, tirx.SeqStmt), f"Expected SeqStmt but got {type(body)}" assert len(body) == 2 - assert isinstance(body[0], tir.AllocBuffer) - assert isinstance(body[1], tir.Evaluate) + assert isinstance(body[0], tirx.AllocBuffer) + assert isinstance(body[1], tirx.Evaluate) assert body[1].value.value == 1 @@ -494,7 +494,7 @@ def test_ir_builder_tir_inline(): eval_actual = ib.get() # the expected evaluate - eval_expected = tir.Evaluate(10) + eval_expected = tirx.Evaluate(10) # Check if the generated ir is expected assert_structural_equal(eval_actual, eval_expected, map_free_vars=True) diff --git a/tests/python/tvmscript/test_tvmscript_meta_programming.py b/tests/python/tvmscript/test_tvmscript_meta_programming.py index 58efac54ed28..10a2c1777062 100644 --- a/tests/python/tvmscript/test_tvmscript_meta_programming.py +++ b/tests/python/tvmscript/test_tvmscript_meta_programming.py @@ -16,7 +16,7 @@ # under the License. import tvm -from tvm.script import tir as T +from tvm.script import tirx as T def test_meta_programming_matmul(): diff --git a/tests/python/tvmscript/test_tvmscript_ops.py b/tests/python/tvmscript/test_tvmscript_ops.py index 7e0de0a151a2..df734f4d042b 100644 --- a/tests/python/tvmscript/test_tvmscript_ops.py +++ b/tests/python/tvmscript/test_tvmscript_ops.py @@ -19,7 +19,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -109,7 +109,7 @@ def alloc_zero_dim_buffer(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [], dtype="float32") B = T.match_buffer(b, [], dtype="float32") # body - # tir.with block("root") + # tirx.with block("root") C = T.sblock_alloc_buffer([], dtype="float32") A[()] = T.float32(2) C[()] = A[()] + B[()] diff --git a/tests/python/tvmscript/test_tvmscript_parser_source.py b/tests/python/tvmscript/test_tvmscript_parser_source.py index 553f7cfcdc70..e3f12a0e6b00 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_source.py +++ b/tests/python/tvmscript/test_tvmscript_parser_source.py @@ -22,7 +22,7 @@ import pytest import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.parser.core import doc_core as doc from tvm.script.parser.core.diagnostics import Source diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 25cd45bddf81..6a51698f1694 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -14,27 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Unittests for tvm.script.parser.tir""" +"""Unittests for tvm.script.parser.tirx""" import pytest import tvm_ffi import tvm.testing -from tvm import ir, tir -from tvm.script.parser import tir as T +from tvm import ir, tirx +from tvm.script.parser import tirx as T def test_tir_buffer_proxy(): buffer_0 = T.Buffer((128, 128), "float32") assert ( - isinstance(buffer_0, tir.Buffer) + isinstance(buffer_0, tirx.Buffer) and list(buffer_0.shape) == [128, 128] and buffer_0.dtype == "float32" ) buffer_1 = T.Buffer((64, 64, 64), "int32") assert ( - isinstance(buffer_1, tir.Buffer) + isinstance(buffer_1, tirx.Buffer) and list(buffer_1.shape) == [64, 64, 64] and buffer_1.dtype == "int32" ) @@ -43,7 +43,7 @@ def test_tir_buffer_proxy(): def test_tir_ptr_proxy(): ptr_0 = T.handle("int32", "global") assert ( - isinstance(ptr_0, tir.Var) + isinstance(ptr_0, tirx.Var) and ptr_0.dtype == "handle" and isinstance(ptr_0.type_annotation, ir.PointerType) and ptr_0.type_annotation.element_type == ir.PrimType("int32") @@ -52,7 +52,7 @@ def test_tir_ptr_proxy(): ptr_1 = T.handle("float32", "shared") assert ( - isinstance(ptr_1, tir.Var) + isinstance(ptr_1, tirx.Var) and ptr_1.dtype == "handle" and isinstance(ptr_1.type_annotation, ir.PointerType) and ptr_1.type_annotation.element_type == ir.PrimType("float32") @@ -498,8 +498,8 @@ def func(a_handle: T.handle, b_handle: T.handle): for i, j in T.grid(M, N): B[i * N + j] = A[i, j] - M = tvm.tir.Var("M", "int64") - N = tvm.tir.Var("N", "int64") + M = tvm.tirx.Var("M", "int64") + N = tvm.tirx.Var("N", "int64") expected = tvm.relax.FuncStructInfo( [ tvm.relax.TensorStructInfo([M, N], "float32"), diff --git a/tests/python/tvmscript/test_tvmscript_pep563_closure.py b/tests/python/tvmscript/test_tvmscript_pep563_closure.py index a5d26d7f1628..13b85f6014c7 100644 --- a/tests/python/tvmscript/test_tvmscript_pep563_closure.py +++ b/tests/python/tvmscript/test_tvmscript_pep563_closure.py @@ -25,7 +25,7 @@ import tvm import tvm.testing from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def _normalize(func): diff --git a/tests/python/tvmscript/test_tvmscript_printer_annotation.py b/tests/python/tvmscript/test_tvmscript_printer_annotation.py index 13ace54ff7c2..a028ae92134d 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_annotation.py +++ b/tests/python/tvmscript/test_tvmscript_printer_annotation.py @@ -21,7 +21,7 @@ import pytest from tvm_ffi.access_path import AccessPath -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -47,7 +47,7 @@ def test_annotation_multi_access_paths(): ) assert ( result - == """# from tvm.script import tir as T + == """# from tvm.script import tirx as T @T.prim_func def main(): @@ -73,7 +73,7 @@ def test_annotate_from_multi_obj(): ) assert ( result - == """# from tvm.script import tir as T + == """# from tvm.script import tirx as T @T.prim_func def main(): @@ -104,7 +104,7 @@ def _func(): ) assert ( result - == """# from tvm.script import tir as T + == """# from tvm.script import tirx as T @T.prim_func def main(): diff --git a/tests/python/tvmscript/test_tvmscript_printer_highlight.py b/tests/python/tvmscript/test_tvmscript_printer_highlight.py index 17024c6cc88a..9dcf2aacb05c 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_highlight.py +++ b/tests/python/tvmscript/test_tvmscript_printer_highlight.py @@ -20,7 +20,7 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.highlight import _format, cprint @@ -33,7 +33,7 @@ def main( # type: ignore b: T.handle, c: T.handle, ) -> None: # pylint: disable=no-self-argument - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) A = T.match_buffer(a, [16, 128, 128]) B = T.match_buffer(b, [16, 128, 128]) C = T.match_buffer(c, [16, 128, 128]) diff --git a/tests/python/tvmscript/test_tvmscript_printer_ir.py b/tests/python/tvmscript/test_tvmscript_printer_ir.py index d82e682c2949..def0fccda509 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_ir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_ir.py @@ -21,7 +21,7 @@ from tvm import IRModule, TVMError from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import ir as I -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T def _assert_print(obj, expected): @@ -41,7 +41,7 @@ def test_ir_module(): mod, """ # from tvm.script import ir as I -# from tvm.script import tir as T +# from tvm.script import tirx as T @I.ir_module class Module: diff --git a/tests/python/tvmscript/test_tvmscript_printer_metadata.py b/tests/python/tvmscript/test_tvmscript_printer_metadata.py index c6796796f17e..f0d8d45c0b83 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_metadata.py +++ b/tests/python/tvmscript/test_tvmscript_printer_metadata.py @@ -18,12 +18,12 @@ # ruff: noqa: F841 import tvm.testing from tvm.script.parser import ir as I -from tvm.script.parser import tir as T +from tvm.script.parser import tirx as T def test_str_metadata(): - # This test is to check we reuse the existing metadata element for the same tir.StringImm - # So metadata["tir.StringImm"][0] will occur in the printed script for three times + # This test is to check we reuse the existing metadata element for the same tirx.StringImm + # So metadata["tirx.StringImm"][0] will occur in the printed script for three times str_imm = T.StringImm("aaa\nbbb\n") @I.ir_module @@ -39,8 +39,8 @@ def foo1() -> None: printed_str = Module.script(verbose_expr=True) assert ( - printed_str.count('metadata["tir.StringImm"][0]') == 3 - and printed_str.count('metadata["tir.StringImm"][1]') == 0 + printed_str.count('metadata["tirx.StringImm"][0]') == 3 + and printed_str.count('metadata["tirx.StringImm"][1]') == 0 ) diff --git a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py index 6287eb0f2848..1cd24deb8357 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py +++ b/tests/python/tvmscript/test_tvmscript_printer_structural_equal.py @@ -22,7 +22,7 @@ import tvm from tvm.ir import assert_structural_equal from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T def _error_message(exception): diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 406a8c6a79f1..7199e74df493 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -22,10 +22,10 @@ import pytest import tvm.testing -from tvm import ir, tir +from tvm import ir, tirx from tvm.ir import Range from tvm.script.ir_builder import IRBuilder -from tvm.script.ir_builder import tir as T +from tvm.script.ir_builder import tirx as T def _assert_print(obj, expected): @@ -33,21 +33,21 @@ def _assert_print(obj, expected): def test_prim_func(): - a = tir.Var("a", "handle") - b = tir.Var("b", "handle") - func = tir.PrimFunc( + a = tirx.Var("a", "handle") + b = tirx.Var("b", "handle") + func = tirx.PrimFunc( params=[a, b], ret_type=None, buffer_map={ - a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), - b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + a: tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B"), }, - body=tir.Evaluate(0), + body=tirx.Evaluate(0), ).with_attr("global_symbol", "main") _assert_print( func, expected=""" -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -56,21 +56,21 @@ def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) def test_prim_func_no_sugar_inlined_buffer(): - a = tir.Var("a", "handle") - b = tir.Var("b", "handle") - func = tir.PrimFunc( + a = tirx.Var("a", "handle") + b = tirx.Var("b", "handle") + func = tirx.PrimFunc( params=[a, b], ret_type=None, buffer_map={ - a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), - b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + a: tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B"), }, - body=tir.Evaluate(a), + body=tirx.Evaluate(a), ).with_attr("global_symbol", "main") _assert_print( func, expected=""" -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(a: T.handle, B: T.Buffer((256, 256), "float32")): @@ -81,22 +81,22 @@ def main(a: T.handle, B: T.Buffer((256, 256), "float32")): def test_prim_func_no_sugar_shared_buffer_data(): - a = tir.Var("a", "handle") - b = tir.Var("b", "handle") - buffer_data = tir.decl_buffer(shape=[128, 128], dtype="float32", name="A").data - func = tir.PrimFunc( + a = tirx.Var("a", "handle") + b = tirx.Var("b", "handle") + buffer_data = tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A").data + func = tirx.PrimFunc( params=[a, b], ret_type=None, buffer_map={ - a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data), - b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data), + a: tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A", data=buffer_data), + b: tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B", data=buffer_data), }, - body=tir.Evaluate(0), + body=tirx.Evaluate(0), ).with_attr("global_symbol", "main") _assert_print( func, expected=""" -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(a: T.handle, b: T.handle): @@ -108,9 +108,9 @@ def main(a: T.handle, b: T.handle): def test_block_realize(): - i = tir.Var("i", "int32") - j = tir.Var("j", "int32") - k = tir.Var("k", "int32") + i = tirx.Var("i", "int32") + j = tirx.Var("j", "int32") + k = tirx.Var("k", "int32") with IRBuilder() as ib: with T.sblock(name="block", no_realize=False): vi = ib.name("vi", T.axis.spatial(128, i)) @@ -137,9 +137,9 @@ def test_block_realize(): def test_block(): - i = tir.Var("i", "int32") - j = tir.Var("j", "int32") - k = tir.Var("k", "int32") + i = tirx.Var("i", "int32") + j = tirx.Var("j", "int32") + k = tirx.Var("k", "int32") with IRBuilder() as ib: with T.sblock(name="block", no_realize=False): vi = ib.name("vi", T.axis.spatial(128, i)) @@ -163,11 +163,11 @@ def test_block(): def test_match_buffer_region(): - src = tir.decl_buffer((128, 128), "float32", name="src") - tgt = tir.decl_buffer((64, 64), "float32", name="tgt") - obj = tir.MatchBufferRegion( + src = tirx.decl_buffer((128, 128), "float32", name="src") + tgt = tirx.decl_buffer((64, 64), "float32", name="tgt") + obj = tirx.MatchBufferRegion( tgt, - tir.BufferRegion( + tirx.BufferRegion( src, [ Range(64, 128), @@ -185,7 +185,7 @@ def test_match_buffer_region(): def test_buffer(): - a = tir.decl_buffer((128, 128), "float16", name="A") + a = tirx.decl_buffer((128, 128), "float16", name="A") _assert_print( a, """A = T.Buffer((128, 128), "float16") @@ -194,8 +194,8 @@ def test_buffer(): def test_buffer_region(): - src = tir.decl_buffer((128, 128), "float32", name="src") - obj = tir.BufferRegion( + src = tirx.decl_buffer((128, 128), "float32", name="src") + obj = tirx.BufferRegion( src, [ Range(64, 128), @@ -212,8 +212,8 @@ def test_buffer_region(): def test_buffer_load(): - a = tir.decl_buffer((128, 128), "float16", name="A") - obj = tir.BufferLoad(a, [128, 128]) + a = tirx.decl_buffer((128, 128), "float16", name="A") + obj = tirx.BufferLoad(a, [128, 128]) _assert_print( obj, """ @@ -224,7 +224,7 @@ def test_buffer_load(): def test_buffer_store(): - a = tir.decl_buffer((128, 128), "float16", name="A") + a = tirx.decl_buffer((128, 128), "float16", name="A") with IRBuilder() as ib: T.buffer_store(a, a[128, 128] + 1, [128, 128]) obj = ib.get() @@ -262,7 +262,7 @@ def test_bind(): _assert_print( obj, """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func(private=True) def main(): @@ -452,7 +452,7 @@ def test_evaluate(): def test_var(): - a = tir.Var("a", "float32") + a = tirx.Var("a", "float32") _assert_print( a, """ @@ -462,7 +462,7 @@ def test_var(): def test_size_var(): - a = tir.SizeVar("a", "float32") + a = tirx.SizeVar("a", "float32") _assert_print( a, """ @@ -472,7 +472,7 @@ def test_size_var(): def test_iter_var(): - a = tir.IterVar((0, 8), "a", iter_type=tir.IterVar.DataPar) + a = tirx.IterVar((0, 8), "a", iter_type=tirx.IterVar.DataPar) _assert_print( a, """ @@ -483,12 +483,12 @@ def test_iter_var(): def test_string_imm(): - s = tir.StringImm("str") + s = tirx.StringImm("str") _assert_print(s, '"str"') def test_cast(): - obj = tir.Cast("float64", tir.Var("a", "float32")) + obj = tirx.Cast("float64", tirx.Var("a", "float32")) _assert_print( obj, """ @@ -499,28 +499,28 @@ def test_cast(): def test_llvm_intrin_imm(): - a = tir.call_llvm_intrin("int32x4", "llvm.donothing") + a = tirx.call_llvm_intrin("int32x4", "llvm.donothing") _assert_print(a, 'T.call_llvm_intrin("int32x4", "llvm.donothing")') - a = tir.call_llvm_pure_intrin("int32x4", "llvm.donothing") + a = tirx.call_llvm_pure_intrin("int32x4", "llvm.donothing") _assert_print(a, 'T.call_llvm_pure_intrin("int32x4", "llvm.donothing")') def test_binary_arith(): - a = tir.Var("a", "int32") - b = tir.Var("b", "int32") + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") for op, sign in [ - (tir.Add, "+"), - (tir.Sub, "-"), - (tir.Mul, "*"), - (tir.Mod, "truncmod"), - (tir.FloorDiv, "//"), - (tir.FloorMod, "%"), - (tir.LT, "<"), - (tir.LE, "<="), - (tir.EQ, "=="), - (tir.NE, "!="), - (tir.GT, ">"), - (tir.GE, ">="), + (tirx.Add, "+"), + (tirx.Sub, "-"), + (tirx.Mul, "*"), + (tirx.Mod, "truncmod"), + (tirx.FloorDiv, "//"), + (tirx.FloorMod, "%"), + (tirx.LT, "<"), + (tirx.LE, "<="), + (tirx.EQ, "=="), + (tirx.NE, "!="), + (tirx.GT, ">"), + (tirx.GE, ">="), ]: obj = op(a, b) if sign.isalpha(): @@ -537,22 +537,22 @@ def test_binary_arith(): def test_binary_arith_const(): - a = tir.IntImm("int64", 3) - b = tir.IntImm("int64", 4) + a = tirx.IntImm("int64", 3) + b = tirx.IntImm("int64", 4) for op, name in [ - (tir.Add, "Add"), - (tir.Sub, "Sub"), - (tir.Mul, "Mul"), - (tir.Div, "Div"), - (tir.Mod, "truncmod"), - (tir.FloorDiv, "FloorDiv"), - (tir.FloorMod, "FloorMod"), - (tir.LT, "LT"), - (tir.LE, "LE"), - (tir.EQ, "EQ"), - (tir.NE, "NE"), - (tir.GT, "GT"), - (tir.GE, "GE"), + (tirx.Add, "Add"), + (tirx.Sub, "Sub"), + (tirx.Mul, "Mul"), + (tirx.Div, "Div"), + (tirx.Mod, "truncmod"), + (tirx.FloorDiv, "FloorDiv"), + (tirx.FloorMod, "FloorMod"), + (tirx.LT, "LT"), + (tirx.LE, "LE"), + (tirx.EQ, "EQ"), + (tirx.NE, "NE"), + (tirx.GT, "GT"), + (tirx.GE, "GE"), ]: obj = op(a, b) expected = f""" @@ -561,10 +561,10 @@ def test_binary_arith_const(): def test_int_div(): - a = tir.Var("a", "int32") - b = tir.Var("b", "int32") + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") _assert_print( - tir.Div(a, b), + tirx.Div(a, b), """ a = T.int32() b = T.int32() @@ -574,10 +574,10 @@ def test_int_div(): def test_logical(): - a = tir.Var("a", "bool") - b = tir.Var("b", "bool") + a = tirx.Var("a", "bool") + b = tirx.Var("b", "bool") _assert_print( - tir.And(a, b), + tirx.And(a, b), """ a = T.bool() b = T.bool() @@ -585,7 +585,7 @@ def test_logical(): """, ) _assert_print( - tir.Or(a, b), + tirx.Or(a, b), """ a = T.bool() b = T.bool() @@ -593,7 +593,7 @@ def test_logical(): """, ) _assert_print( - tir.Not(a), + tirx.Not(a), """ a = T.bool() not a @@ -602,7 +602,7 @@ def test_logical(): def test_select(): - obj = tir.Select(True, 0, 2) + obj = tirx.Select(True, 0, 2) _assert_print( obj, """T.Select(T.bool(True), 0, 2) @@ -611,11 +611,11 @@ def test_select(): @pytest.mark.parametrize( - "lanes, scripted_lanes", [(32, "32"), (tvm.tir.vscale() * 8, "T.vscale() * 8")] + "lanes, scripted_lanes", [(32, "32"), (tvm.tirx.vscale() * 8, "T.vscale() * 8")] ) def test_ramp(lanes, scripted_lanes): - a = tir.Var("a", "int32") - obj = tir.Ramp(a, 1, lanes) + a = tirx.Var("a", "int32") + obj = tirx.Ramp(a, 1, lanes) _assert_print( obj, f""" @@ -626,10 +626,10 @@ def test_ramp(lanes, scripted_lanes): @pytest.mark.parametrize( - "lanes, scripted_lanes", [(4, "4"), (tvm.tir.vscale() * 4, "T.vscale() * 4")] + "lanes, scripted_lanes", [(4, "4"), (tvm.tirx.vscale() * 4, "T.vscale() * 4")] ) def test_broadcast(lanes, scripted_lanes): - obj = tir.Broadcast(0, lanes) + obj = tirx.Broadcast(0, lanes) _assert_print( obj, f""" @@ -639,8 +639,8 @@ def test_broadcast(lanes, scripted_lanes): def test_let_expr(): - x = tir.Var("x", "int32") - obj = tir.Let(x, 1, x + 1) + x = tirx.Var("x", "int32") + obj = tirx.Let(x, 1, x + 1) _assert_print( obj, """ @@ -651,7 +651,7 @@ def test_let_expr(): def test_call(): - obj = tir.atan(T.float32(1.0)) + obj = tirx.atan(T.float32(1.0)) _assert_print( obj, """ @@ -716,7 +716,7 @@ def test_tuple_type(): def test_remap(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def block_with_remap_implicitly(): @@ -739,7 +739,7 @@ def block_with_remap_explicitly(): v4, v5 = T.axis.remap("RS", [i4, i5]) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(): @@ -758,7 +758,7 @@ def main(): def test_root_block(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def root_block_implicitly(): @@ -776,7 +776,7 @@ def root_block_explicitly(): T.evaluate(0) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(): @@ -793,23 +793,23 @@ def main(): def test_private_primfunc(): - from tvm.script import tir as T + from tvm.script import tirx as T - a = tir.Var("a", "handle") - b = tir.Var("b", "handle") - func = tir.PrimFunc( + a = tirx.Var("a", "handle") + b = tirx.Var("b", "handle") + func = tirx.PrimFunc( params=[a, b], ret_type=None, buffer_map={ - a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), - b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + a: tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B"), }, - body=tir.Evaluate(0), + body=tirx.Evaluate(0), ) _assert_print( func, expected=""" -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func(private=True) def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -818,7 +818,7 @@ def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) def test_prim_func_different_symbol(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -826,7 +826,7 @@ def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) T.evaluate(0) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -846,7 +846,7 @@ def test_variable_with_cpp_address(): with "_0x" followed by a hexadecimal number, and that the address is the same for each variable. """ - from tvm.script import tir as T + from tvm.script import tirx as T # The test function has all named objects suffixed with "_name", # to avoid spurious replacement when generating the expected @@ -874,14 +874,14 @@ def func(a_name: T.handle): def test_return_statement(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def func(): T.evaluate(T.ret(5)) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def func(): @@ -910,14 +910,14 @@ def func(): @pytest.mark.parametrize("dtype", CUSTOM_FLOAT_DTYPES) def test_custom_float_types(dtype): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func() def func(): T.evaluate(getattr(T, dtype)(0.0)) expected_output = f""" -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def func(): @@ -927,7 +927,7 @@ def func(): def test_predicated_load_store(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main(a: T.handle, b: T.handle): @@ -938,7 +938,7 @@ def main(a: T.handle, b: T.handle): A.vstore([0, T.Ramp(0, 2, 4)], a_load, predicate=T.Broadcast(T.bool(False), 4)) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -948,24 +948,24 @@ def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) def test_predicated_buffer_load_store(): - a = tir.Var("a", "handle") - b = tir.Var("b", "handle") + a = tirx.Var("a", "handle") + b = tirx.Var("b", "handle") buffer_map = { - a: tir.decl_buffer(shape=[128, 128], dtype="float32", name="A"), - b: tir.decl_buffer(shape=[256, 256], dtype="float32", name="B"), + a: tirx.decl_buffer(shape=[128, 128], dtype="float32", name="A"), + b: tirx.decl_buffer(shape=[256, 256], dtype="float32", name="B"), } - buffer_load = tir.BufferLoad( + buffer_load = tirx.BufferLoad( buffer=buffer_map[b], - indices=[0, tir.Ramp(0, 4, 4)], - predicate=tir.Broadcast(tir.IntImm("bool", 0), 4), + indices=[0, tirx.Ramp(0, 4, 4)], + predicate=tirx.Broadcast(tirx.IntImm("bool", 0), 4), ) - body = tir.BufferStore( + body = tirx.BufferStore( buffer=buffer_map[a], value=buffer_load, - indices=[0, tir.Ramp(0, 2, 4)], - predicate=tir.Broadcast(tir.IntImm("bool", 0), 4), + indices=[0, tirx.Ramp(0, 2, 4)], + predicate=tirx.Broadcast(tirx.IntImm("bool", 0), 4), ) - func = tir.PrimFunc( + func = tirx.PrimFunc( params=[a, b], ret_type=None, buffer_map=buffer_map, @@ -973,7 +973,7 @@ def test_predicated_buffer_load_store(): ) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func(private=True) def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -983,7 +983,7 @@ def main(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) def test_predicated_scalable_load_store(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main(a: T.handle, b: T.handle): @@ -995,7 +995,7 @@ def main(a: T.handle, b: T.handle): A.vstore([0, T.Ramp(0, 2, T.vscale() * 4)], a_load, predicate=mask) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")): @@ -1005,7 +1005,7 @@ def func(A: T.Buffer((128, 128), "float32"), B: T.Buffer((256, 256), "float32")) def test_vload_with_explicit_scalable_data_type(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main(a: T.handle, b: T.handle): @@ -1014,7 +1014,7 @@ def main(a: T.handle, b: T.handle): B[0 : T.vscale() * 4] = A.vload([T.Ramp(0, 1, T.vscale() * 4)], dtype="float32xvscalex4") expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): @@ -1024,7 +1024,7 @@ def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): def test_vectorize_llvm_pure_intrin(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main(a: T.handle, b: T.handle): @@ -1033,7 +1033,7 @@ def main(a: T.handle, b: T.handle): A[T.Ramp(0, 1, 4)] = T.call_llvm_pure_intrin("float32x4", "llvm.sqrt", B[T.Ramp(0, 1, 4)]) expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): @@ -1043,7 +1043,7 @@ def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): def test_func_with_loop_jumps(): - from tvm.script import tir as T + from tvm.script import tirx as T @T.prim_func def main(a: T.handle, b: T.handle): @@ -1057,7 +1057,7 @@ def main(a: T.handle, b: T.handle): break expected_output = """ -# from tvm.script import tir as T +# from tvm.script import tirx as T @T.prim_func def main(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")): diff --git a/tests/python/tvmscript/test_tvmscript_printer_underlining.py b/tests/python/tvmscript/test_tvmscript_printer_underlining.py index f6818b4c935f..7f7510d2d04e 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_underlining.py +++ b/tests/python/tvmscript/test_tvmscript_printer_underlining.py @@ -20,7 +20,7 @@ from tvm_ffi.access_path import AccessPath from tvm.script import ir as I -from tvm.script import tir as T +from tvm.script import tirx as T from tvm.script.printer.doc import ( ExprStmtDoc, IdDoc, @@ -414,7 +414,7 @@ def func(a: T.int32, b: T.int32): result = func.with_attr("global_symbol", "main").script(obj_to_underline=[func.params[0]]) assert result == format_script( """ - # from tvm.script import tir as T + # from tvm.script import tirx as T @T.prim_func def main(a: T.int32, b: T.int32): @@ -453,7 +453,7 @@ def func(): ) assert result == format_script( """ - # from tvm.script import tir as T + # from tvm.script import tirx as T @T.prim_func def main(): @@ -485,7 +485,7 @@ def func(): ) assert result == format_script( """ - # from tvm.script import tir as T + # from tvm.script import tirx as T @T.prim_func ^^^^^^^^^^^^ @@ -512,7 +512,7 @@ def func(): assert result == format_script( """ # from tvm.script import ir as I - # from tvm.script import tir as T + # from tvm.script import tirx as T @I.ir_module class Module: @@ -541,7 +541,7 @@ def func(): assert result == format_script( """ # from tvm.script import ir as I - # from tvm.script import tir as T + # from tvm.script import tirx as T @I.ir_module ^^^^^^^^^^^^ diff --git a/tests/python/tvmscript/test_tvmscript_regression.py b/tests/python/tvmscript/test_tvmscript_regression.py index 6098843ee445..4379cd5447f0 100644 --- a/tests/python/tvmscript/test_tvmscript_regression.py +++ b/tests/python/tvmscript/test_tvmscript_regression.py @@ -19,10 +19,10 @@ import tvm import tvm.testing -from tvm.script import tir as T +from tvm.script import tirx as T # This numpy array is used to test the comparison between the global objects and the -# `tvm.script.tir` submodule. +# `tvm.script.tirx` submodule. np_array = numpy.array([0, 1, 2, 3]) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index ab64737ce1e0..4f87e434720d 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -23,10 +23,10 @@ import tvm import tvm.testing -from tvm import tir +from tvm import tirx from tvm.script import ir as I from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script import tirx as T def opt_gemm_lower(): @@ -35,7 +35,7 @@ class Module: @T.prim_func def mmult(A: T.handle, B: T.handle, C: T.handle) -> None: # function attr dict - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) A_1 = T.match_buffer(A, [16384], elem_offset=0, align=64, offset_factor=1) B_1 = T.match_buffer(B, [1024, 1024], elem_offset=0, align=64, offset_factor=1) C_1 = T.match_buffer(C, [16384], elem_offset=0, align=64, offset_factor=1) @@ -129,7 +129,7 @@ def func( Conv: T.Buffer((16, 14, 14, 32, 16, 16), "float32"), ) -> None: # function attr dict - T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) + T.func_attr({"global_symbol": "default_function", "tirx.noalias": True}) # body A_1 = T.decl_buffer([12845056], dtype="float16", data=A.data) W_1 = T.decl_buffer([1179648], dtype="float16", data=W.data) @@ -1414,9 +1414,9 @@ def opt_conv_tensorcore_mod_host( # function attr dict T.func_attr( { - "tir.noalias": True, + "tirx.noalias": True, "global_symbol": "default_function", - "tir.is_entry_func": True, + "tirx.is_entry_func": True, "calling_conv": 1, } ) @@ -1762,13 +1762,13 @@ def test_matmul_original(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.SBlock) - assert isinstance(rt_func.body.block.body, tir.stmt.For) - assert isinstance(rt_func.body.block.body.body, tir.stmt.For) - assert isinstance(rt_func.body.block.body.body.body, tir.stmt.SeqStmt) - assert isinstance(rt_func.body.block.body.body.body[0].block, tir.stmt.SBlock) - assert isinstance(rt_func.body.block.body.body.body[1], tir.stmt.For) - assert isinstance(rt_func.body.block.body.body.body[1].body.block, tir.stmt.SBlock) + assert isinstance(rt_func.body.block, tirx.stmt.SBlock) + assert isinstance(rt_func.body.block.body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body.body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body.body.body, tirx.stmt.SeqStmt) + assert isinstance(rt_func.body.block.body.body.body[0].block, tirx.stmt.SBlock) + assert isinstance(rt_func.body.block.body.body.body[1], tirx.stmt.For) + assert isinstance(rt_func.body.block.body.body.body[1].body.block, tirx.stmt.SBlock) def test_element_wise(): @@ -1776,15 +1776,15 @@ def test_element_wise(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.SBlock) - assert isinstance(rt_func.body.block.body, tir.stmt.SeqStmt) - assert isinstance(rt_func.body.block.body[0], tir.stmt.For) - assert isinstance(rt_func.body.block.body[0].body, tir.stmt.For) - assert isinstance(rt_func.body.block.body[0].body.body.block, tir.stmt.SBlock) + assert isinstance(rt_func.body.block, tirx.stmt.SBlock) + assert isinstance(rt_func.body.block.body, tirx.stmt.SeqStmt) + assert isinstance(rt_func.body.block.body[0], tirx.stmt.For) + assert isinstance(rt_func.body.block.body[0].body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body[0].body.body.block, tirx.stmt.SBlock) - assert isinstance(rt_func.body.block.body[1], tir.stmt.For) - assert isinstance(rt_func.body.block.body[1].body, tir.stmt.For) - assert isinstance(rt_func.body.block.body[1].body.body.block, tir.stmt.SBlock) + assert isinstance(rt_func.body.block.body[1], tirx.stmt.For) + assert isinstance(rt_func.body.block.body[1].body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body[1].body.body.block, tirx.stmt.SBlock) def test_predicate(): @@ -1792,11 +1792,11 @@ def test_predicate(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.SBlock) - assert isinstance(rt_func.body.block.body, tir.stmt.For) - assert isinstance(rt_func.body.block.body.body, tir.stmt.For) - assert isinstance(rt_func.body.block.body.body.body, tir.stmt.For) - assert isinstance(rt_func.body.block.body.body.body.body.block, tir.stmt.SBlock) + assert isinstance(rt_func.body.block, tirx.stmt.SBlock) + assert isinstance(rt_func.body.block.body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body.body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body.body.body, tirx.stmt.For) + assert isinstance(rt_func.body.block.body.body.body.body.block, tirx.stmt.SBlock) def for_thread_binding(): @@ -1819,10 +1819,10 @@ def test_for_thread_binding(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body, tir.stmt.For) + assert isinstance(rt_func.body, tirx.stmt.For) assert rt_func.body.kind == 4 assert rt_func.body.thread_binding.thread_tag == "threadIdx.x" - assert isinstance(rt_func.body.body, tir.stmt.For) + assert isinstance(rt_func.body.body, tirx.stmt.For) assert rt_func.body.body.kind == 4 assert rt_func.body.body.thread_binding.thread_tag == "threadIdx.y" assert rt_func.body.body.annotations["attr_key"] == "attr_value" @@ -1853,19 +1853,19 @@ def test_match_buffer_region(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body, tir.stmt.SBlockRealize) + assert isinstance(rt_func.body, tirx.stmt.SBlockRealize) root = rt_func.body.block - assert isinstance(root.body, tir.stmt.For) - assert isinstance(root.body.body, tir.stmt.For) - assert isinstance(root.body.body.body, tir.stmt.SBlockRealize) + assert isinstance(root.body, tirx.stmt.For) + assert isinstance(root.body.body, tirx.stmt.For) + assert isinstance(root.body.body.body, tirx.stmt.SBlockRealize) outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) - assert isinstance(outer_block.body, tir.stmt.For) - assert isinstance(outer_block.body.body, tir.stmt.SBlockRealize) + assert isinstance(outer_block.body, tirx.stmt.For) + assert isinstance(outer_block.body.body, tirx.stmt.SBlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer @@ -1898,12 +1898,12 @@ def test_block_elements(): rt_func = tvm.script.from_source(func.script()) tvm.ir.assert_structural_equal(func, rt_func) - assert isinstance(rt_func.body.block, tir.stmt.SBlock) - assert isinstance(rt_func.body.block.body, tir.stmt.SBlockRealize) - assert isinstance(rt_func.body.block.body.block, tir.stmt.SBlock) + assert isinstance(rt_func.body.block, tirx.stmt.SBlock) + assert isinstance(rt_func.body.block.body, tirx.stmt.SBlockRealize) + assert isinstance(rt_func.body.block.body.block, tirx.stmt.SBlock) block = rt_func.body.block.body.block - assert isinstance(block.body, tir.stmt.BufferStore) - assert isinstance(block.init, tir.stmt.BufferStore) + assert isinstance(block.body, tirx.stmt.BufferStore) + assert isinstance(block.init, tirx.stmt.BufferStore) assert len(block.annotations) == 1 assert block.annotations["attr_key"] == "attr_value" @@ -1935,14 +1935,14 @@ def test_opaque_block(): tvm.ir.assert_structural_equal(func, rt_func) root_block = rt_func.body.block - assert isinstance(root_block, tir.stmt.SBlock) - assert isinstance(root_block.body, tir.stmt.For) - assert isinstance(root_block.body.body[0], tir.stmt.For) - assert isinstance(root_block.body.body[0].body, tir.stmt.SBlockRealize) - assert isinstance(root_block.body.body[0].body.block, tir.stmt.SBlock) + assert isinstance(root_block, tirx.stmt.SBlock) + assert isinstance(root_block.body, tirx.stmt.For) + assert isinstance(root_block.body.body[0], tirx.stmt.For) + assert isinstance(root_block.body.body[0].body, tirx.stmt.SBlockRealize) + assert isinstance(root_block.body.body[0].body.block, tirx.stmt.SBlock) assert len(root_block.body.body[0].body.block.iter_vars) == 0 - assert isinstance(root_block.body.body[1], tir.stmt.SBlockRealize) - assert isinstance(root_block.body.body[1].block, tir.stmt.SBlock) + assert isinstance(root_block.body.body[1], tirx.stmt.SBlockRealize) + assert isinstance(root_block.body.body[1].block, tirx.stmt.SBlock) assert len(root_block.body.body[1].block.iter_vars) == 0 @@ -2077,7 +2077,7 @@ def primfunc_with_allocate_annotations(): @T.prim_func def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.handle) -> None: # function attr dict - T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tirx.noalias": True}) placeholder_29 = T.match_buffer(placeholder_28, [802816], dtype="uint8", elem_offset=0, align=64, offset_factor=1) T_cast_7 = T.match_buffer(T_cast_6, [200704], dtype="int16", elem_offset=0, align=64, offset_factor=1) # body @@ -2100,7 +2100,7 @@ def primfunc_with_allocate_annotations(placeholder_28: T.handle, T_cast_6: T.han def comm_reducer_single_reduce_group(): @T.prim_func def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) threadIdx_x = T.env_thread("threadIdx.x") A = T.match_buffer(a, [16384], dtype="float32") for i in T.serial(0, 128): @@ -2115,7 +2115,7 @@ def comm_reducer_single_reduce_group(a: T.handle, b: T.handle) -> None: def comm_reducer_multiple_reduce_groups(): @T.prim_func def comm_reducer_multiple_reduce_groups(a: T.handle, b: T.handle) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) threadIdx_x = T.env_thread("threadIdx.x") A = T.match_buffer(a, [16384], dtype="float32") for i in T.serial(0, 128): @@ -2166,9 +2166,9 @@ def test_div_mod(): rt_func = tvm.script.from_source(func.script(), check_well_formed=False) tvm.ir.assert_structural_equal(func, rt_func, True) - assert isinstance(func.body[0].value, tvm.tir.FloorDiv) - assert isinstance(func.body[1].value, tvm.tir.FloorMod) - assert isinstance(func.body[2].value, tvm.tir.Mod) + assert isinstance(func.body[0].value, tvm.tirx.FloorDiv) + assert isinstance(func.body[1].value, tvm.tirx.FloorMod) + assert isinstance(func.body[2].value, tvm.tirx.Mod) def loop_extent_dependent(): @@ -2444,7 +2444,7 @@ def scalable_vectors(): @T.prim_func def func(a: T.handle): A = T.match_buffer(a, (200,), "float32") - A[T.Ramp(11, 2, 4 * tir.vscale())] = T.Broadcast(125, 4 * tir.vscale()) + A[T.Ramp(11, 2, 4 * tirx.vscale())] = T.Broadcast(125, 4 * tirx.vscale()) return func @@ -2548,7 +2548,7 @@ def func( placeholder: T.Buffer((1, 512, 768), "float32"), T_isinf: T.Buffer((1, 512, 768), "bool") ) -> None: # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) # body # with T.sblock("root") for i0, i1, i2 in T.grid(1, 512, 768): @@ -2674,16 +2674,20 @@ def func(i: T.int32) -> None: def nested_boolean_expressions(): expressions = { - "and_lhs_and": lambda i, j, k: tir.all(tir.all(i, j), k), - "and_rhs_and": lambda i, j, k: tir.all(i, tir.all(j, k)), - "and_lhs_or": lambda i, j, k: tir.all(tir.any(i, j), k), - "and_rhs_or": lambda i, j, k: tir.all(i, tir.any(j, k)), - "or_lhs_and": lambda i, j, k: tir.any(tir.all(i, j), k), - "or_rhs_and": lambda i, j, k: tir.any(i, tir.all(j, k)), - "or_lhs_or": lambda i, j, k: tir.any(tir.any(i, j), k), - "or_rhs_or": lambda i, j, k: tir.any(i, tir.any(j, k)), - "and_of_ors": lambda i, j, k: tir.all(tir.any(i, j), tir.any(j, k), tir.any(i, k), i, j, k), - "or_of_ands": lambda i, j, k: tir.any(tir.all(i, j), tir.all(j, k), tir.all(i, k), i, j, k), + "and_lhs_and": lambda i, j, k: tirx.all(tirx.all(i, j), k), + "and_rhs_and": lambda i, j, k: tirx.all(i, tirx.all(j, k)), + "and_lhs_or": lambda i, j, k: tirx.all(tirx.any(i, j), k), + "and_rhs_or": lambda i, j, k: tirx.all(i, tirx.any(j, k)), + "or_lhs_and": lambda i, j, k: tirx.any(tirx.all(i, j), k), + "or_rhs_and": lambda i, j, k: tirx.any(i, tirx.all(j, k)), + "or_lhs_or": lambda i, j, k: tirx.any(tirx.any(i, j), k), + "or_rhs_or": lambda i, j, k: tirx.any(i, tirx.any(j, k)), + "and_of_ors": lambda i, j, k: tirx.all( + tirx.any(i, j), tirx.any(j, k), tirx.any(i, k), i, j, k + ), + "or_of_ands": lambda i, j, k: tirx.any( + tirx.all(i, j), tirx.all(j, k), tirx.all(i, k), i, j, k + ), } def make_ir_generator(name, expression): @@ -2740,7 +2744,7 @@ def func(): def string_stride(): @T.prim_func def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int32() A = T.match_buffer(a, (n,), strides=("A_s0",), buffer_type="auto") B = T.match_buffer(b, (n,), strides=("B_s0",), buffer_type="auto") @@ -2759,7 +2763,7 @@ def main(a: T.handle, b: T.handle): def string_stride_int64(): @T.prim_func def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) n = T.int64() A_s0 = T.int64() B_s0 = T.int64() @@ -2776,7 +2780,7 @@ def merge_shape_var_def(): @T.prim_func(check_well_formed=False) def main(A: T.handle, B: T.handle): # fmt: off - T.func_attr({"global_symbol": "main", "tir.noalias": True}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True}) m, n = T.int32(), T.int32() A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto") B_1 = T.match_buffer(B, (m, n), strides=("B_1_s0", "B_1_s1"), buffer_type="auto") @@ -2881,7 +2885,7 @@ def func(A: T.Buffer(64, "float32")): T.evaluate(A[bx]) mod = tvm.IRModule.from_expr(func) - return tvm.tir.transform.MakePackedAPI()(mod) + return tvm.tirx.transform.MakePackedAPI()(mod) def tvm_struct_set_generated_in_cpp(): @@ -2914,7 +2918,7 @@ def tir_packed_call(A: T.Buffer(16)): ) ) - return tvm.tir.transform.LowerTVMBuiltin()(Module) + return tvm.tirx.transform.LowerTVMBuiltin()(Module) def ir_module_with_attrs(): @@ -2938,15 +2942,15 @@ def nested_seqstmt(): cause failures to round-trip through TVMScript, including erroneous use of TVMScript's concise-scoping rules. This was resolved by normalizing nested SeqStmt in TIR, such that the use - of `tir.SeqStmt` below results in a single flat `tir.SeqStmt` - containing the three `tir.Evaluate` calls. + of `tirx.SeqStmt` below results in a single flat `tirx.SeqStmt` + containing the three `tirx.Evaluate` calls. """ - func = tvm.tir.PrimFunc( + func = tvm.tirx.PrimFunc( params=[], - body=tvm.tir.SeqStmt( + body=tvm.tirx.SeqStmt( [ - tvm.tir.SeqStmt([tvm.tir.Evaluate(0), tvm.tir.Evaluate(1)]), - tvm.tir.Evaluate(2), + tvm.tirx.SeqStmt([tvm.tirx.Evaluate(0), tvm.tirx.Evaluate(1)]), + tvm.tirx.Evaluate(2), ] ), ) @@ -3047,7 +3051,7 @@ def main(): # Should be equivalent to the bare "mod.subroutine()", but # that relies on `GlobalVar.__call__` returning the # correct IR type. - tir.call_tir(mod.subroutine) + tirx.call_tir(mod.subroutine) @T.prim_func def subroutine(): @@ -3088,7 +3092,7 @@ def func( B: T.Buffer((128, 128), "float32"), D: T.Buffer((128, 128), "float32"), ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + T.func_attr({"global_symbol": "main", "tirx.noalias": True, "layout_free_buffers": [1]}) C = T.sblock_alloc_buffer([128, 128], dtype="float32") for i0, i1, i2 in T.grid(128, 128, 128): with T.sblock("C"): @@ -3249,7 +3253,7 @@ def func(A: R.Object): def relax_symbolic_size_var(): """Relax symbolic variables may be SizeVar""" - N = tvm.tir.SizeVar("N", "int64") + N = tvm.tirx.SizeVar("N", "int64") @R.function def func(A: R.Tensor([N], "float16")): diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index cc707a8ccf5c..5a5b603a5415 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -24,7 +24,7 @@ import tvm.testing from tvm.s_tir.schedule.testing import assert_structural_equal_ignore_global_symbol from tvm.script import from_source -from tvm.script import tir as T +from tvm.script import tirx as T @T.prim_func @@ -447,7 +447,7 @@ def func(i: T.int32): def test_preserve_variable_name(): - """Use variable name when generating tir::Bind""" + """Use variable name when generating tirx::Bind""" @T.prim_func def func(): diff --git a/tests/python/tvmscript/test_tvmscript_type.py b/tests/python/tvmscript/test_tvmscript_type.py index 4129911b396e..11401863a072 100644 --- a/tests/python/tvmscript/test_tvmscript_type.py +++ b/tests/python/tvmscript/test_tvmscript_type.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement -from tvm.script import tir as T +from tvm.script import tirx as T """ This prim func include necessary buffer types that need to be checked diff --git a/tests/scripts/release/make_notes.py b/tests/scripts/release/make_notes.py index 82e5a4372b0a..abee4e85db67 100644 --- a/tests/scripts/release/make_notes.py +++ b/tests/scripts/release/make_notes.py @@ -53,7 +53,7 @@ "byoc": "BYOC", "community": "Community", "tensorir": "TIR", - "tir": "TIR", + "tirx": "TIR", "tensorflow": "Frontend", "tflite": "Frontend", "pytorch": "Frontend", diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh index ee1b31a6e124..ca092d10d657 100755 --- a/tests/scripts/task_python_unittest.sh +++ b/tests/scripts/task_python_unittest.sh @@ -51,9 +51,9 @@ TEST_FILES=( "s_tir/schedule" "s_tir/dlight" "s_tir/analysis" - "tir-analysis" - "tir-base" - "tir-transform" + "tirx-analysis" + "tirx-base" + "tirx-transform" "tvmscript" "relax" )