Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from tvm.runtime import vm
from tvm.runtime.vm import VirtualMachine, VMInstrumentReturnKind

from .type_converter import args_converter

# Expr
from .expr import (
Expr,
Expand Down
25 changes: 12 additions & 13 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ...ir import PrimExpr
from ..expr import Call, Expr, ExternFunc, GlobalVar, ShapeExpr, StringImm, Var
from ..struct_info import StructInfo, TensorStructInfo
from ..utils import args_converter
from ..utils import convert_to_expr
from . import _ffi_api

py_print = print # pylint: disable=invalid-name
Expand Down Expand Up @@ -76,7 +76,9 @@ def _wrap_inline_arg_tuple(args) -> Expr:
in-line relax Tuple.

"""
if (
if isinstance(args, tuple | list):
return tvm.relax.Tuple([convert_to_expr(a) for a in args])
elif (
isinstance(args, Expr)
and not isinstance(args, tvm.relax.Tuple)
and (
Expand All @@ -89,7 +91,6 @@ def _wrap_inline_arg_tuple(args) -> Expr:
return args


@args_converter.auto
def call_tir(
gvar: GlobalVar,
args: Expr,
Expand Down Expand Up @@ -131,7 +132,6 @@ def call_tir(
return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore


@args_converter.auto
def call_tir_with_grad(
gvar: GlobalVar,
args: Expr,
Expand Down Expand Up @@ -190,7 +190,6 @@ def call_tir_with_grad(
)


@args_converter.auto
def call_tir_inplace(
gvar: GlobalVar,
args: Expr,
Expand Down Expand Up @@ -261,7 +260,6 @@ def call_tir_inplace(
)


@args_converter.auto
def call_dps_packed(
func: str | Expr,
args: Expr,
Expand Down Expand Up @@ -303,7 +301,6 @@ def call_dps_packed(
return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore


@args_converter.auto
def call_py_func(
func_name: str,
args: Expr,
Expand Down Expand Up @@ -339,7 +336,6 @@ def call_py_func(
return _ffi_api.call_py_func(func_name, args, out_sinfo) # type: ignore


@args_converter.auto
def call_builtin_with_ctx(
func: str | Expr,
args: Expr,
Expand Down Expand Up @@ -367,6 +363,8 @@ def call_builtin_with_ctx(
if isinstance(func, str):
func = ExternFunc(func)

args = _wrap_inline_arg_tuple(args)

if sinfo_args is not None and not isinstance(sinfo_args, list | tuple):
sinfo_args = [sinfo_args]

Expand All @@ -377,7 +375,6 @@ def call_builtin_with_ctx(
)


@args_converter.auto
def make_closure(
func: Expr,
args: Expr,
Expand All @@ -400,10 +397,11 @@ def make_closure(
The VMClosure.
"""

args = _wrap_inline_arg_tuple(args)

return _ffi_api.make_closure(func, args) # type: ignore


@args_converter.auto
def invoke_closure(
closure: Expr,
args: Expr,
Expand All @@ -428,6 +426,7 @@ def invoke_closure(
ret: Call
A call to `invoke_closure`.
"""
args = _wrap_inline_arg_tuple(args)

if not isinstance(sinfo_args, list | tuple):
sinfo_args = [sinfo_args]
Expand Down Expand Up @@ -677,7 +676,6 @@ def shape_to_tensor(expr: Expr) -> Expr:
return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint: disable=no-member


@args_converter.auto
def call_inplace_packed(
func: str | ExternFunc | GlobalVar,
*args: Expr,
Expand Down Expand Up @@ -731,6 +729,7 @@ def call_inplace_packed(
func = func.global_symbol

op = ExternFunc(func)
args = tuple(convert_to_expr(a) for a in args)
if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")
if isinstance(sinfo_args, tuple): # type: ignore
Expand All @@ -743,7 +742,6 @@ def call_inplace_packed(
return _ffi_api.call_inplace_packed(op, args, inplace_indices, sinfo_args) # type: ignore # pylint: disable=no-member


@args_converter.auto
def call_pure_packed(
func: str | ExternFunc | GlobalVar,
*args: Expr,
Expand Down Expand Up @@ -782,6 +780,7 @@ def call_pure_packed(
func = func.global_symbol

op = ExternFunc(func)
args = tuple(convert_to_expr(a) for a in args)

if sinfo_args is None:
raise ValueError("R.call_pure_packed is required to have type_args")
Expand All @@ -807,7 +806,6 @@ def call_pure_packed(
return _ffi_api.call_pure_packed(op, args, None, sinfo_args) # type: ignore # pylint: disable=no-member


@args_converter.auto
def invoke_pure_closure(
closure: Expr,
args: Expr,
Expand Down Expand Up @@ -838,6 +836,7 @@ def invoke_pure_closure(
ret: Call
A call to `invoke_pure_closure`.
"""
args = _wrap_inline_arg_tuple(args)

if not isinstance(sinfo_args, list | tuple):
sinfo_args = [sinfo_args]
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relax/op/builtin/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
"""The builtin Relax operators."""

from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
from ...utils import args_converter
from ...utils import convert_to_expr
from . import _ffi_api


@args_converter.auto
def alloc_tensor(
shape: Expr,
dtype: str | Expr,
Expand Down Expand Up @@ -49,6 +48,8 @@ def alloc_tensor(
result : Call
A relax Call, which gets the allocated tensor.
"""
if not isinstance(shape, Expr):
shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(runtime_device_index, int):
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/relax/op/distributed/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from tvm.ir import PrimExpr
from tvm.relax.distributed import DTensorStructInfo
from tvm.relax.distributed.struct_info import DeviceMesh, Placement
from tvm.relax.utils import args_converter

from ...expr import Call, Expr, GlobalVar, ShapeExpr
from ...expr import Tuple as RxTuple
from ...utils import convert_to_expr
from . import _ffi_api


Expand Down Expand Up @@ -66,7 +66,6 @@ def redistribute(input: Expr, device_mesh: DeviceMesh, placement: Placement) ->
return _ffi_api.redistribute(input, device_mesh, placement) # type: ignore


@args_converter.auto
def call_tir_local_view(
gvar: GlobalVar,
args: Expr,
Expand Down Expand Up @@ -99,7 +98,9 @@ def call_tir_local_view(
ret: Call
A call node for the call_tir_local_view operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
if isinstance(args, tuple | list):
args = RxTuple([convert_to_expr(a) for a in args])
elif isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

if not isinstance(out_sinfo, list):
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/relax/op/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

from tvm.ir.expr import PrimExpr

from .. import args_converter
from ..expr import Expr
from ..utils import convert_to_expr
from . import _ffi_api

PrimExprLike = int | PrimExpr
Expand Down Expand Up @@ -58,7 +58,6 @@ def take(x: Expr, indices: Expr, axis: int | None = None, mode: str = "fast") ->
return _ffi_api.take(x, indices, axis, mode) # type: ignore


@args_converter.auto
def strided_slice(
x: Expr,
axes: Expr,
Expand Down Expand Up @@ -101,6 +100,11 @@ def strided_slice(
strided_slice require the input `begin`, `end` and `strides` to have the
same length as `axes`.
"""
axes = convert_to_expr(axes)
begin = convert_to_expr(begin)
end = convert_to_expr(end)
if strides is not None:
strides = convert_to_expr(strides)
return _ffi_api.strided_slice(x, axes, begin, end, strides, assume_inbound) # type: ignore


Expand Down
8 changes: 3 additions & 5 deletions python/tvm/relax/op/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
"""Relax memory primitives."""

from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm
from ...utils import args_converter
from ...utils import convert_to_expr
from . import _ffi_api


@args_converter.auto
def alloc_storage(
size: Expr,
virtual_device_index: int | Expr,
Expand Down Expand Up @@ -50,6 +49,7 @@ def alloc_storage(
result : Call
A relax Call, which gets the allocated storage.
"""
size = convert_to_expr(size)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(storage_scope, str):
Expand All @@ -59,7 +59,6 @@ def alloc_storage(
return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype) # type: ignore


@args_converter.auto
def alloc_tensor(
storage: Expr,
offset: int | Expr,
Expand Down Expand Up @@ -94,12 +93,12 @@ def alloc_tensor(
"""
if isinstance(offset, int):
offset = PrimValue(offset)
shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore


@args_converter.auto
def kill_storage(storage: Expr) -> Call:
"""Construct a Call to kill a storage.

Expand All @@ -116,7 +115,6 @@ def kill_storage(storage: Expr) -> Call:
return _ffi_api.kill_storage(storage) # type: ignore


@args_converter.auto
def kill_tensor(tensor: Expr) -> Call:
"""Construct a Call to kill a tensor.

Expand Down
2 changes: 0 additions & 2 deletions python/tvm/relax/op/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
# under the License.
"""Sampling operators."""

from .. import args_converter
from ..expr import Expr
from . import _ffi_api


@args_converter.auto
def multinomial_from_uniform(
prob: Expr,
uniform_sample: Expr,
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relax/op/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Relax unary arithmetic operators."""

from ..expr import Expr
from ..utils import args_converter
from ..utils import convert_to_expr
from . import _ffi_api

###################### Arithmetic operators ######################
Expand Down Expand Up @@ -526,7 +526,6 @@ def trunc(x: Expr) -> Expr:
return _ffi_api.trunc(x) # type: ignore


@args_converter.auto
def clip(x: Expr, min: Expr, max: Expr) -> Expr:
"""Clips tensor values to a specified min and max.

Expand All @@ -546,6 +545,8 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr:
result : relax.Expr
The computed result.
"""
min = convert_to_expr(min)
max = convert_to_expr(max)
return _ffi_api.clip(x, min, max) # type: ignore


Expand Down
8 changes: 4 additions & 4 deletions python/tvm/relax/op/vm/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
"""Relax vm primitives."""

from ...expr import Call, DataTypeImm, Expr, PrimValue, StringImm, Tuple
from ...utils import args_converter
from ...utils import convert_to_expr
from . import _ffi_api


@args_converter.auto
def alloc_storage(
shape: Expr,
runtime_device_index: int | Expr,
Expand Down Expand Up @@ -50,6 +49,7 @@ def alloc_storage(
result : Call
A relax Call, which gets the allocated storage.
"""
shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
if isinstance(storage_scope, str):
Expand All @@ -59,7 +59,6 @@ def alloc_storage(
return _ffi_api.alloc_storage(shape, runtime_device_index, dtype, storage_scope) # type: ignore


@args_converter.auto
def alloc_tensor(
storage: Expr,
offset: int | Expr,
Expand Down Expand Up @@ -94,6 +93,7 @@ def alloc_tensor(
"""
if isinstance(offset, int):
offset = PrimValue(offset)
shape = convert_to_expr(shape)
if isinstance(dtype, str):
dtype = DataTypeImm(dtype)
return _ffi_api.alloc_tensor(storage, offset, shape, dtype, runtime_device_ind) # type: ignore
Expand All @@ -116,7 +116,6 @@ def kill_object(obj: Expr) -> Call:
return _ffi_api.kill_object(obj) # type: ignore


@args_converter.auto
def call_tir_dyn(func: Expr, args: Tuple) -> Call:
"""Construct a Call to call_tir_dyn (invoke the given TIR PrimFunc)
consisting of the input tensors and the shape of the result.
Expand All @@ -134,6 +133,7 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call:
result : Call
A relax Call to call_tir_dyn.
"""
func = convert_to_expr(func)
if isinstance(args, list | tuple):
args = Tuple(args)

Expand Down
Loading
Loading