From e241e6fcd95aff96047f8044e9316383319b5964 Mon Sep 17 00:00:00 2001 From: tlopex <820958424@qq.com> Date: Fri, 10 Apr 2026 22:08:34 -0400 Subject: [PATCH] finish1 --- python/tvm/relax/frontend/nn/core.py | 51 +++- python/tvm/relax/frontend/nn/modules.py | 297 ++++++++++++++++++++++-- 2 files changed, 322 insertions(+), 26 deletions(-) diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 40659e1623d7..f3886e94cbcf 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -123,7 +123,25 @@ def from_scalar(data: int | float, dtype: str) -> "Tensor": @staticmethod def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> "Tensor": - """Construct a nn.Tensor from relax TensorStructInfo""" + """Construct a nn.Tensor from a Relax TensorStructInfo. + + TensorStructInfo is the Relax type-level description of a tensor, carrying its shape + and dtype without holding actual data. This factory creates an unbound placeholder + ``nn.Tensor`` that can be used as a symbolic input when tracing an ``nn.Module``. + + Parameters + ---------- + struct_info : rx.TensorStructInfo + The struct info describing the tensor's shape and dtype. + + name : str + Name hint for the underlying Relax variable. + + Returns + ------- + tensor : Tensor + A symbolic ``nn.Tensor`` backed by a ``relax.Var`` with the given struct info. + """ return Tensor( _expr=rx.Var( name_hint=name, @@ -492,7 +510,36 @@ def jit( # pylint: disable=too-many-arguments out_format: str = "torch", debug: bool = False, ) -> Any: - """Just-in-time compilation of a nn.model to an executable""" + """Just-in-time compile an ``nn.Module`` into a callable executable. + + The method exports the module to a Relax IRModule, applies the given compilation + pipeline, builds a Relax VM executable, and wraps the result so it can be called + directly (e.g. with PyTorch tensors when ``out_format="torch"``). 
+ + Parameters + ---------- + spec : _spec.ModuleSpec + A specification mapping each module input to its shape and dtype. + + device : Union[str, Device] + The device to compile and run on (e.g. ``"cpu"``, ``"cuda"``). + + pipeline : Union[None, str, Pass] + The Relax compilation pipeline to apply. ``"default_build"`` uses the standard + optimization pipeline; ``None`` skips pipeline passes. + + out_format : str + Output wrapper format. ``"torch"`` returns a ``TorchModule`` whose ``forward`` + accepts and returns PyTorch tensors. + + debug : bool + If ``True``, enable effect-based debugging (e.g. printing) in the compiled graph. + + Returns + ------- + module : Any + A callable wrapper (type depends on *out_format*) around the compiled VM. + """ def _compile(spec, device, pipeline, debug): # pylint: disable=import-outside-toplevel diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index cf6c827b9ef2..8753d37bace1 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -95,8 +95,27 @@ def forward(self, x: Tensor): class Linear(Module): - """ - Module for linear layer. + """Applies a linear transformation :math:`y = xW^T + b`. + + Parameters + ---------- + in_features : Union[int, str, tirx.PrimExpr] + Size of each input sample. Can be symbolic. + + out_features : Union[int, str, tirx.PrimExpr] + Size of each output sample. Can be symbolic. + + bias : bool + If ``True``, adds a learnable bias. Default: ``True``. + + dtype : Optional[str] + Data type for weight (and bias when *out_dtype* is ``None``). + ``None`` uses the default dtype. + + out_dtype : Optional[str] + If set, the matmul accumulates in this dtype and the bias is stored in this dtype + instead of *dtype*. Useful for mixed-precision (e.g. ``float32`` accumulation with + ``float16`` weights). 
""" def __init__( @@ -154,8 +173,36 @@ def to(self, dtype: str | None = None) -> None: class Conv1D(Module): - """ - Module for conv1d layer. + """Applies a 1D convolution over an input signal. + + Parameters + ---------- + in_channels : int + Number of channels in the input. + + out_channels : int + Number of channels produced by the convolution. + + kernel_size : int + Size of the convolving kernel. + + stride : int + Stride of the convolution. Default: 1. + + padding : int + Zero-padding added to both sides of the input. Default: 0. + + dilation : int + Spacing between kernel elements. Default: 1. + + groups : int + Number of blocked connections from input to output channels. Default: 1. + + bias : bool + If ``True``, adds a learnable bias. Default: ``True``. + + dtype : Optional[str] + Data type for weight and bias. ``None`` uses the default dtype. """ def __init__( @@ -212,8 +259,39 @@ def forward(self, x: Tensor) -> Tensor: class Conv2D(Module): - """ - Module for conv2d layer. + """Applies a 2D convolution over an input image. + + Parameters + ---------- + in_channels : int + Number of channels in the input image. + + out_channels : int + Number of channels produced by the convolution. + + kernel_size : Union[List[int], int] + Size of the convolving kernel. An int is expanded to a 2-element list. + + stride : int + Stride of the convolution. Default: 1. + + padding : int + Zero-padding added to both sides of the input. Default: 0. + + dilation : int + Spacing between kernel elements. Default: 1. + + groups : int + Number of blocked connections from input to output channels. Default: 1. + + bias : bool + If ``True``, adds a learnable bias. Default: ``True``. + + dtype : Optional[str] + Data type for weight and bias. ``None`` uses the default dtype. + + data_layout : str + Layout of the input data, e.g. ``"NCHW"`` or ``"NHWC"``. Default: ``"NCHW"``. 
""" def __init__( # pylint: disable=too-many-arguments @@ -286,8 +364,39 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name class Conv3D(Module): - """ - Module for conv3d layer. + """Applies a 3D convolution over an input volume. + + Parameters + ---------- + in_channels : int + Number of channels in the input volume. + + out_channels : int + Number of channels produced by the convolution. + + kernel_size : Union[List[int], int] + Size of the convolving kernel. An int is expanded to a 3-element list. + + stride : Union[List[int], int] + Stride of the convolution. Default: 1. + + padding : Union[List[int], int] + Zero-padding added to each side of the input. Default: 0. + + dilation : int + Spacing between kernel elements. Default: 1. + + groups : int + Number of blocked connections from input to output channels. Default: 1. + + bias : bool + If ``True``, adds a learnable bias. Default: ``True``. + + dtype : Optional[str] + Data type for weight and bias. ``None`` uses the default dtype. + + data_layout : str + Layout of the input data, e.g. ``"NCDHW"``. Default: ``"NCDHW"``. """ def __init__( # pylint: disable=too-many-arguments @@ -360,8 +469,39 @@ def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name class ConvTranspose1D(Module): - """ - Module for ConvTranspose1D layer. + """Applies a 1D transposed convolution (fractionally-strided convolution). + + Parameters + ---------- + in_channels : int + Number of channels in the input. + + out_channels : int + Number of channels produced by the transposed convolution. + + kernel_size : int + Size of the convolving kernel. + + stride : int + Stride of the convolution. Default: 1. + + padding : int + Zero-padding added to both sides of the input. Default: 0. + + output_padding : int + Additional size added to one side of the output shape. Default: 0. + + dilation : int + Spacing between kernel elements. Default: 1. 
+ + groups : int + Number of blocked connections from input to output channels. Default: 1. + + bias : bool + If ``True``, adds a learnable bias. Default: ``True``. + + dtype : Optional[str] + Data type for weight and bias. ``None`` uses the default dtype. """ def __init__( @@ -427,8 +567,22 @@ def forward(self, x: Tensor) -> Tensor: class LayerNorm(Module): - """ - Module for Layer Normalization + """Applies Layer Normalization over the last dimension. + + Parameters + ---------- + normalized_shape : int + Size of the last dimension to normalize over. + + eps : Optional[float] + Value added to the denominator for numerical stability. Default: ``1e-5``. + + elementwise_affine : bool + If ``True``, learnable affine parameters (weight and bias) are added. + Default: ``True``. + + dtype : Optional[str] + Data type for the affine parameters. ``None`` uses the default dtype. """ def __init__( @@ -473,8 +627,24 @@ def forward(self, x: Tensor) -> Tensor: class RMSNorm(Module): - """ - Module for rms norm layer. + """Applies Root Mean Square Layer Normalization. + + Parameters + ---------- + hidden_size : int + Size of the weight parameter. + + axes : Union[int, List[int]] + The axes over which to compute the RMS norm. + + epsilon : float + Value added to the denominator for numerical stability. Default: ``1e-5``. + + bias : bool + If ``True``, adds a learnable bias after normalization. Default: ``True``. + + dtype : Optional[str] + Data type for the parameters. ``None`` uses the default dtype. """ def __init__( @@ -515,8 +685,24 @@ def forward(self, x: Tensor): class GroupNorm(Module): - """ - Module for group norm layer. + """Applies Group Normalization. + + Parameters + ---------- + num_groups : int + Number of groups to separate the channels into. + + num_channels : int + Number of channels in the input, must be divisible by *num_groups*. + + eps : float + Value added to the denominator for numerical stability. Default: ``1e-5``. 
+ + affine : bool + If ``True``, learnable per-channel affine parameters are added. Default: ``True``. + + dtype : Optional[str] + Data type for the affine parameters. ``None`` uses the default dtype. """ def __init__( @@ -563,8 +749,28 @@ def forward(self, x: Tensor, channel_axis: int = 1, axes: list[int] | None = Non class KVCache(Effect): - """ - Effect to implement KVCache. + """Managed key-value cache for autoregressive decoding. + + ``KVCache`` is a TVM-specific ``Effect`` that allocates and maintains a runtime cache + for storing past key/value tensors in transformer models. Unlike regular ``Module`` + parameters, effects are registered with the Relax VM and carry mutable state across + calls (append, reset) without being passed as explicit function arguments. + + The cache is pre-allocated with shape ``[init_seq_len, *unit_shape]`` and grows via + the ``append`` method at runtime. Use ``init_seq_len`` to control the initial + allocation size. + + Parameters + ---------- + init_seq_len : int + Initial sequence-length capacity of the cache allocation. + + unit_shape : Sequence[int] + Shape of a single cache entry excluding the sequence dimension. + For multi-head attention this is typically ``[num_heads, head_dim]``. + + dtype : Optional[str] + Data type of the cache tensor. ``None`` uses the default dtype. """ init_seq_len: int @@ -708,8 +914,18 @@ def append(self, new_element: Tensor) -> None: class Embedding(Module): - """ - Module for embedding layer. + """A lookup table that retrieves embeddings by index. + + Parameters + ---------- + num : Union[int, str, tirx.PrimExpr] + Size of the embedding dictionary (vocabulary size). Can be symbolic. + + dim : Union[int, str, tirx.PrimExpr] + Size of each embedding vector. Can be symbolic. + + dtype : Optional[str] + Data type of the embedding weight. ``None`` uses the default dtype. 
""" def __init__( @@ -749,8 +965,31 @@ def forward(self, x: Tensor): class TimestepEmbedding(Module): - """ - Module for HF TimestepEmbedding layer. + """MLP that projects timestep embeddings, following the HuggingFace diffusers convention. + + Consists of two linear layers with an activation in between, and an optional + conditional projection and post-activation. + + Parameters + ---------- + in_channels : int + Dimensionality of the input timestep embedding. + + time_embed_dim : int + Dimensionality of the intermediate (hidden) projection. + + act_fn : str + Activation function name. Currently only ``"silu"`` is supported. + + out_dim : Optional[int] + Dimensionality of the output. If ``None``, defaults to *time_embed_dim*. + + post_act_fn : Optional[str] + Optional post-activation applied after the second linear layer. + + cond_proj_dim : Optional[int] + If set, adds a linear projection from a conditioning signal of this + dimensionality to *in_channels*, which is added to the input sample. """ def __init__( @@ -816,8 +1055,18 @@ def forward(self, sample: Tensor, condition: Tensor | None = None): class Timesteps(Module): - """ - Module for HF timesteps layer. + """Sinusoidal positional embedding for diffusion timesteps (HuggingFace convention). + + Parameters + ---------- + num_channels : int + Dimensionality of the embedding (number of sinusoidal channels). + + flip_sin_to_cos : bool + If ``True``, swap sin and cos components. Default: ``False``. + + downscale_freq_shift : float + Shift applied to the frequency denominator. Default: ``1``. """ def __init__(