diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py index f39ccb6fde1f..b719416e6290 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py @@ -23,6 +23,18 @@ from tvm.script import tirx as T +def _has_volatile_alloc_buffer(mod): + has_volatile_alloc = False + + def visit(node): + nonlocal has_volatile_alloc + if isinstance(node, tvm.tirx.AllocBuffer) and "tirx.volatile" in node.annotations: + has_volatile_alloc = has_volatile_alloc or node.annotations["tirx.volatile"] is True + + tvm.tirx.stmt_functor.post_order_visit(mod["main"].body, visit) + return has_volatile_alloc + + def test_basic(): transform = tvm.s_tir.transform.LowerThreadAllreduce() @@ -503,7 +515,7 @@ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32 After_script = After.script() assert "tvm_warp_shuffle_down" in After_script assert "tvm_storage_sync" in After_script - assert '"tirx.volatile": T.bool(True)' in After_script + assert _has_volatile_alloc_buffer(After) assert "T.uint32(" not in After_script