From ec46a3fb241077c211a2a55780b2e61a58f22823 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 28 Feb 2026 03:36:48 +0000 Subject: [PATCH 1/3] [PYTHON] Use typing.get_type_hints in @type_checked for PEP 563 compat This PR resolves string annotations from PEP 563 (from __future__ import annotations) at decoration time using typing.get_type_hints(), preventing TypeError crashes when @type_checked passes stringified annotations to isinstance(). --- python/tvm/s_tir/schedule/_type_checker.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/s_tir/schedule/_type_checker.py b/python/tvm/s_tir/schedule/_type_checker.py index 57e612688bb0..5ad4f29c55ae 100644 --- a/python/tvm/s_tir/schedule/_type_checker.py +++ b/python/tvm/s_tir/schedule/_type_checker.py @@ -345,17 +345,22 @@ def _type_check(v: Any, name: str, type_: Any) -> str | None: def type_checked(func: FType) -> FType: """Type check the input arguments of a function.""" sig = inspect.signature(func) + try: + hints = typing.get_type_hints(func) + except Exception: + hints = {} @functools.wraps(func) def wrap(*args, **kwargs): bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() for param in sig.parameters.values(): - if param.annotation != inspect.Signature.empty: + type_hint = hints.get(param.name, inspect.Parameter.empty) + if type_hint != inspect.Parameter.empty: error_msg = _type_check( bound_args.arguments[param.name], param.name, - param.annotation, + type_hint, ) if error_msg is not None: error_msg = f'In "{func.__qualname__}", {error_msg}' From 273bf4d11fb73fdc8fa1da7199c348dbb24762c5 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 28 Feb 2026 05:29:23 +0000 Subject: [PATCH 2/3] [REFACTOR][PYTHON] Remove args_converter auto-conversion decorator This PR removes the @args_converter.auto decorator that auto-converted Python types to Relax Expr by inspecting annotations at decoration time. This was PEP 563-incompatible (annotations become strings). The FFI/C++ layer handles type conversion directly, with minimal explicit conversion added at call sites for Python tuples/lists and primitive types. --- python/tvm/relax/__init__.py | 2 - python/tvm/relax/op/base.py | 31 +-- python/tvm/relax/op/builtin/builtin.py | 5 +- .../tvm/relax/op/distributed/distributed.py | 7 +- python/tvm/relax/op/index.py | 8 +- python/tvm/relax/op/memory/memory.py | 10 +- python/tvm/relax/op/sampling.py | 2 - python/tvm/relax/op/unary.py | 7 +- python/tvm/relax/op/vm/vm.py | 11 +- python/tvm/relax/type_converter.py | 179 ------------------ python/tvm/relax/utils.py | 3 - .../script/ir_builder/relax/distributed/ir.py | 6 +- python/tvm/script/ir_builder/relax/ir.py | 12 +- .../python/relax/test_expr_args_converter.py | 147 -------------- 14 files changed, 59 insertions(+), 371 deletions(-) delete mode 100644 python/tvm/relax/type_converter.py delete mode 100644 tests/python/relax/test_expr_args_converter.py diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index 313048e4d4ca..65c44db3ac58 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -21,8 +21,6 @@ from tvm.runtime import vm from tvm.runtime.vm import VirtualMachine, VMInstrumentReturnKind -from .type_converter import args_converter - # Expr from .expr import ( Expr, diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 10a4eb76b310..3c91a34cac4e 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -27,7 +27,7 @@ from ...ir import PrimExpr from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var from ..struct_info import StructInfo, TensorStructInfo -from ..utils import args_converter +from ..utils import convert_to_expr from . import _ffi_api py_print = print # pylint: disable=invalid-name @@ -76,7 +76,9 @@ def _wrap_inline_arg_tuple(args) -> Expr: in-line relax Tuple. """ - if ( + if isinstance(args, tuple | list): + return tvm.relax.Tuple([convert_to_expr(a) if not isinstance(a, Expr) else a for a in args]) + elif ( isinstance(args, Expr) and not isinstance(args, tvm.relax.Tuple) and ( @@ -89,7 +91,6 @@ def _wrap_inline_arg_tuple(args) -> Expr: return args -@args_converter.auto def call_tir( gvar: GlobalVar, args: Expr, @@ -131,7 +132,6 @@ def call_tir( return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore -@args_converter.auto def call_tir_with_grad( gvar: GlobalVar, args: Expr, @@ -190,7 +190,6 @@ def call_tir_with_grad( ) -@args_converter.auto def call_tir_inplace( gvar: GlobalVar, args: Expr, @@ -261,7 +260,6 @@ def call_tir_inplace( ) -@args_converter.auto def call_dps_packed( func: str | Expr, args: Expr, @@ -303,7 +301,6 @@ def call_dps_packed( return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore -@args_converter.auto def call_py_func( func_name: str, args: Expr, @@ -339,7 +336,6 @@ def call_py_func( return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore -@args_converter.auto def call_builtin_with_ctx( func: str | Expr, args: Expr, @@ -367,6 +363,8 @@ def call_builtin_with_ctx( if isinstance(func, str): func = ExternFunc(func) + args = _wrap_inline_arg_tuple(args) + if sinfo_args is not None and not isinstance(sinfo_args, list | tuple): sinfo_args = [sinfo_args] @@ -377,7 +375,6 @@ def call_builtin_with_ctx( ) -@args_converter.auto def make_closure( func: Expr, args: Expr, @@ -400,10 +397,11 @@ def make_closure( The VMClosure. """ + args = _wrap_inline_arg_tuple(args) + return _ffi_api.make_closure(func, args) # type: ignore -@args_converter.auto def invoke_closure( closure: Expr, args: Expr, @@ -428,6 +426,7 @@ def invoke_closure( ret: Call A call to `invoke_closure`. """ + args = _wrap_inline_arg_tuple(args) if not isinstance(sinfo_args, list | tuple): sinfo_args = [sinfo_args] @@ -677,7 +676,6 @@ def shape_to_tensor(expr: Expr) -> Expr: return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint: disable=no-member -@args_converter.auto def call_inplace_packed( func: str | ExternFunc | GlobalVar, *args: Expr, @@ -731,6 +729,10 @@ def call_inplace_packed( func = func.global_symbol op = ExternFunc(func) + args = tuple( + convert_to_expr(a) if isinstance(a, int | float | str | tuple | list) else a + for a in args + ) if sinfo_args is None: raise ValueError("R.call_pure_packed is required to have type_args") if isinstance(sinfo_args, tuple): # type: ignore @@ -743,7 +745,6 @@ def call_inplace_packed( return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args) # type: ignore # pylint: disable=no-member -@args_converter.auto def call_pure_packed( func: str | ExternFunc | GlobalVar, *args: Expr, @@ -782,6 +783,10 @@ def call_pure_packed( func = func.global_symbol op = ExternFunc(func) + args = tuple( + convert_to_expr(a) if isinstance(a, int | float | str | tuple | list) else a + for a in args + ) if sinfo_args is None: raise ValueError("R.call_pure_packed is required to have type_args") @@ -807,7 +812,6 @@ def call_pure_packed( return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member -@args_converter.auto def invoke_pure_closure( closure: Expr, args: Expr, @@ -838,6 +842,7 @@ def invoke_pure_closure( ret: Call A call to `invoke_pure_closure`. """ + args = _wrap_inline_arg_tuple(args) if not isinstance(sinfo_args, list | tuple): sinfo_args = [sinfo_args] diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py index 47fab9c7a5c9..328d63f3cc1d 100644 --- a/python/tvm/relax/op/builtin/builtin.py +++ b/python/tvm/relax/op/builtin/builtin.py @@ -16,11 +16,10 @@ """The builtin Relax operators.""" from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm -from ...utils import args_converter +from ...utils import convert_to_expr from . import _ffi_api -@args_converter.auto def alloc_tensor( shape: Expr, dtype: str | Expr, @@ -49,6 +48,8 @@ def alloc_tensor( result : Call A relax Call, which gets the allocated tensor. """ + if not isinstance(shape, Expr): + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(runtime_device_index, int): diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py index b17708223d8b..12b7da047aa4 100644 --- a/python/tvm/relax/op/distributed/distributed.py +++ b/python/tvm/relax/op/distributed/distributed.py @@ -20,8 +20,6 @@ from tvm.ir import PrimExpr from tvm.relax.distributed import DTensorStructInfo from tvm.relax.distributed.struct_info import DeviceMesh, Placement -from tvm.relax.utils import args_converter - from ...expr import Call, Expr, GlobalVar, ShapeExpr from ...expr import Tuple as RxTuple from . import _ffi_api @@ -66,7 +64,6 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) -> return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore -@args_converter.auto def call_tir_local_view( gvar: GlobalVar, args: Expr, @@ -99,7 +96,9 @@ def call_tir_local_view( ret: Call A call node for the call_tir_local_view operator. """ - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + if isinstance(args, tuple | list): + args = RxTuple(list(args)) + elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) if not isinstance(out_sinfo, list): diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py index 1edb807db840..793d30499305 100644 --- a/python/tvm/relax/op/index.py +++ b/python/tvm/relax/op/index.py @@ -18,8 +18,8 @@ from tvm.ir.expr import PrimExpr -from .. import args_converter from ..expr import Expr +from ..utils import convert_to_expr from . import _ffi_api PrimExprLike = int | PrimExpr @@ -58,7 +58,6 @@ def take(x: Expr, indices: Expr, axis: int | None = None, mode: str = "fast") -> return _ffi_api.take(x, indices, axis, mode) # type: ignore -@args_converter.auto def strided_slice( x: Expr, axes: Expr, @@ -101,6 +100,11 @@ def strided_slice( strided_slice require the input `begin`, `end` and `strides` to have the same length as `axes`. """ + axes = convert_to_expr(axes) if not isinstance(axes, Expr) else axes + begin = convert_to_expr(begin) if not isinstance(begin, Expr) else begin + end = convert_to_expr(end) if not isinstance(end, Expr) else end + if strides is not None and not isinstance(strides, Expr): + strides = convert_to_expr(strides) return _ffi_api.strided_slice(x, axes, begin, end, strides, assume_inbound) # type: ignore diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py index ac8343c310e0..e65877d08738 100644 --- a/python/tvm/relax/op/memory/memory.py +++ b/python/tvm/relax/op/memory/memory.py @@ -16,11 +16,10 @@ """Relax memory primitives.""" from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm -from ...utils import args_converter +from ...utils import convert_to_expr from . import _ffi_api -@args_converter.auto def alloc_storage( size: Expr, virtual_device_index: int | Expr, @@ -50,6 +49,8 @@ def alloc_storage( result : Call A relax Call, which gets the allocated storage. """ + if not isinstance(size, Expr): + size = convert_to_expr(size) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(storage_scope, str): @@ -59,7 +60,6 @@ def alloc_storage( return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype) # type: ignore -@args_converter.auto def alloc_tensor( storage: Expr, offset: int | Expr, @@ -94,12 +94,13 @@ def alloc_tensor( """ if isinstance(offset, int): offset = PrimValue(offset) + if not isinstance(shape, Expr): + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore -@args_converter.auto def kill_storage(storage: Expr) -> Call: """Construct a Call to kill a storage. @@ -116,7 +117,6 @@ def kill_storage(storage: Expr) -> Call: return _ffi_api.kill_storage(storage) # type: ignore -@args_converter.auto def kill_tensor(tensor: Expr) -> Call: """Construct a Call to kill a tensor. diff --git a/python/tvm/relax/op/sampling.py b/python/tvm/relax/op/sampling.py index bcd43a392247..cd4cad925341 100644 --- a/python/tvm/relax/op/sampling.py +++ b/python/tvm/relax/op/sampling.py @@ -16,12 +16,10 @@ # under the License. """Sampling operators.""" -from .. import args_converter from ..expr import Expr from . import _ffi_api -@args_converter.auto def multinomial_from_uniform( prob: Expr, uniform_sample: Expr, diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py index c77d8311d53b..4d8815f9ce5f 100644 --- a/python/tvm/relax/op/unary.py +++ b/python/tvm/relax/op/unary.py @@ -18,7 +18,7 @@ """Relax unary arithmetic operators.""" from ..expr import Expr -from ..utils import args_converter +from ..utils import convert_to_expr from . import _ffi_api ###################### Arithmetic operators ###################### @@ -526,7 +526,6 @@ def trunc(x: Expr) -> Expr: return _ffi_api.trunc(x) # type: ignore -@args_converter.auto def clip(x: Expr, min: Expr, max: Expr) -> Expr: """Clips tensor values to a specified min and max. @@ -546,6 +545,10 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr: result : relax.Expr The computed result. """ + if not isinstance(min, Expr): + min = convert_to_expr(min) + if not isinstance(max, Expr): + max = convert_to_expr(max) return _ffi_api.clip(x, min, max) # type: ignore diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py index de1015d08d8b..eed1b768a40d 100644 --- a/python/tvm/relax/op/vm/vm.py +++ b/python/tvm/relax/op/vm/vm.py @@ -16,11 +16,10 @@ """Relax vm primitives.""" from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm, Tuple -from ...utils import args_converter +from ...utils import convert_to_expr from . import _ffi_api -@args_converter.auto def alloc_storage( shape: Expr, runtime_device_index: int | Expr, @@ -50,6 +49,8 @@ def alloc_storage( result : Call A relax Call, which gets the allocated storage. """ + if not isinstance(shape, Expr): + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(storage_scope, str): @@ -59,7 +60,6 @@ def alloc_storage( return _ffi_api.alloc_storage(shape, runtime_device_index, dtype, storage_scope) # type: ignore -@args_converter.auto def alloc_tensor( storage: Expr, offset: int | Expr, @@ -94,6 +94,8 @@ def alloc_tensor( """ if isinstance(offset, int): offset = PrimValue(offset) + if not isinstance(shape, Expr): + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore @@ -116,7 +118,6 @@ def kill_object(obj: Expr) -> Call: return _ffi_api.kill_object(obj) # type: ignore -@args_converter.auto def call_tir_dyn(func: Expr, args: Tuple) -> Call: """Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc) consisting of the input tensors and the shape of the result. @@ -134,6 +135,8 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call: result : Call A relax Call to call_tir_dyn. """ + if not isinstance(func, Expr): + func = convert_to_expr(func) if isinstance(args, list | tuple): args = Tuple(args) diff --git a/python/tvm/relax/type_converter.py b/python/tvm/relax/type_converter.py deleted file mode 100644 index 50014cba5a9f..000000000000 --- a/python/tvm/relax/type_converter.py +++ /dev/null @@ -1,179 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# pylint: disable=invalid-name,too-many-locals - -"""Argument converter utility for Relax - -This utility is used to decorate constructors of `tvm.relax.Expr`, and -must be able to be imported before `tvm.relax.Expr` or its subtypes -have been defined. Neither the class definitions nor any type -signature in this file may reference relax types. All references must -be exclusively in function bodies to avoid having a circular reference -during module imports. -""" - -import functools -import inspect -from collections.abc import Callable -from typing import Any, TypeVar - -import tvm - -FType = TypeVar("FType", bound=Callable[..., "tvm.relax.Expr"]) - - -class _ArgsConverter: - """A helper class to convert the arguments to Expr.""" - - @staticmethod - def convert(args_to_expr: list[str], args_to_list_expr: list[str]): - """Convert the arguments to Expr. - - Parameters - ---------- - args_to_expr : List[str] - The argument names to be converted to Expr. - - args_to_list_expr : List[str] - The argument names to be converted to List[Expr]. - - Returns - ------- - output : Callable[[FType], FType] - The decorator. - """ - - if any([x in args_to_list_expr for x in args_to_expr]): - raise ValueError("`args_to_expr` and `args_to_list_expr` should be disjoint.") - - def _convert(name: str, value: Any) -> Any: - if value is None: - return value - if name in args_to_expr: - try: - return tvm.relax.utils.convert_to_expr(value) - except Exception as err: - raise TypeError( - f"Argument `{name}` is expected to be converted to `Expr`, " - f"but failed with input value: {value}" - ) from err - elif name in args_to_list_expr: - try: - return [tvm.relax.utils.convert_to_expr(x) for x in value] - except Exception as err: - raise TypeError( - f"Argument `{name}` is expected to be converted to `List[Expr]`, " - f"but failed with input value: {value}" - ) from err - else: - return value - - def inner(func: FType) -> FType: - sig = inspect.signature(func) - param_names = list(sig.parameters.keys()) - for name in args_to_expr + args_to_list_expr: - if name not in param_names: - raise ValueError(f"Argument `{name}` is not found in function signature.") - - @functools.wraps(func) - def wrapper(*args, **kwargs): - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - for param in sig.parameters.values(): - if param.kind == param.VAR_POSITIONAL: - # *args case - values = [_convert(param.name, x) for x in bound.arguments[param.name]] - bound.arguments[param.name] = tuple(values) - elif param.kind == param.VAR_KEYWORD: - # **kwargs case - key_value = { - key: _convert(param.name, value) - for key, value in bound.arguments[param.name].items() - } - bound.arguments[param.name] = key_value - else: - bound.arguments[param.name] = _convert( - param.name, bound.arguments[param.name] - ) - return func(*bound.args, **bound.kwargs) - - return wrapper # type: ignore - - return inner - - @staticmethod - def to_expr(*arg_names: str) -> Callable: - """Convert the arguments to Expr. - - Parameters - ---------- - *arg_names: str - The list of argument names that need to be converted to Expr. - - Returns - ------- - output: Callable - The decorator. - """ - - return _ArgsConverter.convert(args_to_expr=list(arg_names), args_to_list_expr=[]) - - @staticmethod - def to_list_expr(*arg_names: str) -> Callable: - """Convert the arguments to List of Expr. - - Parameters - ---------- - *arg_names: str - The list of argument names that need to be converted to List of Expr. - - Returns - ------- - output: Callable - The decorator. - """ - - return _ArgsConverter.convert(args_to_expr=[], args_to_list_expr=list(arg_names)) - - @staticmethod - def auto(func: FType) -> FType: - """Decorator for automatically convert the arguments to Expr according to type annotation. - Only two patterns are supported: - - 1. The argument is Expr or Expr | None. - - 2. The argument is List[Expr] or Optional[List[Expr]]. - - """ - sig = inspect.signature(func) - args_to_expr = [] - args_to_list_expr = [] - - from . import Expr # pylint: disable=import-outside-toplevel - - for param in sig.parameters.values(): - anno = param.annotation - if anno in (Expr, Expr | None): - args_to_expr.append(param.name) - if anno in (list[Expr], list[Expr] | None): - args_to_list_expr.append(param.name) - - return _ArgsConverter.convert(args_to_expr, args_to_list_expr)(func) - - -args_converter = _ArgsConverter() # pylint: disable=invalid-name diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index f76459294aba..75f62c525d4b 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -39,9 +39,6 @@ from .expr import Tuple as rx_Tuple from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo -# Re-export `args_converter` here for backwards compatibility -from .type_converter import args_converter # pylint: disable=unused-import - def metadata_partitioner(rx_txt: str) -> list[str]: """Extract Relax program and metadata section. diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py index 49bfb6139a97..674b110c3ef2 100644 --- a/python/tvm/script/ir_builder/relax/distributed/ir.py +++ b/python/tvm/script/ir_builder/relax/distributed/ir.py @@ -40,7 +40,6 @@ from tvm.relax.op.distributed import ( redistribute as _redistribute, ) -from tvm.relax.utils import args_converter from tvm.runtime import _tensor from ... import IRBuilder @@ -49,7 +48,6 @@ from . import _ffi_api -@args_converter.auto def call_tir( func: str | Expr, args: Expr, @@ -82,7 +80,9 @@ def call_tir( if isinstance(func, str): func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + if isinstance(args, tuple | list): + args = RxTuple(list(args)) + elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) if not isinstance(out_sinfo, list): diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index a7428c32a30b..bcbcf5353632 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -198,7 +198,7 @@ ) from tvm.relax.op.builtin import stop_lift_params from tvm.relax.struct_info import StructInfo -from tvm.relax.utils import args_converter, gen_call_tir_inputs +from tvm.relax.utils import convert_to_expr, gen_call_tir_inputs from tvm.runtime import Object as tvm_Object from tvm.runtime import ObjectConvertible from tvm.runtime._tensor import ( @@ -224,6 +224,8 @@ py_tuple = tuple # pylint: disable=used-before-assignment py_str = str # pylint: disable=used-before-assignment +_CONVERTIBLE_TYPES = (int, float, str, tuple, list) + ################################ Device ################################ @@ -403,7 +405,6 @@ def output(*vars: tuple[Var]) -> None: ################################## Ops ################################# -@args_converter.auto def call_packed( func: py_str, *args: Expr, @@ -428,6 +429,9 @@ def call_packed( The created Relax Call """ op = ExternFunc(func) + args = py_tuple( + convert_to_expr(a) if isinstance(a, _CONVERTIBLE_TYPES) else a for a in args + ) if sinfo_args is None: sinfo_args = [] if isinstance(sinfo_args, py_tuple): # type: ignore @@ -460,7 +464,6 @@ def call_packed( return Call(op, args, attrs=attrs, sinfo_args=sinfo_args) -@args_converter.auto def call_py_func( py_func_name: py_str, *args: Expr, @@ -485,6 +488,9 @@ def call_py_func( call: Call The created Relax Call for call_py_func operator. """ + args = py_tuple( + convert_to_expr(a) if isinstance(a, _CONVERTIBLE_TYPES) else a for a in args + ) if isinstance(out_sinfo, py_tuple): # type: ignore out_sinfo = list(out_sinfo) elif not isinstance(out_sinfo, list): diff --git a/tests/python/relax/test_expr_args_converter.py b/tests/python/relax/test_expr_args_converter.py deleted file mode 100644 index d156245452ec..000000000000 --- a/tests/python/relax/test_expr_args_converter.py +++ /dev/null @@ -1,147 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: E731 - -from collections.abc import Callable -from typing import Any - -import pytest - -import tvm -import tvm.testing -from tvm import relax -from tvm.relax import Expr -from tvm.relax.utils import args_converter - - -def _test_base(f_checker: Callable, arg: Any, *args: Any, **kwargs: Any) -> None: - # Test converting to `Expr` - assert f_checker(arg) - # Test converting `*args` - assert isinstance(args, tuple) - assert all([f_checker(arg) for arg in args]) - # Test converting `**kwargs` - assert isinstance(kwargs, dict) - assert all([f_checker(arg) for arg in kwargs.values()]) - - -def _test_expr(arg: Expr, *args: Expr, **kwargs: Expr) -> None: - f_checker = lambda x: isinstance(x, Expr) - _test_base(f_checker, arg, *args, **kwargs) - - -def _test_optional_expr(arg: Expr | None, *args: Expr | None, **kwargs: Expr | None) -> None: - f_checker = lambda x: x is None or isinstance(x, Expr) - _test_base(f_checker, arg, *args, **kwargs) - - -def _test_list_expr(arg: list[Expr], *args: list[Expr], **kwargs: list[Expr]) -> None: - f_checker = lambda x: isinstance(x, list) and all([isinstance(arg, Expr) for arg in x]) - _test_base(f_checker, arg, *args, **kwargs) - - -def _test_optional_list_expr( - arg: list[Expr] | None, *args: list[Expr] | None, **kwargs: list[Expr] | None -) -> None: - f_checker = lambda x: ( - x is None or (isinstance(x, list) and all([isinstance(arg, Expr) for arg in x])) - ) - _test_base(f_checker, arg, *args, **kwargs) - - -prim_value = 1 -str_value = "value_to_be_convert" -shape_value = (1, 1) -tuple_value = (relax.const(1), (1, 1)) -placeholder = relax.const(0) - -test_cases = [prim_value, str_value, shape_value, tuple_value, placeholder] - - -def test_args_to_expr(): - for _f in [_test_expr, _test_optional_expr]: - f = args_converter.to_expr("arg", "args", "kwargs")(_f) - for x in test_cases: - f( - x, - x, # the first argument in *args - x, # the second argument in *args - test_kwargs=x, - ) - - if _f == _test_optional_expr: - f(None, None, x, test_kwargs=None) - - -def test_args_to_list_expr(): - for _f in [_test_list_expr, _test_optional_list_expr]: - f = args_converter.to_list_expr("arg", "args", "kwargs")(_f) - for x in test_cases: - f( - [x], - [x], # the first argument in *args - [x, x], # the second argument in *args - test_kwargs=[x, (x,)], - ) - - if _f == _test_optional_list_expr: - f(None, None, [x], test_kwargs=None) - - -def test_error(): - f = args_converter.to_list_expr("arg", "args", "kwargs")(_test_list_expr) - with pytest.raises(TypeError): - f(prim_value) # fail to convert prim_value to `List[Expr]` - - -def test_auto_convert(): - for _f in [_test_expr, _test_optional_expr]: - f = args_converter.auto(_f) - for x in test_cases: - f(x, (x,), test_kwargs=x) - - if _f == _test_optional_expr: - f(None, x, test_kwargs=None) - - for _f in [_test_list_expr, _test_optional_list_expr]: - f = args_converter.auto(_f) - for x in test_cases: - f([x], [x, x], test_kwargs=[x, (x,)]) - - if _f == _test_optional_list_expr: - f(None, None, [x], test_kwargs=None) - - -def test_auto_convert_skip(): - def _test_expr_skip(arg: int, *args: str | Expr, **kwargs: list[Expr | None]) -> None: - f_checker = lambda x: not isinstance(x, Expr) - _test_base(f_checker, arg, *args, **kwargs) - - f = args_converter.auto(_test_expr_skip) - f(1, "str", test_kwargs=[None]) - - -def test_empty_tuple(): - def _test(arg: Expr): - assert isinstance(arg, relax.Tuple) - - f = args_converter.auto(_test) - f(()) - - -if __name__ == "__main__": - tvm.testing.main() From 8aba2c4e2029eabd4399d96a1e08358ed136a860 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 28 Feb 2026 13:08:03 +0000 Subject: [PATCH 3/3] [FIX] Simplify convert_to_expr calls and fix distributed op conversion - Fix distributed ops (distributed.py, distributed/ir.py): apply convert_to_expr element-wise when args is a list/tuple before constructing RxTuple, so non-Expr elements are properly converted. - Simplify redundant isinstance(x, Expr) guards before convert_to_expr calls (convert_to_expr is idempotent on Expr inputs) in: memory/memory.py, op/base.py, op/index.py, op/unary.py, op/vm/vm.py, script/ir_builder/relax/ir.py. --- python/tvm/relax/op/base.py | 12 +++--------- python/tvm/relax/op/distributed/distributed.py | 4 +++- python/tvm/relax/op/index.py | 8 ++++---- python/tvm/relax/op/memory/memory.py | 6 ++---- python/tvm/relax/op/unary.py | 6 ++---- python/tvm/relax/op/vm/vm.py | 9 +++------ python/tvm/script/ir_builder/relax/distributed/ir.py | 3 ++- python/tvm/script/ir_builder/relax/ir.py | 10 ++-------- 8 files changed, 21 insertions(+), 37 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 3c91a34cac4e..28a7aa897a50 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -77,7 +77,7 @@ def _wrap_inline_arg_tuple(args) -> Expr: """ if isinstance(args, tuple | list): - return tvm.relax.Tuple([convert_to_expr(a) if not isinstance(a, Expr) else a for a in args]) + return tvm.relax.Tuple([convert_to_expr(a) for a in args]) elif ( isinstance(args, Expr) and not isinstance(args, tvm.relax.Tuple) @@ -729,10 +729,7 @@ def call_inplace_packed( func = func.global_symbol op = ExternFunc(func) - args = tuple( - convert_to_expr(a) if isinstance(a, int | float | str | tuple | list) else a - for a in args - ) + args = tuple(convert_to_expr(a) for a in args) if sinfo_args is None: raise ValueError("R.call_pure_packed is required to have type_args") if isinstance(sinfo_args, tuple): # type: ignore @@ -783,10 +780,7 @@ def call_pure_packed( func = func.global_symbol op = ExternFunc(func) - args = tuple( - convert_to_expr(a) if isinstance(a, int | float | str | tuple | list) else a - for a in args - ) + args = tuple(convert_to_expr(a) for a in args) if sinfo_args is None: raise ValueError("R.call_pure_packed is required to have type_args") diff --git a/python/tvm/relax/op/distributed/distributed.py b/python/tvm/relax/op/distributed/distributed.py index 12b7da047aa4..b09f8686ac48 100644 --- a/python/tvm/relax/op/distributed/distributed.py +++ b/python/tvm/relax/op/distributed/distributed.py @@ -20,8 +20,10 @@ from tvm.ir import PrimExpr from tvm.relax.distributed import DTensorStructInfo from tvm.relax.distributed.struct_info import DeviceMesh, Placement + from ...expr import Call, Expr, GlobalVar, ShapeExpr from ...expr import Tuple as RxTuple +from ...utils import convert_to_expr from . import _ffi_api @@ -97,7 +99,7 @@ def call_tir_local_view( A call node for the call_tir_local_view operator. """ if isinstance(args, tuple | list): - args = RxTuple(list(args)) + args = RxTuple([convert_to_expr(a) for a in args]) elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py index 793d30499305..71cae89c3ccf 100644 --- a/python/tvm/relax/op/index.py +++ b/python/tvm/relax/op/index.py @@ -100,10 +100,10 @@ def strided_slice( strided_slice require the input `begin`, `end` and `strides` to have the same length as `axes`. """ - axes = convert_to_expr(axes) if not isinstance(axes, Expr) else axes - begin = convert_to_expr(begin) if not isinstance(begin, Expr) else begin - end = convert_to_expr(end) if not isinstance(end, Expr) else end - if strides is not None and not isinstance(strides, Expr): + axes = convert_to_expr(axes) + begin = convert_to_expr(begin) + end = convert_to_expr(end) + if strides is not None: strides = convert_to_expr(strides) return _ffi_api.strided_slice(x, axes, begin, end, strides, assume_inbound) # type: ignore diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py index e65877d08738..ba54eea9ef00 100644 --- a/python/tvm/relax/op/memory/memory.py +++ b/python/tvm/relax/op/memory/memory.py @@ -49,8 +49,7 @@ def alloc_storage( result : Call A relax Call, which gets the allocated storage. """ - if not isinstance(size, Expr): - size = convert_to_expr(size) + size = convert_to_expr(size) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(storage_scope, str): @@ -94,8 +93,7 @@ def alloc_tensor( """ if isinstance(offset, int): offset = PrimValue(offset) - if not isinstance(shape, Expr): - shape = convert_to_expr(shape) + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py index 4d8815f9ce5f..e7b800b5c9af 100644 --- a/python/tvm/relax/op/unary.py +++ b/python/tvm/relax/op/unary.py @@ -545,10 +545,8 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr: result : relax.Expr The computed result. """ - if not isinstance(min, Expr): - min = convert_to_expr(min) - if not isinstance(max, Expr): - max = convert_to_expr(max) + min = convert_to_expr(min) + max = convert_to_expr(max) return _ffi_api.clip(x, min, max) # type: ignore diff --git a/python/tvm/relax/op/vm/vm.py b/python/tvm/relax/op/vm/vm.py index eed1b768a40d..868ac922952c 100644 --- a/python/tvm/relax/op/vm/vm.py +++ b/python/tvm/relax/op/vm/vm.py @@ -49,8 +49,7 @@ def alloc_storage( result : Call A relax Call, which gets the allocated storage. """ - if not isinstance(shape, Expr): - shape = convert_to_expr(shape) + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) if isinstance(storage_scope, str): @@ -94,8 +93,7 @@ def alloc_tensor( """ if isinstance(offset, int): offset = PrimValue(offset) - if not isinstance(shape, Expr): - shape = convert_to_expr(shape) + shape = convert_to_expr(shape) if isinstance(dtype, str): dtype = DataTypeImm(dtype) return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore @@ -135,8 +133,7 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call: result : Call A relax Call to call_tir_dyn. """ - if not isinstance(func, Expr): - func = convert_to_expr(func) + func = convert_to_expr(func) if isinstance(args, list | tuple): args = Tuple(args) diff --git a/python/tvm/script/ir_builder/relax/distributed/ir.py b/python/tvm/script/ir_builder/relax/distributed/ir.py index 674b110c3ef2..485e91e3e53e 100644 --- a/python/tvm/script/ir_builder/relax/distributed/ir.py +++ b/python/tvm/script/ir_builder/relax/distributed/ir.py @@ -40,6 +40,7 @@ from tvm.relax.op.distributed import ( redistribute as _redistribute, ) +from tvm.relax.utils import convert_to_expr from tvm.runtime import _tensor from ... import IRBuilder @@ -81,7 +82,7 @@ def call_tir( func = ExternFunc(func) if isinstance(args, tuple | list): - args = RxTuple(list(args)) + args = RxTuple([convert_to_expr(a) for a in args]) elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index bcbcf5353632..a40517992e83 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -224,8 +224,6 @@ py_tuple = tuple # pylint: disable=used-before-assignment py_str = str # pylint: disable=used-before-assignment -_CONVERTIBLE_TYPES = (int, float, str, tuple, list) - ################################ Device ################################ @@ -429,9 +427,7 @@ def call_packed( The created Relax Call """ op = ExternFunc(func) - args = py_tuple( - convert_to_expr(a) if isinstance(a, _CONVERTIBLE_TYPES) else a for a in args - ) + args = py_tuple(convert_to_expr(a) for a in args) if sinfo_args is None: sinfo_args = [] if isinstance(sinfo_args, py_tuple): # type: ignore @@ -488,9 +484,7 @@ def call_py_func( call: Call The created Relax Call for call_py_func operator. """ - args = py_tuple( - convert_to_expr(a) if isinstance(a, _CONVERTIBLE_TYPES) else a for a in args - ) + args = py_tuple(convert_to_expr(a) for a in args) if isinstance(out_sinfo, py_tuple): # type: ignore out_sinfo = list(out_sinfo) elif not isinstance(out_sinfo, list):