Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 49 additions & 2 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,25 @@ def from_scalar(data: int | float, dtype: str) -> "Tensor":

@staticmethod
def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> "Tensor":
"""Construct a nn.Tensor from relax TensorStructInfo"""
"""Construct a nn.Tensor from a Relax TensorStructInfo.

TensorStructInfo is the Relax type-level description of a tensor, carrying its shape
and dtype without holding actual data. This factory creates an unbound placeholder
``nn.Tensor`` that can be used as a symbolic input when tracing an ``nn.Module``.

Parameters
----------
struct_info : rx.TensorStructInfo
The struct info describing the tensor's shape and dtype.

name : str
Name hint for the underlying Relax variable.

Returns
-------
tensor : Tensor
A symbolic ``nn.Tensor`` backed by a ``relax.Var`` with the given struct info.
"""
return Tensor(
_expr=rx.Var(
name_hint=name,
Expand Down Expand Up @@ -492,7 +510,36 @@ def jit( # pylint: disable=too-many-arguments
out_format: str = "torch",
debug: bool = False,
) -> Any:
"""Just-in-time compilation of a nn.model to an executable"""
"""Just-in-time compile an ``nn.Module`` into a callable executable.

The method exports the module to a Relax IRModule, applies the given compilation
pipeline, builds a Relax VM executable, and wraps the result so it can be called
directly (e.g. with PyTorch tensors when ``out_format="torch"``).

Parameters
----------
spec : _spec.ModuleSpec
A specification mapping each module input to its shape and dtype.

device : Union[str, Device]
The device to compile and run on (e.g. ``"cpu"``, ``"cuda"``).

pipeline : Union[None, str, Pass]
The Relax compilation pipeline to apply. ``"default_build"`` uses the standard
optimization pipeline; ``None`` skips pipeline passes.

out_format : str
Output wrapper format. ``"torch"`` returns a ``TorchModule`` whose ``forward``
accepts and returns PyTorch tensors.

debug : bool
If ``True``, enable effect-based debugging (e.g. printing) in the compiled graph.

Returns
-------
module : Any
A callable wrapper (type depends on *out_format*) around the compiled VM.
"""

def _compile(spec, device, pipeline, debug):
# pylint: disable=import-outside-toplevel
Expand Down
Loading
Loading