Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,8 +1812,9 @@ def _index_put(self, node: fx.Node) -> relax.Var:
)
)
# Reshape to [dim_size, 1, 1, ...] for broadcasting
# Add an extra dimension so it broadcasts with other indices
arange_idx = self.block_builder.emit(
relax.op.reshape(arange_idx, [data_shape[i]] + [1] * (max_ndim - 1))
relax.op.reshape(arange_idx, [data_shape[i]] + [1] * max_ndim)
)
processed_indices.append(arange_idx)
else:
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def index_put(
[0.0, 3.0, 0.0],
]
"""
if not isinstance(indices, (list, tuple)):
if isinstance(indices, (list, tuple)):
indices = RxTuple(indices)
return _ffi_api.index_put(data, indices, values, accumulate) # type: ignore

Expand Down
68 changes: 56 additions & 12 deletions python/tvm/topi/index_put.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contrir_builderutor license agreements. See the NOTICE file
# distrir_builderuted with this work for additional information
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
Expand All @@ -9,7 +9,7 @@
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distrir_builderuted under the License is distrir_builderuted on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
Expand All @@ -29,7 +29,8 @@ def index_put(data, indices, values, accumulate=False):
The source array to be modified.

indices : Tuple[tvm.te.Tensor]
Tuple of 1D index tensors (one for each dimension) specifying positions.
Tuple of index tensors (can be multi-dimensional) specifying positions.
Index tensors are broadcast together following NumPy broadcasting rules.

values : tvm.te.Tensor
The values to place at the specified indices.
Expand Down Expand Up @@ -60,11 +61,28 @@ def index_put(data, indices, values, accumulate=False):
for dim in shape:
full_range *= dim

# Check all indices have same length
index_len = len(indices[0])
for idx in indices[1:]:
if not utils.equal_const_int(len(idx), index_len):
raise ValueError("All index tensors must have same length")
index_shapes = [idx.shape for idx in indices]
broadcast_ndim = max(len(s) for s in index_shapes)
broadcast_shape = []

for i in range(broadcast_ndim):
max_dim = 1
for idx_shape in index_shapes:
# Right-align shapes
dim_idx = len(idx_shape) - broadcast_ndim + i
if dim_idx >= 0:
dim_size = idx_shape[dim_idx]
if not utils.equal_const_int(dim_size, 1):
if utils.equal_const_int(max_dim, 1):
max_dim = dim_size
elif not utils.equal_const_int(dim_size, max_dim):
raise ValueError(f"Cannot broadcast index shapes: {index_shapes}")
broadcast_shape.append(max_dim)

# Compute total number of elements after broadcasting
index_len = 1
for dim in broadcast_shape:
index_len *= dim

def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
ir_builder = tir.ir_builder.create()
Expand All @@ -78,12 +96,38 @@ def gen_ir(data_ptr, index_ptrs, values_ptr, out_ptr, reduce_func):
out[i] = data[i]

with ir_builder.for_range(0, index_len, "k", kind="parallel") as k:
# Calculate multi-dimensional index
# Decompose k into multi-dimensional broadcast index
k_temp = k
broadcast_indices = []
for i in range(broadcast_ndim - 1, -1, -1):
broadcast_indices.insert(0, k_temp % broadcast_shape[i])
k_temp = k_temp // broadcast_shape[i]

flat_index = 0
stride = 1
for dim in range(len(shape) - 1, -1, -1):
# Get index and shift to positive if needed
idx_val = indices[dim][k]
# Get the index for this dimension using broadcasting
idx_shape = index_shapes[dim]
idx_ndim = len(idx_shape)

# Compute the linear index into this index tensor
idx_offset = 0
idx_stride = 1
for i in range(broadcast_ndim - 1, -1, -1):
# Right-align the index shape with broadcast shape
dim_idx = idx_ndim - broadcast_ndim + i
if dim_idx >= 0:
dim_size = idx_shape[dim_idx]
# Use broadcasting: if size is 1, use index 0
# otherwise use broadcast_indices[i]
if utils.equal_const_int(dim_size, 1):
idx_in_dim = 0
else:
idx_in_dim = broadcast_indices[i]
idx_offset += idx_in_dim * idx_stride
idx_stride *= dim_size

idx_val = indices[dim][idx_offset]
shifted_idx = idx_val + (idx_val < 0) * shape[dim]
flat_index += shifted_idx * stride
stride *= shape[dim]
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7133,6 +7133,54 @@ def main(
R.output(gv)
return gv

# Test case 9: batched indexing with slice (e.g., M[:, rows, cols] = x)
class IndexPutBatchedWithNone(Module):
# Model under test: starts from an all-zero (B, 11, 11) tensor and assigns x
# through two index tensors preceded by a full slice over the batch axis.
def forward(self, x):
# Batch size is taken from the leading dimension of the input.
B = x.size(0)
M = torch.zeros(B, 11, 11)
# rows holds 0..9 and cols holds 1..10, so each (row, col) pair lies one
# column to the right of the main diagonal.
rows = torch.arange(10)
cols = rows + 1
# The leading ":" selects every batch entry; rows/cols pick 10 positions in
# each one. With the (2, 10) example input, x supplies one value per
# (batch, position) pair.
M[:, rows, cols] = x # Batched index assignment
return M

example_args_batched_none = (torch.randn(2, 10, dtype=torch.float32),)

@I.ir_module
class ExpectedBatchedWithNone:
# Expected Relax module for the batched index assignment with x of shape (2, 10).
@R.function
def main(
x: R.Tensor((2, 10), dtype="float32")
) -> R.Tuple(R.Tensor((2, 11, 11), dtype="float32")):
with R.dataflow():
# Zero-filled (2, 11, 11) destination tensor.
lv: R.Tensor((2, 11, 11), dtype="float32") = R.full(
R.shape([2, 11, 11]), R.const(0.0, "float32"), dtype="float32"
)
# Row indices 0..9.
lv1: R.Tensor((10,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
)
# Column indices: the row indices shifted by one.
lv2: R.Tensor((10,), dtype="int64") = R.add(lv1, R.const(1, "int64"))
# Slice of lv starting at 0 with step 1; the bound 9223372036854775807 is
# 2**63 - 1, and the result keeps the full (2, 11, 11) shape.
# NOTE(review): the four tuples are presumably axes, begin, end, strides
# (so this slices axis 0) - confirm against R.strided_slice.
lv3: R.Tensor((2, 11, 11), dtype="float32") = R.strided_slice(
lv,
(R.prim_value(0),),
(R.prim_value(0),),
(R.prim_value(9223372036854775807),),
(R.prim_value(1),),
assume_inbound=False,
)
# Explicit index tensor for the batch axis: 0..1.
lv4: R.Tensor((2,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(2), R.prim_value(1), dtype="int64"
)
# Reshaped to (2, 1) so it can broadcast against the (10,) row/column indices.
lv5: R.Tensor((2, 1), dtype="int64") = R.reshape(lv4, R.shape([2, 1]))
# Write x at the positions given by the three index tensors, without accumulation.
lv6: R.Tensor((2, 11, 11), dtype="float32") = R.index_put(
lv3, (lv5, lv1, lv2), x, accumulate=False
)
# Scatter the updated slice back into lv along axis 0 (range 0..2, step 1).
lv7: R.Tensor((2, 11, 11), dtype="float32") = R.slice_scatter(
lv, lv6, R.prim_value(0), R.prim_value(2), R.prim_value(1), axis=0
)
gv: R.Tuple(R.Tensor((2, 11, 11), dtype="float32")) = (lv7,)
R.output(gv)
return gv

# Run verification for each case
verify_model(IndexPut1D(), example_args_1d, {}, Expected1D)
verify_model(IndexPut2D(), example_args_2d, {}, Expected2D)
Expand All @@ -7142,6 +7190,7 @@ def main(
verify_model(IndexPutBroadcast1D(), example_args_broadcast1, {}, ExpectedBroadcast1D)
verify_model(IndexPutBroadcast2D(), example_args_broadcast2, {}, ExpectedBroadcast2D)
verify_model(IndexPutBroadcast3D(), example_args_broadcast3d, {}, ExpectedBroadcast3D)
verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone)


def test_flip():
Expand Down
Loading