From adced7d1edd4f89e26fb5fec4e882dd90150f109 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 8 Mar 2026 13:15:22 +0000 Subject: [PATCH] [TVMScript] Normalize T.Bind to T.bind for statement builder convention --- python/tvm/script/ir_builder/tir/ir.py | 8 +++---- python/tvm/script/parser/tir/parser.py | 6 ++--- .../test_s_tir_transform_thread_sync.py | 16 ++++++------- .../test_tir_analysis_verify_well_formed.py | 24 +++++++++---------- .../test_tir_inline_private_functions.py | 4 ++-- .../test_tir_transform_convert_ssa.py | 6 ++--- .../test_tvmscript_ir_builder_tir.py | 4 ++-- .../tvmscript/test_tvmscript_printer_tir.py | 2 +- .../tvmscript/test_tvmscript_roundtrip.py | 4 ++-- .../tvmscript/test_tvmscript_syntax_sugar.py | 4 ++-- 10 files changed, 39 insertions(+), 39 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index ccdfe3fd6783..ccc730f805b7 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -981,7 +981,7 @@ def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> fr return _ffi_api.Assert(condition, error_kind, message) # type: ignore[attr-defined] # pylint: disable=no-member -def Bind( # pylint: disable=invalid-name +def bind( value: PrimExpr, type_annotation: Type | None = None, # pylint: disable=redefined-outer-name *, @@ -1024,7 +1024,7 @@ def Let( # pylint: disable=invalid-name return tir.Let(var, value, expr) -bind = Bind +Bind = bind # backward-compat alias def let( @@ -1055,9 +1055,9 @@ def let( def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: return tir.Let(v, value, body) - @deprecated("T.let", "T.Bind") + @deprecated("T.let", "T.bind") def let_stmt(v: Var, value: PrimExpr) -> Var: - return Bind(value, var=v) + return bind(value, var=v) if body is None: return let_stmt(v, value) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 660085ba3cc5..b4d6f88edd00 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -145,7 +145,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - return value else: value = tvm.runtime.convert(value) - var = T.Bind(value) + var = T.bind(value) IRBuilder.name(var_name, var) return var @@ -349,7 +349,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should be Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - T.Bind(rhs, var=ann_var) + T.bind(rhs, var=ann_var) @dispatch.register(token="tir", type_name="With") @@ -467,7 +467,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() elif isinstance(res, Var): - # Standalone Var expression (e.g. from T.Bind(value, var=v)) -- + # Standalone Var expression (e.g. from T.bind(value, var=v)) -- # the Bind statement was already emitted to the parent frame by the FFI call, # so just discard the returned Var. pass diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index ec4b5afe0cf4..08a51d265556 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -113,13 +113,13 @@ def func(A: T.Buffer((16 * 512), "float32")): A_shared_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1 = T.decl_buffer((1,), data=in_thread_A_temp.data, scope="local") in_thread_A_temp_1[0] = T.float32(0) - A_temp_1 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) + A_temp_1 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x]) in_thread_A_temp_1[0] = A_temp_1 - A_temp_2 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) + A_temp_2 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128]) in_thread_A_temp_1[0] = A_temp_2 - A_temp_3 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) + A_temp_3 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256]) in_thread_A_temp_1[0] = A_temp_3 - A_temp_4 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) + A_temp_4 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384]) in_thread_A_temp_1[0] = A_temp_4 cross_thread_A_temp_1 = T.decl_buffer((1,), data=cross_thread_A_temp.data, scope="local") with T.attr( @@ -148,13 +148,13 @@ def expected(A: T.Buffer((8192,), "float32")): in_thread_A_temp_1_1 = T.decl_buffer((1,), data=in_thread_A_temp_1.data, scope="local") in_thread_A_temp_1_1[0] = T.float32(0) T.tvm_storage_sync("shared") - A_temp_1 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) + A_temp_1 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) in_thread_A_temp_1_1[0] = A_temp_1 - A_temp_2 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) + A_temp_2 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) in_thread_A_temp_1_1[0] = A_temp_2 - A_temp_3 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) + A_temp_3 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256]) in_thread_A_temp_1_1[0] = A_temp_3 - A_temp_4 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) + A_temp_4 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384]) in_thread_A_temp_1_1[0] = A_temp_4 cross_thread_A_temp_1_1 = T.decl_buffer( (1,), data=cross_thread_A_temp_1.data, scope="local" diff --git a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py index b89a1cb9c739..d6c1dae3b64c 100644 --- a/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py +++ b/tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py @@ -92,8 +92,8 @@ def test_error_for_nested_rebind_usage(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - T.Bind(42, var=i) - T.Bind(42, var=i) + T.bind(42, var=i) + T.bind(42, var=i) T.evaluate(i) with pytest.raises( @@ -113,9 +113,9 @@ def test_error_for_repeated_binding(): @T.prim_func(check_well_formed=False) def func(): i = T.int32() - T.Bind(42, var=i) + T.bind(42, var=i) T.evaluate(i) - T.Bind(17, var=i) + T.bind(17, var=i) T.evaluate(i) with pytest.raises(ValueError, match="multiple nested definitions of variable i"): @@ -131,12 +131,12 @@ def test_error_for_cross_function_reuse(): class mod: @T.prim_func def func1(): - T.Bind(42, var=i) + T.bind(42, var=i) T.evaluate(i) @T.prim_func def func2(): - T.Bind(42, var=i) + T.bind(42, var=i) T.evaluate(i) with pytest.raises(ValueError, match="multiple definitions of variable i"): @@ -295,10 +295,10 @@ def test_error_message_without_previous_definition_location(): def func(): x = T.int32() - T.Bind(42, var=x) + T.bind(42, var=x) T.evaluate(x) - T.Bind(99, var=x) # This should trigger the error + T.bind(99, var=x) # This should trigger the error T.evaluate(x) with pytest.raises(ValueError) as exc_info: @@ -322,8 +322,8 @@ def test_error_message_with_previous_definition_location(): def func(): x = T.int32() - T.Bind(42, var=x) - T.Bind(99, var=x) # This should trigger the error + T.bind(42, var=x) + T.bind(99, var=x) # This should trigger the error T.evaluate(x) with pytest.raises(ValueError) as exc_info: @@ -351,10 +351,10 @@ def test_sequential_redefinition_with_location(): def func(): x = T.int32() - T.Bind(1, var=x) + T.bind(1, var=x) T.evaluate(x) - T.Bind(2, var=x) # This should trigger the error + T.bind(2, var=x) # This should trigger the error T.evaluate(x) with pytest.raises(ValueError) as exc_info: diff --git a/tests/python/tir-transform/test_tir_inline_private_functions.py b/tests/python/tir-transform/test_tir_inline_private_functions.py index e681073fa6f4..e2f41fda16a1 100644 --- a/tests/python/tir-transform/test_tir_inline_private_functions.py +++ b/tests/python/tir-transform/test_tir_inline_private_functions.py @@ -150,7 +150,7 @@ def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")): class Expected: @T.prim_func def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): - A_data_1 = T.Bind(T.address_of(A[0, 0]), T.handle("float32")) + A_data_1 = T.bind(T.address_of(A[0, 0]), T.handle("float32")) A_1 = T.decl_buffer(16, "float32", data=A_data_1) B_data_1: T.handle("float32") = T.address_of(B[0, 0]) B_1 = T.decl_buffer(16, "float32", data=B_data_1) @@ -158,7 +158,7 @@ def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")): with T.sblock("scalar_mul_1"): B_1[i] = A_1[i] * 2.0 - A_data_2 = T.Bind(T.address_of(A[1, 0]), T.handle("float32")) + A_data_2 = T.bind(T.address_of(A[1, 0]), T.handle("float32")) A_2 = T.decl_buffer(16, "float32", data=A_data_2) B_data_2: T.handle("float32") = T.address_of(B[1, 0]) B_2 = T.decl_buffer(16, "float32", data=B_data_2) diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index df69bc384de1..625001bf9f8f 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -42,9 +42,9 @@ def test_reuse_in_sequential_bind(): @T.prim_func(private=True) def expected(): - var1 = T.Bind(T.int32(16)) + var1 = T.bind(T.int32(16)) T.evaluate(var1) - var2 = T.Bind(T.int32(32)) + var2 = T.bind(T.int32(32)) T.evaluate(var2) mod = tvm.IRModule.from_expr(before) @@ -108,7 +108,7 @@ def test_reused_var_across_module(): @T.prim_func(private=True) def func(): - var = T.Bind(10) + var = T.bind(10) T.evaluate(var) before = tvm.IRModule( diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index ee45aebedaca..460457601ae1 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -315,9 +315,9 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_bind(): - # Test that T.Bind emits a flat Bind statement and returns the Var. + # Test that T.bind emits a flat Bind statement and returns the Var. with IRBuilder() as ib: - v = T.Bind(tir.IntImm("int32", 2)) + v = T.bind(tir.IntImm("int32", 2)) # the let binding generated by IRBuilder let_actual = ib.get() diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 7bf0c9f1d02f..406a8c6a79f1 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -255,7 +255,7 @@ def test_for(): def test_bind(): with IRBuilder() as ib: with T.prim_func(): - v = T.Bind(T.float32(10)) + v = T.bind(T.float32(10)) ib.name("v", v) T.evaluate(1) obj = ib.get() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 040932720adc..ab64737ce1e0 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2729,8 +2729,8 @@ def func(): def bind_var(): @T.prim_func def func(): - x = T.Bind(0) - y = T.Bind(0) + x = T.bind(0) + y = T.bind(0) T.evaluate(0) T.evaluate(0) diff --git a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py index f3d19f8ebad2..cc707a8ccf5c 100644 --- a/tests/python/tvmscript/test_tvmscript_syntax_sugar.py +++ b/tests/python/tvmscript/test_tvmscript_syntax_sugar.py @@ -410,7 +410,7 @@ def test_preserve_trivial_let_binding(): @T.prim_func def explicit(i: T.int32): j = T.int32() - T.Bind(i, var=j) + T.bind(i, var=j) T.evaluate(j) @T.prim_func @@ -425,7 +425,7 @@ def test_preserve_trivial_let_binding_of_value(): @T.prim_func def explicit(i: T.int32): j = T.int32() - T.Bind(42, var=j) + T.bind(42, var=j) T.evaluate(j) @T.prim_func