From 7c3d48bbe794426f550ec61d1635ef9fb4e4219c Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Mon, 1 Dec 2025 19:11:56 +0800 Subject: [PATCH] Fix index_put with broadcast indices --- .../torch/base_fx_graph_translator.py | 3 +- python/tvm/relax/op/manipulate.py | 2 +- python/tvm/topi/index_put.py | 68 +++++++++++++++---- .../test_frontend_from_exported_program.py | 49 +++++++++++++ 4 files changed, 108 insertions(+), 14 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 33a22b34fcc0..692a553f0907 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1812,8 +1812,9 @@ def _index_put(self, node: fx.Node) -> relax.Var: ) ) # Reshape to [dim_size, 1, 1, ...] for broadcasting + # Add an extra dimension so it broadcasts with other indices arange_idx = self.block_builder.emit( - relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1)) + relax.op.reshape(arange_idx, [data_shape[i]] + [1] * max_ndim) ) processed_indices.append(arange_idx) else: diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py index bb134f114855..ee486b0ab69c 100644 --- a/python/tvm/relax/op/manipulate.py +++ b/python/tvm/relax/op/manipulate.py @@ -642,7 +642,7 @@ def index_put( [0.0, 3.0, 0.0], ] """ - if not isinstance(indices, (list, tuple)): + if isinstance(indices, (list, tuple)): indices = RxTuple(indices) return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore diff --git a/python/tvm/topi/index_put.py b/python/tvm/topi/index_put.py index f51c6718ab99..52406d402cdd 100644 --- a/python/tvm/topi/index_put.py +++ b/python/tvm/topi/index_put.py @@ -1,6 +1,6 @@ # Licensed to the Apache Software Foundation (ASF) under one -# or more contrir_builderutor license agreements. See the NOTICE file -# distrir_builderuted with this work for additional information +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance @@ -9,7 +9,7 @@ # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, -# software distrir_builderuted under the License is distrir_builderuted on an +# software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations @@ -29,7 +29,8 @@ def index_put(data, indices, values, accumulate=False): The source array to be modified. indices : Tuple[tvm.te.Tensor] - Tuple of 1D index tensors (one for each dimension) specifying positions. + Tuple of index tensors (can be multi-dimensional) specifying positions. + Index tensors are broadcast together following NumPy broadcasting rules. values : tvm.te.Tensor The values to place at the specified indices. @@ -60,11 +61,28 @@ def index_put(data, indices, values, accumulate=False): for dim in shape: full_range *= dim - # Check all indices have same length - index_len = len(indices[0]) - for idx in indices[1:]: - if not utils.equal_const_int(len(idx), index_len): - raise ValueError("All index tensors must have same length") + index_shapes = [idx.shape for idx in indices] + broadcast_ndim = max(len(s) for s in index_shapes) + broadcast_shape = [] + + for i in range(broadcast_ndim): + max_dim = 1 + for idx_shape in index_shapes: + # Right-align shapes + dim_idx = len(idx_shape) - broadcast_ndim + i + if dim_idx >= 0: + dim_size = idx_shape[dim_idx] + if not utils.equal_const_int(dim_size, 1): + if utils.equal_const_int(max_dim, 1): + max_dim = dim_size + elif not utils.equal_const_int(dim_size, max_dim): + raise ValueError(f"Cannot broadcast index shapes: {index_shapes}") + broadcast_shape.append(max_dim) + + # Compute total number of elements after broadcasting + index_len = 1 + for dim in broadcast_shape: + index_len *= dim def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): ir_builder = tir.ir_builder.create() @@ -78,12 +96,38 @@ def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func): out[i] = data[i] with ir_builder.for_range(0, index_len, "k", kind="parallel") as k: - # Calculate multi-dimensional index + # Decompose k into multi-dimensional broadcast index + k_temp = k + broadcast_indices = [] + for i in range(broadcast_ndim - 1, -1, -1): + broadcast_indices.insert(0, k_temp % broadcast_shape[i]) + k_temp = k_temp // broadcast_shape[i] + flat_index = 0 stride = 1 for dim in range(len(shape) - 1, -1, -1): - # Get index and shift to positive if needed - idx_val = indices[dim][k] + # Get the index for this dimension using broadcasting + idx_shape = index_shapes[dim] + idx_ndim = len(idx_shape) + + # Compute the linear index into this index tensor + idx_offset = 0 + idx_stride = 1 + for i in range(broadcast_ndim - 1, -1, -1): + # Right-align the index shape with broadcast shape + dim_idx = idx_ndim - broadcast_ndim + i + if dim_idx >= 0: + dim_size = idx_shape[dim_idx] + # Use broadcasting: if size is 1, use index 0 + # otherwise use broadcast_indices[i] + if utils.equal_const_int(dim_size, 1): + idx_in_dim = 0 + else: + idx_in_dim = broadcast_indices[i] + idx_offset += idx_in_dim * idx_stride + idx_stride *= dim_size + + idx_val = indices[dim][idx_offset] shifted_idx = idx_val + (idx_val < 0) * shape[dim] flat_index += shifted_idx * stride stride *= shape[dim] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 7397b3f21aef..ebaeb4faa1be 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7133,6 +7133,54 @@ def main( R.output(gv) return gv + # Test case 9: batched indexing with slice (e.g., M[:, rows, cols] = x) + class IndexPutBatchedWithNone(Module): + def forward(self, x): + B = x.size(0) + M = torch.zeros(B, 11, 11) + rows = torch.arange(10) + cols = rows + 1 + M[:, rows, cols] = x # Batched index assignment + return M + + example_args_batched_none = (torch.randn(2, 10, dtype=torch.float32),) + + @I.ir_module + class ExpectedBatchedWithNone: + @R.function + def main( + x: R.Tensor((2, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 11, 11), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 11, 11), dtype="float32") = R.full( + R.shape([2, 11, 11]), R.const(0.0, "float32"), dtype="float32" + ) + lv1: R.Tensor((10,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64" + ) + lv2: R.Tensor((10,), dtype="int64") = R.add(lv1, R.const(1, "int64")) + lv3: R.Tensor((2, 11, 11), dtype="float32") = R.strided_slice( + lv, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(9223372036854775807),), + (R.prim_value(1),), + assume_inbound=False, + ) + lv4: R.Tensor((2,), dtype="int64") = R.arange( + R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64" + ) + lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, R.shape([2, 1])) + lv6: R.Tensor((2, 11, 11), dtype="float32") = R.index_put( + lv3, (lv5, lv1, lv2), x, accumulate=False + ) + lv7: R.Tensor((2, 11, 11), dtype="float32") = R.slice_scatter( + lv, lv6, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=0 + ) + gv: R.Tuple(R.Tensor((2, 11, 11), dtype="float32")) = (lv7,) + R.output(gv) + return gv + # Run verification for each case verify_model(IndexPut1D(), example_args_1d, {}, Expected1D) verify_model(IndexPut2D(), example_args_2d, {}, Expected2D) @@ -7142,6 +7190,7 @@ def main( verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D) verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D) verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D) + verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone) def test_flip():