Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,9 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
height_use_int_div = False
width_use_int_div = False
if method == "nearest_neighbor" and coordinate_transformation_mode == "asymmetric":
height_use_int_div = can_convert_multiply_to_intdiv(image_height, target_height)
width_use_int_div = can_convert_multiply_to_intdiv(image_width, target_width)
if rounding_method == "floor" or rounding_method == "":
height_use_int_div = can_convert_multiply_to_intdiv(image_height, target_height)
width_use_int_div = can_convert_multiply_to_intdiv(image_width, target_width)

n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout)
box_idx = box_indices(n) if box_indices is not None else n
Expand Down
45 changes: 33 additions & 12 deletions python/tvm/topi/testing/resize_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,39 @@ def get_inx(x, image_width, target_width, coordinate_transformation_mode):
return in_x


def get_index(x, image_width, target_width, coordinate_transformation_mode):
def get_index(x, image_width, target_width, coordinate_transformation_mode, rounding_method=""):
"""get and round the nearest index for nearest_neighbor"""
in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode)
if coordinate_transformation_mode == "align_corners":
# round prefer ceil
if rounding_method == "" or rounding_method == "floor":
if coordinate_transformation_mode == "align_corners":
out = math.floor(in_x + 0.5)
else:
out = math.floor(in_x)
elif rounding_method == "round":
out = math.floor(in_x + 0.5)
elif rounding_method == "round_prefer_floor":
out = math.ceil(in_x - 0.5)
elif rounding_method == "round_prefer_ceil":
out = math.floor(in_x + 0.5)
elif rounding_method == "ceil":
out = math.ceil(in_x)
else:
out = math.floor(in_x)
out = max(min(out, image_width - 1), 0)
return out
Comment on lines +44 to 63

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of get_index has a few issues that could lead to incorrect test results or mask bugs:

  1. Incorrect Rounding for "round" method: The logic for rounding_method="round" uses math.floor(in_x + 0.5), which is "round half up". The TVM operator te.round implements "round half to even". The reference implementation should use Python's built-in round() to match this behavior. This also affects the default behavior for coordinate_transformation_mode="align_corners".
  2. Incorrect logic for "floor" method: When rounding_method="floor" and coordinate_transformation_mode="align_corners", the current code applies "round half up" instead of flooring. The TVM operator does not alter the rounding method if it's explicitly provided.
  3. Inconsistent Error Handling: Unknown rounding methods silently fall back to floor. This differs from the main operator, which raises a ValueError. The test helper should also raise an error to detect invalid inputs.
  4. Return Type: The function may return a non-integer value, but its result is used as an array index. It should explicitly return an int.

I suggest refactoring the function to align with the TVM operator's logic, which will fix these issues and improve clarity.

def get_index(x, image_width, target_width, coordinate_transformation_mode, rounding_method=""):
    """Pick the nearest input index for output position ``x`` (nearest_neighbor).

    Parameters
    ----------
    x : output coordinate, forwarded unchanged to ``get_inx``.
    image_width : input extent; the result is clamped to ``[0, image_width - 1]``.
    target_width : output extent, forwarded unchanged to ``get_inx``.
    coordinate_transformation_mode : forwarded to ``get_inx``; also selects the
        default rounding when ``rounding_method`` is empty.
    rounding_method : one of ``"floor"``, ``"round"``, ``"round_prefer_floor"``,
        ``"round_prefer_ceil"``, ``"ceil"``. A falsy value means "use the
        default": ``"round"`` for ``"align_corners"``, ``"floor"`` otherwise.

    Returns
    -------
    int
        The clamped input index.

    Raises
    ------
    ValueError
        If ``rounding_method`` is non-empty and not one of the names above.
    """
    # The source coordinate is computed first, before the method is validated.
    coord = get_inx(x, image_width, target_width, coordinate_transformation_mode)

    # Resolve the rounding rule; only a falsy method triggers the mode-based default.
    if rounding_method:
        mode = rounding_method
    elif coordinate_transformation_mode == "align_corners":
        mode = "round"
    else:
        mode = "floor"

    if mode == "ceil":
        nearest = math.ceil(coord)
    elif mode == "round_prefer_ceil":
        # Halves go up.
        nearest = math.floor(coord + 0.5)
    elif mode == "round_prefer_floor":
        # Halves go down.
        nearest = math.ceil(coord - 0.5)
    elif mode == "round":
        # Built-in round() rounds halves to even, chosen to match te.round.
        nearest = round(coord)
    elif mode == "floor":
        nearest = math.floor(coord)
    else:
        raise ValueError(f"Unknown rounding method: {rounding_method!r}")

    # Keep the index inside the valid input range.
    upper_bounded = min(nearest, image_width - 1)
    return int(max(upper_bounded, 0))



def resize3d_nearest(arr, scale, coordinate_transformation_mode):
def resize3d_nearest(arr, scale, coordinate_transformation_mode, rounding_method=""):
"""Populate the array by scale factor"""
d, h, w = arr.shape
out_d, out_h, out_w = [round(i * s) for i, s in zip(arr.shape, scale)]
out = np.empty((out_d, out_h, out_w))
for z in range(out_d):
for y in range(out_h):
for x in range(out_w):
in_z = get_index(z, d, out_d, coordinate_transformation_mode)
in_y = get_index(y, h, out_h, coordinate_transformation_mode)
in_x = get_index(x, w, out_w, coordinate_transformation_mode)
in_z = get_index(z, d, out_d, coordinate_transformation_mode, rounding_method)
in_y = get_index(y, h, out_h, coordinate_transformation_mode, rounding_method)
in_x = get_index(x, w, out_w, coordinate_transformation_mode, rounding_method)
out[z, y, x] = arr[in_z, in_y, in_x]
return out

Expand Down Expand Up @@ -170,7 +180,11 @@ def _get_patch(zint, yint, xint):


def resize3d_ncdhw(
data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners"
data,
scale,
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""reference kernel for 3D image resizing"""
ishape = data.shape
Expand All @@ -189,7 +203,7 @@ def resize3d_ncdhw(
for c in range(oshape[1]):
if method == "nearest_neighbor":
output_np[b, c, :, :, :] = resize3d_nearest(
data[b, c, :, :, :], scale, coordinate_transformation_mode
data[b, c, :, :, :], scale, coordinate_transformation_mode, rounding_method
)
elif method == "linear":
output_np[b, c, :, :, :] = resize3d_linear(
Expand All @@ -211,14 +225,17 @@ def resize1d_python(
layout="NCW",
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""Python version of 3D scaling using nearest neighbour"""

if layout == "NWC":
data = data.transpose([0, 2, 1])

data = np.expand_dims(data, axis=[2, 3])
output_np = resize3d_ncdhw(data, (1, 1) + scale, method, coordinate_transformation_mode)
output_np = resize3d_ncdhw(
data, (1, 1) + scale, method, coordinate_transformation_mode, rounding_method
)
output_np = np.squeeze(output_np, axis=2)
output_np = np.squeeze(output_np, axis=2)

Expand All @@ -234,6 +251,7 @@ def resize2d_python(
layout="NCHW",
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""Python version of scaling using nearest neighbour"""

Expand All @@ -248,7 +266,9 @@ def resize2d_python(
)

data = np.expand_dims(data, axis=2)
output_np = resize3d_ncdhw(data, (1,) + scale, method, coordinate_transformation_mode)
output_np = resize3d_ncdhw(
data, (1,) + scale, method, coordinate_transformation_mode, rounding_method
)
output_np = np.squeeze(output_np, axis=2)

if layout == "NHWC":
Expand All @@ -266,13 +286,14 @@ def resize3d_python(
layout="NCDHW",
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""Python version of 3D scaling using nearest neighbour"""

if layout == "NDHWC":
data = data.transpose([0, 4, 1, 2, 3])

output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode)
output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode, rounding_method)

if layout == "NDHWC":
output_np = output_np.transpose([0, 2, 3, 4, 1])
Expand Down
Loading