From 4e7227668eb1b5e985c44ececf93992b8725a00c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 10 Nov 2022 23:50:08 -0800 Subject: [PATCH] [IRBuilder][Minor] Add intrinsics like `T.int32x4` This PR adds all common TIR intrinsics like `T.int32x4`, `T.floatx4`. Co-authored-by: Yaxing Cai --- include/tvm/script/ir_builder/tir/frame.h | 16 +- include/tvm/script/ir_builder/tir/ir.h | 46 +- python/tvm/script/ir_builder/tir/frame.py | 4 +- python/tvm/script/ir_builder/tir/ir.py | 473 +++++++----------- python/tvm/tir/op.py | 57 ++- src/script/ir_builder/tir/frame.cc | 14 +- src/script/ir_builder/tir/ir.cc | 56 ++- .../unittest/test_tvmscript_error_report.py | 2 +- .../unittest/test_tvmscript_ir_builder_tir.py | 21 +- 9 files changed, 348 insertions(+), 341 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index aa2386e7f1e4..b95d575360e6 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -453,8 +453,8 @@ class AllocateFrameNode : public TIRFrameNode { PrimExpr condition; /*! \brief Additional annotation hints. */ Map annotations; - /*! \brief The buffer. */ - tvm::tir::Buffer buffer; + /*! \brief The buffer var. */ + tvm::tir::Var buffer_var; void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); @@ -463,7 +463,7 @@ class AllocateFrameNode : public TIRFrameNode { v->Visit("storage_scope", &storage_scope); v->Visit("condition", &condition); v->Visit("annotations", &annotations); - v->Visit("buffer", &buffer); + v->Visit("buffer_var", &buffer_var); } static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame"; @@ -500,8 +500,8 @@ class AllocateConstFrameNode : public TIRFrameNode { Array extents; /*! \brief The data associated with the constant. */ tvm::runtime::NDArray data; - /*! \brief The buffer */ - tvm::tir::Buffer buffer; + /*! \brief The buffer var */ + tvm::tir::Var buffer_var; /*! \brief Additional annotations about the allocation. */ Map annotations; @@ -510,7 +510,7 @@ class AllocateConstFrameNode : public TIRFrameNode { v->Visit("dtype", &dtype); v->Visit("extents", &extents); v->Visit("data", &data); - v->Visit("buffer", &buffer); + v->Visit("buffer_var", &buffer_var); v->Visit("annotations", &annotations); } @@ -723,11 +723,15 @@ class ElseFrame : public TIRFrame { class DeclBufferFrameNode : public TIRFrameNode { public: + /*! \brief The declared buffer. */ tvm::tir::Buffer buffer; + /*! \brief The buffer allocated or not. */ + bool allocated; void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); v->Visit("buffer", &buffer); + v->Visit("allocated", &allocated); } static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame"; diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 7460099f9448..d9e1a1b49063 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -339,9 +339,8 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s * \param annotations Additional annotation hints. * \return The created AllocateConstFrame. */ -AllocateConstFrame AllocateConst( - NDArray data, DataType dtype, Array extents, - Map annotations = NullValue>()); +AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array extents, + Optional> annotations = NullOpt); /*! * \brief Create an attribute. @@ -449,21 +448,32 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global"); return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \ } -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8)); -TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16)); +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64)); + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int); + +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64)); + +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \ + TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64); + +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt); +TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle()); TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void()); diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index b9b50dfa9876..a57c878bd929 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -69,14 +69,14 @@ class RealizeFrame(TIRFrame): class AllocateFrame(TIRFrame): def __enter__(self) -> Buffer: super().__enter__() - return self.buffer + return self.buffer_var @_register_object("script.ir_builder.tir.AllocateConstFrame") class AllocateConstFrame(TIRFrame): def __enter__(self) -> Buffer: super().__enter__() - return self.buffer + return self.buffer_var @_register_object("script.ir_builder.tir.AttrFrame") diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 4ec1511f2907..bd9e4e1db522 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -14,41 +14,75 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-docstring """IRBuilder for TIR""" -import inspect import functools +import inspect from numbers import Integral -from typing import Any, Callable, Dict, List, Optional, Union, Tuple -import numpy as np # type: ignore +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal +# isort: on + +import numpy as np # type: ignore from tvm.ir import Range, Type from tvm.runtime import convert, ndarray +from tvm.target import Target + +# pylint: disable=unused-import from tvm.target.codegen import llvm_lookup_intrinsic_id -from tvm.tir import ( - Buffer, +from tvm.tir import Buffer, BufferRegion, PrimExpr +from tvm.tir import op as _tir_op +from tvm.tir import type_annotation + +# import tir.expr for direct ir construction to pass structural_equal comparison +from tvm.tir.expr import ( + EQ, + GE, + GT, + LE, + LT, + NE, + Add, + And, + Broadcast, BufferLoad, - BufferRegion, + Call, + CallEffectKind, Cast, CommReducer, + Div, + FloatImm, + FloorDiv, + FloorMod, IntImm, IterVar, Let, - PrimExpr, + Load, + Max, + Min, + Mod, + Mul, + Not, + Or, + ProducerLoad, + Ramp, + Reduce, Select, Shuffle, + SizeVar, StringImm, - type_annotation, + Sub, Var, ) -from tvm.tir import Broadcast as broadcast -from tvm.tir import Ramp as ramp -from tvm.tir import op as _tir_op from tvm.tir.generic import cast from . import _ffi_api, frame +# pylint: enable=unused-import + def buffer_decl( shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], @@ -56,7 +90,7 @@ def buffer_decl( data: Var = None, strides: List[PrimExpr] = None, elem_offset: PrimExpr = None, - scope: str = "", + scope: str = "global", align: int = 0, offset_factor: int = 0, buffer_type: str = "", @@ -187,7 +221,7 @@ def func_ret(ret_type: Type) -> Type: def match_buffer( param: Union[Var, BufferLoad, BufferRegion], - shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral], + shape: Union[List[PrimExpr], Tuple[PrimExpr], PrimExpr, Integral] = None, dtype: str = "float32", data: Var = None, strides: List[PrimExpr] = None, @@ -256,6 +290,12 @@ def match_buffer( res : Buffer The matched buffer. """ + if shape is None: + if isinstance(param, BufferRegion): + dtype = param.buffer.dtype + shape = [region.extent for region in param.region] + else: + raise ValueError("Shape must be specified when binding input param") shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape if strides is None: strides = [] @@ -447,7 +487,7 @@ def alloc_buffer( data: Var = None, strides: List[PrimExpr] = None, elem_offset: PrimExpr = None, - scope: str = "", + scope: str = "global", align: int = -1, offset_factor: int = 0, buffer_type: str = "default", @@ -526,10 +566,14 @@ def _as_range(dom: Union[Range, List[PrimExpr]]) -> Range: return dom if isinstance(dom, (list, tuple)): return Range(dom[0], dom[1]) + if hasattr(dom, "dtype"): + return Range(IntImm(dom.dtype, 0), dom) return Range(0, dom) class axis: # pylint: disable=invalid-name + """The axis class""" + @staticmethod def spatial( dom: Union[Range, List[PrimExpr], Tuple[PrimExpr]], binding: PrimExpr, dtype: str = "int32" @@ -686,7 +730,10 @@ def serial( """ if stop is None: stop = start - start = 0 + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 return _ffi_api.Serial(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member @@ -713,7 +760,10 @@ def parallel( """ if stop is None: stop = start - start = 0 + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 return _ffi_api.Parallel(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member @@ -740,7 +790,10 @@ def vectorized( """ if stop is None: stop = start - start = 0 + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 return _ffi_api.Vectorized(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member @@ -767,7 +820,10 @@ def unroll( """ if stop is None: stop = start - start = 0 + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 return _ffi_api.Unroll(start, stop, annotations) # type: ignore[attr-defined] # pylint: disable=no-member @@ -804,10 +860,16 @@ def thread_binding( raise ValueError("Thread cannot be None for thread_binding") thread = stop stop = start - start = 0 + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 elif stop is None: stop = start - start = 0 + if hasattr(start, "dtype"): + start = IntImm(start.dtype, 0) + else: + start = 0 return _ffi_api.ThreadBinding( # type: ignore[attr-defined] # pylint: disable=no-member start, stop, thread, annotations ) @@ -907,7 +969,7 @@ def realize( def allocate( extents: List[PrimExpr], dtype: str, - scope: str = "", + scope: str = "global", condition: PrimExpr = None, annotations=None, ) -> frame.AllocateFrame: @@ -959,9 +1021,18 @@ def allocate_const( annotations : Optional[Map] Additional annotations about the allocation. """ + np_data = np.asarray(data, dtype=dtype) + prod_extent = 1 + for extent in extents: + prod_extent *= extent + prod_shape = 1 + for shape in np_data.shape: + prod_shape *= shape + if prod_extent == prod_shape: + np_data = np_data.reshape(extents) return _ffi_api.AllocateConst( # type: ignore[attr-defined] # pylint: disable=no-member - ndarray.array(np.asarray(data, dtype)), dtype, extents, annotations + ndarray.array(np_data), dtype, extents, annotations ) @@ -1054,7 +1125,7 @@ def decl_buffer( data=None, strides=None, elem_offset=None, - scope="", + scope="global", align=0, offset_factor=0, buffer_type="", @@ -1221,247 +1292,41 @@ def evaluate(value: PrimExpr) -> None: """ if isinstance(value, str): value = StringImm(value) + if isinstance(value, bool): + value = cast(value, "bool") return _ffi_api.Evaluate(value) # type: ignore[attr-defined] # pylint: disable=no-member -def int8(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int8 or cast expression to type int8. +__all__ = [] +for _dtype in ["Float", "UInt", "Int"]: + for _size in ["8", "16", "32", "64"]: + for _lanes in ["", "x4", "x8", "x16", "x32", "x64"]: + _name = _dtype + _size + _lanes # pylint: disable=invalid-name - Parameters - ---------- - expr: PrimExpr - The expression to be cast. + def func_gen(name: str): + """Generate a function for each PrimExpr dtype. - Returns - ------- - res : PrimExpr - The new tir.Var with type int8 or casted expression with type int8. - """ - return _ffi_api.Int8(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def int16(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int16 or cast expression to type int16. + Parameters + ---------- + name: str + The ffi function name to call. + """ - Parameters - ---------- - expr: PrimExpr - The expression to be cast. + def func( + expr: Union[ + None, + PrimExpr, + Literal["inf", "-inf", "nan"], + ] = None + ) -> PrimExpr: + if isinstance(expr, str): + expr = float(expr) + return getattr(_ffi_api, name)(expr) - Returns - ------- - res : PrimExpr - The new tir.Var with type int16 or casted expression with type int16. - """ - return _ffi_api.Int16(expr) # type: ignore[attr-defined] # pylint: disable=no-member + return func - -def int32(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int32 or cast expression to type int32. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type int32 or casted expression with type int32. - """ - return _ffi_api.Int32(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def int64(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int64 or cast expression to type int64. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type int64 or casted expression with type int64. - """ - return _ffi_api.Int64(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def uint8(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type uint8 or cast expression to type uint8. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type uint8 or casted expression with type uint8. - """ - return _ffi_api.UInt8(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def uint16(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type uint16 or cast expression to type uint16. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type uint16 or casted expression with type uint16. - """ - return _ffi_api.UInt16(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def uint32(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type uint32 or cast expression to type uint32. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type uint32 or casted expression with type uint32. - """ - return _ffi_api.UInt32(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def uint64(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type uint64 or cast expression to type uint64. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type uint64 or casted expression with type uint64. - """ - return _ffi_api.UInt64(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def float8(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type float8 or cast expression to type float8. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type float8 or casted expression with type float8. - """ - return _ffi_api.Float8(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def float16(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type float16 or cast expression to type float16. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type float16 or casted expression with type float16. - """ - return _ffi_api.Float16(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def float32(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type float32 or cast expression to type float32. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type float32 or casted expression with type float32. - """ - return _ffi_api.Float32(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def float64(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type float64 or cast expression to type float64. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type float64 or casted expression with type float64. - """ - return _ffi_api.Float64(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def int32x4(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int32x4 or cast expression to type int32x4. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type int32x4 or casted expression with type int32x4. - """ - return _ffi_api.Int32x4(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def int32x8(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int32x8 or cast expression to type int32x8. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type int32x8 or casted expression with type int32x8. - """ - return _ffi_api.Int32x8(expr) # type: ignore[attr-defined] # pylint: disable=no-member - - -def int32x16(expr: Optional[PrimExpr] = None) -> PrimExpr: - """Construct a new tir.Var with type int32x16 or cast expression to type int32x16. - - Parameters - ---------- - expr: PrimExpr - The expression to be cast. - - Returns - ------- - res : PrimExpr - The new tir.Var with type int32x16 or casted expression with type int32x16. - """ - return _ffi_api.Int32x16(expr) # type: ignore[attr-defined] # pylint: disable=no-member + globals()[_name.lower()] = func_gen(_name) + __all__.append(_name.lower()) def boolean(expr: Optional[PrimExpr] = None) -> PrimExpr: @@ -1645,6 +1510,27 @@ def comm_reducer(combiner: Callable, identity: List[PrimExpr]) -> CommReducer: return CommReducer(args[: num_args // 2], args[num_args // 2 :], res, identity) +def target(target_config: Union[Dict, str]) -> Target: + """ + Create a target + + Parameters + ---------- + target_config : Union[Dict, str] + The target configuration. + + Returns + ------- + res : Target + The target. + """ + if not isinstance(target_config, (str, dict)): + raise ValueError( + f"T.target expected a config dict or string, but got {type(target_config)}" + ) + return Target(target_config) + + def _op_wrapper(func): @functools.wraps(func) def wrapped(*args, **kwargs): @@ -1667,6 +1553,9 @@ def wrapped(*args, **kwargs): # pylint: disable=invalid-name +broadcast = Broadcast +ramp = Ramp + buffer_var = ptr abs = _op_wrapper(_tir_op.abs) # pylint: disable=redefined-builtin fabs = abs @@ -1713,6 +1602,7 @@ def wrapped(*args, **kwargs): popcount = _op_wrapper(_tir_op.popcount) power = _op_wrapper(_tir_op.power) q_multiply_shift = _op_wrapper(_tir_op.q_multiply_shift) +q_multiply_shift_per_axis = _op_wrapper(_tir_op.q_multiply_shift_per_axis) ret = _op_wrapper(_tir_op.ret) reinterpret = _dtype_forward(_tir_op.reinterpret) round = _op_wrapper(_tir_op.round) # pylint: disable=redefined-builtin @@ -1733,6 +1623,7 @@ def wrapped(*args, **kwargs): tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) +tvm_check_return = _op_wrapper(_tir_op.tvm_check_return) call_packed = _op_wrapper(_tir_op.call_packed) call_cpacked = _op_wrapper(_tir_op.call_cpacked) call_packed_lowered = _op_wrapper(_tir_op.call_packed_lowered) @@ -1742,7 +1633,6 @@ def wrapped(*args, **kwargs): call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) -tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) tvm_tuple = _op_wrapper(_tir_op.tvm_tuple) tvm_struct_set = _op_wrapper(_tir_op.tvm_struct_set) tvm_struct_get = _tir_op.tvm_struct_get @@ -1771,6 +1661,8 @@ def wrapped(*args, **kwargs): tvm_call_cpacked_lowered = call_cpacked_lowered TVMBackendAllocWorkspace = _op_wrapper(_tir_op.TVMBackendAllocWorkspace) TVMBackendFreeWorkspace = _op_wrapper(_tir_op.TVMBackendFreeWorkspace) +start_profile_intrinsic = _op_wrapper(_tir_op.start_profile_intrinsic) +end_profile_intrinsic = _op_wrapper(_tir_op.end_profile_intrinsic) class inline: @@ -1796,7 +1688,7 @@ def f(): # pylint: enable=invalid-name -__all__ = [ +__all__ += [ "buffer_decl", "prim_func", "arg", @@ -1835,21 +1727,6 @@ def f(): "buffer_store", "prefetch", "evaluate", - "int8", - "int16", - "int32", - "int64", - "uint8", - "uint16", - "uint32", - "uint64", - "float8", - "float16", - "float32", - "float64", - "int32x4", - "int32x8", - "int32x16", "boolean", "handle", "void", @@ -1859,6 +1736,7 @@ def f(): "max", "iter_var", "comm_reducer", + "target", "buffer_var", "abs", "fabs", @@ -1905,6 +1783,7 @@ def f(): "popcount", "power", "q_multiply_shift", + "q_multiply_shift_per_axis", "ret", "reinterpret", "round", @@ -1925,6 +1804,7 @@ def f(): "tvm_stack_alloca", "tvm_stack_make_shape", "tvm_stack_make_array", + "tvm_check_return", "call_packed", "call_cpacked", "call_packed_lowered", @@ -1934,7 +1814,6 @@ def f(): "call_llvm_intrin", "call_llvm_pure_intrin", "call_pure_extern", - "tvm_access_ptr", "tvm_tuple", "tvm_struct_set", "tvm_struct_get", @@ -1963,14 +1842,50 @@ def f(): "tvm_call_cpacked_lowered", "TVMBackendAllocWorkspace", "TVMBackendFreeWorkspace", + "start_profile_intrinsic", + "end_profile_intrinsic", "inline", "llvm_lookup_intrinsic_id", - "Cast", - "Let", - "Select", - "Shuffle", "type_annotation", "broadcast", "ramp", "cast", + # tvm.tir.expr + "Var", + "SizeVar", + "Reduce", + "FloatImm", + "IntImm", + "StringImm", + "Cast", + "Add", + "Sub", + "Mul", + "Div", + "Mod", + "FloorDiv", + "FloorMod", + "Min", + "Max", + "EQ", + "NE", + "LT", + "LE", + "GT", + "GE", + "And", + "Or", + "Not", + "Select", + "BufferLoad", + "ProducerLoad", + "Load", + "Ramp", + "Broadcast", + "Shuffle", + "Call", + "CallEffectKind", + "Let", + "IterVar", + "CommReducer", ] diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 588b40ae4033..e1adc0a6bbd7 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -18,14 +18,15 @@ """Operators used in TIR expression.""" import warnings from typing import Any, Optional + import tvm._ffi -from tvm.ir.base import Span -from tvm.runtime import convert, const from tvm.ir import Array, Op, PrimExpr +from tvm.ir.base import Span +from tvm.runtime import const, convert -from .buffer import Buffer -from .expr import Call, PrimExprWithOp, StringImm, Var, CommReducer, IntImm from . import _ffi_api +from .buffer import Buffer +from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var def _pack_buffer(buf, span=None): @@ -322,6 +323,24 @@ def call_llvm_pure_intrin(dtype, name, *args, span=None): ) +def tvm_check_return(expected, return_unexpected, nested_call): + """Return new on stack dtype[num] + Parameters + ---------- + expected : int + The expected return code. + return_unexpected : int + The unexpected return code. + nested_call : PrimExpr + The call expression to check return. + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("int32", "tir.tvm_check_return", expected, return_unexpected, nested_call) + + def tvm_stack_alloca(dtype_str, num): """Return new on stack dtype[num] @@ -403,7 +422,7 @@ def assume(cond=None): call : PrimExpr The call expression. """ - return call_intrin("int32", "tir.assume", cond) + return call_intrin("bool", "tir.assume", cond) def undef(): @@ -417,6 +436,34 @@ def undef(): return call_intrin("int32", "tir.undef") +def start_profile_intrinsic(id): + """Start profile intrinsic. + Parameters + ---------- + id : int + The intrinsic id. + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.start_profile_intrinsic", id) + + +def end_profile_intrinsic(id): + """End profile intrinsic. + Parameters + ---------- + id : int + The intrinsic id. + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("handle", "tir.end_profile_intrinsic", id) + + def tvm_tuple(*value): """Create a tuple structure in value field of AttrStmt diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index aa9efa653f71..f48ee52506b4 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -117,14 +117,14 @@ void LaunchThreadFrameNode::ExitWithScope() { void AllocateFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, condition, - AsStmt(stmts), annotations)); + AddToParent( + tvm::tir::Allocate(buffer_var, dtype, extents, condition, AsStmt(stmts), annotations)); } void AllocateConstFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); AddToParent( - tvm::tir::AllocateConst(buffer->data, dtype, extents, data, AsStmt(stmts), annotations)); + tvm::tir::AllocateConst(buffer_var, dtype, extents, data, AsStmt(stmts), annotations)); } void AttrFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); @@ -182,7 +182,13 @@ void ElseFrameNode::ExitWithScope() { void DeclBufferFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); - AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts))); + if (allocated) { + AddToParent(tvm::tir::DeclBuffer(buffer, AsStmt(stmts))); + } else { + AddToParent(tvm::tir::Allocate(buffer->data, buffer->dtype, buffer->shape, + tvm::IntImm(DataType::Bool(), 1), + tvm::tir::DeclBuffer(buffer, AsStmt(stmts)))); + } } TVM_REGISTER_NODE_TYPE(TIRFrameNode); diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 6be6e2619fea..78107136d492 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -452,20 +452,19 @@ AllocateFrame Allocate(Array extents, DataType dtype, String storage_s n->storage_scope = storage_scope; n->condition = condition.value_or(tvm::Bool(true)); n->annotations = annotations.value_or(Map()); - n->buffer = BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, storage_scope, 0, 0, - "default", NullOpt); + n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype), storage_scope)); return AllocateFrame(n); } AllocateConstFrame AllocateConst(tvm::runtime::NDArray data, DataType dtype, - Array extents, Map annotations) { + Array extents, + Optional> annotations) { ObjectPtr n = make_object(); n->dtype = dtype; n->extents = extents; n->data = data; - n->annotations = annotations; - n->buffer = - BufferDecl(extents, dtype, "", NullOpt, NullOpt, NullOpt, "", 0, 0, "default", NullOpt); + n->annotations = annotations.value_or(Map()); + n->buffer_var = Var("", tvm::PointerType(tvm::PrimType(dtype))); return AllocateConstFrame(n); } @@ -529,6 +528,7 @@ DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_ ObjectPtr n = make_object(); n->buffer = BufferDecl(shape, dtype, buffer_name, data, strides, elem_offset, storage_scope, align, offset_factor, buffer_type, axis_separators); + n->allocated = data.defined(); return DeclBufferFrame(n); } @@ -638,21 +638,35 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int8").set_body_typed(Int8); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int16").set_body_typed(Int16); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32").set_body_typed(Int32); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int64").set_body_typed(Int64); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt8").set_body_typed(UInt8); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt16").set_body_typed(UInt16); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt32").set_body_typed(UInt32); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.UInt64").set_body_typed(UInt64); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8").set_body_typed(Float8); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float16").set_body_typed(Float16); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float32").set_body_typed(Float32); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float64").set_body_typed(Float64); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x4").set_body_typed(Int32x4); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x8").set_body_typed(Int32x8); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Int32x16").set_body_typed(Int32x16); +#define TVM_TMP_STR(x) #x + +#define TVM_REGISTER_GLOBAL_SIZE(Prefix, DType) \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(8)).set_body_typed(DType##8); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(16)).set_body_typed(DType##16); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(32)).set_body_typed(DType##32); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(64)).set_body_typed(DType##64); + +TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Float", Float); +TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.UInt", UInt); +TVM_REGISTER_GLOBAL_SIZE("script.ir_builder.tir.Int", Int); + +#define TVM_REGISTER_GLOBAL_LANES(Prefix, Func) \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x4)).set_body_typed(Func##x4); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x8)).set_body_typed(Func##x8); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x16)).set_body_typed(Func##x16); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x32)).set_body_typed(Func##x32); \ + TVM_REGISTER_GLOBAL(Prefix TVM_TMP_STR(x64)).set_body_typed(Func##x64); + +#define TVM_REGISTER_GLOBAL_SIZES_LANES(Prefix, DType) \ + TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(8), DType##8); \ + TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(16), DType##16); \ + TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(32), DType##32); \ + TVM_REGISTER_GLOBAL_LANES(Prefix TVM_TMP_STR(64), DType##64); + +TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float); +TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt); +TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int); + TVM_REGISTER_GLOBAL("script.ir_builder.tir.Boolean").set_body_typed(Boolean); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Handle").set_body_typed(Handle); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Void").set_body_typed(Void); diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index 36de35fa928b..2ec52bfbfe41 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -52,7 +52,7 @@ def render(e): return error = errors[0] assert ( - error.span.line - 1 == rel_lineno + error.span.line - 1 == rel_lineno or error.span.line == rel_lineno ), f"Expected error to be on line {rel_lineno}, but it was on {error.span.line - 1}" error_line = source_code.split("\n")[rel_lineno] diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index dbc9b594fb87..a3df5a183bab 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -16,15 +16,15 @@ # under the License. # pylint: disable=invalid-name, missing-docstring """Unittests for tvm.script.ir_builder.tir""" -import pytest import numpy as np +import pytest import tvm import tvm.testing from tvm import tir +from tvm.ir.base import assert_structural_equal from tvm.runtime import ndarray -from tvm.script.ir_builder import tir as T from tvm.script.ir_builder import IRBuilder -from tvm.ir.base import assert_structural_equal +from tvm.script.ir_builder import tir as T def test_ir_builder_tir_primfunc_base(): @@ -372,7 +372,12 @@ def test_ir_builder_tir_allocate_const(): # the expected allocate const buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("int32"))) ir_expected = tir.AllocateConst( - buffer_var, "int32", [10], ndarray.array(np.asarray(data, "int32")), tir.Evaluate(1) + buffer_var, + "int32", + [10], + ndarray.array(np.asarray(data, "int32")), + tir.Evaluate(1), + annotations={}, ) # Check if the generated ir is expected @@ -470,7 +475,13 @@ def test_ir_builder_tir_decl_buffer(): # the expected decl_buffer buffer = T.buffer_decl((128, 128), "float32") - ir_expected = tir.DeclBuffer(buffer, tir.Evaluate(0)) + ir_expected = tir.Allocate( + buffer.data, + "float32", + (128, 128), + tir.IntImm("bool", True), + tir.DeclBuffer(buffer, tir.Evaluate(0)), + ) # Check if the generated ir is expected assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)