From 6e4a4da40a4b36694fe35bcc3765984f0f6aeab0 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 11 Jun 2026 00:48:36 -0400 Subject: [PATCH 1/4] [Tests][MetaSchedule] Update s_tir sketch tests for current defaults Fix the s_tir MetaSchedule sketch tests that no longer matched the design spaces generated by current TVM: * test_meta_schedule_schedule_rule_add_rfactor.py::test_cpu_argmax The argmax workload and its expected sketches used the legacy `v: T.int32 = ...` annotated-assignment syntax. The TIRx parser now lowers that form to a mutable local-scalar buffer plus a store, which the rfactor/cross-thread-reduction reducer matching correctly rejects (reduction combiner temporaries must be immutable binds). Switch the temporaries to `v: T.let[T.int32] = ...`, producing Bind nodes - the same canonical form te.create_prim_func emits for comm_reducer-based reductions - so AddRFactor generates the three expected sketches again. * test_meta_schedule_space_cuda.py (cap, dil, gmm, t2d, nrm, sfm, cbr, tbg) and test_meta_schedule_space_cuda_async.py (c2d) Commit b46564618a (#18927) expanded DefaultCUDA unroll_max_steps from {0, 16, 64, 512, 1024} to {0, 16, 32, 64, 128, 256, 512, 1024} without updating the expected SampleCategorical decisions, so the recorded indices selected different unroll values than the expected modules encode. Remap the decision indices (2->3, 3->6, 4->7) so each test keeps sampling the same unroll value (64, 512 and 1024, respectively). The expected modules and all other decisions are unchanged; every sketch was re-verified by replaying the trace and structurally comparing against the expected module. 
--- ...meta_schedule_schedule_rule_add_rfactor.py | 28 +++++++++++-------- .../test_meta_schedule_space_cuda.py | 18 ++++++------ .../test_meta_schedule_space_cuda_async.py | 2 +- 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py index fc6043526d76..550cc35e9401 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_add_rfactor.py @@ -138,8 +138,10 @@ def argmax( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.min_value("float32") - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 @@ -160,8 +162,10 @@ def argmax_0( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 @@ -184,12 +188,12 @@ def argmax_1( with T.init(): argmax_v0_rf[i, vi1_1] = -1 argmax_v1_rf[i, vi1_1] = T.float32(-3.4028234663852886e38) - v_argmax_v0_rf: T.int32 = T.Select( + v_argmax_v0_rf: T.let[T.int32] = T.Select( argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1], argmax_v0_rf[i, vi1_1], idx[i, vi1_0 * 16 + vi1_1], ) - v_argmax_v1_rf: T.float32 = T.Select( + v_argmax_v1_rf: T.let[T.float32] = T.Select( argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + 
vi1_1], argmax_v1_rf[i, vi1_1], val[i, vi1_0 * 16 + vi1_1], @@ -205,10 +209,10 @@ def argmax_1( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v0[i], argmax_v0_rf[i, vi1_1] ) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v1[i], argmax_v1_rf[i, vi1_1] ) argmax_v0[i] = v_argmax_v0 @@ -233,12 +237,12 @@ def argmax_2( with T.init(): argmax_v0_rf[i, vi1_0] = -1 argmax_v1_rf[i, vi1_0] = T.float32(-3.4028234663852886e38) - v_argmax_v0_rf: T.int32 = T.Select( + v_argmax_v0_rf: T.let[T.int32] = T.Select( argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1], argmax_v0_rf[i, vi1_0], idx[i, vi1_0 * 16 + vi1_1], ) - v_argmax_v1_rf: T.float32 = T.Select( + v_argmax_v1_rf: T.let[T.float32] = T.Select( argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1], argmax_v1_rf[i, vi1_0], val[i, vi1_0 * 16 + vi1_1], @@ -254,10 +258,10 @@ def argmax_2( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v0[i], argmax_v0_rf[i, vi1_0] ) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v1[i], argmax_v1_rf[i, vi1_0] ) argmax_v0[i] = v_argmax_v0 diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py index ba9ac778a581..d748d074edc6 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py @@ -375,7 +375,7 @@ def cap_0(inputs: T.Buffer((1, 16, 16, 4, 4, 32), "float32"), weight: T.Buffer(( ("SamplePerfectTile", [8, 4, 1]), ("SampleCategorical", 1), 
("SampleCategorical", 3), - ("SampleCategorical", 2), + ("SampleCategorical", 3), ] mod = create_te_workload("CAP", 0) actual = _design_space(mod) @@ -537,7 +537,7 @@ def dil_0(inputs: T.Buffer((1, 224, 224, 3), "float32"), weight: T.Buffer((7, 7, ("SamplePerfectTile", [3, 1, 1]), ("SampleCategorical", 1), ("SampleCategorical", 3), - ("SampleCategorical", 3), + ("SampleCategorical", 6), ] mod = create_te_workload("DIL", 0) actual = _design_space(mod) @@ -611,7 +611,7 @@ def gmm_0(X: T.Buffer((1, 128, 128), "float32"), Y: T.Buffer((1, 128, 128), "flo ("SamplePerfectTile", [1, 32, 4]), ("SampleCategorical", 1), ("SampleCategorical", 0), - ("SampleCategorical", 4), + ("SampleCategorical", 7), ] mod = create_te_workload("GMM", 0) actual = _design_space(mod) @@ -776,7 +776,7 @@ def t2d_0(inputs: T.Buffer((1, 4, 4, 512), "float32"), weight: T.Buffer((4, 4, 5 ("SamplePerfectTile", [16, 4, 8]), ("SampleCategorical", 1), ("SampleCategorical", 3), - ("SampleCategorical", 2), + ("SampleCategorical", 3), ] mod = create_te_workload("T2D", 0) actual = _design_space(mod) @@ -846,11 +846,11 @@ def nrm_1(A: T.Buffer((1, 256, 256), "float32"), D: T.Buffer(1, "float32")) -> N D[v_b] = T.sqrt(C_shared[v_b]) # fmt: on decision_0 = [ - ("SampleCategorical", 3), + ("SampleCategorical", 6), ] decision_1 = [ ("SampleCategorical", 5), - ("SampleCategorical", 4), + ("SampleCategorical", 7), ] mod = create_te_workload("NRM", 0) actual = _design_space(mod) @@ -1043,7 +1043,7 @@ def sfm_3(A: T.Buffer((256, 256), "float32"), T_softmax_norm: T.Buffer((256, 256 ] decision_2 = [ ("SampleCategorical", 7), - ("SampleCategorical", 3), + ("SampleCategorical", 6), ("SampleCategorical", 0), ] decision_3 = [ @@ -1132,7 +1132,7 @@ def cbr_0(data: T.Buffer((1, 224, 224, 3), "float32"), kernel: T.Buffer((7, 7, 3 ("SamplePerfectTile", [3, 1, 1]), ("SampleCategorical", 0), ("SampleCategorical", 0), - ("SampleCategorical", 3), + ("SampleCategorical", 6), ] mod = create_te_workload("CBR", 0) actual = 
_design_space(mod) @@ -1211,7 +1211,7 @@ def tbg_0(query: T.Buffer((1, 128, 12, 64), "float32"), value: T.Buffer((1, 128, ("SamplePerfectTile", [8, 4, 2]), ("SampleCategorical", 2), ("SampleCategorical", 3), - ("SampleCategorical", 4), + ("SampleCategorical", 7), ] mod = create_te_workload("TBG", 0) actual = _design_space(mod) diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py index 4c44feae910b..1907502b4392 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda_async.py @@ -178,7 +178,7 @@ def test_cuda_c2d(): ("SamplePerfectTile", [3, 1, 1]), ("SampleCategorical", 3), ("SampleCategorical", 2), - ("SampleCategorical", 4), + ("SampleCategorical", 7), ] mod = create_te_workload("C2D", 0) From a33b2ba10317e85f761ad8d0ca155299448b6662 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 11 Jun 2026 00:34:47 -0400 Subject: [PATCH 2/4] [Tests][S-TIR] Use T.let binds in cross-thread tuple reduction tests The TVMScript TIRX parser now treats `v: T.int32 = expr` as a mutable local scalar buffer (AllocBuffer + BufferStore) rather than an immutable Bind. The tuple-style argmax/argmin/layer-norm reduction tests in test_s_tir_transform_lower_cross_thread_reduction.py still used the old spelling, so their reduction blocks no longer matched the reduction-block pattern required by LowerCrossThreadReduction (condition #3: the number of consecutive Binds in the block body must equal the number of BufferStores in the block init), and the pass rejected them. Switch the reduction update bindings to `v: T.let[dtype] = expr`, which produces the Bind nodes the pass expects, matching the spelling already used by the s_tir rfactor schedule tests. No pass behavior changes. 
--- ..._transform_lower_cross_thread_reduction.py | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py index 34e08718f578..5e83bd8f7fb3 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_cross_thread_reduction.py @@ -1246,8 +1246,10 @@ def argmax_split( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 @@ -1278,10 +1280,10 @@ def lowered_argmax_split( k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) T.writes(in_thread_argmax_v0[0], in_thread_argmax_v1[0]) - v_argmax_v0: T.int32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( in_thread_argmax_v1[0] >= val[i, k], in_thread_argmax_v0[0], idx[i, k] ) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v1: T.let[T.float32] = T.Select( in_thread_argmax_v1[0] >= val[i, k], in_thread_argmax_v1[0], val[i, k] ) in_thread_argmax_v0[0] = v_argmax_v0 @@ -1338,8 +1340,10 @@ def argmin_split_init_update_reordered( with T.init(): argmin_v1[i] = T.float32(3.4028234663852886e38) argmin_v0[i] = -1 - v_argmin_v0: T.int32 = T.Select(argmin_v1[i] <= val[i, k], argmin_v0[i], idx[i, k]) - v_argmin_v1: T.float32 = T.Select( + v_argmin_v0: T.let[T.int32] = T.Select( + argmin_v1[i] <= val[i, k], argmin_v0[i], idx[i, k] + ) + v_argmin_v1: T.let[T.float32] = T.Select( argmin_v1[i] <= val[i, k], argmin_v1[i], val[i, k] ) argmin_v1[i] = v_argmin_v1 
@@ -1370,10 +1374,10 @@ def lowered_argmin_split_init_update_reordered( k = T.axis.reduce(128, i1_0 * 32 + i1_1) T.reads(idx[i, k], val[i, k]) T.writes(in_thread_argmin_v0[0], in_thread_argmin_v1[0]) - v_argmin_v0: T.int32 = T.Select( + v_argmin_v0: T.let[T.int32] = T.Select( in_thread_argmin_v1[0] <= val[i, k], in_thread_argmin_v0[0], idx[i, k] ) - v_argmin_v1: T.float32 = T.Select( + v_argmin_v1: T.let[T.float32] = T.Select( in_thread_argmin_v1[0] <= val[i, k], in_thread_argmin_v1[0], val[i, k] ) in_thread_argmin_v1[0] = v_argmin_v1 @@ -1433,8 +1437,8 @@ def layer_norm_tuple_sum( with T.init(): data_red_temp_v0[ax0] = T.float32(0) data_red_temp_v1[ax0] = T.float32(0) - v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] + data[ax0, k1] - v_data_red_temp_v1: T.float32 = ( + v_data_red_temp_v0: T.let[T.float32] = data_red_temp_v0[ax0] + data[ax0, k1] + v_data_red_temp_v1: T.let[T.float32] = ( data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1] ) data_red_temp_v0[ax0] = v_data_red_temp_v0 @@ -1499,8 +1503,10 @@ def lowered_layer_norm_tuple_sum( k1 = T.axis.reduce(768, i1_0 * 32 + i1_1) T.reads(data[ax0, k1]) T.writes(in_thread_data_red_temp_v0[0], in_thread_data_red_temp_v1[0]) - v_data_red_temp_v0: T.float32 = in_thread_data_red_temp_v0[0] + data[ax0, k1] - v_data_red_temp_v1: T.float32 = ( + v_data_red_temp_v0: T.let[T.float32] = ( + in_thread_data_red_temp_v0[0] + data[ax0, k1] + ) + v_data_red_temp_v1: T.let[T.float32] = ( in_thread_data_red_temp_v1[0] + data[ax0, k1] * data[ax0, k1] ) in_thread_data_red_temp_v0[0] = v_data_red_temp_v0 From f7aa998d40cabcb75c21516dfd6a0cb077338503 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 11 Jun 2026 01:17:00 -0400 Subject: [PATCH 3/4] [Tests][S-TIR] Migrate remaining let-binding tests to T.let spelling Sweep the s_tir test tree for other tests broken by the same TIRx parser semantics change: plain `x = expr` and `x: T.int32 = expr` now create mutable local-scalar buffers instead of immutable binds, 
so tests whose intent is a Bind (LetStmt) must spell it `x: T.let[dtype] = expr`. * test_s_tir_transform_compact_buffer_region.py TestLetBinding: index vars rii/rjj are meant to be binds that the pass can analyze through; the scalar-buffer form made the compaction result diverge from the expected result. TestNonIndexLetBinding: plain assignments of call_extern results (incl. handle and void dtypes) crashed CompactBufferAllocation when parsed as scalar buffers. * test_s_tir_transform_hoist_expression.py test_hoist_with_let / test_hoist_disable_let / test_hoist_let_expr: the hoisted condition and Let-expr bindings must be Bind nodes for HoistExpression to hoist (or deliberately not hoist) them. * test_s_tir_transform_remove_undef.py test_remove_let_undef / test_raise_error_for_undef_as_store_indices: binding T.undef() through a mutable scalar hid the undef from RemoveStoreUndef, leaving a stray allocation in one test and swallowing the expected error in the other. Verified by running the full tests/python/s_tir and tests/python/tirx trees: the only remaining failures are unrelated (nvcc too old for compute_120a on the local RTX 5090, buffer_data_alignment annotation mismatches in lower_opaque_block, SBlockRealize well-formedness in default_gpu_schedule, and one cross-file test-isolation flake in test_parser_printer), none caused by bind spelling. 
--- ...st_s_tir_transform_compact_buffer_region.py | 18 +++++++++--------- .../test_s_tir_transform_hoist_expression.py | 10 +++++----- .../test_s_tir_transform_remove_undef.py | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py index e398876f35d6..3185147f8e14 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_compact_buffer_region.py @@ -755,8 +755,8 @@ def before(): for rii, rjj in T.grid(8, 8): C[rii, rjj] = T.float32(0) for riijj in T.serial(8 * 8): - rii: T.int32 = riijj // 8 - rjj: T.int32 = riijj % 8 + rii: T.let[T.int32] = riijj // 8 + rjj: T.let[T.int32] = riijj % 8 C[rii, rjj] += A[rk, rii] * B[rk, rjj] expected = before @@ -766,13 +766,13 @@ class TestNonIndexLetBinding(BaseCompactTest): @T.prim_func(s_tir=True) def before(): A = T.sblock_alloc_buffer((64), "float32") - x1 = T.call_extern("get", dtype="float16") - x2 = T.call_extern("get", dtype="float32") - x3 = T.call_extern("get", dtype="float64") - x4 = T.call_extern("get", dtype="uint8") - x5 = T.call_extern("get", dtype="int32x16") - x6 = T.call_extern("get", dtype="handle") - x7 = T.call_extern("get", dtype="") + x1: T.let[T.float16] = T.call_extern("get", dtype="float16") + x2: T.let[T.float32] = T.call_extern("get", dtype="float32") + x3: T.let[T.float64] = T.call_extern("get", dtype="float64") + x4: T.let[T.uint8] = T.call_extern("get", dtype="uint8") + x5: T.let[T.int32x16] = T.call_extern("get", dtype="int32x16") + x6: T.let[T.handle] = T.call_extern("get", dtype="handle") + x7: T.let = T.call_extern("get", dtype="") for rk in range(64): A[rk] = T.call_extern("load_ptr", x1, x2, x3, x4, x5, x6, x7, dtype="float32") diff --git a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py 
b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py index b4c52d283187..da86edb310b0 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_hoist_expression.py @@ -221,14 +221,14 @@ def test_hoist_with_let(): def before(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): for j in T.serial(4): - condition = i < 3 + condition: T.let[T.bool] = i < 3 if condition: A[i, j] = 0.0 @T.prim_func(private=True, s_tir=True) def expected(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): - condition: T.bool = i < 3 # noqa: F841 + condition: T.let[T.bool] = i < 3 # noqa: F841 if i < 3: for j in T.serial(4): A[i, j] = T.float32(0.0) @@ -250,14 +250,14 @@ def test_hoist_disable_let(): def before(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): for j in T.serial(4): - condition = i < 3 + condition: T.let[T.bool] = i < 3 if condition: A[i, j] = 0.0 @T.prim_func(private=True, s_tir=True) def expected(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): - condition: T.bool = i < 3 # noqa: F841 + condition: T.let[T.bool] = i < 3 # noqa: F841 if i < 3: A[i, j] = T.float32(0.0) @@ -519,7 +519,7 @@ def before(A: T.Buffer((4, 4), "float32")): @T.prim_func(private=True, s_tir=True) def expected(A: T.Buffer((4, 4), "float32")): for i in T.serial(4): - x: T.float32 = T.cast(i + 1, "float32") # noqa: F841 + x: T.let[T.float32] = T.cast(i + 1, "float32") # noqa: F841 for j in T.serial(4): A[i, j] = T.float32(5.0) * T.cast(i + 1, "float32") + T.cast(j, "float32") diff --git a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py index cdc39c443a74..a1dc101c74df 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_remove_undef.py @@ -84,7 +84,7 @@ def test_remove_let_undef(): class Before: @T.prim_func(s_tir=True) 
def main(A: T.Buffer(1, "int32")): - val = T.undef(dtype="int32") + val: T.let[T.int32] = T.undef(dtype="int32") A[0] = val @I.ir_module @@ -104,7 +104,7 @@ def test_raise_error_for_undef_as_store_indices(): class Before: @T.prim_func(s_tir=True) def main(A: T.Buffer(1, "int32")): - val = T.undef(dtype="int32") + val: T.let[T.int32] = T.undef(dtype="int32") A[val] = 5 with pytest.raises(TVMError): From 85ffba19d82b0cba0bed6e53a7a40c9d097004c1 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Thu, 11 Jun 2026 01:35:33 -0400 Subject: [PATCH 4/4] [Tests][S-TIR] Use T.let for reduction temporaries in green tests Sweep follow-up: convert reduction-combiner temporaries (v_*_red_temp_*, v_argmax_*) from the legacy `v: T.dtype = expr` spelling to `v: T.let[dtype] = expr` in tests that still pass but feed schedule rules / passes a non-canonical mutable-scalar form. Real lowered workloads (te.create_prim_func of comm_reducer reductions) produce Bind nodes, so these hand-written mimics should too; with the mutable-scalar spelling the reducer pattern matching in rfactor / cross-thread reduction would reject these blocks at lowering time even though the tests themselves stayed green. Deliberately left unchanged: tvmscript_printer_annotation (tests the scalar-assignment sugar itself), non-reduction scalar temporaries in schedule-error / plan-update / trace-apply tests (value semantics are equivalent and no pattern matching depends on them), and hardware-gated hexagon / nvshmem files that cannot be verified locally. 
--- .../test_transform_rewrite_cuda_graph.py | 4 +-- .../dlight/test_gpu_general_reduction.py | 16 ++++----- ...tproc_rewrite_parallel_vectorize_unroll.py | 8 ++--- ...le_schedule_rule_cross_thread_reduction.py | 36 ++++++++++++------- .../test_tir_schedule_compute_inline.py | 8 ++--- .../schedule/test_tir_schedule_utilities.py | 4 +-- 6 files changed, 44 insertions(+), 32 deletions(-) diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 80637edcc07e..3897e444bc93 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -475,10 +475,10 @@ def layer_norm( with T.init(): A_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0) A_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1, v_ax2] + T.Cast( + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0, v_ax1, v_ax2] + T.Cast( "float32", A[v_ax0, v_ax1, v_ax2, v_k3] ) - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1, v_ax2] + T.Cast( + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0, v_ax1, v_ax2] + T.Cast( "float32", A[v_ax0, v_ax1, v_ax2, v_k3] ) * T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_k3]) A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0 diff --git a/tests/python/s_tir/dlight/test_gpu_general_reduction.py b/tests/python/s_tir/dlight/test_gpu_general_reduction.py index 7022cef9f20d..6ddc4671b9b2 100644 --- a/tests/python/s_tir/dlight/test_gpu_general_reduction.py +++ b/tests/python/s_tir/dlight/test_gpu_general_reduction.py @@ -337,8 +337,8 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: with T.init(): A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] 
* lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 for ax0, ax1, ax2 in T.grid(T.int64(1), n, T.int64(2560)): @@ -377,8 +377,8 @@ def main(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: with T.init(): A_red_temp_v0_shared[T.int64(0), v0] = T.float32(0) A_red_temp_v1_shared[T.int64(0), v0] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] - v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] * lv6[T.int64(0), v0, v1] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1_shared[T.int64(0), v0] + lv6[T.int64(0), v0, v1] * lv6[T.int64(0), v0, v1] A_red_temp_v0_shared[T.int64(0), v0] = v_A_red_temp_v0 A_red_temp_v1_shared[T.int64(0), v0] = v_A_red_temp_v1 for ax1_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"): @@ -481,8 +481,8 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: with T.init(): A_red_temp_v0[v_ax0, v_ax1] = T.float32(0) A_red_temp_v1[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2] * T_reshape_1[v_ax0, v_ax1, v_k2] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0, v_ax1] + T_reshape_1[v_ax0, v_ax1, v_k2] * T_reshape_1[v_ax0, v_ax1, v_k2] A_red_temp_v0[v_ax0, v_ax1] = v_A_red_temp_v0 A_red_temp_v1[v_ax0, v_ax1] = v_A_red_temp_v1 for ax0, ax1 in T.grid(32, 
64): @@ -531,8 +531,8 @@ def main(A: T.Buffer((1, 2048), "float32"), B: T.Buffer((2048,), "float32"), C: with T.init(): A_red_temp_v0_shared[0, v0] = T.float32(0) A_red_temp_v1_shared[0, v0] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[0, v0] + A[0, v0 * 64 + v1] - v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[0, v0] + A[0, v0 * 64 + v1] * A[0, v0 * 64 + v1] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0_shared[0, v0] + A[0, v0 * 64 + v1] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1_shared[0, v0] + A[0, v0 * 64 + v1] * A[0, v0 * 64 + v1] A_red_temp_v0_shared[0, v0] = v_A_red_temp_v0 A_red_temp_v1_shared[0, v0] = v_A_red_temp_v1 for ax1_1 in T.thread_binding(256, thread="threadIdx.x"): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py index 84d667f47d8e..5152e2fa15f0 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -241,8 +241,8 @@ def layer_norm(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), "f with T.init(): A_red_temp_v0[v_ax0] = T.float32(0) A_red_temp_v1[v_ax0] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0] + A[v_ax0, v_k1, v_k2, v_k3] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0] + A[v_ax0, v_k1, v_k2, v_k3] * A[v_ax0, v_k1, v_k2, v_k3] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[v_ax0] + A[v_ax0, v_k1, v_k2, v_k3] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[v_ax0] + A[v_ax0, v_k1, v_k2, v_k3] * A[v_ax0, v_k1, v_k2, v_k3] A_red_temp_v0[v_ax0] = v_A_red_temp_v0 A_red_temp_v1[v_ax0] = v_A_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32): @@ -267,8 +267,8 @@ def expected(A: T.Buffer((1, 4, 4, 32), "float32"), B: T.Buffer((4, 4, 32), 
"flo with T.init(): A_red_temp_v0[0] = T.float32(0) A_red_temp_v1[0] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0[0] + A[0, v_k1, v_k2, v_k3] - v_A_red_temp_v1: T.float32 = A_red_temp_v1[0] + A[0, v_k1, v_k2, v_k3] * A[0, v_k1, v_k2, v_k3] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0[0] + A[0, v_k1, v_k2, v_k3] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1[0] + A[0, v_k1, v_k2, v_k3] * A[0, v_k1, v_k2, v_k3] A_red_temp_v0[0] = v_A_red_temp_v0 A_red_temp_v1[0] = v_A_red_temp_v1 for ax0, ax1, ax2, ax3 in T.grid(1, 4, 4, 32): diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py index eaecaa0fb598..0c6cc9bb2459 100644 --- a/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -582,8 +582,12 @@ def argmax( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.min_value("float32") - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] + ) argmax_v0[i] = v_argmax_v0 argmax_v1[i] = v_argmax_v1 @@ -604,8 +608,12 @@ def argmax_32( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.min_value("float32") - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, 
k] + ) argmax_v0[i] = v_argmax_v0 argmax_v1[i] = v_argmax_v1 @@ -628,8 +636,10 @@ def argmax_0( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 @@ -654,10 +664,10 @@ def argmax_1( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] ) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 @@ -701,8 +711,10 @@ def argmax_0( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 @@ -728,10 +740,10 @@ def argmax_1( with T.init(): argmax_v0[i] = -1 argmax_v1[i] = T.float32(-3.4028234663852886e38) - v_argmax_v0: T.int32 = T.Select( + v_argmax_v0: T.let[T.int32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] ) - v_argmax_v1: T.float32 = T.Select( + v_argmax_v1: T.let[T.float32] = T.Select( argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] ) argmax_v0[i] = v_argmax_v0 diff --git a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py index df0e4963b9c3..d05d67b4a822 100644 --- 
a/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py +++ b/tests/python/s_tir/schedule/test_tir_schedule_compute_inline.py @@ -1424,8 +1424,8 @@ def before(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias with T.init(): A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0) A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0 A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 for ax2_0 in range(T.int64(10)): @@ -1465,8 +1465,8 @@ def after(p_lv6: T.handle, weight1: T.Buffer((T.int64(2560),), "float32"), bias: with T.init(): A_red_temp_v0_shared[v_ax0, v_ax1] = T.float32(0) A_red_temp_v1_shared[v_ax0, v_ax1] = T.float32(0) - v_A_red_temp_v0: T.float32 = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] - v_A_red_temp_v1: T.float32 = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v0: T.let[T.float32] = A_red_temp_v0_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] + v_A_red_temp_v1: T.let[T.float32] = A_red_temp_v1_shared[v_ax0, v_ax1] + lv6[v_ax0, v_ax1, v_k2] * lv6[v_ax0, v_ax1, v_k2] A_red_temp_v0_shared[v_ax0, v_ax1] = v_A_red_temp_v0 A_red_temp_v1_shared[v_ax0, v_ax1] = v_A_red_temp_v1 for ax2_0 in range(T.int64(10)): diff --git a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py index dcd3b7b5a296..a2aa9c699db8 100644 --- a/tests/python/s_tir/schedule/test_tir_schedule_utilities.py +++ 
b/tests/python/s_tir/schedule/test_tir_schedule_utilities.py @@ -147,8 +147,8 @@ def tuple_reduction(data: T.Buffer((4, 32), "float32"), T_add: T.Buffer((4,), "f with T.init(): data_red_temp_v0[ax0] = T.float32(0) data_red_temp_v1[ax0] = T.float32(0) - v_data_red_temp_v0: T.float32 = data_red_temp_v0[ax0] + data[ax0, k1] - v_data_red_temp_v1: T.float32 = ( + v_data_red_temp_v0: T.let[T.float32] = data_red_temp_v0[ax0] + data[ax0, k1] + v_data_red_temp_v1: T.let[T.float32] = ( data_red_temp_v1[ax0] + data[ax0, k1] * data[ax0, k1] ) data_red_temp_v0[ax0] = v_data_red_temp_v0