diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 313048e4d4ca..65c44db3ac58 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -21,8 +21,6 @@ from tvm.runtime import vm from tvm.runtime.vm import VirtualMachine, VMInstrumentReturnKind -from .type_converter import args_converter - # Expr from .expr import ( Expr, diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 10a4eb76b310..28a7aa897a50 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -27,7 +27,7 @@ from ...ir import PrimExpr from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var from ..struct_info import StructInfo, TensorStructInfo -from ..utils import args_converter +from ..utils import convert_to_expr from . import _ffi_api py_print = print # pylint: disable=invalid-name @@ -76,7 +76,9 @@ def _wrap_inline_arg_tuple(args) -> Expr: in-line relax Tuple. """ - if ( + if isinstance(args, tuple | list): + return tvm.relax.Tuple([convert_to_expr(a) for a in args]) + elif ( isinstance(args, Expr) and not isinstance(args, tvm.relax.Tuple) and ( @@ -89,7 +91,6 @@ def _wrap_inline_arg_tuple(args) -> Expr: return args -@args_converter.auto def call_tir( gvar: GlobalVar, args: Expr, @@ -131,7 +132,6 @@ def call_tir( return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore -@args_converter.auto def call_tir_with_grad( gvar: GlobalVar, args: Expr, @@ -190,7 +190,6 @@ def call_tir_with_grad( ) -@args_converter.auto def call_tir_inplace( gvar: GlobalVar, args: Expr, @@ -261,7 +260,6 @@ def call_tir_inplace( ) -@args_converter.auto def call_dps_packed( func: str | Expr, args: Expr, @@ -303,7 +301,6 @@ def call_dps_packed( return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore -@args_converter.auto def call_py_func( func_name: str, args: Expr, @@ -339,7 +336,6 @@ def call_py_func( return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore -@args_converter.auto def call_builtin_with_ctx( func: str | Expr, args: Expr, @@ -367,6 +363,8 @@ def call_builtin_with_ctx( if isinstance(func, str): func = ExternFunc(func) + args = _wrap_inline_arg_tuple(args) + if sinfo_args is not None and not isinstance(sinfo_args, list | tuple): sinfo_args = [sinfo_args] @@ -377,7 +375,6 @@ def call_builtin_with_ctx( ) -@args_converter.auto def make_closure( func: Expr, args: Expr, @@ -400,10 +397,11 @@ def make_closure( The VMClosure. """ + args = _wrap_inline_arg_tuple(args) + return _ffi_api.make_closure(func, args) # type: ignore -@args_converter.auto def invoke_closure( closure: Expr, args: Expr, @@ -428,6 +426,7 @@ def invoke_closure( ret: Call A call to `invoke_closure`. """ + args = _wrap_inline_arg_tuple(args) if not isinstance(sinfo_args, list | tuple): sinfo_args = [sinfo_args] @@ -677,7 +676,6 @@ def shape_to_tensor(expr: Expr) -> Expr: return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint: disable=no-member -@args_converter.auto def call_inplace_packed( func: str | ExternFunc | GlobalVar, *args: Expr, @@ -731,6 +729,7 @@ def call_inplace_packed( func = func.global_symbol op = ExternFunc(func) + args = tuple(convert_to_expr(a) for a in args) if sinfo_args is None: raise ValueError("R.call_pure_packed is required to have type_args") if isinstance(sinfo_args, tuple): # type: ignore @@ -743,7 +742,6 @@ def call_inplace_packed( return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args) # type: ignore # pylint: disable=no-member -@args_converter.auto def call_pure_packed( func: str | ExternFunc | GlobalVar, *args: Expr, @@ -782,6 +780,7 @@ def call_pure_packed( func = func.global_symbol op = ExternFunc(func) + args = tuple(convert_to_expr(a) for a in args) if sinfo_args is None: raise ValueError("R.call_pure_packed is required to have type_args") @@ -807,7 +806,6 @@ def call_pure_packed( return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member -@args_converter.auto def invoke_pure_closure( closure: Expr, args: Expr, @@ -838,6 +836,7 @@ def invoke_pure_closure( ret: Call A call to `invoke_pure_closure`. """ + args = _wrap_inline_arg_tuple(args) if not isinstance(sinfo_args, list | tuple): sinfo_args = [sinfo_args] diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py index 47fab9c7a5c9..328d63f3cc1d 100644 --- a/python/tvm/relax/op/builtin/builtin.py +++ b/python/tvm/relax/op/builtin/builtin.py @@ -16,11 +16,10 @@ """The builtin Relax operators.""" from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm -from ...utils import args_converter +from ...utils import convert_to_expr from . import _ffi_api -@args_converter.auto def alloc_tensor( shape: Expr, dtype: str | Expr, @@ -49,6 +48,8 @@ def alloc_tensor( result : Call A relax Call, which gets the allocated tensor. """ + if not isinstance(shape, Expr): + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(runtime_device_index, int): diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py index b17708223d8b..b09f8686ac48 100644 --- a/python/tvm/relax/op/distributed/distributed.py +++ b/python/tvm/relax/op/distributed/distributed.py @@ -20,10 +20,10 @@ from tvm.ir import PrimExpr from tvm.relax.distributed import DTensorStructInfo from tvm.relax.distributed.struct_info import DeviceMesh, Placement -from tvm.relax.utils import args_converter from ...expr import Call, Expr, GlobalVar, ShapeExpr from ...expr import Tuple as RxTuple +from ...utils import convert_to_expr from . import _ffi_api @@ -66,7 +66,6 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) -> return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore -@args_converter.auto def call_tir_local_view( gvar: GlobalVar, args: Expr, @@ -99,7 +98,9 @@ def call_tir_local_view( ret: Call A call node for the call_tir_local_view operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + if isinstance(args, tuple | list): + args = RxTuple([convert_to_expr(a) for a in args]) + elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) if not isinstance(out_sinfo, list): diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py index 1edb807db840..71cae89c3ccf 100644 --- a/python/tvm/relax/op/index.py +++ b/python/tvm/relax/op/index.py @@ -18,8 +18,8 @@ from tvm.ir.expr import PrimExpr -from .. import args_converter from ..expr import Expr +from ..utils import convert_to_expr from . import _ffi_api PrimExprLike = int | PrimExpr @@ -58,7 +58,6 @@ def take(x: Expr, indices: Expr, axis: int | None = None, mode: str = "fast") -> return _ffi_api.take(x, indices, axis, mode) # type: ignore -@args_converter.auto def strided_slice( x: Expr, axes: Expr, @@ -101,6 +100,11 @@ def strided_slice( strided_slice require the input `begin`, `end` and `strides` to have the same length as `axes`. """ + axes = convert_to_expr(axes) + begin = convert_to_expr(begin) + end = convert_to_expr(end) + if strides is not None: + strides = convert_to_expr(strides) return _ffi_api.strided_slice(x, axes, begin, end, strides, assume_inbound) # type: ignore diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py index ac8343c310e0..ba54eea9ef00 100644 --- a/python/tvm/relax/op/memory/memory.py +++ b/python/tvm/relax/op/memory/memory.py @@ -16,11 +16,10 @@ """Relax memory primitives.""" from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm -from ...utils import args_converter +from ...utils import convert_to_expr from . import _ffi_api -@args_converter.auto def alloc_storage( size: Expr, virtual_device_index: int | Expr, @@ -50,6 +49,7 @@ def alloc_storage( result : Call A relax Call, which gets the allocated storage. """ + size = convert_to_expr(size) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(storage_scope, str): @@ -59,7 +59,6 @@ def alloc_storage( return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype) # type: ignore -@args_converter.auto def alloc_tensor( storage: Expr, offset: int | Expr, @@ -94,12 +93,12 @@ def alloc_tensor( """ if isinstance(offset, int): offset = PrimValue(offset) + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore -@args_converter.auto def kill_storage(storage: Expr) -> Call: """Construct a Call to kill a storage. @@ -116,7 +115,6 @@ def kill_storage(storage: Expr) -> Call: return _ffi_api.kill_storage(storage) # type: ignore -@args_converter.auto def kill_tensor(tensor: Expr) -> Call: """Construct a Call to kill a tensor. diff --git a/python/tvm/relax/op/sampling.py b/python/tvm/relax/op/sampling.py index bcd43a392247..cd4cad925341 100644 --- a/python/tvm/relax/op/sampling.py +++ b/python/tvm/relax/op/sampling.py @@ -16,12 +16,10 @@ # under the License. """Sampling operators.""" -from .. import args_converter from ..expr import Expr from . import _ffi_api -@args_converter.auto def multinomial_from_uniform( prob: Expr, uniform_sample: Expr, diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py index c77d8311d53b..e7b800b5c9af 100644 --- a/python/tvm/relax/op/unary.py +++ b/python/tvm/relax/op/unary.py @@ -18,7 +18,7 @@ """Relax unary arithmetic operators.""" from ..expr import Expr -from ..utils import args_converter +from ..utils import convert_to_expr from . import _ffi_api ###################### Arithmetic operators ###################### @@ -526,7 +526,6 @@ def trunc(x: Expr) -> Expr: return _ffi_api.trunc(x) # type: ignore -@args_converter.auto def clip(x: Expr, min: Expr, max: Expr) -> Expr: """Clips tensor values to a specified min and max. @@ -546,6 +545,8 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr: result : relax.Expr The computed result. """ + min = convert_to_expr(min) + max = convert_to_expr(max) return _ffi_api.clip(x, min, max) # type: ignore diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py index de1015d08d8b..868ac922952c 100644 --- a/python/tvm/relax/op/vm/vm.py +++ b/python/tvm/relax/op/vm/vm.py @@ -16,11 +16,10 @@ """Relax vm primitives.""" from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm, Tuple -from ...utils import args_converter +from ...utils import convert_to_expr from . import _ffi_api -@args_converter.auto def alloc_storage( shape: Expr, runtime_device_index: int | Expr, @@ -50,6 +49,7 @@ def alloc_storage( result : Call A relax Call, which gets the allocated storage. """ + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(storage_scope, str): @@ -59,7 +59,6 @@ def alloc_storage( return _ffi_api.alloc_storage(shape, runtime_device_index, dtype, storage_scope) # type: ignore -@args_converter.auto def alloc_tensor( storage: Expr, offset: int | Expr, @@ -94,6 +93,7 @@ def alloc_tensor( """ if isinstance(offset, int): offset = PrimValue(offset) + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore @@ -116,7 +116,6 @@ def kill_object(obj: Expr) -> Call: return _ffi_api.kill_object(obj) # type: ignore -@args_converter.auto def call_tir_dyn(func: Expr, args: Tuple) -> Call: """Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc) consisting of the input tensors and the shape of the result. @@ -134,6 +133,7 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call: result : Call A relax Call to call_tir_dyn. """ + func = convert_to_expr(func) if isinstance(args, list | tuple): args = Tuple(args) diff --git a/python/tvm/relax/type_converter.py b/python/tvm/relax/type_converter.py deleted file mode 100644 index 50014cba5a9f..000000000000 --- a/python/tvm/relax/type_converter.py +++ /dev/null @@ -1,179 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=invalid-name,too-many-locals - -"""Argument converter utility for Relax - -This utility is used to decorate constructors of `tvm.relax.Expr`, and -must be able to be imported before `tvm.relax.Expr` or its subtypes -have been defined. Neither the class definitions nor any type -signature in this file may reference relax types. All references must -be exclusively in function bodies to avoid having a circular reference -during module imports. -""" - -import functools -import inspect -from collections.abc import Callable -from typing import Any, TypeVar - -import tvm - -FType = TypeVar("FType", bound=Callable[..., "tvm.relax.Expr"]) - - -class _ArgsConverter: - """A helper class to convert the arguments to Expr.""" - - @staticmethod - def convert(args_to_expr: list[str], args_to_list_expr: list[str]): - """Convert the arguments to Expr. - - Parameters - ---------- - args_to_expr : List[str] - The argument names to be converted to Expr. - - args_to_list_expr : List[str] - The argument names to be converted to List[Expr]. - - Returns - ------- - output : Callable[[FType], FType] - The decorator. - """ - - if any([x in args_to_list_expr for x in args_to_expr]): - raise ValueError("`args_to_expr` and `args_to_list_expr` should be disjoint.") - - def _convert(name: str, value: Any) -> Any: - if value is None: - return value - if name in args_to_expr: - try: - return tvm.relax.utils.convert_to_expr(value) - except Exception as err: - raise TypeError( - f"Argument `{name}` is expected to be converted to `Expr`, " - f"but failed with input value: {value}" - ) from err - elif name in args_to_list_expr: - try: - return [tvm.relax.utils.convert_to_expr(x) for x in value] - except Exception as err: - raise TypeError( - f"Argument `{name}` is expected to be converted to `List[Expr]`, " - f"but failed with input value: {value}" - ) from err - else: - return value - - def inner(func: FType) -> FType: - sig = inspect.signature(func) - param_names = list(sig.parameters.keys()) - for name in args_to_expr + args_to_list_expr: - if name not in param_names: - raise ValueError(f"Argument `{name}` is not found in function signature.") - - @functools.wraps(func) - def wrapper(*args, **kwargs): - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - for param in sig.parameters.values(): - if param.kind == param.VAR_POSITIONAL: - # *args case - values = [_convert(param.name, x) for x in bound.arguments[param.name]] - bound.arguments[param.name] = tuple(values) - elif param.kind == param.VAR_KEYWORD: - # **kwargs case - key_value = { - key: _convert(param.name, value) - for key, value in bound.arguments[param.name].items() - } - bound.arguments[param.name] = key_value - else: - bound.arguments[param.name] = _convert( - param.name, bound.arguments[param.name] - ) - return func(*bound.args, **bound.kwargs) - - return wrapper # type: ignore - - return inner - - @staticmethod - def to_expr(*arg_names: str) -> Callable: - """Convert the arguments to Expr. - - Parameters - ---------- - *arg_names: str - The list of argument names that need to be converted to Expr. - - Returns - ------- - output: Callable - The decorator. - """ - - return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[]) - - @staticmethod - def to_list_expr(*arg_names: str) -> Callable: - """Convert the arguments to List of Expr. - - Parameters - ---------- - *arg_names: str - The list of argument names that need to be converted to List of Expr. - - Returns - ------- - output: Callable - The decorator. - """ - - return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names)) - - @staticmethod - def auto(func: FType) -> FType: - """Decorator for automatically convert the arguments to Expr according to type annotation. - Only two patterns are supported: - - 1. The argument is Expr or Expr | None. - - 2. The argument is List[Expr] or Optional[List[Expr]]. - - """ - sig = inspect.signature(func) - args_to_expr = [] - args_to_list_expr = [] - - from . import Expr # pylint: disable=import-outside-toplevel - - for param in sig.parameters.values(): - anno = param.annotation - if anno in (Expr, Expr | None): - args_to_expr.append(param.name) - if anno in (list[Expr], list[Expr] | None): - args_to_list_expr.append(param.name) - - return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func) - - -args_converter = _ArgsConverter() # pylint: disable=invalid-name diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index f76459294aba..75f62c525d4b 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -39,9 +39,6 @@ from .expr import Tuple as rx_Tuple from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo -# Re-export `args_converter` here for backwards compatibility -from .type_converter import args_converter # pylint: disable=unused-import - def metadata_partitioner(rx_txt: str) -> list[str]: """Extract Relax program and metadata section. diff --git a/python/tvm/s_tir/schedule/_type_checker.py b/python/tvm/s_tir/schedule/_type_checker.py index 57e612688bb0..5ad4f29c55ae 100644 --- a/python/tvm/s_tir/schedule/_type_checker.py +++ b/python/tvm/s_tir/schedule/_type_checker.py @@ -345,17 +345,22 @@ def _type_check(v: Any, name: str, type_: Any) -> str | None: def type_checked(func: FType) -> FType: """Type check the input arguments of a function.""" sig = inspect.signature(func) + try: + hints = typing.get_type_hints(func) + except Exception: + hints = {} @functools.wraps(func) def wrap(*args, **kwargs): bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() for param in sig.parameters.values(): - if param.annotation != inspect.Signature.empty: + type_hint = hints.get(param.name, inspect.Parameter.empty) + if type_hint != inspect.Parameter.empty: error_msg = _type_check( bound_args.arguments[param.name], param.name, - param.annotation, + type_hint, ) if error_msg is not None: error_msg = f'In "{func.__qualname__}", {error_msg}' diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py index 49bfb6139a97..485e91e3e53e 100644 --- a/python/tvm/script/ir_builder/relax/distributed/ir.py +++ b/python/tvm/script/ir_builder/relax/distributed/ir.py @@ -40,7 +40,7 @@ from tvm.relax.op.distributed import ( redistribute as _redistribute, ) -from tvm.relax.utils import args_converter +from tvm.relax.utils import convert_to_expr from tvm.runtime import _tensor from ... import IRBuilder @@ -49,7 +49,6 @@ from . import _ffi_api -@args_converter.auto def call_tir( func: str | Expr, args: Expr, @@ -82,7 +81,9 @@ def call_tir( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + if isinstance(args, tuple | list): + args = RxTuple([convert_to_expr(a) for a in args]) + elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) if not isinstance(out_sinfo, list): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a7428c32a30b..a40517992e83 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -198,7 +198,7 @@ ) from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo -from tvm.relax.utils import args_converter, gen_call_tir_inputs +from tvm.relax.utils import convert_to_expr, gen_call_tir_inputs from tvm.runtime import Object as tvm_Object from tvm.runtime import ObjectConvertible from tvm.runtime._tensor import ( @@ -403,7 +403,6 @@ def output(*vars: tuple[Var]) -> None: ################################## Ops ################################# -@args_converter.auto def call_packed( func: py_str, *args: Expr, @@ -428,6 +427,7 @@ def call_packed( The created Relax Call """ op = ExternFunc(func) + args = py_tuple(convert_to_expr(a) for a in args) if sinfo_args is None: sinfo_args = [] if isinstance(sinfo_args, py_tuple): # type: ignore @@ -460,7 +460,6 @@ def call_packed( return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) -@args_converter.auto def call_py_func( py_func_name: py_str, *args: Expr, @@ -485,6 +484,7 @@ def call_py_func( call: Call The created Relax Call for call_py_func operator. """ + args = py_tuple(convert_to_expr(a) for a in args) if isinstance(out_sinfo, py_tuple): # type: ignore out_sinfo = list(out_sinfo) elif not isinstance(out_sinfo, list): diff --git a/tests/python/relax/test_expr_args_converter.py b/tests/python/relax/test_expr_args_converter.py deleted file mode 100644 index d156245452ec..000000000000 --- a/tests/python/relax/test_expr_args_converter.py +++ /dev/null @@ -1,147 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E731 - -from collections.abc import Callable -from typing import Any - -import pytest - -import tvm -import tvm.testing -from tvm import relax -from tvm.relax import Expr -from tvm.relax.utils import args_converter - - -def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) -> None: - # Test converting to `Expr` - assert f_checker(arg) - # Test converting `*args` - assert isinstance(args, tuple) - assert all([f_checker(arg) for arg in args]) - # Test converting `**kwargs` - assert isinstance(kwargs, dict) - assert all([f_checker(arg) for arg in kwargs.values()]) - - -def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None: - f_checker = lambda x: isinstance(x, Expr) - _test_base(f_checker, arg, *args, **kwargs) - - -def _test_optional_expr(arg: Expr | None, *args: Expr | None, **kwargs: Expr | None) -> None: - f_checker = lambda x: x is None or isinstance(x, Expr) - _test_base(f_checker, arg, *args, **kwargs) - - -def _test_list_expr(arg: list[Expr], *args: list[Expr], **kwargs: list[Expr]) -> None: - f_checker = lambda x: isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]) - _test_base(f_checker, arg, *args, **kwargs) - - -def _test_optional_list_expr( - arg: list[Expr] | None, *args: list[Expr] | None, **kwargs: list[Expr] | None -) -> None: - f_checker = lambda x: ( - x is None or (isinstance(x, list) and all([isinstance(arg, Expr) for arg in x])) - ) - _test_base(f_checker, arg, *args, **kwargs) - - -prim_value = 1 -str_value = "value_to_be_convert" -shape_value = (1, 1) -tuple_value = (relax.const(1), (1, 1)) -placeholder = relax.const(0) - -test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder] - - -def test_args_to_expr(): - for _f in [_test_expr, _test_optional_expr]: - f = args_converter.to_expr("arg", "args", "kwargs")(_f) - for x in test_cases: - f( - x, - x, # the first argument in *args - x, # the second argument in *args - test_kwargs=x, - ) - - if _f == _test_optional_expr: - f(None, None, x, test_kwargs=None) - - -def test_args_to_list_expr(): - for _f in [_test_list_expr, _test_optional_list_expr]: - f = args_converter.to_list_expr("arg", "args", "kwargs")(_f) - for x in test_cases: - f( - [x], - [x], # the first argument in *args - [x, x], # the second argument in *args - test_kwargs=[x, (x,)], - ) - - if _f == _test_optional_list_expr: - f(None, None, [x], test_kwargs=None) - - -def test_error(): - f = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr) - with pytest.raises(TypeError): - f(prim_value) # fail to convert prim_value to `List[Expr]` - - -def test_auto_convert(): - for _f in [_test_expr, _test_optional_expr]: - f = args_converter.auto(_f) - for x in test_cases: - f(x, (x,), test_kwargs=x) - - if _f == _test_optional_expr: - f(None, x, test_kwargs=None) - - for _f in [_test_list_expr, _test_optional_list_expr]: - f = args_converter.auto(_f) - for x in test_cases: - f([x], [x, x], test_kwargs=[x, (x,)]) - - if _f == _test_optional_list_expr: - f(None, None, [x], test_kwargs=None) - - -def test_auto_convert_skip(): - def _test_expr_skip(arg: int, *args: str | Expr, **kwargs: list[Expr | None]) -> None: - f_checker = lambda x: not isinstance(x, Expr) - _test_base(f_checker, arg, *args, **kwargs) - - f = args_converter.auto(_test_expr_skip) - f(1, "str", test_kwargs=[None]) - - -def test_empty_tuple(): - def _test(arg: Expr): - assert isinstance(arg, relax.Tuple) - - f = args_converter.auto(_test) - f(()) - - -if __name__ == "__main__": - tvm.testing.main()