diff --git a/python/tvm/topi/cuda/scatter.py b/python/tvm/topi/cuda/scatter.py index cee13d7e01a2..c697b648786e 100644 --- a/python/tvm/topi/cuda/scatter.py +++ b/python/tvm/topi/cuda/scatter.py @@ -801,9 +801,11 @@ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr): bdim = ceil_div(fused_updates_dimension, tdim) ib.scope_attr(bx, "thread_extent", bdim) - with ib.for_range(0, ceil_div(fused_shape, bdim)) as i: + # Copy data into the output. This loop writes to the same portions of + # memory as the following loop, so we do not need a memory sync. + with ib.for_range(0, ceil_div(fused_shape, fused_updates_dimension), name="i") as i: index = i * fused_updates_dimension + bx * tdim + tx - with ib.if_scope(index < fused_shape): + with ib.if_scope(bx * tdim + tx < fused_updates_dimension): out[index] = data[index] with ib.for_range(0, fused_indices_dimension) as i: