Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python/tvm/topi/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,9 @@ def _cast_output(value, data_dtype="float32", out_dtype=None):
height_use_int_div = False
width_use_int_div = False
if method == "nearest_neighbor" and coordinate_transformation_mode == "asymmetric":
height_use_int_div = can_convert_multiply_to_intdiv(image_height, target_height)
width_use_int_div = can_convert_multiply_to_intdiv(image_width, target_width)
if rounding_method == "floor" or rounding_method == "":
height_use_int_div = can_convert_multiply_to_intdiv(image_height, target_height)
width_use_int_div = can_convert_multiply_to_intdiv(image_width, target_width)

n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout)
box_idx = box_indices(n) if box_indices is not None else n
Expand Down
55 changes: 41 additions & 14 deletions python/tvm/topi/testing/resize_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,29 +41,45 @@ def get_inx(x, image_width, target_width, coordinate_transformation_mode):
return in_x


def get_index(x, image_width, target_width, coordinate_transformation_mode):
def get_index(x, image_width, target_width, coordinate_transformation_mode, rounding_method=""):
"""get and round the nearest index for nearest_neighbor"""
in_x = get_inx(x, image_width, target_width, coordinate_transformation_mode)
if coordinate_transformation_mode == "align_corners":
# round prefer ceil

effective_rounding_method = rounding_method
if not effective_rounding_method:
if coordinate_transformation_mode == "align_corners":
effective_rounding_method = "round"
else:
effective_rounding_method = "floor"
Comment thread
tqchen marked this conversation as resolved.

if effective_rounding_method == "floor":
out = math.floor(in_x)
elif effective_rounding_method == "round":
out = round(in_x)
elif effective_rounding_method == "round_prefer_floor":
out = math.ceil(in_x - 0.5)
elif effective_rounding_method == "round_prefer_ceil":
out = math.floor(in_x + 0.5)
elif effective_rounding_method == "ceil":
out = math.ceil(in_x)
else:
out = math.floor(in_x)
raise ValueError(f"Unknown rounding method: {rounding_method!r}")

out = max(min(out, image_width - 1), 0)
return out
return int(out)


def resize3d_nearest(arr, scale, coordinate_transformation_mode):
def resize3d_nearest(arr, scale, coordinate_transformation_mode, rounding_method=""):
"""Populate the array by scale factor"""
d, h, w = arr.shape
out_d, out_h, out_w = [round(i * s) for i, s in zip(arr.shape, scale)]
out = np.empty((out_d, out_h, out_w))
for z in range(out_d):
for y in range(out_h):
for x in range(out_w):
in_z = get_index(z, d, out_d, coordinate_transformation_mode)
in_y = get_index(y, h, out_h, coordinate_transformation_mode)
in_x = get_index(x, w, out_w, coordinate_transformation_mode)
in_z = get_index(z, d, out_d, coordinate_transformation_mode, rounding_method)
in_y = get_index(y, h, out_h, coordinate_transformation_mode, rounding_method)
in_x = get_index(x, w, out_w, coordinate_transformation_mode, rounding_method)
out[z, y, x] = arr[in_z, in_y, in_x]
return out

Expand Down Expand Up @@ -170,7 +186,11 @@ def _get_patch(zint, yint, xint):


def resize3d_ncdhw(
data, scale, method="nearest_neighbor", coordinate_transformation_mode="align_corners"
data,
scale,
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""reference kernel for 3D image resizing"""
ishape = data.shape
Expand All @@ -189,7 +209,7 @@ def resize3d_ncdhw(
for c in range(oshape[1]):
if method == "nearest_neighbor":
output_np[b, c, :, :, :] = resize3d_nearest(
data[b, c, :, :, :], scale, coordinate_transformation_mode
data[b, c, :, :, :], scale, coordinate_transformation_mode, rounding_method
)
elif method == "linear":
output_np[b, c, :, :, :] = resize3d_linear(
Expand All @@ -211,14 +231,17 @@ def resize1d_python(
layout="NCW",
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""Python version of 3D scaling using nearest neighbour"""

if layout == "NWC":
data = data.transpose([0, 2, 1])

data = np.expand_dims(data, axis=[2, 3])
output_np = resize3d_ncdhw(data, (1, 1) + scale, method, coordinate_transformation_mode)
output_np = resize3d_ncdhw(
data, (1, 1) + scale, method, coordinate_transformation_mode, rounding_method
)
output_np = np.squeeze(output_np, axis=2)
output_np = np.squeeze(output_np, axis=2)

Expand All @@ -234,6 +257,7 @@ def resize2d_python(
layout="NCHW",
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""Python version of scaling using nearest neighbour"""

Expand All @@ -248,7 +272,9 @@ def resize2d_python(
)

data = np.expand_dims(data, axis=2)
output_np = resize3d_ncdhw(data, (1,) + scale, method, coordinate_transformation_mode)
output_np = resize3d_ncdhw(
data, (1,) + scale, method, coordinate_transformation_mode, rounding_method
)
output_np = np.squeeze(output_np, axis=2)

if layout == "NHWC":
Expand All @@ -266,13 +292,14 @@ def resize3d_python(
layout="NCDHW",
method="nearest_neighbor",
coordinate_transformation_mode="align_corners",
rounding_method="",
):
"""Python version of 3D scaling using nearest neighbour"""

if layout == "NDHWC":
data = data.transpose([0, 4, 1, 2, 3])

output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode)
output_np = resize3d_ncdhw(data, scale, method, coordinate_transformation_mode, rounding_method)

if layout == "NDHWC":
output_np = output_np.transpose([0, 2, 3, 4, 1])
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_transform_legalize_ops_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def resize2d(rxplaceholder: T.Buffer((T.int64(2), T.int64(8), T.int64(8), T.int6
for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(16), T.int64(16), T.int64(3)):
with T.sblock("resize"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1])
T.reads(rxplaceholder[i0_1, T.int64(0):T.int64(8), T.int64(0):T.int64(8), i3_1])
T.writes(resize[i0_1, i1_1, i2_1, i3_1])
resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, T.max(T.min(T.Div(i1_1, T.int64(2)), T.int64(7)), T.int64(0)), T.max(T.min(T.Div(i2_1, T.int64(2)), T.int64(7)), T.int64(0)), i3_1]
resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, T.max(T.min(T.Cast("int64", T.round(T.float32(0.5) * T.Cast("float32", i1_1))), T.int64(7)), T.int64(0)), T.max(T.min(T.Cast("int64", T.round(T.float32(0.5) * T.Cast("float32", i2_1))), T.int64(7)), T.int64(0)), i3_1]
# fmt: on

mod = LegalizeOps()(Resize2D)
Expand Down
Loading