Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ def _interpolate(self, node: fx.Node) -> relax.Var:
# Determine spatial dimension indices based on layout
# NCHW: spatial dims are [2, 3, ...] (skip batch and channel)
# NHWC: spatial dims are [1, 2, ...] (skip batch, before channel)
if self.default_image_layout == "NHWC":
if self.default_image_layout in ("NHWC", "NDHWC"):
spatial_start = 1
spatial_end = len(shape) - 1
else: # NCHW or other layouts
Expand All @@ -547,25 +547,43 @@ def _interpolate(self, node: fx.Node) -> relax.Var:

if method.startswith("nearest"):
method = "nearest_neighbor"
elif method[0:2] == "bi":
elif method.startswith("bi"):
method = method[2:]
elif method.startswith("tri"):
method = method[3:]

if method == "nearest_neighbor":
coord_trans = "asymmetric"
elif align_corners:
elif align_corners is True:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"

return self.block_builder.emit(
relax.op.image.resize2d(
data,
size,
layout=self.default_image_layout,
method=method,
coordinate_transformation_mode=coord_trans,
if data.struct_info.ndim == 5:
if self.default_image_layout == "NDHWC":
layout_3d = "NDHWC"
else:
layout_3d = "NCDHW"

return self.block_builder.emit(
relax.op.image.resize3d(
data,
size,
layout=layout_3d,
method=method,
coordinate_transformation_mode=coord_trans,
)
)
else:
return self.block_builder.emit(
relax.op.image.resize2d(
data,
size,
layout=self.default_image_layout,
method=method,
coordinate_transformation_mode=coord_trans,
)
)
)

def _linear_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
Expand Down
222 changes: 222 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3671,6 +3671,149 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor(

verify_model(Interpolate4(), input_info, {}, expected4)

input_info_5d = [([1, 3, 4, 10, 10], "float32")]
class Interpolate5(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=None,
scale_factor=(2.0, 2.0, 2.0),
mode="trilinear",
align_corners=False,
)
@tvm.script.ir_module
class expected5:
@R.function
def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
(1, 3, 8, 20, 20), dtype="float32"
):

with R.dataflow():
lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = R.image.resize3d(
input_5,
(8, 20, 20),
roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
layout="NCDHW",
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Interpolate5(), input_info_5d, {}, expected5)

class Interpolate6(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=None,
scale_factor=(2.0,4.0,4.0),
mode="trilinear",
align_corners=False,
)
@tvm.script.ir_module
class expected6:
@R.function
def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
(1, 3, 8, 40, 40), dtype="float32"
):

with R.dataflow():
lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d(
input_5,
(8, 40, 40),
roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
layout="NCDHW",
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Interpolate6(), input_info_5d, {}, expected6)

class Interpolate7(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=(8,40,40),
mode="trilinear",
align_corners=False,
)
@tvm.script.ir_module
class expected7:
@R.function
def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
(1, 3, 8, 40, 40), dtype="float32"
):

with R.dataflow():
lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d(
input_5,
(8, 40, 40),
roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
layout="NCDHW",
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Interpolate7(), input_info_5d, {}, expected7)

class Interpolate8(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=(8,40,40),
mode="trilinear",
align_corners=True,
)
@tvm.script.ir_module
class expected8:
@R.function
def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> R.Tensor(
(1, 3, 8, 40, 40), dtype="float32"
):

with R.dataflow():
lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = R.image.resize3d(
input_5,
(8, 40, 40),
roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
layout="NCDHW",
method="linear",
coordinate_transformation_mode="align_corners",
rounding_method="",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
R.output(gv)
return gv

verify_model(Interpolate8(), input_info_5d, {}, expected8)


def test_interpolate_nhwc_layout():
# First verify backward compatibility - default should still be NCHW
Expand Down Expand Up @@ -3786,6 +3929,85 @@ def main(input_1: R.Tensor((1, 10, 10, 3), dtype="float32")) -> R.Tensor(
mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC")
tvm.ir.assert_structural_equal(mod2, expected_nhwc2)

input_info_5d = [([1, 4, 10, 10, 3], "float32")]

class InterpolateNHWC3(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=None,
scale_factor=(2.0,4.0,4.0),
mode="trilinear",
align_corners=False,
)
@tvm.script.ir_module
class expected_nhwc3:
@R.function
def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor(
(1, 8, 40, 40, 3), dtype="float32"
):

with R.dataflow():
lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = R.image.resize3d(
input_5,
(8, 40, 40),
roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
layout="NDHWC",
method="linear",
coordinate_transformation_mode="half_pixel",
rounding_method="",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv
R.output(gv)
return gv

graph_model3 = fx.symbolic_trace(InterpolateNHWC3())
with torch.no_grad():
mod3 = from_fx(graph_model3, input_info_5d, default_image_layout="NDHWC")
tvm.ir.assert_structural_equal(mod3, expected_nhwc3)

class InterpolateNHWC4(Module):
def forward(self, input):
return torch.nn.functional.interpolate(
input,
size=None,
scale_factor=(2.0,4.0,4.0),
mode="trilinear",
align_corners=True,
)
@tvm.script.ir_module
class expected_nhwc4:
@R.function
def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> R.Tensor(
(1, 8, 40, 40, 3), dtype="float32"
):

with R.dataflow():
lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = R.image.resize3d(
input_5,
(8, 40, 40),
roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
layout="NDHWC",
method="linear",
coordinate_transformation_mode="align_corners",
rounding_method="",
cubic_alpha=-0.75,
cubic_exclude=0,
extrapolation_value=0,
out_dtype="",
)
gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv
R.output(gv)
return gv

graph_model4 = fx.symbolic_trace(InterpolateNHWC4())
with torch.no_grad():
mod4 = from_fx(graph_model4, input_info_5d, default_image_layout="NDHWC")
tvm.ir.assert_structural_equal(mod4, expected_nhwc4)

def test_addmm():
input_info = [
Expand Down