Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,7 @@ def Assert(condition: PrimExpr, message, error_kind: str = "RuntimeError") -> fr
return _ffi_api.Assert(condition, error_kind, message) # type: ignore[attr-defined] # pylint: disable=no-member


def Bind( # pylint: disable=invalid-name
def bind(
value: PrimExpr,
type_annotation: Type | None = None, # pylint: disable=redefined-outer-name
*,
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def Let( # pylint: disable=invalid-name
return tir.Let(var, value, expr)


bind = Bind
Bind = bind # backward-compat alias


def let(
Expand Down Expand Up @@ -1055,9 +1055,9 @@ def let(
def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr:
return tir.Let(v, value, body)

@deprecated("T.let", "T.Bind")
@deprecated("T.let", "T.bind")
def let_stmt(v: Var, value: PrimExpr) -> Var:
return Bind(value, var=v)
return bind(value, var=v)

if body is None:
return let_stmt(v, value)
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
return value
else:
value = tvm.runtime.convert(value)
var = T.Bind(value)
var = T.bind(value)
IRBuilder.name(var_name, var)
return var

Expand Down Expand Up @@ -349,7 +349,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
if not isinstance(ann_var, Var):
self.report_error(node.annotation, "Annotation should be Var")
self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
T.Bind(rhs, var=ann_var)
T.bind(rhs, var=ann_var)


@dispatch.register(token="tir", type_name="With")
Expand Down Expand Up @@ -467,7 +467,7 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
elif isinstance(res, Var):
# Standalone Var expression (e.g. from T.Bind(value, var=v)) --
# Standalone Var expression (e.g. from T.bind(value, var=v)) --
# the Bind statement was already emitted to the parent frame by the FFI call,
# so just discard the returned Var.
pass
Expand Down
16 changes: 8 additions & 8 deletions tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,13 @@ def func(A: T.Buffer((16 * 512), "float32")):
A_shared_1[ax0] = A[blockIdx_x * 512 + ax0]
in_thread_A_temp_1 = T.decl_buffer((1,), data=in_thread_A_temp.data, scope="local")
in_thread_A_temp_1[0] = T.float32(0)
A_temp_1 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x])
A_temp_1 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x])
in_thread_A_temp_1[0] = A_temp_1
A_temp_2 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128])
A_temp_2 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 128])
in_thread_A_temp_1[0] = A_temp_2
A_temp_3 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256])
A_temp_3 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 256])
in_thread_A_temp_1[0] = A_temp_3
A_temp_4 = T.Bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384])
A_temp_4 = T.bind(in_thread_A_temp_1[0] + A_shared_1[threadIdx_x + 384])
in_thread_A_temp_1[0] = A_temp_4
cross_thread_A_temp_1 = T.decl_buffer((1,), data=cross_thread_A_temp.data, scope="local")
with T.attr(
Expand Down Expand Up @@ -148,13 +148,13 @@ def expected(A: T.Buffer((8192,), "float32")):
in_thread_A_temp_1_1 = T.decl_buffer((1,), data=in_thread_A_temp_1.data, scope="local")
in_thread_A_temp_1_1[0] = T.float32(0)
T.tvm_storage_sync("shared")
A_temp_1 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x])
A_temp_1 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x])
in_thread_A_temp_1_1[0] = A_temp_1
A_temp_2 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128])
A_temp_2 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128])
in_thread_A_temp_1_1[0] = A_temp_2
A_temp_3 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256])
A_temp_3 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 256])
in_thread_A_temp_1_1[0] = A_temp_3
A_temp_4 = T.Bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384])
A_temp_4 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 384])
in_thread_A_temp_1_1[0] = A_temp_4
cross_thread_A_temp_1_1 = T.decl_buffer(
(1,), data=cross_thread_A_temp_1.data, scope="local"
Expand Down
24 changes: 12 additions & 12 deletions tests/python/tir-analysis/test_tir_analysis_verify_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def test_error_for_nested_rebind_usage():
@T.prim_func(check_well_formed=False)
def func():
i = T.int32()
T.Bind(42, var=i)
T.Bind(42, var=i)
T.bind(42, var=i)
T.bind(42, var=i)
T.evaluate(i)

with pytest.raises(
Expand All @@ -113,9 +113,9 @@ def test_error_for_repeated_binding():
@T.prim_func(check_well_formed=False)
def func():
i = T.int32()
T.Bind(42, var=i)
T.bind(42, var=i)
T.evaluate(i)
T.Bind(17, var=i)
T.bind(17, var=i)
T.evaluate(i)

with pytest.raises(ValueError, match="multiple nested definitions of variable i"):
Expand All @@ -131,12 +131,12 @@ def test_error_for_cross_function_reuse():
class mod:
@T.prim_func
def func1():
T.Bind(42, var=i)
T.bind(42, var=i)
T.evaluate(i)

@T.prim_func
def func2():
T.Bind(42, var=i)
T.bind(42, var=i)
T.evaluate(i)

with pytest.raises(ValueError, match="multiple definitions of variable i"):
Expand Down Expand Up @@ -295,10 +295,10 @@ def test_error_message_without_previous_definition_location():
def func():
x = T.int32()

T.Bind(42, var=x)
T.bind(42, var=x)
T.evaluate(x)

T.Bind(99, var=x) # This should trigger the error
T.bind(99, var=x) # This should trigger the error
T.evaluate(x)

with pytest.raises(ValueError) as exc_info:
Expand All @@ -322,8 +322,8 @@ def test_error_message_with_previous_definition_location():
def func():
x = T.int32()

T.Bind(42, var=x)
T.Bind(99, var=x) # This should trigger the error
T.bind(42, var=x)
T.bind(99, var=x) # This should trigger the error
T.evaluate(x)

with pytest.raises(ValueError) as exc_info:
Expand Down Expand Up @@ -351,10 +351,10 @@ def test_sequential_redefinition_with_location():
def func():
x = T.int32()

T.Bind(1, var=x)
T.bind(1, var=x)
T.evaluate(x)

T.Bind(2, var=x) # This should trigger the error
T.bind(2, var=x) # This should trigger the error
T.evaluate(x)

with pytest.raises(ValueError) as exc_info:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,15 +150,15 @@ def subroutine(A_data: T.handle("float32"), B_data: T.handle("float32")):
class Expected:
@T.prim_func
def main(A: T.Buffer([80, 16], "float32"), B: T.Buffer([64, 16], "float32")):
A_data_1 = T.Bind(T.address_of(A[0, 0]), T.handle("float32"))
A_data_1 = T.bind(T.address_of(A[0, 0]), T.handle("float32"))
A_1 = T.decl_buffer(16, "float32", data=A_data_1)
B_data_1: T.handle("float32") = T.address_of(B[0, 0])
B_1 = T.decl_buffer(16, "float32", data=B_data_1)
for i in range(16):
with T.sblock("scalar_mul_1"):
B_1[i] = A_1[i] * 2.0

A_data_2 = T.Bind(T.address_of(A[1, 0]), T.handle("float32"))
A_data_2 = T.bind(T.address_of(A[1, 0]), T.handle("float32"))
A_2 = T.decl_buffer(16, "float32", data=A_data_2)
B_data_2: T.handle("float32") = T.address_of(B[1, 0])
B_2 = T.decl_buffer(16, "float32", data=B_data_2)
Expand Down
6 changes: 3 additions & 3 deletions tests/python/tir-transform/test_tir_transform_convert_ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def test_reuse_in_sequential_bind():

@T.prim_func(private=True)
def expected():
var1 = T.Bind(T.int32(16))
var1 = T.bind(T.int32(16))
T.evaluate(var1)
var2 = T.Bind(T.int32(32))
var2 = T.bind(T.int32(32))
T.evaluate(var2)

mod = tvm.IRModule.from_expr(before)
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_reused_var_across_module():

@T.prim_func(private=True)
def func():
var = T.Bind(10)
var = T.bind(10)
T.evaluate(var)

before = tvm.IRModule(
Expand Down
4 changes: 2 additions & 2 deletions tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,9 @@ def test_ir_builder_tir_assert():


def test_ir_builder_tir_bind():
# Test that T.Bind emits a flat Bind statement and returns the Var.
# Test that T.bind emits a flat Bind statement and returns the Var.
with IRBuilder() as ib:
v = T.Bind(tir.IntImm("int32", 2))
v = T.bind(tir.IntImm("int32", 2))
# the let binding generated by IRBuilder
let_actual = ib.get()

Expand Down
2 changes: 1 addition & 1 deletion tests/python/tvmscript/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_for():
def test_bind():
with IRBuilder() as ib:
with T.prim_func():
v = T.Bind(T.float32(10))
v = T.bind(T.float32(10))
ib.name("v", v)
T.evaluate(1)
obj = ib.get()
Expand Down
4 changes: 2 additions & 2 deletions tests/python/tvmscript/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2729,8 +2729,8 @@ def func():
def bind_var():
@T.prim_func
def func():
x = T.Bind(0)
y = T.Bind(0)
x = T.bind(0)
y = T.bind(0)
T.evaluate(0)
T.evaluate(0)

Expand Down
4 changes: 2 additions & 2 deletions tests/python/tvmscript/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def test_preserve_trivial_let_binding():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
T.Bind(i, var=j)
T.bind(i, var=j)
T.evaluate(j)

@T.prim_func
Expand All @@ -425,7 +425,7 @@ def test_preserve_trivial_let_binding_of_value():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
T.Bind(42, var=j)
T.bind(42, var=j)
T.evaluate(j)

@T.prim_func
Expand Down
Loading