From 12d24980c321e0b05e265a8e2be11930cc3a4edd Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Fri, 20 Mar 2026 16:57:03 +0900 Subject: [PATCH] [Refactor] Update type references from tir to tirx in PyTorch ExportedProgram frontend Follow up for #18913 and #18917 --- .../relax/frontend/torch/exported_program_translator.py | 8 ++++---- .../test_transform_legalize_ops_index_linear_algebra.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index fd03f67332cd..67e0e45da0a3 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -29,7 +29,7 @@ from torch import fx import tvm -from tvm import relax, tir +from tvm import relax from .base_fx_graph_translator import BaseFXGraphImporter @@ -968,11 +968,11 @@ def _slice(self, node: fx.Node) -> relax.Var: # tensor's own dimension size (common with dynamic shapes). if isinstance(start, int) and start == 0 and isinstance(step, int) and step == 1: in_shape = self.shape_of(x) - if in_shape is not None and isinstance(end_val, tir.PrimExpr): + if in_shape is not None and isinstance(end_val, tvm.tirx.PrimExpr): actual_dim = dim if dim >= 0 else len(in_shape) + dim dim_expr = in_shape[actual_dim] - if isinstance(dim_expr, tir.PrimExpr): - if tir.analysis.expr_deep_equal(end_val, dim_expr): + if isinstance(dim_expr, tvm.tirx.PrimExpr): + if tvm.tirx.analysis.expr_deep_equal(end_val, dim_expr): return x axes = [dim] diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py index b8dbe1934bb5..9f45c7031f6c 100644 --- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py +++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py @@ -277,7 +277,7 @@ def main(x: R.Tensor((8, 9, 10), dtype="float32")) -> R.Tensor((8, 9, 3), dtype= @T.prim_func(private=True) def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9), T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(8), T.int64(9), T.int64(3)), "float32")): - T.func_attr({"tir.noalias": True}) + T.func_attr({"tirx.noalias": True}) for ax0, ax1, ax2 in T.grid(T.int64(8), T.int64(9), T.int64(3)): with T.sblock("T_strided_slice_with_axes"): v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])