diff --git a/tests/python/tirx/transform/test_transform_lower_tirx.py b/tests/python/tirx/transform/test_transform_lower_tirx.py index 3e20d61f8059..80e68243d0b3 100644 --- a/tests/python/tirx/transform/test_transform_lower_tirx.py +++ b/tests/python/tirx/transform/test_transform_lower_tirx.py @@ -24,7 +24,7 @@ from tvm.tirx.layout import laneid, warpid, wg_local_layout from tvm.tirx.stmt import ExecScopeStmt from tvm.tirx.stmt_functor import post_order_visit -from tvm.tirx.transform import LowerTIRx, Simplify +from tvm.tirx.transform import LowerTIRx, StmtSimplify def _contains_exec_scope(mod): @@ -1000,7 +1000,7 @@ def before(A_ptr: Tx.handle): with tvm.target.Target("cuda"): lowered = LowerTIRx()(tvm.IRModule({"main": before})) - simplified = Simplify()(lowered) + simplified = StmtSimplify()(lowered) script = simplified.script(extra_config={"tirx.prefix": "Tx"}) assert "if warp_id_in_cta // 4 == 0:" in script