From 666c016ee84771feb90acc893529598004e25346 Mon Sep 17 00:00:00 2001 From: shingjan Date: Wed, 1 Dec 2021 23:22:32 -0800 Subject: [PATCH 1/8] add test file --- .../unittest/test_tvmscript_syntax_sugar.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/python/unittest/test_tvmscript_syntax_sugar.py diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py new file mode 100644 index 000000000000..1d4b916e9d4a --- /dev/null +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring,invalid-name,pointless-string-statement +import sys + +import pytest +from tvm.ir import assert_structural_equal +from tvm.script import tir as T + + +@T.prim_func +def transformed_matmul_no_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) + T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) + T.writes([C[vi, vj], A[vi, vk]]) + with T.init(): + C[vi, vj] = 0.0 + A[vi, vk] = A[vi, vk] + B[vj, vk] + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + + +@T.prim_func +def transformed_matmul_syntax_sugar(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) + T.reads(C[vi, vj], A[vi, vk], B[vj, vk]) + T.writes(C[vi, vj], A[vi, vk]) + with T.init(): + C[vi, vj] = 0.0 + A[vi, vk] = A[vi, vk] + B[vj, vk] + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + + +def test_reads_writes_syntax_sugar(): + assert_structural_equal(transformed_matmul_no_syntax_sugar, transformed_matmul_syntax_sugar) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 9f333321d7fee983f7ed161a7dab975b6445f7ae Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 2 Dec 2021 13:05:43 -0800 Subject: [PATCH 2/8] add syntax sugar support --- python/tvm/script/tir/special_stmt.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 5f1b37dd4731..62d83f7a8687 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -17,7 +17,7 @@ """TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements # pylint: disable=relative-beyond-top-level -from typing import Callable, List, Optional, Tuple, Any, Mapping, Union +from typing import Callable, List, Optional, Tuple, Any, Mapping, Union, overload import synr from synr import ast @@ -310,7 +310,11 @@ class BlockReads(SpecialStmt): """ def __init__(self): - def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None): + def reads( + read_regions: Union[BufferSlice, List[BufferSlice]], + *other_regions: BufferSlice, + span: Span = None, + ): assert self.context, "call 'exit_scope' before 'enter_scope'" block_scope = self.context.current_block_scope() if block_scope is None: @@ -327,6 +331,8 @@ def reads(read_regions: Union[BufferSlice, List[BufferSlice]], span: Span = None ) if isinstance(read_regions, BufferSlice): read_regions = [read_regions] + for region in other_regions: + read_regions.append(region) if not isinstance(read_regions, list): self.context.report_error( "Incorrect input type. " @@ -350,7 +356,11 @@ class BlockWrites(SpecialStmt): """ def __init__(self): - def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = None): + def writes( + write_region: Union[BufferSlice, List[BufferSlice]], + *other_region: BufferSlice, + span: Span = None, + ): assert self.context, "call 'exit_scope' before 'enter_scope'" block_scope = self.context.current_block_scope() if block_scope is None: @@ -369,6 +379,8 @@ def writes(write_region: Union[BufferSlice, List[BufferSlice]], span: Span = Non pass elif isinstance(write_region, BufferSlice): write_region = [write_region] + for region in other_region: + write_region.append(region) else: self.context.report_error( "Incorrect input type. " From 0d69de008e9059e65eed82f25e0493ca736c03ea Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 2 Dec 2021 13:13:58 -0800 Subject: [PATCH 3/8] add comments --- python/tvm/script/tir/special_stmt.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 62d83f7a8687..f56bd28dc4d3 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -300,7 +300,12 @@ def alloc_buffer( @register class BlockReads(SpecialStmt): - """Special function reads([read_buffer_regions]) + """Special function reads([read_regions], *other_regions) + + Note + ---- + *other_region is an unpackable list of BufferSlice to support + reads syntax sugar like reads(BufferRegion1, BufferRegion2, ...) Example ------- @@ -346,7 +351,12 @@ def reads( @register class BlockWrites(SpecialStmt): - """Special function writes([write_buffer_regions]) + """Special function writes([write_regions], *other_regions) + + Note + ---- + *other_region is an unpackable list of BufferSlice to support + writes syntax sugar like writes(BufferRegion1, BufferRegion2, ...) Example ------- From d0e32633cb8a1704d432002da8c90d07c9ad21ee Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 2 Dec 2021 13:15:40 -0800 Subject: [PATCH 4/8] cleanup --- python/tvm/script/tir/special_stmt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index f56bd28dc4d3..16de060e6783 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -17,7 +17,7 @@ """TVM Script Parser Special Stmt Classes""" # pylint: disable=unused-argument, no-self-argument, inconsistent-return-statements # pylint: disable=relative-beyond-top-level -from typing import Callable, List, Optional, Tuple, Any, Mapping, Union, overload +from typing import Callable, List, Optional, Tuple, Any, Mapping, Union import synr from synr import ast From f841ee7b82d734cdf1dd9f7d07a4a9401a93e9a6 Mon Sep 17 00:00:00 2001 From: shingjan Date: Thu, 2 Dec 2021 23:09:59 -0800 Subject: [PATCH 5/8] update stub --- python/tvm/script/tir/__init__.pyi | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index ad0a2507c709..aaecc2233197 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -211,8 +211,12 @@ def alloc_buffer( special_stmt - Reads/Writes """ -def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... -def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def reads( + read_regions: Union[BufferSlice, List[BufferSlice]], *other_regions: BufferSlice +) -> None: ... +def writes( + write_region: Union[BufferSlice, List[BufferSlice]], *other_regions: BufferSlice +) -> None: ... def block_attr(attrs: Mapping[str, Object]) -> None: ... """ From 9d76e582656109a3359a802769d735901b0b95a7 Mon Sep 17 00:00:00 2001 From: shingjan Date: Fri, 3 Dec 2021 16:30:26 -0800 Subject: [PATCH 6/8] remove failed tests --- tests/python/unittest/test_tvmscript_error_report.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 11b360287cb7..158996515bec 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -411,17 +411,6 @@ def test_error_index_with_stop_slice(): check_error(error_bufferslice_index_with_stop, 8) -def mismatch_args() -> None: - A = T.alloc_buffer((128, 128), "float32") - with T.block(): - T.reads(A[0, 0], A[1, 1]) # error - T.evaluate(1.0) - - -def test_mismatch_args(): - check_error(mismatch_args, 4) - - def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error T.evaluate(1.0) From 3f47bb72bd44c69780a8bc8d9df855233bff67ef Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 6 Dec 2021 10:37:29 -0800 Subject: [PATCH 7/8] update stub with overload --- python/tvm/script/tir/__init__.pyi | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index aaecc2233197..c47072590f8b 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -210,13 +210,14 @@ def alloc_buffer( """ special_stmt - Reads/Writes """ - -def reads( - read_regions: Union[BufferSlice, List[BufferSlice]], *other_regions: BufferSlice -) -> None: ... -def writes( - write_region: Union[BufferSlice, List[BufferSlice]], *other_regions: BufferSlice -) -> None: ... +@overload +def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... +@overload +def reads(read_regions: BufferSlice, *other_regions: BufferSlice) -> None: ... +@overload +def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... +@overload +def writes(write_region: BufferSlice, *other_regions: BufferSlice) -> None: ... def block_attr(attrs: Mapping[str, Object]) -> None: ... """ From ef411bc0c0cb028665971a75029babef18085069 Mon Sep 17 00:00:00 2001 From: shingjan Date: Mon, 6 Dec 2021 10:51:06 -0800 Subject: [PATCH 8/8] address comments --- python/tvm/script/tir/__init__.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi index c47072590f8b..a197d7467a20 100644 --- a/python/tvm/script/tir/__init__.pyi +++ b/python/tvm/script/tir/__init__.pyi @@ -211,13 +211,13 @@ def alloc_buffer( special_stmt - Reads/Writes """ @overload -def reads(read_regions: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def reads(read_regions: List[BufferSlice]) -> None: ... @overload -def reads(read_regions: BufferSlice, *other_regions: BufferSlice) -> None: ... +def reads(*read_regions: BufferSlice) -> None: ... @overload -def writes(write_region: Union[BufferSlice, List[BufferSlice]]) -> None: ... +def writes(write_region: List[BufferSlice]) -> None: ... @overload -def writes(write_region: BufferSlice, *other_regions: BufferSlice) -> None: ... +def writes(*write_region: BufferSlice) -> None: ... def block_attr(attrs: Mapping[str, Object]) -> None: ... """