diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 17fa38339f81..3936789a91f4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -35,8 +35,4 @@ jobs: with: fetch-depth: 0 fetch-tags: true - - name: Set up uv - uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0 - - name: Set up Python environment - run: uv sync --group lint --no-install-project - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 diff --git a/docs/arch/pass_infra.rst b/docs/arch/pass_infra.rst index b04868e2c6ca..1fb78bcd5254 100644 --- a/docs/arch/pass_infra.rst +++ b/docs/arch/pass_infra.rst @@ -667,4 +667,3 @@ new ``PassInstrument`` are called. .. _src/tirx/transform/unroll_loop.cc: https://github.com/apache/tvm/blob/main/src/tirx/transform/unroll_loop.cc .. _use pass infra: https://github.com/apache/tvm/blob/main/docs/how_to/tutorials/customize_opt.py - diff --git a/docs/conf.py b/docs/conf.py index 74e4b881814a..6bcd1fbbc8a2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -724,9 +724,7 @@ def _dedup_find_obj(self, env, modname, classname, name, objtype, searchmode=0): return context_matches # Fall back to the unique match that best shares the current module prefix. - match_scores = { - match[0]: _common_prefix_len(modname, match[0]) for match in matches - } + match_scores = {match[0]: _common_prefix_len(modname, match[0]) for match in matches} best_score = max(match_scores.values()) if best_score > 1: best_matches = [match for match in matches if match_scores[match[0]] == best_score] diff --git a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py index 91d1cb9c2633..c3bc95dcc854 100644 --- a/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py +++ b/docs/how_to/tutorials/mix_python_and_tvm_with_pymodule.py @@ -14,7 +14,6 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E402 """ .. _mix_python_and_tvm: @@ -163,8 +162,10 @@ def forward(self, x, weights): logits = self._convert_tvm_to_pytorch(out) # Inspect intermediate value — impossible with a compiled-only workflow - print(f" [DEBUG] logits shape: {logits.shape}, " - f"min: {logits.min():.4f}, max: {logits.max():.4f}") + print( + f" [DEBUG] logits shape: {logits.shape}, " + f"min: {logits.min():.4f}, max: {logits.max():.4f}" + ) result = F.softmax(logits, dim=-1) @@ -198,12 +199,10 @@ def forward(self, x, weights): # — for example, CUBLAS or cuDNN bindings that TVM wraps as packed functions. if RUN_EXAMPLE: - # Register a packed function (simulating an external library binding) @tvm.register_global_func("my_bias_add", override=True) def my_bias_add(x, bias, out): """Packed function: adds bias to each row of x.""" - import numpy as np x_np = x.numpy() b_np = bias.numpy() @@ -230,14 +229,16 @@ def forward(self, x, weights, bias): x_tvm = self._convert_pytorch_to_tvm(x) w_tvm = self._convert_pytorch_to_tvm(weights) h = self.call_tir( - self.matmul_tir, [x_tvm, w_tvm], + self.matmul_tir, + [x_tvm, w_tvm], out_sinfo=R.Tensor((2, 3), "float32"), ) h_pt = self._convert_tvm_to_pytorch(h) # 2. 
Packed function for bias add (simulating an external library) h_biased = self.call_dps_packed( - "my_bias_add", [h_pt, bias], + "my_bias_add", + [h_pt, bias], out_sinfo=R.Tensor((2, 3), "float32"), ) @@ -291,7 +292,8 @@ def main( h = R.matmul(x, w) cls = DenseLayer h_bias = R.call_tir( - cls.bias_add_tir, (h, b), + cls.bias_add_tir, + (h, b), out_sinfo=R.Tensor((2, 4), "float32"), ) return R.nn.relu(h_bias) @@ -324,8 +326,7 @@ def main( print("\nAfter CanonicalizeBindings pass:") print(" Converted result:", py_result_late) - print(" Still matches: ", - torch.allclose(py_result_late, expected, atol=1e-5)) + print(" Still matches: ", torch.allclose(py_result_late, expected, atol=1e-5)) assert torch.allclose(py_result_late, expected, atol=1e-5) @@ -363,12 +364,8 @@ def main( x: R.Tensor((4, 8), "float32"), ) -> R.Tensor((4, 8), "float32"): # The VM calls back into Python for these two ops - h = R.call_py_func( - "layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32") - ) - out = R.call_py_func( - "silu", (h,), out_sinfo=R.Tensor((4, 8), "float32") - ) + h = R.call_py_func("layer_norm", (x,), out_sinfo=R.Tensor((4, 8), "float32")) + out = R.call_py_func("silu", (h,), out_sinfo=R.Tensor((4, 8), "float32")) return out mod = HybridVMModule(device=tvm.cpu(0)) @@ -390,7 +387,7 @@ def main( # ``BasePyModule`` is designed for **cross-level interoperability**: Python functions can call # TIR and Relax functions, and Relax functions can call Python functions. 
We have already seen: # -# - Python → TIR via ``call_tir`` (Steps 1–3) +# - Python → TIR via ``call_tir`` (Steps 1-3) # - Python → packed function via ``call_dps_packed`` (Step 3) # - Relax → Python via ``R.call_py_func`` (Step 5) # @@ -441,9 +438,7 @@ def add_relax( # Python → TIR with symbolic output shape n = T.int64() x7 = torch.randn(7) - scaled = mod.call_tir( - "scale_tir", [x7], relax.TensorStructInfo((n,), "float32") - ) + scaled = mod.call_tir("scale_tir", [x7], relax.TensorStructInfo((n,), "float32")) print("scale_tir(len=7):", scaled) assert torch.allclose(torch.tensor(scaled.numpy()), x7 * 2.0, atol=1e-5) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 5c1931c3ee23..45abeb9d5b7e 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -303,11 +303,12 @@ struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" "dimensions respectively. Convolution is applied on the 'D', 'H', and" "'W' dimensions.") - .def_ro("kernel_layout", &Conv3DTransposeAttrs::kernel_layout, - "Dimension ordering of weight. Can be 'IODHW', etc." - "'I', 'O', 'D', 'H', 'W' stands for input_channel, output_channel, depth, height, and " - "width" - "dimensions respectively.") + .def_ro( + "kernel_layout", &Conv3DTransposeAttrs::kernel_layout, + "Dimension ordering of weight. Can be 'IODHW', etc." + "'I', 'O', 'D', 'H', 'W' stands for input_channel, output_channel, depth, height, and " + "width" + "dimensions respectively.") .def_ro("out_layout", &Conv3DTransposeAttrs::out_layout, "Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc." 
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width" diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 15629dcbe61f..b65a241450bf 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -28,11 +28,11 @@ class Node(Object): """Base class of all IR Nodes.""" def __repr__(self) -> str: - from tvm.runtime.script_printer import _script # noqa: PLC0415 + from tvm.runtime.script_printer import _script try: return _script(self, None) - except Exception: # noqa: BLE001 + except Exception: return super().__repr__() diff --git a/python/tvm/relax/backend/contrib/example_npu/__init__.py b/python/tvm/relax/backend/contrib/example_npu/__init__.py index 018997f3228a..a1d484c0fcee 100644 --- a/python/tvm/relax/backend/contrib/example_npu/__init__.py +++ b/python/tvm/relax/backend/contrib/example_npu/__init__.py @@ -26,6 +26,6 @@ constraints, making them available for graph partitioning. """ -from . import patterns # noqa: F401 +from . import patterns __all__ = ["patterns"] diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index 3725a84d61f8..1301fa471ff6 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -646,7 +646,9 @@ def __setitem__(self, key: str, param: Parameter) -> None: if not isinstance(key, str): raise TypeError(f"ParameterDict keys must be strings, but got {type(key).__name__}") if not isinstance(param, Parameter): - raise TypeError(f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}") + raise TypeError( + f"ParameterDict values must be nn.Parameter, but got {type(param).__name__}" + ) self.params[key] = param def __len__(self) -> int: @@ -731,7 +733,9 @@ def __getitem__(self, idx: int) -> Parameter: def __setitem__(self, idx: int, param: Parameter) -> None: if not isinstance(param, Parameter): - raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}") + raise TypeError( + 
f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}" + ) self.params[idx] = param def __len__(self) -> int: @@ -739,8 +743,10 @@ def __len__(self) -> int: def append(self, param: Parameter) -> None: """Add a parameter to the end of the ParameterList""" - if not isinstance(param, Parameter): - raise TypeError(f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}") + if not isinstance(param, Parameter): + raise TypeError( + f"ParameterList elements must be nn.Parameter, but got {type(param).__name__}" + ) self.params.append(param) def extend(self, params: list[Parameter]) -> None: diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 878f976c9504..560b644de8cc 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -792,9 +792,7 @@ def _legacy_softmax_prepare( return flattened, tuple(original_shape) -def _get_axis_extent( - data: relax.Expr, axis: int, op_name: str -) -> tuple[int, int | tirx.PrimExpr]: +def _get_axis_extent(data: relax.Expr, axis: int, op_name: str) -> tuple[int, int | tirx.PrimExpr]: """Return normalized axis and axis extent when rank/shape are known.""" rank = _get_known_tensor_rank(data) @@ -803,7 +801,9 @@ def _get_axis_extent( normalized_axis = _normalize_constant_axes([axis], rank, op_name)[0] struct_info = data.struct_info - if isinstance(struct_info, relax.TensorStructInfo) and isinstance(struct_info.shape, relax.ShapeExpr): + if isinstance(struct_info, relax.TensorStructInfo) and isinstance( + struct_info.shape, relax.ShapeExpr + ): axis_extent = struct_info.shape.values[normalized_axis] if isinstance(axis_extent, tirx.IntImm): axis_extent = int(axis_extent.value) @@ -881,9 +881,7 @@ def _hardmax_impl(cls, *args): bb = None data, axis = args else: - raise TypeError( - "Hardmax._hardmax_impl expects (bb, data, axis) or (data, axis)." 
- ) + raise TypeError("Hardmax._hardmax_impl expects (bb, data, axis) or (data, axis).") if bb is not None: data = bb.normalize(data) @@ -1130,7 +1128,7 @@ def _impl_v13(cls, bb, inputs, attr, params): relax.op.take(data_shape_tensor, relax.const(axis, "int64"), axis=0, mode="wrap") ) - if indices_dtype !="int64": + if indices_dtype != "int64": axis_extent = bb.normalize(relax.op.astype(axis_extent, indices_dtype)) indices = bb.normalize( @@ -1182,9 +1180,7 @@ def _get_onnx_reduction(attr, valid_reductions: list[str]): reduction = reduction.decode("utf-8") reduction = "update" if reduction == "none" else reduction if reduction not in valid_reductions: - raise ValueError( - f"Only {valid_reductions} reductions are supported, but got {reduction}" - ) + raise ValueError(f"Only {valid_reductions} reductions are supported, but got {reduction}") return reduction @@ -1775,10 +1771,7 @@ def _impl_v1(cls, bb, inputs, attr, params): pads_end: list[int] = [] for i in range(spatial_dims): total_pad = ( - (kernel_shape[i] - 1) * dilations[i] - + 1 - + output_padding[i] - - strides[i] + (kernel_shape[i] - 1) * dilations[i] + 1 + output_padding[i] - strides[i] ) total_pad = max(total_pad, 0) if auto_pad == "SAME_UPPER": @@ -1844,18 +1837,20 @@ def _impl_v14(cls, bb, inputs, attr, params): else: raise ValueError( "CumSum axis input must be a scalar (0-D) or a single-element 1-D tensor, " - "got shape {}".format(axis_data.shape) + f"got shape {axis_data.shape}" ) elif isinstance(axis_input, relax.Var): - axis_shape = axis_input.struct_info.shape if hasattr(axis_input.struct_info, "shape") else None + axis_shape = ( + axis_input.struct_info.shape if hasattr(axis_input.struct_info, "shape") else None + ) raise ValueError( "CumSum with non-constant axis input is not supported yet. " "ONNX permits runtime axis tensors, but Relax/TE currently requires a compile-time " - "constant axis for cumsum/flip. Got axis shape {}".format(axis_shape) + f"constant axis for cumsum/flip. 
Got axis shape {axis_shape}" ) else: raise TypeError("CumSum axis input must be a Constant or Var") - + if attr.get("reverse", 0) != 0: data = bb.emit_te(topi.flip, data, axis=axis) @@ -4694,7 +4689,6 @@ def _impl_v11(cls, bb, inputs, attr, params): input_tensor = inputs[0] input_shape = input_tensor.struct_info.shape - split_is_scalar = False if len(inputs) == 1: split = _np.array(1) @@ -4711,7 +4705,7 @@ def _impl_v11(cls, bb, inputs, attr, params): chunk_size = int(split) dim_size = input_shape[axis] - if isinstance(dim_size, (int, tirx.IntImm)): + if isinstance(dim_size, int | tirx.IntImm): dim_size_int = int(dim_size) split = math.ceil(dim_size_int / chunk_size) else: diff --git a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py index 5152b6996ecf..148f404f25db 100644 --- a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py +++ b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py @@ -110,7 +110,9 @@ def decode_vector(self, end, size, byte_width): value_type = FlexBufferType(value_type_packed >> 2) value_bit_width = BitWidth(value_type_packed & 3) value_byte_width = 1 << value_bit_width - value_bytes = self.buffer[end + i * byte_width : end + i * byte_width + value_byte_width] + value_bytes = self.buffer[ + end + i * byte_width : end + i * byte_width + value_byte_width + ] if value_type == FlexBufferType.FBT_BOOL: value = bool(value_bytes[0]) elif value_type == FlexBufferType.FBT_INT: diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 0b71990c90a7..145e953394cd 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -240,21 +240,15 @@ def __init__(self, model, subgraph, exp_tab, ctx): "SQRT": functools.partial(self._convert_unary_elemwise, relax_op=_op.sqrt), "SQUARE": self.convert_square, "SQUARED_DIFFERENCE": self.convert_squared_difference, - 
"STABLEHLO_ABS": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.abs - ), - "STABLEHLO_ADD": functools.partial( - self._convert_stablehlo_binary, relax_op=_op.add - ), + "STABLEHLO_ABS": functools.partial(self._convert_stablehlo_unary, relax_op=_op.abs), + "STABLEHLO_ADD": functools.partial(self._convert_stablehlo_binary, relax_op=_op.add), "STABLEHLO_AND": self._convert_stablehlo_and, "STABLEHLO_BROADCAST_IN_DIM": self._convert_stablehlo_broadcast_in_dim, "STABLEHLO_CLAMP": self._convert_stablehlo_clamp, "STABLEHLO_COMPARE": self._convert_stablehlo_compare, "STABLEHLO_CONCATENATE": self._convert_stablehlo_concatenate, "STABLEHLO_CONVERT": self._convert_stablehlo_convert, - "STABLEHLO_COSINE": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.cos - ), + "STABLEHLO_COSINE": functools.partial(self._convert_stablehlo_unary, relax_op=_op.cos), "STABLEHLO_DIVIDE": functools.partial( self._convert_stablehlo_binary, relax_op=_op.divide ), @@ -262,14 +256,10 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_EXPONENTIAL": functools.partial( self._convert_stablehlo_unary, relax_op=_op.exp ), - "STABLEHLO_FLOOR": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.floor - ), + "STABLEHLO_FLOOR": functools.partial(self._convert_stablehlo_unary, relax_op=_op.floor), "STABLEHLO_GATHER": self._convert_stablehlo_gather, "STABLEHLO_IOTA": self._convert_stablehlo_iota, - "STABLEHLO_LOG": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.log - ), + "STABLEHLO_LOG": functools.partial(self._convert_stablehlo_unary, relax_op=_op.log), "STABLEHLO_LOGISTIC": functools.partial( self._convert_stablehlo_unary, relax_op=_op.sigmoid ), @@ -290,9 +280,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_POWER": functools.partial( self._convert_stablehlo_binary, relax_op=_op.power ), - "STABLEHLO_RSQRT": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.rsqrt - ), + "STABLEHLO_RSQRT": 
functools.partial(self._convert_stablehlo_unary, relax_op=_op.rsqrt), "STABLEHLO_SELECT": functools.partial( self._convert_stablehlo_ternary, relax_op=_op.where ), @@ -302,9 +290,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_SUBTRACT": functools.partial( self._convert_stablehlo_binary, relax_op=_op.subtract ), - "STABLEHLO_TANH": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.tanh - ), + "STABLEHLO_TANH": functools.partial(self._convert_stablehlo_unary, relax_op=_op.tanh), "SQUEEZE": self.convert_squeeze, "STRIDED_SLICE": self.convert_strided_slice, "SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract), @@ -631,7 +617,9 @@ def _get_shape_expr_from_tensor(self, shape_tensor, prefix): dims_expr = self.get_expr(shape_tensor.tensor_idx) dims_ndim = int(self.get_tensor_shape(shape_tensor)[0]) dims_dtype = self.get_tensor_type_str(shape_tensor.tensor.Type()) - dims_expr = self.bb.match_cast(dims_expr, relax.TensorStructInfo([dims_ndim], dims_dtype)) + dims_expr = self.bb.match_cast( + dims_expr, relax.TensorStructInfo([dims_ndim], dims_dtype) + ) dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64")) shape_dataflow_var = self.bb.emit(relax.op.tensor_to_shape(dims_expr)) shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(dims_ndim)] @@ -969,7 +957,9 @@ def convert_lrn(self, op): ) pooled = self.bb.normalize(_op.reshape(pooled, data_shape)) denom = relax.op.power( - relax.op.add(relax.const(bias, in_type), relax.op.multiply(relax.const(alpha, in_type), pooled)), + relax.op.add( + relax.const(bias, in_type), relax.op.multiply(relax.const(alpha, in_type), pooled) + ), relax.const(beta, in_type), ) out = relax.op.divide(in_expr, denom) @@ -1062,7 +1052,8 @@ def get_scalar_value(tensor): # relax.op.arange currently expects scalar-like values here. # Keep dynamic scalar RANGE explicit until frontend support is added. 
raise tvm.error.OpNotImplemented( - "TFLite RANGE with dynamic scalar inputs is not supported in Relax frontend yet." + "TFLite RANGE with dynamic scalar inputs is not supported in " + "Relax frontend yet." ) else: value = self.get_tensor_value(tensor) @@ -1074,7 +1065,7 @@ def get_scalar_value(tensor): start_value = get_scalar_value(start) limit_value = get_scalar_value(limit) delta_value = get_scalar_value(delta) - + # out type inference if delta.tensor.Type() == TensorType.FLOAT32: out_type = self.get_tensor_type_str(delta.tensor.Type()) @@ -1434,9 +1425,7 @@ def _convert_stablehlo_and(self, op): elif dtype.startswith(("int", "uint")): op_fn = _op.bitwise_and else: - raise tvm.error.OpNotImplemented( - f"STABLEHLO_AND with dtype {dtype} is not supported" - ) + raise tvm.error.OpNotImplemented(f"STABLEHLO_AND with dtype {dtype} is not supported") return self.bb.normalize(op_fn(lhs, rhs)) def _convert_stablehlo_or(self, op): @@ -1454,9 +1443,7 @@ def _convert_stablehlo_or(self, op): elif dtype.startswith(("int", "uint")): op_fn = _op.bitwise_or else: - raise tvm.error.OpNotImplemented( - f"STABLEHLO_OR with dtype {dtype} is not supported" - ) + raise tvm.error.OpNotImplemented(f"STABLEHLO_OR with dtype {dtype} is not supported") return self.bb.normalize(op_fn(lhs, rhs)) def _convert_stablehlo_ternary(self, op, relax_op): @@ -1681,9 +1668,7 @@ def _convert_stablehlo_pad(self, op): for lo, hi in zip(edge_low, edge_high): pad_width.extend([lo, hi]) - return self.bb.normalize( - relax.op.nn.pad(operand, pad_width=pad_width, pad_value=pad_val) - ) + return self.bb.normalize(relax.op.nn.pad(operand, pad_width=pad_width, pad_value=pad_val)) def _convert_stablehlo_dynamic_slice(self, op): """Convert STABLEHLO_DYNAMIC_SLICE to Relax (dynamic_strided_slice). 
@@ -1732,10 +1717,7 @@ def _const_1d(values, dtype="int64"): end = _const_1d(end_vals) strides = _const_1d(stride_vals) - return self.bb.normalize( - relax.op.dynamic_strided_slice(operand, begin, end, strides) - ) - + return self.bb.normalize(relax.op.dynamic_strided_slice(operand, begin, end, strides)) def _convert_stablehlo_gather(self, op): """Convert STABLEHLO_GATHER to Relax (take-equivalent subset only). @@ -1775,9 +1757,7 @@ def _convert_stablehlo_gather(self, op): "STABLEHLO_GATHER only supports collapsed_slice_dims matching the gather axis" ) if len(slice_sizes) != len(data_shape): - raise tvm.error.OpNotImplemented( - "STABLEHLO_GATHER slice_sizes must match operand rank" - ) + raise tvm.error.OpNotImplemented("STABLEHLO_GATHER slice_sizes must match operand rank") for i, (size, dim) in enumerate(zip(slice_sizes, data_shape)): expected = 1 if i == axis else dim if size != expected: @@ -1789,9 +1769,7 @@ def _convert_stablehlo_gather(self, op): "STABLEHLO_GATHER only supports trailing index_vector_dim" ) if not indices_shape or indices_shape[index_vector_dim] != 1: - raise tvm.error.OpNotImplemented( - "STABLEHLO_GATHER only supports index vector size 1" - ) + raise tvm.error.OpNotImplemented("STABLEHLO_GATHER only supports index vector size 1") indices_batch_shape = indices_shape[:index_vector_dim] expected_offset_dims = list(range(axis)) + list( @@ -1802,9 +1780,7 @@ def _convert_stablehlo_gather(self, op): "STABLEHLO_GATHER offset_dims do not match Relax take output layout" ) - expected_output_shape = ( - data_shape[:axis] + indices_batch_shape + data_shape[axis + 1 :] - ) + expected_output_shape = data_shape[:axis] + indices_batch_shape + data_shape[axis + 1 :] if output_shape != expected_output_shape: raise tvm.error.OpNotImplemented( "STABLEHLO_GATHER output shape does not match Relax take semantics" @@ -1815,7 +1791,6 @@ def _convert_stablehlo_gather(self, op): indices = self.bb.normalize(relax.op.reshape(indices, indices_batch_shape)) return 
self.bb.normalize(relax.op.take(data, indices, axis=axis, mode="fast")) - def convert_elu(self, op): """Convert TFLite ELU""" input_tensors = self.get_input_tensors(op) @@ -1959,7 +1934,7 @@ def convert_add_n(self, op): rhs_expr = self.get_tensor_expr(rhs_tensor) lhs_expr = relax.op.add(lhs_expr, rhs_expr) return lhs_expr - + def convert_cumsum(self, op): """Convert TFLite CUMSUM""" if self.is_quantized(op): @@ -1972,7 +1947,7 @@ def convert_cumsum(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" - + input_expr = self.get_tensor_expr(input_tensors[0]) if self.has_expr(input_tensors[1].tensor_idx): @@ -1993,7 +1968,7 @@ def convert_cumsum(self, op): raise tvm.error.OpNotImplemented( "The TFLite to Relax converter does not support reverse CUMSUM operator yet." ) - + output_tensors = self.get_output_tensors(op) assert len(output_tensors) == 1, "output tensors length should be 1" @@ -2954,7 +2929,7 @@ def convert_conv3d(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) >= 2, "input tensors length should be >= 2" - + input_tensor = input_tensors[0] input_tensor_idx = input_tensor.tensor_idx weight_tensor = input_tensors[1] @@ -3023,8 +2998,7 @@ def convert_conv3d(self, op): weight_value = self.get_tensor_value(weight_tensor) weight_expr = self.exp_tab.new_const( - weight_value, dtype=weight_tensor_type_str, - source_name=weight_tensor.tensor.Name() + weight_value, dtype=weight_tensor_type_str, source_name=weight_tensor.tensor.Name() ) if padding == Padding.VALID: @@ -3035,9 +3009,12 @@ def convert_conv3d(self, op): pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w) do_pad = not ( - pad_front == 0 and pad_back == 0 - and pad_top == 0 and pad_bottom == 0 - and pad_left == 0 and pad_right == 0 + pad_front == 0 + and pad_back == 0 + and pad_top == 0 + and pad_bottom == 0 + and pad_left == 0 + and pad_right == 0 ) if do_pad: params["padding"] = 
[pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] @@ -3163,8 +3140,7 @@ def convert_conv3d_transpose(self, op): weight_value = self.get_tensor_value(weight_tensor) weight_expr = self.exp_tab.new_const( - weight_value, dtype=weight_tensor_type_str, - source_name=weight_tensor.tensor.Name() + weight_value, dtype=weight_tensor_type_str, source_name=weight_tensor.tensor.Name() ) if padding == Padding.VALID: @@ -3297,9 +3273,7 @@ def convert_split_v(self, op): outputs = [] for i in range(num_splits): - start_val = relax.op.strided_slice( - padded_cumsum, axes=[0], begin=[i], end=[i + 1] - ) + start_val = relax.op.strided_slice(padded_cumsum, axes=[0], begin=[i], end=[i + 1]) end_val = relax.op.strided_slice( padded_cumsum, axes=[0], begin=[i + 1], end=[i + 2] ) @@ -3403,7 +3377,7 @@ def _get_segment_num_segments(self, op_name, input_tensors): raise tvm.error.OpNotImplemented( "TFLite SEGMENT_SUM with runtime segment_ids is not supported, " "because TFLite does not encode a reliable output segment count." - ) + ) segment_ids = self.get_tensor_value(segment_ids_tensor) if np.any(segment_ids < 0): raise tvm.error.OpNotImplemented( @@ -4563,7 +4537,7 @@ def convert_dilate(self, op): dilations_tensor = input_tensors[1] padding_expr = self.get_tensor_expr(input_tensors[2]) - # Runtime dilations bind tensor values to TIR Vars for symbolic + # Runtime dilations bind tensor values to TIR Vars for symbolic # per-axis math. 
if self.has_expr(dilations_tensor.tensor_idx): dilations_expr = self.get_expr(dilations_tensor.tensor_idx) @@ -4980,9 +4954,7 @@ def convert_nms_v5(self, op): if soft_nms_sigma > 0.0: # Extract decayed scores from the processed data (score_index=0) - selected_scores = relax.op.strided_slice( - processed_data, axes=[1], begin=[0], end=[1] - ) + selected_scores = relax.op.strided_slice(processed_data, axes=[1], begin=[0], end=[1]) selected_scores = relax.op.squeeze(selected_scores, axis=[1]) selected_scores = relax.op.strided_slice( selected_scores, axes=[0], begin=[0], end=[max_output_size] @@ -5126,11 +5098,20 @@ def convert_matrix_set_diag(self, op): output_shape = to_int_list(self.get_tensor_shape(output_tensor)) output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) - # topi.matrix_set_diag(input, diagonal, k1, k2, super_diag_right_align, sub_diag_right_align) + # topi.matrix_set_diag( + # input, diagonal, k1, k2, super_diag_right_align, sub_diag_right_align + # ) # TFLite MATRIX_SET_DIAG only sets the main diagonal, so k1=0, k2=0 out = relax.op.call_dps_packed( "topi.matrix_set_diag", - (input_expr, diagonal_expr, relax.const(0), relax.const(0), relax.const(False), relax.const(False)), + ( + input_expr, + diagonal_expr, + relax.const(0), + relax.const(0), + relax.const(False), + relax.const(False), + ), out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), ) return out @@ -5158,11 +5139,20 @@ def convert_matrix_diag(self, op): diagonal_expr = self.get_tensor_expr(diagonal) zeros_expr = relax.op.zeros(output_shape, output_dtype) - # topi.matrix_set_diag(input, diagonal, k1, k2, super_diag_right_align, sub_diag_right_align) + # topi.matrix_set_diag( + # input, diagonal, k1, k2, super_diag_right_align, sub_diag_right_align + # ) # TFLite MATRIX_DIAG only sets the main diagonal, so k1=0, k2=0 out = relax.op.call_dps_packed( "topi.matrix_set_diag", - (zeros_expr, diagonal_expr, relax.const(0), relax.const(0), relax.const(False), 
relax.const(False)), + ( + zeros_expr, + diagonal_expr, + relax.const(0), + relax.const(0), + relax.const(False), + relax.const(False), + ), out_sinfo=relax.TensorStructInfo(output_shape, output_dtype), ) return out @@ -5271,9 +5261,7 @@ def get_tensor_expr(self, tensor, is_sparse=False): type_str = self.get_tensor_type_str(tensor.tensor.Type()) value = self.get_tensor_value_or_prefetched(tensor, is_sparse) - return self.exp_tab.new_const( - value, dtype=type_str, source_name=tensor.tensor.Name() - ) + return self.exp_tab.new_const(value, dtype=type_str, source_name=tensor.tensor.Name()) def get_tensor_shape(self, tensor_wrapper): """Returns tensor shape. Infers shape if the shape is empty.""" diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 89c91e37735d..e9bddc4500bb 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1650,9 +1650,10 @@ def _var(self, node: fx.Node) -> relax.Var: # Legacy fx `tensor.var(...)` calls go through the original path # below to keep this fix narrowly scoped. target = node.target - if getattr(target, "_overloadname", None) == "correction" or getattr( - target, "overload_name", None - ) == "correction": + if ( + getattr(target, "_overloadname", None) == "correction" + or getattr(target, "overload_name", None) == "correction" + ): return self._var_correction(node) args = self.retrieve_args(node) x = args[0] @@ -1674,8 +1675,7 @@ def _var_correction(self, node: fx.Node) -> relax.Var: n = self._reduction_size(x, dim) if n is None: raise NotImplementedError( - "var/std with non-zero correction requires statically known " - "reduction-axis sizes." + "var/std with non-zero correction requires statically known reduction-axis sizes." ) # PyTorch returns NaN (with a warning) when `n - correction <= 0`; # mirror that semantics rather than failing the import. 
@@ -1698,7 +1698,7 @@ def _reduction_size(x: relax.Expr, dim) -> int | None: axes = list(range(rank)) elif isinstance(dim, int): axes = [dim] - elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a in dim): + elif isinstance(dim, list | tuple) and all(isinstance(a, int) for a in dim): axes = list(dim) else: return None diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 5bd2c785f205..596dc60f555e 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1140,9 +1140,7 @@ def _affine_grid_generator(self, node: fx.Node) -> relax.Var: target_w = size[3] # Relax affine_grid outputs [N, 2, H, W] - grid = self.block_builder.emit( - relax.op.image.affine_grid(theta, (target_h, target_w)) - ) + grid = self.block_builder.emit(relax.op.image.affine_grid(theta, (target_h, target_w))) # Permute to PyTorch convention [N, H, W, 2] return self.block_builder.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1])) @@ -1361,10 +1359,7 @@ def _visit(value): # Preserve explicit None outputs as Relax null objects. 
flattened.append(relax.op.null_value()) else: - raise ValueError( - "Unsupported output type in exported graph output: " - f"{type(value)}" - ) + raise ValueError(f"Unsupported output type in exported graph output: {type(value)}") _visit(output_args) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index c81768f6d946..d4dd6902ae54 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -564,7 +564,7 @@ def _interpolate(self, node: fx.Node) -> relax.Var: layout_3d = "NDHWC" else: layout_3d = "NCDHW" - + return self.block_builder.emit( relax.op.image.resize3d( data, diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index c31a9744022a..6755782fdab4 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -587,7 +587,8 @@ def conv3d_transpose( See Also -------- conv3d : Forward 3D convolution (default ``OIDHW`` weights vs. ``IODHW`` here). - conv2d_transpose : 2D analogue; legalization supports the same TOPI subset (canonical layout, dilation 1). + conv2d_transpose : 2D analogue; legalization supports the same TOPI subset + (canonical layout, dilation 1). 
Returns ------- diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py index 19431a2731aa..42cb72f7c988 100644 --- a/python/tvm/relax/transform/legalize_ops/image.py +++ b/python/tvm/relax/transform/legalize_ops/image.py @@ -57,10 +57,9 @@ def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.image.affine_grid") def _image_affine_grid(bb: BlockBuilder, call: Call) -> Expr: for v in call.args[1].values: - if not isinstance(v, (int, tirx.IntImm)): + if not isinstance(v, int | tirx.IntImm): raise ValueError( - "affine_grid legalization requires static target_shape, " - f"got symbolic value: {v}" + f"affine_grid legalization requires static target_shape, got symbolic value: {v}" ) target_shape = [int(v) for v in call.args[1].values] return bb.call_te( diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 157ec8b148cf..c0b7b166d1e3 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -222,7 +222,8 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr: @register_legalize("relax.nn.conv3d_transpose") def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr: - # Keep policy in sync with _nn_conv2d_transpose: only lower when TOPI supports the layout/dilation. + # Keep policy in sync with _nn_conv2d_transpose: only lower when TOPI supports + # the layout/dilation. 
if call.attrs.out_layout != call.attrs.data_layout: logging.info( "TOPI conv3d_transpose does not support different input-output " diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py b/python/tvm/relax/transform/legalize_ops/qdq.py index 5e28d1b29105..aa86f6fca2c3 100644 --- a/python/tvm/relax/transform/legalize_ops/qdq.py +++ b/python/tvm/relax/transform/legalize_ops/qdq.py @@ -17,7 +17,6 @@ # pylint: disable=invalid-name """Default legalization function for quantize/dequantize operators.""" -from typing import Union import tvm from tvm import te, tirx @@ -59,8 +58,8 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr: def te_quantize( data: te.Tensor, - scale: Union[te.Tensor, tirx.IntImm, tirx.FloatImm], - zp: Union[te.Tensor, tirx.IntImm, tirx.FloatImm], + scale: te.Tensor | tirx.IntImm | tirx.FloatImm, + zp: te.Tensor | tirx.IntImm | tirx.FloatImm, ): scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, te.Tensor) else False zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) else False @@ -121,8 +120,8 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr: def te_dequantize( data: te.Tensor, - scale: Union[te.Tensor, tirx.IntImm, tirx.FloatImm], - zp: Union[te.Tensor, tirx.IntImm, tirx.FloatImm], + scale: te.Tensor | tirx.IntImm | tirx.FloatImm, + zp: te.Tensor | tirx.IntImm | tirx.FloatImm, ): scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, te.Tensor) else False zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) else False diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py index ec7a025c54ec..5a05f46a08a8 100644 --- a/python/tvm/s_tir/dlight/analysis/common_analysis.py +++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py @@ -20,7 +20,6 @@ """Analysis on TIR blocks, loops and functions.""" import logging - from collections import namedtuple from typing import Literal @@ -34,6 +33,7 @@ logger = 
logging.getLogger(__name__) # pylint: disable=invalid-name + class IterInfo: """Information about a loop/iter var.""" @@ -373,6 +373,7 @@ def get_max_threads_per_block(target: Target) -> int: "vulkan": 16384, } + def get_max_shared_memory_per_block(target: Target) -> int: _assert_gpu_target(target) max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) diff --git a/python/tvm/s_tir/meta_schedule/relax_integration.py b/python/tvm/s_tir/meta_schedule/relax_integration.py index 0cd19b0aad8a..c8a2e0e248f8 100644 --- a/python/tvm/s_tir/meta_schedule/relax_integration.py +++ b/python/tvm/s_tir/meta_schedule/relax_integration.py @@ -463,7 +463,8 @@ def compile_relax( @tvm.transform.module_pass(opt_level=3) def _ms_pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: - fuse_seq = dispatch_passes + [ + fuse_seq = [ + *dispatch_passes, relax.transform.LegalizeOps(enable_warning=enable_warning), relax.transform.AnnotateTIROpPattern(), relax.transform.FoldConstant(), diff --git a/python/tvm/topi/testing/get_valid_counts_python.py b/python/tvm/topi/testing/get_valid_counts_python.py index 2caab6babc9d..7f0591901edd 100644 --- a/python/tvm/topi/testing/get_valid_counts_python.py +++ b/python/tvm/topi/testing/get_valid_counts_python.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Numpy reference implementation for get_valid_counts.""" + import numpy as np diff --git a/python/tvm/topi/testing/nms_python.py b/python/tvm/topi/testing/nms_python.py index c8711c70dde2..1b0d613f47a9 100644 --- a/python/tvm/topi/testing/nms_python.py +++ b/python/tvm/topi/testing/nms_python.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
"""Numpy reference implementation for classic non_max_suppression.""" + import numpy as np diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index 23ead47ae670..7dc416b272d2 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -188,8 +188,8 @@ def get_const_tuple(in_tuple): """ if isinstance(in_tuple, te.tensor.Tensor): raise TypeError( - f"get_const_tuple expects a tuple-like shape (e.g., tensor.shape), " - f"but got a te.Tensor. Did you mean get_const_tuple(tensor.shape)?" + "get_const_tuple expects a tuple-like shape (e.g., tensor.shape), " + "but got a te.Tensor. Did you mean get_const_tuple(tensor.shape)?" ) ret = [] ana = None diff --git a/python/tvm/topi/vision/multibox_transform_loc.py b/python/tvm/topi/vision/multibox_transform_loc.py index ab965e798141..e6816d8eec08 100644 --- a/python/tvm/topi/vision/multibox_transform_loc.py +++ b/python/tvm/topi/vision/multibox_transform_loc.py @@ -74,14 +74,14 @@ def multibox_transform_loc( th = tvm.tirx.const(float(threshold), dtype) def decode_bbox(b, a, k): - l = anchor[0, a, 0] - t = anchor[0, a, 1] - r = anchor[0, a, 2] - br = anchor[0, a, 3] - ay = (t + br) * half - ax = (l + r) * half - ah = br - t - aw = r - l + left = anchor[0, a, 0] + top = anchor[0, a, 1] + right = anchor[0, a, 2] + bottom = anchor[0, a, 3] + ay = (top + bottom) * half + ax = (left + right) * half + ah = bottom - top + aw = right - left ex = loc_reshaped[b, a, 0] ey = loc_reshaped[b, a, 1] ew = loc_reshaped[b, a, 2] diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index ad548978a186..9ac20869bde0 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -123,9 +123,7 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): out_tensor_buf = tvm.tirx.decl_buffer( (batch_size, num_anchors, box_data_length), data.dtype, "out_tensor" ) - out_indices_buf = tvm.tirx.decl_buffer( - (batch_size, num_anchors), "int32", "out_indices" - ) + 
out_indices_buf = tvm.tirx.decl_buffer((batch_size, num_anchors), "int32", "out_indices") if is_score_threshold_tensor: score_thresh_buf = tvm.tirx.decl_buffer( @@ -135,8 +133,13 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): [(batch_size,), (batch_size, num_anchors, box_data_length), (batch_size, num_anchors)], [data, score_threshold], lambda ins, outs: _get_valid_counts_ir( - ins[0], ins[1], id_index_const, score_index_const, - outs[0], outs[1], outs[2], + ins[0], + ins[1], + id_index_const, + score_index_const, + outs[0], + outs[1], + outs[2], ), dtype=["int32", data.dtype, "int32"], out_buffers=[valid_count_buf, out_tensor_buf, out_indices_buf], @@ -151,8 +154,13 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): # score_threshold is a TIR constant, not a tensor def _ir_with_const_threshold(ins, outs): return _get_valid_counts_ir( - ins[0], score_threshold, id_index_const, score_index_const, - outs[0], outs[1], outs[2], + ins[0], + score_threshold, + id_index_const, + score_index_const, + outs[0], + outs[1], + outs[2], ) valid_count, out_tensor, out_indices = te.extern( @@ -318,9 +326,9 @@ def compute_iou(lhs_idx, rhs_idx): with T.If(best_idx[0] != num_valid_boxes[0]): with T.Then(): tmp_idx[0] = out_box_indices[i, num_valid_boxes[0]] - out_box_indices[ - i, num_valid_boxes[0] - ] = out_box_indices[i, best_idx[0]] + out_box_indices[i, num_valid_boxes[0]] = ( + out_box_indices[i, best_idx[0]] + ) out_box_indices[i, best_idx[0]] = tmp_idx[0] with T.serial(0, box_data_length) as k: @@ -362,9 +370,7 @@ def compute_iou(lhs_idx, rhs_idx): out_data[i, j, score_index] = ( out_data[i, j, score_index] * tvm.tirx.exp( - soft_nms_scale - * iou - * iou + soft_nms_scale * iou * iou ) ) with T.If( @@ -372,9 +378,9 @@ def compute_iou(lhs_idx, rhs_idx): <= thresh ): with T.Then(): - out_box_indices[ - i, j - ] = T.int32(-1) + out_box_indices[i, j] = ( + T.int32(-1) + ) num_valid_boxes[0] = num_valid_boxes[0] + 1 @@ -389,9 
+395,7 @@ def compute_iou(lhs_idx, rhs_idx): with T.If(j >= num_valid_boxes[0]): with T.Then(): with T.serial(0, box_data_length) as k: - out_data[i, j, k] = tvm.tirx.Cast( - data.dtype, T.float32(-1.0) - ) + out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) out_box_indices[i, j] = T.int32(-1) else: with T.serial(0, num_anchors) as j: @@ -552,7 +556,7 @@ def non_max_suppression( if isinstance(max_output_size, int): max_output_size = tvm.tirx.const(max_output_size, dtype="int32") - if isinstance(iou_threshold, (float, int)): + if isinstance(iou_threshold, float | int): iou_threshold = tvm.tirx.const(iou_threshold, dtype=data.dtype) # Sort by score @@ -581,14 +585,26 @@ def non_max_suppression( [data.shape, (batch_size, num_anchors), (batch_size, 1)], [data, sort_tensor, valid_count, indices], lambda ins, outs: _classic_nms_ir( - ins[0], ins[1], ins[2], ins[3], - batch_size, num_anchors, box_data_length, - max_output_size, iou_threshold, - force_suppress, top_k, - coord_start, score_index, id_index, + ins[0], + ins[1], + ins[2], + ins[3], + batch_size, + num_anchors, + box_data_length, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + score_index, + id_index, return_indices, - outs[0], outs[1], outs[2], - soft_nms_sigma, score_threshold, + outs[0], + outs[1], + outs[2], + soft_nms_sigma, + score_threshold, ), dtype=[data.dtype, "int32", "int32"], out_buffers=[out_data_buf, out_box_indices_buf, out_valid_box_count_buf], @@ -604,14 +620,26 @@ def non_max_suppression( [data.shape, (batch_size, num_anchors)], [data, sort_tensor, valid_count, indices], lambda ins, outs: _classic_nms_ir( - ins[0], ins[1], ins[2], ins[3], - batch_size, num_anchors, box_data_length, - max_output_size, iou_threshold, - force_suppress, top_k, - coord_start, score_index, id_index, + ins[0], + ins[1], + ins[2], + ins[3], + batch_size, + num_anchors, + box_data_length, + max_output_size, + iou_threshold, + force_suppress, + top_k, + coord_start, + 
score_index, + id_index, return_indices, - outs[0], outs[1], None, - soft_nms_sigma, score_threshold, + outs[0], + outs[1], + None, + soft_nms_sigma, + score_threshold, ), dtype=[data.dtype, "int32"], out_buffers=[out_data_buf, out_box_indices_buf], @@ -644,9 +672,7 @@ def _rearrange_ir(ins, outs): valid_idx[0] = T.int32(0) with T.serial(0, num_anchors) as j: - with T.If( - data[i, j, score_index] >= tvm.tirx.Cast(data.dtype, T.float32(0.0)) - ): + with T.If(data[i, j, score_index] >= tvm.tirx.Cast(data.dtype, T.float32(0.0))): with T.Then(): with T.serial(0, box_data_length) as k: out[i, valid_idx[0], k] = data[i, j, k] diff --git a/python/tvm/topi/vision/nms_util.py b/python/tvm/topi/vision/nms_util.py index f9bb460bc840..b9f02ab982b1 100644 --- a/python/tvm/topi/vision/nms_util.py +++ b/python/tvm/topi/vision/nms_util.py @@ -308,7 +308,7 @@ def _all_class_nms_ir( if selected_scores is not None: selected_scores = T.buffer_proxy(selected_scores) - if isinstance(iou_threshold, (float, int)): + if isinstance(iou_threshold, float | int): iou_threshold = tvm.tirx.FloatImm("float32", float(iou_threshold)) elif isinstance(iou_threshold, te.Tensor): if len(iou_threshold.shape) == 0: diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index 31b4cc292762..c7bd5061217b 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -43,8 +43,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("value", &RXPlaceholderOpNode::value); + refl::ObjectDef().def_ro("value", &RXPlaceholderOpNode::value); } // FFI system configuration for structural equality and hashing diff --git a/src/relax/op/vision/multibox_transform_loc.cc b/src/relax/op/vision/multibox_transform_loc.cc index e01e569b78f0..13855cbd6625 100644 --- a/src/relax/op/vision/multibox_transform_loc.cc +++ b/src/relax/op/vision/multibox_transform_loc.cc @@ -188,9 +188,10 @@ StructInfo 
InferStructInfoMultiboxTransformLoc(const Call& call, const BlockBuil } TVM_REGISTER_OP("relax.vision.multibox_transform_loc") - .describe("Decode SSD/TFLite-style priors and offsets into boxes and softmax scores. If " - "cls_pred shape is unknown, N-based loc/anchor shape checks are skipped in " - "inference. Very large variances (w,h) can overflow exp in half box sizes.") + .describe( + "Decode SSD/TFLite-style priors and offsets into boxes and softmax scores. If " + "cls_pred shape is unknown, N-based loc/anchor shape checks are skipped in " + "inference. Very large variances (w,h) can overflow exp in half box sizes.") .set_attrs_type() .set_num_inputs(3) .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).") diff --git a/src/runtime/hexagon/hexagon_common.h b/src/runtime/hexagon/hexagon_common.h index 7ffc4457192a..acd9b6b3b70f 100644 --- a/src/runtime/hexagon/hexagon_common.h +++ b/src/runtime/hexagon/hexagon_common.h @@ -24,9 +24,9 @@ #define TVM_RUNTIME_HEXAGON_HEXAGON_COMMON_H_ #include +#include #include #include -#include #if defined(__hexagon__) #include diff --git a/src/runtime/hexagon/hexagon_thread_manager.cc b/src/runtime/hexagon/hexagon_thread_manager.cc index 76e57c67e8a1..8c82325661f1 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.cc +++ b/src/runtime/hexagon/hexagon_thread_manager.cc @@ -18,6 +18,7 @@ */ #include "hexagon_thread_manager.h" + #include namespace tvm { diff --git a/src/runtime/hexagon/hexagon_thread_manager.h b/src/runtime/hexagon/hexagon_thread_manager.h index c02e23f29c34..09d13c2b5f78 100644 --- a/src/runtime/hexagon/hexagon_thread_manager.h +++ b/src/runtime/hexagon/hexagon_thread_manager.h @@ -20,9 +20,9 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_THREAD_MANAGER_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_THREAD_MANAGER_H_ +#include #include #include -#include #include #include diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.cc b/src/runtime/hexagon/hexagon_vtcm_pool.cc index f96ba975da0d..5516c7825640 
100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.cc +++ b/src/runtime/hexagon/hexagon_vtcm_pool.cc @@ -17,6 +17,7 @@ * under the License. */ #include "hexagon_vtcm_pool.h" + #include #include "HAP_compute_res.h" diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/runtime/hexagon/hexagon_vtcm_pool.h index 5159c458c8d6..cef9cbcaad12 100644 --- a/src/runtime/hexagon/hexagon_vtcm_pool.h +++ b/src/runtime/hexagon/hexagon_vtcm_pool.h @@ -20,10 +20,10 @@ #ifndef TVM_RUNTIME_HEXAGON_HEXAGON_VTCM_POOL_H_ #define TVM_RUNTIME_HEXAGON_HEXAGON_VTCM_POOL_H_ +#include #include #include #include -#include #include #include diff --git a/src/runtime/memory/memory_manager.cc b/src/runtime/memory/memory_manager.cc index 626222e6c87f..2c50af475bfe 100644 --- a/src/runtime/memory/memory_manager.cc +++ b/src/runtime/memory/memory_manager.cc @@ -24,8 +24,8 @@ #include #include #include -#include #include +#include #include #include diff --git a/src/runtime/metadata.h b/src/runtime/metadata.h index c034041ce4a4..0a8d96261e26 100644 --- a/src/runtime/metadata.h +++ b/src/runtime/metadata.h @@ -96,7 +96,8 @@ class FunctionInfoObj : public ffi::Object { auto sarg_types_arr = src.at("arg_types").cast(); arg_types = ffi::Array(); for (size_t i = 0; i < sarg_types_arr.size(); ++i) { - arg_types.push_back(ffi::StringToDLDataType(std::string(sarg_types_arr[i].cast()))); + arg_types.push_back( + ffi::StringToDLDataType(std::string(sarg_types_arr[i].cast()))); } auto lt = src.find("launch_param_tags"); if (lt != src.end()) { diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index ebbbcde071b4..184cbca05fb8 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -30,10 +30,10 @@ #import #import #import +#include #include #include #include -#include #include #include diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 84cf45a6ca92..059c2fc2e402 100644 --- a/src/runtime/minrpc/minrpc_server.h 
+++ b/src/runtime/minrpc/minrpc_server.h @@ -29,9 +29,9 @@ #define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ #include +#include #include #include -#include #include #include diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index d80a52e5e705..df2f370fd038 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -24,10 +24,10 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ #define TVM_RUNTIME_OPENCL_OPENCL_COMMON_H_ +#include #include #include #include -#include #include #include #include diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 952a9b67141c..14823f18b3cb 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -22,8 +22,8 @@ */ #include #include -#include #include +#include #include diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc index 6e0dd162b3ba..e828f752d9b8 100644 --- a/src/runtime/rpc/rpc_device_api.cc +++ b/src/runtime/rpc/rpc_device_api.cc @@ -20,10 +20,10 @@ /*! * \file rpc_device_api.cc */ +#include #include #include #include -#include #include diff --git a/src/runtime/static_library.cc b/src/runtime/static_library.cc index d3ea7b345838..f288a843d8f9 100644 --- a/src/runtime/static_library.cc +++ b/src/runtime/static_library.cc @@ -29,9 +29,9 @@ #include #include #include +#include #include #include -#include #include diff --git a/src/runtime/static_library.h b/src/runtime/static_library.h index 0ce4d9e003c6..9baa5c6fb39f 100644 --- a/src/runtime/static_library.h +++ b/src/runtime/static_library.h @@ -27,8 +27,8 @@ #define TVM_RUNTIME_STATIC_LIBRARY_H_ #include -#include #include +#include #include #include diff --git a/src/runtime/tensor.cc b/src/runtime/tensor.cc index 61f037caec55..bb526ef8437c 100644 --- a/src/runtime/tensor.cc +++ b/src/runtime/tensor.cc @@ -21,11 +21,11 @@ * \file tensor.cc * \brief Tensor container infratructure. 
*/ +#include #include #include #include #include -#include #include #include "tvm/runtime/data_type.h" diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index ba2b89770bd7..c7e0b9979e1d 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -22,11 +22,11 @@ * \brief Threadpool for multi-threading runtime. */ #include +#include #include #include #include #include -#include #include "threading_backend.h" #if TVM_THREADPOOL_USE_OPENMP diff --git a/src/runtime/vm/attn_backend.h b/src/runtime/vm/attn_backend.h index ae88843667c3..8d523e4e0506 100644 --- a/src/runtime/vm/attn_backend.h +++ b/src/runtime/vm/attn_backend.h @@ -26,9 +26,9 @@ #define TVM_RUNTIME_VM_ATTN_BACKEND_H_ #include +#include #include #include -#include #include #include diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index f5485e7a3326..322a0a137c17 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -22,12 +22,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include #include @@ -611,7 +611,8 @@ bool ReadIfCond(ffi::AnyView cond) { break; } default: - TVM_FFI_THROW(InternalError) << "Unknown scalar int type: " << ffi::DLDataTypeToString(arr->dtype); + TVM_FFI_THROW(InternalError) + << "Unknown scalar int type: " << ffi::DLDataTypeToString(arr->dtype); throw; } return result != 0; diff --git a/src/runtime/vm/kv_state.h b/src/runtime/vm/kv_state.h index fd001f8048a2..198bd18d979d 100644 --- a/src/runtime/vm/kv_state.h +++ b/src/runtime/vm/kv_state.h @@ -20,10 +20,10 @@ #define TVM_RUNTIME_VM_KV_STATE_H_ #include #include +#include #include #include #include -#include #include namespace tvm { diff --git a/src/runtime/vm/lm_support.cc b/src/runtime/vm/lm_support.cc index d07f84be1647..51b271441a27 100644 --- a/src/runtime/vm/lm_support.cc +++ b/src/runtime/vm/lm_support.cc @@ -37,10 +37,10 @@ */ #include #include +#include #include #include #include -#include 
#include #include #include diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc index bb3aee7e340b..d4bc3f874e2c 100644 --- a/src/runtime/vm/paged_kv_cache.cc +++ b/src/runtime/vm/paged_kv_cache.cc @@ -20,11 +20,11 @@ * \file src/runtime/vm/paged_kv_cache.cc * \brief Runtime paged KV cache object for language models. */ +#include #include #include #include #include -#include #include #include diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b7e29710aff9..d6ffab9be018 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,10 +23,10 @@ #include #include #include +#include #include #include #include -#include #include diff --git a/src/runtime/vulkan/spirv_shader.h b/src/runtime/vulkan/spirv_shader.h index e9575defd110..40b3fd70904c 100644 --- a/src/runtime/vulkan/spirv_shader.h +++ b/src/runtime/vulkan/spirv_shader.h @@ -20,10 +20,10 @@ #ifndef TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ #define TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ +#include #include #include #include -#include #include #include diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 826048d8578d..2372c02f366a 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -20,10 +20,10 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_ #define TVM_RUNTIME_VULKAN_VULKAN_COMMON_H_ +#include #include #include #include -#include #include #include diff --git a/src/runtime/vulkan/vulkan_instance.cc b/src/runtime/vulkan/vulkan_instance.cc index fc88db7644cd..92ee82fe1f8a 100644 --- a/src/runtime/vulkan/vulkan_instance.cc +++ b/src/runtime/vulkan/vulkan_instance.cc @@ -18,6 +18,7 @@ */ #include "vulkan_instance.h" + #include #include diff --git a/src/s_tir/analysis/is_pure_function.cc b/src/s_tir/analysis/is_pure_function.cc index 2ca557b171d1..1c4981c90814 100644 --- a/src/s_tir/analysis/is_pure_function.cc +++ b/src/s_tir/analysis/is_pure_function.cc @@ -33,7 +33,6 @@ namespace tvm { namespace s_tir { 
using namespace tvm::tirx; - namespace { class PurityChecker : TIRVisitorWithPath { public: diff --git a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc index 27c3ded758ad..b4e89b6bb79e 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -17,8 +17,8 @@ * under the License. */ #include -#include #include +#include #include "../utils.h" diff --git a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc index 01d619302a5a..a0030ee28bee 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_tensorize.cc @@ -17,9 +17,9 @@ * under the License. */ #include +#include #include #include -#include #include diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc index 87244c8809e4..09d787689d90 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -19,9 +19,9 @@ #include "./multi_level_tiling.h" #include +#include #include #include -#include #include #include diff --git a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 2dc9de361e8f..039754b04fee 100644 --- a/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/s_tir/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -820,7 +820,8 @@ ffi::Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i]; } - auto f_get_sub_index_map = [&](const tirx::Buffer& lhs_buffer, const ffi::Array& lhs_region) { + auto 
f_get_sub_index_map = [&](const tirx::Buffer& lhs_buffer, + const ffi::Array& lhs_region) { std::vector sub_index_map_src; std::vector sub_index_map_tgt; const tirx::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer]; diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h index 5dc99d744c28..2dfba623a067 100644 --- a/src/s_tir/meta_schedule/utils.h +++ b/src/s_tir/meta_schedule/utils.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +44,6 @@ #include #include #include -#include #include #include diff --git a/src/s_tir/support/parallel_for.h b/src/s_tir/support/parallel_for.h index 1b2c5fa18fbb..22b2073e56ba 100644 --- a/src/s_tir/support/parallel_for.h +++ b/src/s_tir/support/parallel_for.h @@ -24,8 +24,8 @@ #ifndef TVM_S_TIR_SUPPORT_PARALLEL_FOR_H_ #define TVM_S_TIR_SUPPORT_PARALLEL_FOR_H_ -#include #include +#include #include #include diff --git a/src/s_tir/transform/inject_double_buffer.cc b/src/s_tir/transform/inject_double_buffer.cc index 9c5e9bf0b8b5..b476f0dca6ad 100644 --- a/src/s_tir/transform/inject_double_buffer.cc +++ b/src/s_tir/transform/inject_double_buffer.cc @@ -24,11 +24,11 @@ #include #include #include +#include #include #include #include #include -#include #include "../../tirx/transform/ir_utils.h" diff --git a/src/s_tir/transform/loop_partition.cc b/src/s_tir/transform/loop_partition.cc index d47c861873a7..e5f03c29f57d 100644 --- a/src/s_tir/transform/loop_partition.cc +++ b/src/s_tir/transform/loop_partition.cc @@ -25,13 +25,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include diff --git a/src/s_tir/transform/lower_async_dma.cc b/src/s_tir/transform/lower_async_dma.cc index 756461b0dd08..218de17c11a5 100644 --- a/src/s_tir/transform/lower_async_dma.cc +++ b/src/s_tir/transform/lower_async_dma.cc @@ -26,13 +26,13 @@ #include #include #include +#include #include #include #include #include #include 
#include -#include #include #include diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index 683806768dc2..347461bd1a06 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -19,8 +19,8 @@ #include #include #include -#include #include +#include #include "./utils.h" diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 78b9b9fa986f..957421c0bc29 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -16,9 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include -#include #include #include diff --git a/src/target/hexagon/llvm/codegen_hexagon.cc b/src/target/hexagon/llvm/codegen_hexagon.cc index c83af58c4ce7..e0beb0262752 100644 --- a/src/target/hexagon/llvm/codegen_hexagon.cc +++ b/src/target/hexagon/llvm/codegen_hexagon.cc @@ -43,9 +43,9 @@ #include #include #include +#include #include #include -#include #include #include diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index 9e1a8ce068cc..8d35ef87238f 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -23,10 +23,10 @@ */ #include "intrin_rule.h" +#include #include #include #include -#include namespace tvm { namespace codegen { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 09308a6ebbfd..10a129eca74f 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -50,8 +50,8 @@ #include #include #include -#include #include +#include #include #include diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index a0e237500c19..4dad2fc4b3ec 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -80,8 +80,8 @@ #include #include #include -#include #include +#include #include #include @@ -2266,8 +2266,9 @@ 
llvm::DIType* CodeGenLLVM::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) if (dtype.is_scalable_vector()) return nullptr; - return dbg_info_->di_builder_->createBasicType(ffi::DLDataTypeToString(dtype).operator std::string(), - dtype.bits() * dtype.lanes(), dwarf_type); + return dbg_info_->di_builder_->createBasicType( + ffi::DLDataTypeToString(dtype).operator std::string(), dtype.bits() * dtype.lanes(), + dwarf_type); } else { std::string type_str; diff --git a/src/target/metal/codegen_metal.cc b/src/target/metal/codegen_metal.cc index c84df824a14f..986bda6c66b2 100644 --- a/src/target/metal/codegen_metal.cc +++ b/src/target/metal/codegen_metal.cc @@ -25,8 +25,8 @@ #include #include #include -#include #include +#include #include #include diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index e593852e43ad..b19c41056deb 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -25,9 +25,9 @@ #include #include #include +#include #include #include -#include #include diff --git a/src/tirx/analysis/verify_memory.cc b/src/tirx/analysis/verify_memory.cc index 6c4ba1193400..27853fb04c13 100644 --- a/src/tirx/analysis/verify_memory.cc +++ b/src/tirx/analysis/verify_memory.cc @@ -24,12 +24,12 @@ #include #include #include +#include #include #include #include #include #include -#include namespace tvm { namespace tirx { diff --git a/src/tirx/analysis/verify_well_formed.cc b/src/tirx/analysis/verify_well_formed.cc index cc33e59d5690..b3adda7812e4 100644 --- a/src/tirx/analysis/verify_well_formed.cc +++ b/src/tirx/analysis/verify_well_formed.cc @@ -40,7 +40,6 @@ namespace tvm { namespace tirx { - namespace { template diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index 4e2cfd3f1474..99eab3590203 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -510,7 +510,8 @@ MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { // Validate shape TVM_FFI_ICHECK(source->region.size() >= buffer->shape.size()) - << 
"Dimension of source ffi::Array expected to be larger or equal than target buffer shape, but " + << "Dimension of source ffi::Array expected to be larger or equal than target buffer " + "shape, but " "got " << source->region.size() << " vs. " << buffer->shape.size(); size_t offset = source->region.size() - buffer->shape.size(); diff --git a/src/tirx/ir/tir_visitor_with_path.cc b/src/tirx/ir/tir_visitor_with_path.cc index 512225344959..42bae073cbe9 100644 --- a/src/tirx/ir/tir_visitor_with_path.cc +++ b/src/tirx/ir/tir_visitor_with_path.cc @@ -35,7 +35,6 @@ namespace tvm { namespace tirx { - void TIRVisitorWithPath::Visit(const IRModule& mod, ffi::reflection::AccessPath path) { // To ensure deterministic order of visits, sort the GlobalVar first // by visibility (public then private), then alphabetically by name. @@ -333,10 +332,10 @@ void TIRVisitorWithPath::VisitExpr_(const CallNode* op, ffi::reflection::AccessP Visit(op->args, path->Attr("args")); } -#define DEFINE_BINOP_VISIT_(OP) \ +#define DEFINE_BINOP_VISIT_(OP) \ void TIRVisitorWithPath::VisitExpr_(const OP* op, ffi::reflection::AccessPath path) { \ - Visit(op->a, path->Attr("a")); \ - Visit(op->b, path->Attr("b")); \ + Visit(op->a, path->Attr("a")); \ + Visit(op->b, path->Attr("b")); \ } DEFINE_BINOP_VISIT_(AddNode); diff --git a/src/tirx/script/builder/ir.cc b/src/tirx/script/builder/ir.cc index 6dc316590c01..c0d61919d7d9 100644 --- a/src/tirx/script/builder/ir.cc +++ b/src/tirx/script/builder/ir.cc @@ -22,8 +22,8 @@ #include #include #include -#include #include +#include #include "./utils.h" diff --git a/src/tirx/transform/lower_intrin.cc b/src/tirx/transform/lower_intrin.cc index a8c60d33b9d2..981615b0d1d5 100644 --- a/src/tirx/transform/lower_intrin.cc +++ b/src/tirx/transform/lower_intrin.cc @@ -24,12 +24,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include diff --git a/src/tirx/transform/lower_tvm_builtin.cc 
b/src/tirx/transform/lower_tvm_builtin.cc index 3ba72294bb2c..085f62d668c0 100644 --- a/src/tirx/transform/lower_tvm_builtin.cc +++ b/src/tirx/transform/lower_tvm_builtin.cc @@ -25,11 +25,11 @@ #include #include #include +#include #include #include #include #include -#include #include diff --git a/src/tirx/transform/storage_rewrite.cc b/src/tirx/transform/storage_rewrite.cc index f64d262f97c8..24cd5ce4a274 100644 --- a/src/tirx/transform/storage_rewrite.cc +++ b/src/tirx/transform/storage_rewrite.cc @@ -27,13 +27,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include #include diff --git a/src/tirx/transform/tvm_ffi_binder.cc b/src/tirx/transform/tvm_ffi_binder.cc index 881d94ba61ab..16b7eab7af2c 100644 --- a/src/tirx/transform/tvm_ffi_binder.cc +++ b/src/tirx/transform/tvm_ffi_binder.cc @@ -25,11 +25,11 @@ #include #include +#include #include #include #include #include -#include #include "ir_utils.h" @@ -226,7 +226,8 @@ bool TVMFFIABIBuilder::BindScalar(const PrimExpr& arg, const PrimExpr& value, // ============================================================ /*! - * \brief Render PrimExpr to string with variable names replaced by ffi::reflection::AccessPath names. + * \brief Render PrimExpr to string with variable names replaced by ffi::reflection::AccessPath + * names. * * Uses ExprFunctor for generic dispatch over all expression types. * The default TIR printer sanitizes Var name_hints (e.g. 
"B.shape[0]" -> "B_shape_0_") @@ -342,8 +343,8 @@ void TVMFFIABIBuilder::BindArray(const ffi::Array& arg, const ffi::Arr // BindBuffer (buffer-to-buffer bind with ffi::reflection::AccessPath) // ============================================================ -void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, ffi::reflection::AccessPath base_path, - bool fuzzy_match) { +void TVMFFIABIBuilder::BindBuffer(const Buffer& arg, const Buffer& value, + ffi::reflection::AccessPath base_path, bool fuzzy_match) { TVM_FFI_ICHECK_EQ(arg.scope(), value.scope()) << "Argument " << arg->name << " Buffer bind scope mismatch"; TVM_FFI_ICHECK_EQ(arg->dtype, value->dtype) @@ -514,7 +515,8 @@ void TVMFFIABIBuilder::DecodeParam(int param_index) { } // Bind scalar param to loaded value (defines vars before buffer binds reference them) - ffi::reflection::AccessPath param_path = ffi::reflection::AccessPath::Root()->Extend(AccessStep::ArrayItem(param_index)); + ffi::reflection::AccessPath param_path = + ffi::reflection::AccessPath::Root()->Extend(AccessStep::ArrayItem(param_index)); BindScalar(param, arg_value, param_path, true); } @@ -536,8 +538,9 @@ void TVMFFIABIBuilder::DecodeAllParams() { Var param = params_[i]; if (buffer_map_.count(param)) { Buffer buffer = buffer_map_[param]; - ffi::reflection::AccessPath param_path = - ffi::reflection::AccessPath::Root()->Extend(AccessStep::ArrayItem(i))->Attr(ffi::String(buffer->name)); + ffi::reflection::AccessPath param_path = ffi::reflection::AccessPath::Root() + ->Extend(AccessStep::ArrayItem(i)) + ->Attr(ffi::String(buffer->name)); DecodeParamDLTensor(buffer, device_type_, device_id_, param, func_name_ + "." 
+ param->name_hint, param_path); decl_buffers_.push_back(DeclBuffer(buffer)); @@ -607,7 +610,8 @@ void TVMFFIABIBuilder::BindAutoBroadcastStrides(const Buffer& buffer, const Var& PrimExpr value = cast(buffer->shape[k].dtype(), LoadInt64ArrayElem(strides_ptr, k)); value = tvm::if_then_else(v_strides_is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); - ffi::reflection::AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); + ffi::reflection::AccessPath strides_k_path = + param_path->Attr(ffi::String("strides"))->ArrayItem(k); BindScalar(buffer->strides[k], value, strides_k_path, true); stride = analyzer_.Simplify(stride * buffer->shape[k]); } @@ -619,7 +623,8 @@ void TVMFFIABIBuilder::BindRegularStrides(const Buffer& buffer, const Var& strid PrimExpr stride_from_shape = 1; for (int k = buffer->strides.size() - 1; k >= 0; k--) { PrimExpr explicit_stride = cast(buffer->shape[k].dtype(), LoadInt64ArrayElem(strides_ptr, k)); - ffi::reflection::AccessPath strides_k_path = param_path->Attr(ffi::String("strides"))->ArrayItem(k); + ffi::reflection::AccessPath strides_k_path = + param_path->Attr(ffi::String("strides"))->ArrayItem(k); BindScalar(buffer->strides[k], tvm::if_then_else(v_strides_is_null, stride_from_shape, explicit_stride), strides_k_path, true); @@ -633,7 +638,8 @@ void TVMFFIABIBuilder::BindRegularStrides(const Buffer& buffer, const Var& strid void TVMFFIABIBuilder::DecodeParamDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, const Var& handle, - const std::string& arg_name, ffi::reflection::AccessPath base_path) { + const std::string& arg_name, + ffi::reflection::AccessPath base_path) { const DataType tvm_ndim_type = DataType::Int(32); std::string buf_name = buffer->name; diff --git a/src/tirx/transform/tvm_ffi_binder.h b/src/tirx/transform/tvm_ffi_binder.h index 03ed0b77fede..92af52df6bcb 100644 --- a/src/tirx/transform/tvm_ffi_binder.h +++ 
b/src/tirx/transform/tvm_ffi_binder.h @@ -85,7 +85,8 @@ namespace tirx { */ class TVMFFIABIBuilder { public: - /*! \brief Variable definition info: bound value and the ffi::reflection::AccessPath where first defined. */ + /*! \brief Variable definition info: bound value and the ffi::reflection::AccessPath where first + * defined. */ struct VarDefInfo { PrimExpr value; ffi::reflection::AccessPath first_def_path; diff --git a/src/tirx/transform/vectorize_loop.cc b/src/tirx/transform/vectorize_loop.cc index 45d6a5e118be..bf1085165ad4 100644 --- a/src/tirx/transform/vectorize_loop.cc +++ b/src/tirx/transform/vectorize_loop.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include #include #include diff --git a/tests/python/contrib/test_example_npu.py b/tests/python/contrib/test_example_npu.py index e152051234b7..217d50d11f10 100644 --- a/tests/python/contrib/test_example_npu.py +++ b/tests/python/contrib/test_example_npu.py @@ -122,9 +122,9 @@ def test_example_npu_patterns_registered(): "example_npu.max_pool2d", } - assert core_patterns.issubset( - pattern_names - ), f"Missing core patterns: {core_patterns - pattern_names}" + assert core_patterns.issubset(pattern_names), ( + f"Missing core patterns: {core_patterns - pattern_names}" + ) # Check that at least some activation patterns are available activation_patterns = {name for name in pattern_names if "relu" in name or "sigmoid" in name} @@ -224,7 +224,7 @@ def test_example_npu_codegen(): @example_npu_enabled def test_example_npu_runtime_execution(): """Test end-to-end execution with the example NPU runtime""" - import tvm.relax.backend.contrib.example_npu # noqa: F401 + import tvm.relax.backend.contrib.example_npu # Create simple test inputs np.random.seed(42) diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index 339a92f43524..cfced6f7c84c 100644 --- a/tests/python/ir/test_ir_type.py +++ 
b/tests/python/ir/test_ir_type.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# ruff: noqa: E711, F401, F821 +# ruff: noqa: F401, F821 """Test type nodes in the IR""" import tvm diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index f3e2e581e1d5..5d032ba5c778 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7404,10 +7404,10 @@ def main(x: R.Tensor((2, 10), dtype="float32")) -> R.Tuple( def test_index_put_with_tuple_output(): class IndexPutTupleOutput(Module): - def forward(self, x, l, idx): + def forward(self, x, buf, idx): values = x - l[..., idx, idx] = values - return x[..., 1], l + buf[..., idx, idx] = values + return x[..., 1], buf example_args = ( torch.ones(2, 3, 5, dtype=torch.float32), @@ -7425,8 +7425,7 @@ def forward(self, x, l, idx): assert len(tensor_fields) >= 2 assert any( - len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 - for f in tensor_fields + len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 for f in tensor_fields ) @@ -7434,14 +7433,14 @@ def test_m4d_diag_index_put_tuple_output_regression(): class M4D(Module): def forward(self, x): b, k, n = 2, 3, 5 - l = x.new_zeros(b, k, n, n) + buf = x.new_zeros(b, k, n, n) idx = torch.arange(n, device=x.device) - diag = l[..., idx, idx] + diag = buf[..., idx, idx] diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8 - l[..., idx, idx] = diag + buf[..., idx, idx] = diag - return x[..., :1], l + return x[..., :1], buf ex_in = torch.zeros(2, 3, 5, dtype=torch.float32) exported_program = export(M4D().eval(), args=(ex_in,)) @@ -7458,10 +7457,9 @@ def forward(self, x): assert len(tensor_fields) >= 2 # x: (2, 3, 5) → x[..., :1]: (2, 3, 1) assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 
for f in tensor_fields) - # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5 + # buf: (2, 3, 5, 5) → 4-D with spatial dims 5x5 assert any( - len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 - for f in tensor_fields + len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 for f in tensor_fields ) @@ -9290,9 +9288,7 @@ def false_fn(x): def test_affine_grid(): class AffineGrid(Module): def forward(self, theta): - return torch.nn.functional.affine_grid( - theta, [1, 3, 16, 16], align_corners=True - ) + return torch.nn.functional.affine_grid(theta, [1, 3, 16, 16], align_corners=True) @tvm.script.ir_module class expected: @@ -9321,9 +9317,7 @@ def test_affine_grid_numerically(): class AffineGrid(Module): def forward(self, theta): - return torch.nn.functional.affine_grid( - theta, [2, 3, 8, 12], align_corners=True - ) + return torch.nn.functional.affine_grid(theta, [2, 3, 8, 12], align_corners=True) model = AffineGrid() example_args = (torch.randn(2, 2, 3, dtype=torch.float32),) diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 890c6ef3a1ff..b2fe59b50799 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3672,6 +3672,7 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor( verify_model(Interpolate4(), input_info, {}, expected4) input_info_5d = [([1, 3, 4, 10, 10], "float32")] + class Interpolate5(Module): def forward(self, input): return torch.nn.functional.interpolate( @@ -3681,13 +3682,13 @@ def forward(self, input): mode="trilinear", align_corners=False, ) + @tvm.script.ir_module class expected5: @R.function def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( (1, 3, 8, 20, 20), dtype="float32" ): - with R.dataflow(): lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = R.image.resize3d( input_5, @@ -3713,17 +3714,17 @@ def forward(self, input): return 
torch.nn.functional.interpolate( input, size=None, - scale_factor=(2.0,4.0,4.0), + scale_factor=(2.0, 4.0, 4.0), mode="trilinear", align_corners=False, ) + @tvm.script.ir_module class expected6: @R.function def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( (1, 3, 8, 40, 40), dtype="float32" ): - with R.dataflow(): lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( input_5, @@ -3748,17 +3749,17 @@ class Interpolate7(Module): def forward(self, input): return torch.nn.functional.interpolate( input, - size=(8,40,40), + size=(8, 40, 40), mode="trilinear", align_corners=False, ) + @tvm.script.ir_module class expected7: @R.function def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( (1, 3, 8, 40, 40), dtype="float32" ): - with R.dataflow(): lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( input_5, @@ -3783,17 +3784,17 @@ class Interpolate8(Module): def forward(self, input): return torch.nn.functional.interpolate( input, - size=(8,40,40), + size=(8, 40, 40), mode="trilinear", align_corners=True, ) + @tvm.script.ir_module class expected8: @R.function def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor( (1, 3, 8, 40, 40), dtype="float32" ): - with R.dataflow(): lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d( input_5, @@ -3936,17 +3937,17 @@ def forward(self, input): return torch.nn.functional.interpolate( input, size=None, - scale_factor=(2.0,4.0,4.0), + scale_factor=(2.0, 4.0, 4.0), mode="trilinear", align_corners=False, ) + @tvm.script.ir_module class expected_nhwc3: @R.function def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor( (1, 8, 40, 40, 3), dtype="float32" ): - with R.dataflow(): lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = R.image.resize3d( input_5, @@ -3975,17 +3976,17 @@ def forward(self, input): return torch.nn.functional.interpolate( input, size=None, - scale_factor=(2.0,4.0,4.0), + 
scale_factor=(2.0, 4.0, 4.0), mode="trilinear", align_corners=True, ) + @tvm.script.ir_module class expected_nhwc4: @R.function def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor( (1, 8, 40, 40, 3), dtype="float32" ): - with R.dataflow(): lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = R.image.resize3d( input_5, @@ -4009,6 +4010,7 @@ def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor( mod4 = from_fx(graph_model4, input_info_5d, default_image_layout="NDHWC") tvm.ir.assert_structural_equal(mod4, expected_nhwc4) + def test_addmm(): input_info = [ ([10, 10], "float32"), diff --git a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py index 51ea99226895..d252eeb9d740 100644 --- a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py +++ b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py @@ -39,7 +39,7 @@ compared on the unpadded positions (padded positions are intentionally free to contain arbitrary garbage). 
""" -# ruff: noqa: E501 + import math import numpy as np diff --git a/tests/python/relax/test_frontend_nn_parameter_containers.py b/tests/python/relax/test_frontend_nn_parameter_containers.py index d07a21405a61..037925c7d033 100644 --- a/tests/python/relax/test_frontend_nn_parameter_containers.py +++ b/tests/python/relax/test_frontend_nn_parameter_containers.py @@ -171,12 +171,8 @@ def test_load_state_dict(): assert unexpected_keys == [] tvm.testing.assert_allclose(m.list_params[0].data.numpy(), np.full((4,), 1.0, "float32")) tvm.testing.assert_allclose(m.list_params[1].data.numpy(), np.full((4,), 2.0, "float32")) - tvm.testing.assert_allclose( - m.dict_params["foo"].data.numpy(), np.full((4,), 3.0, "float32") - ) - tvm.testing.assert_allclose( - m.dict_params["bar"].data.numpy(), np.full((4,), 4.0, "float32") - ) + tvm.testing.assert_allclose(m.dict_params["foo"].data.numpy(), np.full((4,), 3.0, "float32")) + tvm.testing.assert_allclose(m.dict_params["bar"].data.numpy(), np.full((4,), 4.0, "float32")) def test_export_tvm_parameter_names(): diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py index a06fa05c7723..db4652b2df79 100644 --- a/tests/python/relax/test_frontend_nn_subroutines.py +++ b/tests/python/relax/test_frontend_nn_subroutines.py @@ -137,7 +137,8 @@ def forward(self, x: relax.Expr, y: relax.Expr) -> relax.Var: func for gvar, func in tvm_mod.functions.items() if isinstance(func, relax.Function) - and gvar.name_hint not in ( + and gvar.name_hint + not in ( "forward", "_initialize_effect", ) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index c46709e33de8..7f1cecd1c979 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -549,12 +549,13 @@ def test_concat_with_param_shape_value(): helper.make_node("Reshape", ["x", "new_shape"], ["y"]), ] graph = helper.make_graph( - nodes, 
"concat_param_shape", [inp], [out], + nodes, + "concat_param_shape", + [inp], + [out], initializer=[twelve, starts, ends], ) - model = helper.make_model( - graph, opset_imports=[helper.make_opsetid("", 13)] - ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = 8 onnx.checker.check_model(model) # Both modes should succeed; previously True crashed with @@ -880,8 +881,8 @@ def _verify_gather(data_shape, indices, out_shape, axis=0): (0, [-1, 0], [2, 4]), (1, [-1, 0], [3, 2]), ( - 1, - [[-1, 0], [1, -2]], + 1, + [[-1, 0], [1, -2]], [3, 2, 2], ), ], @@ -1995,7 +1996,7 @@ def test_cumsum_axis_shape_validation(): model = helper.make_model(graph, producer_name="cumsum_invalid_axis_shape_graph") with pytest.raises( - ValueError, + ValueError, match="axis input must be a scalar \(0-D\) or a single-element 1-D tensor", ): from_onnx(model, opset=14, keep_params_in_input=True) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index fc509a4d0f49..a53906d2f147 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -184,7 +184,7 @@ class Cumsum(tf.Module): @tf.function( input_signature=[ tf.TensorSpec(shape=(3, 4), dtype=tf.float32), - tf.TensorSpec(shape=(5, 6), dtype=tf.int32) + tf.TensorSpec(shape=(5, 6), dtype=tf.int32), ] ) def func(self, x, y): @@ -567,7 +567,8 @@ def func(self, start, limit, delta): with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar inputs"): verify(RangeDynamic) - + + def test_tile_ir(): """TILE conversion with explicit Relax IR structural check.""" @@ -735,18 +736,6 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="floa verify(TfInput, Expected) -def test_prelu(): - alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, dtype=np.float32)) - prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init) - - class TfInput(tf.Module): - 
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) - def func(self, x): - return prelu(x) - - verify(TfInput) - - def test_fill(): class TfInput(tf.Module): @tf.function( @@ -819,9 +808,7 @@ def test_random_standard_normal_dynamic_shape(): class TfRandomStandardNormal(tf.Module): @tf.function(input_signature=[tf.TensorSpec(shape=(2,), dtype=tf.int32)]) def func(self, shape): - return tf.raw_ops.RandomStandardNormal( - shape=shape, dtype=tf.float32, seed=3, seed2=5 - ) + return tf.raw_ops.RandomStandardNormal(shape=shape, dtype=tf.float32, seed=3, seed2=5) cf = TfRandomStandardNormal().func.get_concrete_function() mod = _get_mod_from_cfunc(cf) @@ -959,9 +946,9 @@ def func(self, s0, s1): @I.ir_module class Expected: @R.function - def main( - s0: R.Tensor((3,), dtype="int32"), s1: R.Tensor((3,), dtype="int32") - ) -> R.Tensor((3,), dtype="int32"): + def main(s0: R.Tensor((3,), dtype="int32"), s1: R.Tensor((3,), dtype="int32")) -> R.Tensor( + (3,), dtype="int32" + ): R.func_attr({"num_input": 2}) with R.dataflow(): lv: R.Tensor((0,), dtype="int32") = R.full( @@ -999,9 +986,9 @@ def func(self, s0, s1): @I.ir_module class Expected: @R.function - def main( - s0: R.Tensor((1,), dtype="int32"), s1: R.Tensor((3,), dtype="int32") - ) -> R.Tensor((3,), dtype="int32"): + def main(s0: R.Tensor((1,), dtype="int32"), s1: R.Tensor((3,), dtype="int32")) -> R.Tensor( + (3,), dtype="int32" + ): R.func_attr({"num_input": 2}) with R.dataflow(): lv: R.Tensor((2,), dtype="int32") = R.full( @@ -1631,9 +1618,7 @@ def func(self, data, kernel): def test_conv3d_valid(): - Conv3DModule = _make_conv3d_module( - (1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "VALID" - ) + Conv3DModule = _make_conv3d_module((1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "VALID") @I.ir_module class Expected: @@ -1663,9 +1648,7 @@ def main( def test_conv3d_same(): - Conv3DModule = _make_conv3d_module( - (1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "SAME" - ) + 
Conv3DModule = _make_conv3d_module((1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "SAME") @I.ir_module class Expected: @@ -1709,7 +1692,7 @@ def _make_conv3d_transpose_module(data_shape, kernel_shape, strides, padding): out_spatial.append((in_size - 1) * s + k_size) else: # SAME out_spatial.append(in_size * s) - computed_output_shape = [batch] + out_spatial + [out_channels] + computed_output_shape = [batch, *out_spatial, out_channels] class Conv3DTransposeModule(tf.Module): @tf.function( @@ -1730,7 +1713,6 @@ def func(self, data, kernel): return Conv3DTransposeModule - def test_conv3d_transpose_valid(): Conv3DTransposeModule = _make_conv3d_transpose_module( (1, 8, 8, 8, 3), (3, 3, 3, 8, 3), (1, 1, 1, 1, 1), "VALID" @@ -2012,6 +1994,7 @@ def func(self, condition, x, y): verify(ModelBroadcasting) + def test_scatter_nd(): class Model(tf.Module): @tf.function( @@ -2047,9 +2030,7 @@ def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="flo lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims( R.const([0, 0, 1, 2], "int32"), axis=[1] ) - gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd( - lv, lv1, data, reduction="add" - ) + gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(lv, lv1, data, reduction="add") R.output(gv) return gv @@ -2080,9 +2061,7 @@ def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="flo lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims( R.const([2, 0, 2, 1], "int32"), axis=[1] ) - gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd( - lv, lv1, data, reduction="min" - ) + gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd(lv, lv1, data, reduction="min") R.output(gv) return gv @@ -2113,9 +2092,7 @@ def main(data: R.Tensor((4, 2), dtype="float32")) -> R.Tensor((3, 2), dtype="flo lv1: R.Tensor((4, 1), dtype="int32") = R.expand_dims( R.const([1, 0, 1, 2], "int32"), axis=[1] ) - gv: R.Tensor((3, 2), dtype="float32") = R.scatter_nd( - lv, lv1, data, reduction="mul" - ) + gv: R.Tensor((3, 2), 
dtype="float32") = R.scatter_nd(lv, lv1, data, reduction="mul") R.output(gv) return gv @@ -3477,6 +3454,18 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 30), dtype="floa verify(ReLU_N1_to_1, Expected) +def test_prelu_basic(): + alpha_init = tf.keras.initializers.Constant(np.linspace(0.1, 0.3, 30, dtype=np.float32)) + prelu = tf.keras.layers.PReLU(alpha_initializer=alpha_init) + + class TfInput(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)]) + def func(self, x): + return prelu(x) + + verify(TfInput) + + @pytest.mark.parametrize( "shared_axes", [ @@ -3652,6 +3641,7 @@ def main( # Since TensorFlow does not provide an API to create sparse TFLite models, # we manually build them using the flatbuffers API. + # Import schema helpers explicitly. CI's generated tflite package does not # reliably re-export these builder helpers and enums at the package top-level. def _get_tflite_schema_module(module_name): @@ -3784,9 +3774,7 @@ def _build_operator( builtin_options2=None, ): inputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartInputsVector, inputs) - outputs_vec = _tflite_int32_vector( - builder, _tfl_operator.OperatorStartOutputsVector, outputs - ) + outputs_vec = _tflite_int32_vector(builder, _tfl_operator.OperatorStartOutputsVector, outputs) _tfl_operator.OperatorStart(builder) _tfl_operator.OperatorAddOpcodeIndex(builder, opcode_index) _tfl_operator.OperatorAddInputs(builder, inputs_vec) @@ -3819,9 +3807,7 @@ def _build_subgraph(builder, *, tensors, operators, inputs, outputs): builder, _tfl_subgraph.SubGraphStartOperatorsVector, operators ) inputs_vec = _tflite_int32_vector(builder, _tfl_subgraph.SubGraphStartInputsVector, inputs) - outputs_vec = _tflite_int32_vector( - builder, _tfl_subgraph.SubGraphStartOutputsVector, outputs - ) + outputs_vec = _tflite_int32_vector(builder, _tfl_subgraph.SubGraphStartOutputsVector, outputs) _tfl_subgraph.SubGraphStart(builder) 
_tfl_subgraph.SubGraphAddTensors(builder, tensors_vec) @@ -3935,9 +3921,7 @@ def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type): ) def test_stablehlo_unary(builtin_name, relax_op): """TFLite StableHLO unary elementwise operators.""" - mod = _load_model_from_buffer( - _build_stablehlo_model(builtin_name=builtin_name, input_count=1) - ) + mod = _load_model_from_buffer(_build_stablehlo_model(builtin_name=builtin_name, input_count=1)) @I.ir_module class Expected: @@ -3966,9 +3950,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float3 ) def test_stablehlo_binary(builtin_name, relax_op): """TFLite StableHLO binary elementwise operators.""" - mod = _load_model_from_buffer( - _build_stablehlo_model(builtin_name=builtin_name, input_count=2) - ) + mod = _load_model_from_buffer(_build_stablehlo_model(builtin_name=builtin_name, input_count=2)) @I.ir_module class Expected: @@ -3999,9 +3981,7 @@ def main( def test_stablehlo_typed_binary(builtin_name, relax_op, dtype, tensor_type): """TFLite StableHLO binary elementwise operators with non-float dtype requirements.""" mod = _load_model_from_buffer( - _build_stablehlo_typed_binary_model( - builtin_name=builtin_name, tensor_type=tensor_type - ) + _build_stablehlo_typed_binary_model(builtin_name=builtin_name, tensor_type=tensor_type) ) @I.ir_module @@ -4075,12 +4055,9 @@ def main( R.output(gv) return gv - tvm.ir.assert_structural_equal(mod, Expected) - - def _build_stablehlo_convert_model(): """STABLEHLO_CONVERT: float32 input -> int32 output.""" builder = flatbuffers.Builder(1024) @@ -4090,9 +4067,7 @@ def _build_stablehlo_convert_model(): t_out = _build_tensor(builder, 1, shape, tensor_type=_tfl_tensor_type.INT32) tensors = [t_in, t_out] - op_code = _build_operator_code( - builder, _get_stablehlo_builtin_operator("STABLEHLO_CONVERT") - ) + op_code = _build_operator_code(builder, _get_stablehlo_builtin_operator("STABLEHLO_CONVERT")) op = _build_operator(builder, 0, [0], [1]) subgraph 
= _build_subgraph( builder, @@ -4164,9 +4139,9 @@ def _build_stablehlo_concat_model(dimension, num_inputs): out_shape = [num_inputs * shape[0], shape[1]] else: out_shape = [shape[0], num_inputs * shape[1]] - tensors = [ - _build_tensor(builder, i, shape) for i in range(num_inputs) - ] + [_build_tensor(builder, num_inputs, out_shape)] + tensors = [_build_tensor(builder, i, shape) for i in range(num_inputs)] + [ + _build_tensor(builder, num_inputs, out_shape) + ] op = _build_operator( builder, @@ -4275,9 +4250,7 @@ class Expected: def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - gv: R.Tensor((2, 3), dtype="float32") = R.broadcast_to( - R.reshape(x, (1, 3)), (2, 3) - ) + gv: R.Tensor((2, 3), dtype="float32") = R.broadcast_to(R.reshape(x, (1, 3)), (2, 3)) R.output(gv) return gv @@ -4381,12 +4354,30 @@ def _build_stablehlo_compare_model(direction): @pytest.mark.parametrize( "direction_enum, relax_op", [ - (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_EQ, R.equal), - (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_NE, R.not_equal), - (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GE, R.greater_equal), - (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, R.greater), - (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LE, R.less_equal), - (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, R.less), + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_EQ, + R.equal, + ), + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_NE, + R.not_equal, + ), + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GE, + R.greater_equal, + ), + ( + 
_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, + R.greater, + ), + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LE, + R.less_equal, + ), + ( + _tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, + R.less, + ), ], ) def test_stablehlo_compare(direction_enum, relax_op): @@ -4506,21 +4497,13 @@ def _build_stablehlo_gather_model( ) _tfl_stablehlo_gather_opts.StablehloGatherOptionsStart(builder) - _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddOffsetDims( - builder, offset_dims_vec - ) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddOffsetDims(builder, offset_dims_vec) _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddCollapsedSliceDims( builder, collapsed_slice_dims_vec ) - _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddStartIndexMap( - builder, start_index_map_vec - ) - _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddIndexVectorDim( - builder, index_vector_dim - ) - _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddSliceSizes( - builder, slice_sizes_vec - ) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddStartIndexMap(builder, start_index_map_vec) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddIndexVectorDim(builder, index_vector_dim) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddSliceSizes(builder, slice_sizes_vec) gather_opts = _tfl_stablehlo_gather_opts.StablehloGatherOptionsEnd(builder) builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_GATHER") @@ -4613,6 +4596,7 @@ def test_stablehlo_gather_complex_unsupported(): with pytest.raises(tvm.error.OpNotImplemented, match="start_index_map"): from_tflite(tflite_model) + def _pad_vector(builder, start_vector_fn, values): """Build a FlatBuffers int64 vector for pad options.""" start_vector_fn(builder, len(values)) @@ -4657,13 +4641,19 @@ def _build_stablehlo_pad_model(edge_low, edge_high, interior): tensors = [t_in, t_pad_val, t_out] op = _build_operator( 
- builder, 0, [0, 1], [2], + builder, + 0, + [0, 1], + [2], builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, builtin_options2=pad_opts, ) subgraph = _build_subgraph( - builder, tensors=tensors, operators=[op], - inputs=[0], outputs=[2], + builder, + tensors=tensors, + operators=[op], + inputs=[0], + outputs=[2], ) buffers = [ _build_buffer(builder), @@ -4733,13 +4723,19 @@ def test_stablehlo_pad_interior_unsupported(): tensors = [t_in, t_pv, t_out] op = _build_operator( - builder, 0, [0, 1], [2], + builder, + 0, + [0, 1], + [2], builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, builtin_options2=pad_opts, ) subgraph = _build_subgraph( - builder, tensors=tensors, operators=[op], - inputs=[0], outputs=[2], + builder, + tensors=tensors, + operators=[op], + inputs=[0], + outputs=[2], ) buffers = [ _build_buffer(builder), @@ -4792,13 +4788,19 @@ def test_stablehlo_pad_negative_unsupported(): tensors = [t_in, t_pv, t_out] op = _build_operator( - builder, 0, [0, 1], [2], + builder, + 0, + [0, 1], + [2], builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, builtin_options2=pad_opts, ) subgraph = _build_subgraph( - builder, tensors=tensors, operators=[op], - inputs=[0], outputs=[2], + builder, + tensors=tensors, + operators=[op], + inputs=[0], + outputs=[2], ) buffers = [ _build_buffer(builder), @@ -4822,17 +4824,13 @@ def _build_stablehlo_dynamic_slice_model(slice_sizes, start_vals): ndim = len(slice_sizes) # Build SliceSizes vector - _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector( - builder, ndim - ) + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector(builder, ndim) for v in reversed(slice_sizes): builder.PrependInt64(v) sizes_vec = builder.EndVector() _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder) - _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes( - builder, sizes_vec - ) + 
_tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes(builder, sizes_vec) dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder) builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE") @@ -4845,26 +4843,28 @@ def _build_stablehlo_dynamic_slice_model(slice_sizes, start_vals): start_buffers = [] for i, sv in enumerate(start_vals): bidx = 1 + i - start_tensors.append( - _build_tensor(builder, bidx, [], tensor_type=_tfl_tensor_type.INT32) - ) + start_tensors.append(_build_tensor(builder, bidx, [], tensor_type=_tfl_tensor_type.INT32)) start_inputs.append(bidx) - start_buffers.append( - _build_buffer(builder, np.array([sv], dtype=np.int32).tobytes()) - ) + start_buffers.append(_build_buffer(builder, np.array([sv], dtype=np.int32).tobytes())) out_idx = 1 + ndim t_out = _build_tensor(builder, out_idx, slice_sizes) tensors = [t_in, *start_tensors, t_out] op_inputs = [0, *start_inputs] op = _build_operator( - builder, 0, op_inputs, [out_idx], + builder, + 0, + op_inputs, + [out_idx], builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions, builtin_options2=dyn_opts, ) subgraph = _build_subgraph( - builder, tensors=tensors, operators=[op], - inputs=[0], outputs=[out_idx], + builder, + tensors=tensors, + operators=[op], + inputs=[0], + outputs=[out_idx], ) buffers = [_build_buffer(builder), *start_buffers, _build_buffer(builder)] return _finish_tflite_model( @@ -4877,17 +4877,13 @@ def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes): builder = flatbuffers.Builder(1024) ndim = len(slice_sizes) - _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector( - builder, ndim - ) + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector(builder, ndim) for v in reversed(slice_sizes): builder.PrependInt64(v) sizes_vec = builder.EndVector() _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder) - 
_tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes( - builder, sizes_vec - ) + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes(builder, sizes_vec) dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder) builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE") @@ -4895,8 +4891,7 @@ def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes): t_in = _build_tensor(builder, 0, [3, 3]) start_tensors = [ - _build_tensor(builder, 1 + i, [], tensor_type=_tfl_tensor_type.INT32) - for i in range(ndim) + _build_tensor(builder, 1 + i, [], tensor_type=_tfl_tensor_type.INT32) for i in range(ndim) ] out_idx = 1 + ndim t_out = _build_tensor(builder, out_idx, slice_sizes) @@ -4905,13 +4900,19 @@ def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes): op_inputs = [0, *start_inputs] op = _build_operator( - builder, 0, op_inputs, [out_idx], + builder, + 0, + op_inputs, + [out_idx], builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions, builtin_options2=dyn_opts, ) subgraph = _build_subgraph( - builder, tensors=tensors, operators=[op], - inputs=op_inputs, outputs=[out_idx], + builder, + tensors=tensors, + operators=[op], + inputs=op_inputs, + outputs=[out_idx], ) buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] return _finish_tflite_model( @@ -4922,9 +4923,7 @@ def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes): def test_stablehlo_dynamic_slice(): """TFLite StableHLO DYNAMIC_SLICE: start=[0,1], sizes=[2,2] from (3,3).""" mod = _load_model_from_buffer( - _build_stablehlo_dynamic_slice_model( - slice_sizes=[2, 2], start_vals=[0, 1] - ) + _build_stablehlo_dynamic_slice_model(slice_sizes=[2, 2], start_vals=[0, 1]) ) @I.ir_module @@ -5282,6 +5281,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float3 tvm.ir.assert_structural_equal(mod, Expected) + def test_densify_with_conv2d(): """Test 
DENSIFY followed by CONV2D - a real-world scenario. @@ -5315,6 +5315,7 @@ def main(x: R.Tensor((1, 4, 4, 1), dtype="float32")) -> R.Tensor( tvm.ir.assert_structural_equal(mod, Expected) + def test_densify_with_fully_connected(): """Test DENSIFY followed by FULLY_CONNECTED - a real-world scenario. @@ -5399,9 +5400,7 @@ def test_dilate(): _build_buffer(builder), _build_buffer(builder), _build_buffer(builder, np.asarray(dilations, dtype=np.int32).tobytes()), - _build_buffer( - builder, np.asarray([dilation_value], dtype=np.float32).tobytes() - ), + _build_buffer(builder, np.asarray([dilation_value], dtype=np.float32).tobytes()), _build_buffer(builder), ] @@ -5434,9 +5433,7 @@ def main( lv4: R.Tensor((5, 4), dtype="float32") = R.strided_slice( lv3, [0, 1], [0, 0], [5, 4], [1, 1], assume_inbound=False ) - lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape( - lv4, R.shape([5, 4, 1]) - ) + lv5: R.Tensor((5, 4, 1), dtype="float32") = R.reshape(lv4, R.shape([5, 4, 1])) lv6: R.Tensor((5, 4, 1), dtype="float32") = R.full( R.shape([5, 4, 1]), R.const(0.5, "float32"), dtype="float32" ) @@ -5470,9 +5467,7 @@ def test_dilate_dynamic_dilations(): _build_buffer(builder), _build_buffer(builder), _build_buffer(builder), # dilations is a runtime input so empty buffer - _build_buffer( - builder, np.asarray([dilation_value], dtype=np.float32).tobytes() - ), + _build_buffer(builder, np.asarray([dilation_value], dtype=np.float32).tobytes()), _build_buffer(builder), ] @@ -5513,9 +5508,9 @@ def main( R.const(0.5, "float32"), dtype="float32", ) - lv6: R.Tensor( - (3, 1 + (dilate_stride_0 - 1), 4), dtype="float32" - ) = R.concat((lv4, lv5), axis=1) + lv6: R.Tensor((3, 1 + (dilate_stride_0 - 1), 4), dtype="float32") = R.concat( + (lv4, lv5), axis=1 + ) lv7: R.Tensor((3 * dilate_stride_0, 4), dtype="float32") = R.reshape( lv6, R.shape([3 * dilate_stride_0, 4]) ) @@ -5530,9 +5525,9 @@ def main( [1, 1], assume_inbound=False, ) - lv9: R.Tensor( - (2 * dilate_stride_0 + 1, 4, 1), dtype="float32" 
- ) = R.reshape(lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1])) + lv9: R.Tensor((2 * dilate_stride_0 + 1, 4, 1), dtype="float32") = R.reshape( + lv8, R.shape([2 * dilate_stride_0 + 1, 4, 1]) + ) lv10: R.Tensor( (2 * dilate_stride_0 + 1, 4, dilate_stride_1 - 1), dtype="float32" ) = R.full( @@ -5544,10 +5539,8 @@ def main( (2 * dilate_stride_0 + 1, 4, 1 + (dilate_stride_1 - 1)), dtype="float32", ) = R.concat((lv9, lv10), axis=2) - lv12: R.Tensor( - (2 * dilate_stride_0 + 1, 4 * dilate_stride_1), dtype="float32" - ) = R.reshape( - lv11, R.shape([2 * dilate_stride_0 + 1, 4 * dilate_stride_1]) + lv12: R.Tensor((2 * dilate_stride_0 + 1, 4 * dilate_stride_1), dtype="float32") = ( + R.reshape(lv11, R.shape([2 * dilate_stride_0 + 1, 4 * dilate_stride_1])) ) gv: R.Tensor( ( diff --git a/tests/python/relax/test_meta_schedule_relax_integration.py b/tests/python/relax/test_meta_schedule_relax_integration.py index c28b8c444bef..13d4496fb140 100644 --- a/tests/python/relax/test_meta_schedule_relax_integration.py +++ b/tests/python/relax/test_meta_schedule_relax_integration.py @@ -25,8 +25,8 @@ import tvm import tvm.testing from tvm import relax -from tvm.runtime import tensor as tvm_tensor from tvm.runtime import cpu as tvm_cpu +from tvm.runtime import tensor as tvm_tensor from tvm.runtime.vm import VirtualMachine from tvm.s_tir import meta_schedule as ms from tvm.script import ir as I diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py index bf0abb09b093..43d5dffab70d 100644 --- a/tests/python/relax/test_op_nn_convolution.py +++ b/tests/python/relax/test_op_nn_convolution.py @@ -1669,9 +1669,7 @@ def test_conv3d_transpose_wrong_output_padding(): bb.normalize(relax.op.nn.conv3d_transpose(x0, w0, strides=2, output_padding=2)) with pytest.raises(TVMError): bb.normalize( - relax.op.nn.conv3d_transpose( - x0, w0, strides=(2, 2, 2), output_padding=(2, 2, 2) - ) + relax.op.nn.conv3d_transpose(x0, w0, strides=(2, 2, 2), 
output_padding=(2, 2, 2)) ) diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index ef260cf18858..167ccdf45a4d 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -278,9 +278,9 @@ def test_nms_op_correctness(): data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) indices = relax.Var("indices", R.Tensor((2, 10), "int32")) - assert relax.op.vision.non_max_suppression( - data, valid_count, indices - ).op == Op.get("relax.vision.non_max_suppression") + assert relax.op.vision.non_max_suppression(data, valid_count, indices).op == Op.get( + "relax.vision.non_max_suppression" + ) def test_nms_infer_struct_info_return_indices(): @@ -290,9 +290,7 @@ def test_nms_infer_struct_info_return_indices(): indices = relax.Var("indices", R.Tensor((2, 10), "int32")) _check_inference( bb, - relax.op.vision.non_max_suppression( - data, valid_count, indices, return_indices=True - ), + relax.op.vision.non_max_suppression(data, valid_count, indices, return_indices=True), relax.TupleStructInfo( [ relax.TensorStructInfo((2, 10), "int32"), @@ -329,9 +327,7 @@ def test_nms_infer_struct_info_return_data(): indices = relax.Var("indices", R.Tensor((2, 10), "int32")) _check_inference( bb, - relax.op.vision.non_max_suppression( - data, valid_count, indices, return_indices=False - ), + relax.op.vision.non_max_suppression(data, valid_count, indices, return_indices=False), relax.TensorStructInfo((2, 10, 6), "float32"), ) @@ -346,9 +342,7 @@ def test_nms_infer_struct_info_return_data_shape_var(): indices = relax.Var("indices", R.Tensor((batch_size, num_anchors), "int32")) _check_inference( bb, - relax.op.vision.non_max_suppression( - data, valid_count, indices, return_indices=False - ), + relax.op.vision.non_max_suppression(data, valid_count, indices, return_indices=False), relax.TensorStructInfo((batch_size, num_anchors, elem_length), "float32"), ) @@ 
-402,9 +396,7 @@ def test_nms_wrong_aux_input_shape(): indices_bad_anchors = relax.Var("indices_bad_anchors", R.Tensor((2, 9), "int32")) with pytest.raises(TVMError): bb.normalize( - relax.op.vision.non_max_suppression( - data, valid_count_bad_batch, indices_bad_anchors - ) + relax.op.vision.non_max_suppression(data, valid_count_bad_batch, indices_bad_anchors) ) with pytest.raises(TVMError): bb.normalize(relax.op.vision.non_max_suppression(data, valid_count, indices_bad_batch)) @@ -1264,6 +1256,8 @@ def main( mod["main"].ret_struct_info, relax.TensorStructInfo((2, 2, 3, 2), "float32"), ) + + def test_all_class_non_max_suppression_infer_struct_info(): bb = relax.BlockBuilder() batch_size, num_classes, num_boxes = 10, 8, 5 diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py index d4755ee367f7..ac5fa78b39ec 100644 --- a/tests/python/relax/test_tvmscript_parser_op_vision.py +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -96,9 +96,7 @@ def foo( bb = relax.BlockBuilder() with bb.function("foo", [data]): gv = bb.emit( - relax.op.vision.get_valid_counts( - data, score_threshold=0.5, id_index=0, score_index=1 - ) + relax.op.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1) ) bb.emit_func_output(gv) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index cf3e28388eb5..65c76675a0bc 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -103,7 +103,9 @@ def test_extern_func_with_struct_info(): { "my_ext": relax.ExternFunc( "my_ext", - relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True), + relax.FuncStructInfo( + [], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True + ), ), } ) @@ -125,7 +127,9 @@ def test_extern_func_with_struct_info_roundtrip(): { "my_ext": relax.ExternFunc( "my_ext", 
- relax.FuncStructInfo([], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True), + relax.FuncStructInfo( + [], relax.TensorStructInfo(dtype="float32", ndim=2), purity=True + ), ), } ) diff --git a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py index c290720327a5..bd43cd3679de 100644 --- a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py +++ b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py @@ -528,6 +528,7 @@ def expected(B0: T.Buffer((512, 6144), "uint32"), B1: T.Buffer((128, 6144), "flo mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) # pylint: disable=not-callable tvm.ir.assert_structural_equal(mod["main"], expected) + def test_low_batch_gemv_cuda_target_without_max_shared_memory_per_block(): # fmt: off @T.prim_func(private=True) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py index 558d67cadee6..1306386bde38 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_lower_thread_all_reduce.py @@ -503,7 +503,7 @@ def main(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32 After_script = After.script() assert "tvm_warp_shuffle_down" in After_script assert "tvm_storage_sync" in After_script - assert "\"tirx.volatile\": T.bool(True)" in After_script + assert '"tirx.volatile": T.bool(True)' in After_script assert "T.uint32(" not in After_script diff --git a/tests/python/tirx-base/test_tir_constructor.py b/tests/python/tirx-base/test_tir_constructor.py index 358091f8cdfa..16f85f962505 100644 --- a/tests/python/tirx-base/test_tir_constructor.py +++ b/tests/python/tirx-base/test_tir_constructor.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# ruff: noqa: E711 import pytest