From 9f4f6f34101a2eb35a5c1c122b0cecc5fc1f25ca Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 23 Sep 2022 14:57:33 -0500 Subject: [PATCH 1/2] [TVMScript] Infer T.match_buffer parameters for region When using `T.match_buffer` to define a view into another buffer, default shape and dtype parameters can be inferred. --- python/tvm/script/tir/special_stmt.py | 68 ++++++++++++++----- .../unittest/test_tvmscript_syntax_sugar.py | 25 +++++++ 2 files changed, 77 insertions(+), 16 deletions(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 15502055b7fc..7cbf47441053 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -121,8 +121,8 @@ class MatchBuffer(SpecialStmt): def __init__(self): def match_buffer( param, - shape, - dtype="float32", + shape=None, + dtype=None, data=None, strides=None, elem_offset=None, @@ -146,28 +146,64 @@ def match_buffer( offset_factor, "offset_factor", self.context.report_error, self.node.span ) buffer_name: str = self.node.lhs[0].id.name - buffer = tvm.tir.decl_buffer( - shape, - dtype, - buffer_name, - data, - strides, - elem_offset, - scope, - align, - offset_factor, - buffer_type, - axis_separators, - span=span, - ) + if isinstance(param, tvm.tir.Var): + if shape is None: + self.context.report_error( + "Shape must be specified when binding input param", + self.node.rhs.span, + ) + + if dtype is None: + dtype = "float32" + + buffer = tvm.tir.decl_buffer( + shape, + dtype, + buffer_name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + span=span, + ) if param not in self.context.func_params: self.context.report_error( "Can not bind non-input param to buffer", self.node.rhs.params[0].span ) self.context.func_buffer_map[param] = buffer + elif isinstance(param, BufferSlice): buffer_region = param.as_buffer_region() + + if shape is None: + shape = [dim.extent for dim in buffer_region.region] + + if dtype is None: + dtype = buffer_region.buffer.dtype + + if elem_offset is None and offset_factor == 0: + offset_factor = 1 + + buffer = tvm.tir.decl_buffer( + shape, + dtype, + buffer_name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + axis_separators, + span=span, + ) + self.context.current_block_scope().match_buffers.append( tvm.tir.MatchBufferRegion(buffer, buffer_region) ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index d955ec0a8c80..2a2f7354d7cd 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -251,6 +251,31 @@ def test_match_buffer_int64(): assert_structural_equal(original, after_roundtrip, True) +def test_match_buffer_region_has_implicit_shape_dtype(): + @T.prim_func + def explicit_shape_dtype(A: T.Buffer[(16, 64), "int32"]): + with T.block(): + B = T.match_buffer(A[8:16, 32:64], shape=(8, 32), dtype="int32") + T.evaluate(0) + + @T.prim_func + def implicit_shape_dtype(A: T.Buffer[(16, 64), "int32"]): + with T.block(): + B = T.match_buffer(A[8:16, 32:64]) + T.evaluate(0) + + assert_structural_equal(explicit_shape_dtype, implicit_shape_dtype) + + +def test_match_buffer_input_requires_shape_arg(): + with pytest.raises(tvm.error.DiagnosticError): + + @T.prim_func + def func(a: T.handle): + A = T.match_buffer(a, dtype="int32") + T.evaluate(0) + + def test_letstmt_bufferload_without_type_annotation(): # Variable assignment of PrimExpr types uses the dtype of the # PrimExpr to determine the variable's dtype. Parsing of From 515330e3893a64b8a2a0ded490c231a0fa6168c0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 26 Sep 2022 10:49:24 -0500 Subject: [PATCH 2/2] Updated unit test for new behavior The test intentionally triggers a failed match based on mismatched `elem_offset`. Therefore, the test now needs to explicitly pass an `elem_offset` to trigger the failure, as this now defaults to having a `Var` for `match_buffer` calls that represent views. --- tests/python/unittest/test_tir_lower_match_buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 93b7caf9cdde..6120cf2b673c 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -464,7 +464,7 @@ def fail_match_load(a: T.handle) -> None: with T.block(): T.reads(A[i, j]) T.writes([]) - sub_A = T.match_buffer(A[i, j], ()) + sub_A = T.match_buffer(A[i, j], (), elem_offset=0) T.evaluate(sub_A[()]) @@ -475,7 +475,7 @@ def fail_match_store(a: T.handle) -> None: with T.block(): T.reads([]) T.writes(A[i, j]) - sub_A = T.match_buffer(A[i, j], ()) + sub_A = T.match_buffer(A[i, j], (), elem_offset=0) sub_A[()] = 1