Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,4 @@ jobs:
with:
fetch-depth: 0
fetch-tags: true
- name: Set up uv
uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
- name: Set up Python environment
run: uv sync --group lint --no-install-project
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
1 change: 0 additions & 1 deletion docs/arch/pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -667,4 +667,3 @@ new ``PassInstrument`` are called.
.. _src/tirx/transform/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tirx/transform/unroll_loop.cc

.. _use pass infra: https://github.com/apache/tvm/blob/main/docs/how_to/tutorials/customize_opt.py

4 changes: 1 addition & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,9 +724,7 @@ def _dedup_find_obj(self, env, modname, classname, name, objtype, searchmode=0):
return context_matches

# Fall back to the unique match that best shares the current module prefix.
match_scores = {
match[0]: _common_prefix_len(modname, match[0]) for match in matches
}
match_scores = {match[0]: _common_prefix_len(modname, match[0]) for match in matches}
best_score = max(match_scores.values())
if best_score > 1:
best_matches = [match for match in matches if match_scores[match[0]] == best_score]
Expand Down
35 changes: 15 additions & 20 deletions docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# ruff: noqa: E402

"""
.. _mix_python_and_tvm:
Expand Down Expand Up @@ -163,8 +162,10 @@ def forward(self, x, weights):
logits = self._convert_tvm_to_pytorch(out)

# Inspect intermediate value — impossible with a compiled-only workflow
print(f" [DEBUG] logits shape: {logits.shape}, "
f"min: {logits.min():.4f}, max: {logits.max():.4f}")
print(
f" [DEBUG] logits shape: {logits.shape}, "
f"min: {logits.min():.4f}, max: {logits.max():.4f}"
)

result = F.softmax(logits, dim=-1)

Expand Down Expand Up @@ -198,12 +199,10 @@ def forward(self, x, weights):
# — for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions.

if RUN_EXAMPLE:

# Register a packed function (simulating an external library binding)
@tvm.register_global_func("my_bias_add", override=True)
def my_bias_add(x, bias, out):
"""Packed function: adds bias to each row of x."""
import numpy as np

x_np = x.numpy()
b_np = bias.numpy()
Expand All @@ -230,14 +229,16 @@ def forward(self, x, weights, bias):
x_tvm = self._convert_pytorch_to_tvm(x)
w_tvm = self._convert_pytorch_to_tvm(weights)
h = self.call_tir(
self.matmul_tir, [x_tvm, w_tvm],
self.matmul_tir,
[x_tvm, w_tvm],
out_sinfo=R.Tensor((2, 3), "float32"),
)
h_pt = self._convert_tvm_to_pytorch(h)

# 2. Packed function for bias add (simulating an external library)
h_biased = self.call_dps_packed(
"my_bias_add", [h_pt, bias],
"my_bias_add",
[h_pt, bias],
out_sinfo=R.Tensor((2, 3), "float32"),
)

Expand Down Expand Up @@ -291,7 +292,8 @@ def main(
h = R.matmul(x, w)
cls = DenseLayer
h_bias = R.call_tir(
cls.bias_add_tir, (h, b),
cls.bias_add_tir,
(h, b),
out_sinfo=R.Tensor((2, 4), "float32"),
)
return R.nn.relu(h_bias)
Expand Down Expand Up @@ -324,8 +326,7 @@ def main(

print("\nAfter CanonicalizeBindings pass:")
print(" Converted result:", py_result_late)
print(" Still matches: ",
torch.allclose(py_result_late, expected, atol=1e-5))
print(" Still matches: ", torch.allclose(py_result_late, expected, atol=1e-5))
assert torch.allclose(py_result_late, expected, atol=1e-5)


Expand Down Expand Up @@ -363,12 +364,8 @@ def main(
x: R.Tensor((4, 8), "float32"),
) -> R.Tensor((4, 8), "float32"):
# The VM calls back into Python for these two ops
h = R.call_py_func(
"layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32")
)
out = R.call_py_func(
"silu", (h,), out_sinfo=R.Tensor((4, 8), "float32")
)
h = R.call_py_func("layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32"))
out = R.call_py_func("silu", (h,), out_sinfo=R.Tensor((4, 8), "float32"))
return out

mod = HybridVMModule(device=tvm.cpu(0))
Expand All @@ -390,7 +387,7 @@ def main(
# ``BasePyModule`` is designed for **cross-level interoperability**: Python functions can call
# TIR and Relax functions, and Relax functions can call Python functions. We have already seen:
#
# - Python → TIR via ``call_tir`` (Steps 13)
# - Python → TIR via ``call_tir`` (Steps 1-3)
# - Python → packed function via ``call_dps_packed`` (Step 3)
# - Relax → Python via ``R.call_py_func`` (Step 5)
#
Expand Down Expand Up @@ -441,9 +438,7 @@ def add_relax(
# Python → TIR with symbolic output shape
n = T.int64()
x7 = torch.randn(7)
scaled = mod.call_tir(
"scale_tir", [x7], relax.TensorStructInfo((n,), "float32")
)
scaled = mod.call_tir("scale_tir", [x7], relax.TensorStructInfo((n,), "float32"))
print("scale_tir(len=7):", scaled)
assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5)

Expand Down
11 changes: 6 additions & 5 deletions include/tvm/relax/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,12 @@ struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter<Conv3DTransposeAttrs>
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Convolution is applied on the 'D', 'H', and"
"'W' dimensions.")
.def_ro("kernel_layout", &Conv3DTransposeAttrs::kernel_layout,
"Dimension ordering of weight. Can be 'IODHW', etc."
"'I', 'O', 'D', 'H', 'W' stands for input_channel, output_channel, depth, height, and "
"width"
"dimensions respectively.")
.def_ro(
"kernel_layout", &Conv3DTransposeAttrs::kernel_layout,
"Dimension ordering of weight. Can be 'IODHW', etc."
"'I', 'O', 'D', 'H', 'W' stands for input_channel, output_channel, depth, height, and "
"width"
"dimensions respectively.")
.def_ro("out_layout", &Conv3DTransposeAttrs::out_layout,
"Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ class Node(Object):
"""Base class of all IR Nodes."""

def __repr__(self) -> str:
from tvm.runtime.script_printer import _script # noqa: PLC0415
from tvm.runtime.script_printer import _script

try:
return _script(self, None)
except Exception: # noqa: BLE001
except Exception:
return super().__repr__()


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/backend/contrib/example_npu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@
constraints, making them available for graph partitioning.
"""

from . import patterns # noqa: F401
from . import patterns

__all__ = ["patterns"]
14 changes: 10 additions & 4 deletions python/tvm/relax/frontend/nn/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,9 @@ def __setitem__(self, key: str, param: Parameter) -> None:
if not isinstance(key, str):
raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}")
if not isinstance(param, Parameter):
raise TypeError(f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}")
raise TypeError(
f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}"
)
self.params[key] = param

def __len__(self) -> int:
Expand Down Expand Up @@ -731,16 +733,20 @@ def __getitem__(self, idx: int) -> Parameter:

def __setitem__(self, idx: int, param: Parameter) -> None:
if not isinstance(param, Parameter):
raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
raise TypeError(
f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}"
)
self.params[idx] = param

def __len__(self) -> int:
return len(self.params)

def append(self, param: Parameter) -> None:
"""Add a parameter to the end of the ParameterList"""
if not isinstance(param, Parameter):
raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}")
if not isinstance(param, Parameter):
raise TypeError(
f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}"
)
self.params.append(param)

def extend(self, params: list[Parameter]) -> None:
Expand Down
36 changes: 15 additions & 21 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,9 +792,7 @@ def _legacy_softmax_prepare(
return flattened, tuple(original_shape)


def _get_axis_extent(
data: relax.Expr, axis: int, op_name: str
) -> tuple[int, int | tirx.PrimExpr]:
def _get_axis_extent(data: relax.Expr, axis: int, op_name: str) -> tuple[int, int | tirx.PrimExpr]:
"""Return normalized axis and axis extent when rank/shape are known."""

rank = _get_known_tensor_rank(data)
Expand All @@ -803,7 +801,9 @@ def _get_axis_extent(

normalized_axis = _normalize_constant_axes([axis], rank, op_name)[0]
struct_info = data.struct_info
if isinstance(struct_info, relax.TensorStructInfo) and isinstance(struct_info.shape, relax.ShapeExpr):
if isinstance(struct_info, relax.TensorStructInfo) and isinstance(
struct_info.shape, relax.ShapeExpr
):
axis_extent = struct_info.shape.values[normalized_axis]
if isinstance(axis_extent, tirx.IntImm):
axis_extent = int(axis_extent.value)
Expand Down Expand Up @@ -881,9 +881,7 @@ def _hardmax_impl(cls, *args):
bb = None
data, axis = args
else:
raise TypeError(
"Hardmax._hardmax_impl expects (bb, data, axis) or (data, axis)."
)
raise TypeError("Hardmax._hardmax_impl expects (bb, data, axis) or (data, axis).")

if bb is not None:
data = bb.normalize(data)
Expand Down Expand Up @@ -1130,7 +1128,7 @@ def _impl_v13(cls, bb, inputs, attr, params):
relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap")
)

if indices_dtype !="int64":
if indices_dtype != "int64":
axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype))

indices = bb.normalize(
Expand Down Expand Up @@ -1182,9 +1180,7 @@ def _get_onnx_reduction(attr, valid_reductions: list[str]):
reduction = reduction.decode("utf-8")
reduction = "update" if reduction == "none" else reduction
if reduction not in valid_reductions:
raise ValueError(
f"Only {valid_reductions} reductions are supported, but got {reduction}"
)
raise ValueError(f"Only {valid_reductions} reductions are supported, but got {reduction}")

return reduction

Expand Down Expand Up @@ -1775,10 +1771,7 @@ def _impl_v1(cls, bb, inputs, attr, params):
pads_end: list[int] = []
for i in range(spatial_dims):
total_pad = (
(kernel_shape[i] - 1) * dilations[i]
+ 1
+ output_padding[i]
- strides[i]
(kernel_shape[i] - 1) * dilations[i] + 1 + output_padding[i] - strides[i]
)
total_pad = max(total_pad, 0)
if auto_pad == "SAME_UPPER":
Expand Down Expand Up @@ -1844,18 +1837,20 @@ def _impl_v14(cls, bb, inputs, attr, params):
else:
raise ValueError(
"CumSum axis input must be a scalar (0-D) or a single-element 1-D tensor, "
"got shape {}".format(axis_data.shape)
f"got shape {axis_data.shape}"
)
elif isinstance(axis_input, relax.Var):
axis_shape = axis_input.struct_info.shape if hasattr(axis_input.struct_info, "shape") else None
axis_shape = (
axis_input.struct_info.shape if hasattr(axis_input.struct_info, "shape") else None
)
raise ValueError(
"CumSum with non-constant axis input is not supported yet. "
"ONNX permits runtime axis tensors, but Relax/TE currently requires a compile-time "
"constant axis for cumsum/flip. Got axis shape {}".format(axis_shape)
f"constant axis for cumsum/flip. Got axis shape {axis_shape}"
)
else:
raise TypeError("CumSum axis input must be a Constant or Var")

if attr.get("reverse", 0) != 0:
data = bb.emit_te(topi.flip, data, axis=axis)

Expand Down Expand Up @@ -4694,7 +4689,6 @@ def _impl_v11(cls, bb, inputs, attr, params):

input_tensor = inputs[0]
input_shape = input_tensor.struct_info.shape
split_is_scalar = False

if len(inputs) == 1:
split = _np.array(1)
Expand All @@ -4711,7 +4705,7 @@ def _impl_v11(cls, bb, inputs, attr, params):
chunk_size = int(split)
dim_size = input_shape[axis]

if isinstance(dim_size, (int, tirx.IntImm)):
if isinstance(dim_size, int | tirx.IntImm):
dim_size_int = int(dim_size)
split = math.ceil(dim_size_int / chunk_size)
else:
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ def decode_vector(self, end, size, byte_width):
value_type = FlexBufferType(value_type_packed >> 2)
value_bit_width = BitWidth(value_type_packed & 3)
value_byte_width = 1 << value_bit_width
value_bytes = self.buffer[end + i * byte_width : end + i * byte_width + value_byte_width]
value_bytes = self.buffer[
end + i * byte_width : end + i * byte_width + value_byte_width
]
if value_type == FlexBufferType.FBT_BOOL:
value = bool(value_bytes[0])
elif value_type == FlexBufferType.FBT_INT:
Expand Down
Loading
Loading