Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1722,6 +1722,9 @@ def _split(self, node: fx.Node) -> relax.Var:
def _squeeze(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
# Support both "dim" and "dims" parameters
if dim is None:
dim = node.kwargs.get("dims", None)
Comment on lines 1724 to +1727

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to get dim can be simplified into a single line by chaining dict.get calls. This makes the code more concise and easier to read.

Suggested change
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
# Support both "dim" and "dims" parameters
if dim is None:
dim = node.kwargs.get("dims", None)
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", node.kwargs.get("dims"))

return self.block_builder.emit(relax.op.squeeze(x, dim))

def _stack(self, node: fx.Node) -> relax.Var:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,7 @@ def create_convert_map(
"split_with_sizes.default": self._split,
"squeeze.default": self._squeeze,
"squeeze.dim": self._squeeze,
"squeeze.dims": self._squeeze,
"stack.default": self._stack,
"take.default": self._take,
"tile.default": self._tile,
Expand Down Expand Up @@ -1075,6 +1076,7 @@ def create_convert_map(
# other
"getitem": self._getitem,
"item.default": self._item,
"_local_scalar_dense.default": self._item,
}

def create_input_vars(
Expand Down
122 changes: 92 additions & 30 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5823,7 +5823,7 @@ def main(
return gv

example_input = torch.randn(5, 3, dtype=torch.float32)
verify_model(Cumprod(), (example_input,), {}, Expected)
verify_model(Cumprod(), (example_input,), {}, Expected, run_ep_decomposition=True)


def test_where():
Expand All @@ -5849,7 +5849,7 @@ def main(
x = torch.randn(5, 3, dtype=torch.float32)
y = torch.randn(5, 3, dtype=torch.float32)

verify_model(Where(), (condition, x, y), {}, Expected)
verify_model(Where(), (condition, x, y), {}, Expected, run_ep_decomposition=True)


def test_bucketize():
Expand All @@ -5874,7 +5874,7 @@ def main(
input_tensor = torch.arange(0, 20)
boundaries = torch.arange(0, 20, 2)

verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected)
verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected, run_ep_decomposition=True)


def test_argsort():
Expand All @@ -5890,12 +5890,18 @@ def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype
lv: R.Tensor((5, 3), dtype="int32") = R.argsort(
x, axis=1, descending=True, dtype="int32"
)
gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,)
lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1)
lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
lv1,
lv,
)
lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)
Comment on lines +5893 to +5899

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The generated IR for argsort is unnecessarily complex. It computes both sorted values (lv1) and indices (lv), then creates a tuple (lv2), only to extract the indices (lv3). The sorted values are computed via gather_elements but are never used for the final result. This seems to be a result of decomposing torch.argsort into torch.sort and then taking the indices.

While a Dead Code Elimination (DCE) pass might clean this up, it would be more efficient to have a more direct translation for argsort that doesn't compute the sorted values if they are not needed.

R.output(gv)
return gv

example_args = (torch.randn(5, 3, dtype=torch.float32),)
verify_model(Argsort(), example_args, {}, Expected)
verify_model(Argsort(), example_args, {}, Expected, run_ep_decomposition=True)


def test_topk():
Expand Down Expand Up @@ -5923,7 +5929,7 @@ def main(
return gv

example_args = (torch.randn(5, 3, dtype=torch.float32),)
verify_model(Topk(), example_args, {}, Expected)
verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True)


def test_dynamic_shape():
Expand Down Expand Up @@ -5972,7 +5978,7 @@ def main(
return gv

example_args = (torch.randn(5, 1, dtype=torch.float32),)
verify_model(BroadcastTo(), example_args, {}, Expected)
verify_model(BroadcastTo(), example_args, {}, Expected, run_ep_decomposition=True)


def test_narrow():
Expand All @@ -5992,6 +5998,7 @@ def main(
(R.prim_value(1),),
(R.prim_value(0),),
(R.prim_value(2),),
(R.prim_value(1),),
assume_inbound=False,
)
gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,)
Expand All @@ -6000,7 +6007,7 @@ def main(
return gv

example_args = (torch.randn(5, 3, dtype=torch.float32),)
verify_model(Narrow(), example_args, {}, Expected)
verify_model(Narrow(), example_args, {}, Expected, run_ep_decomposition=True)


def test_item():
Expand All @@ -6019,7 +6026,7 @@ def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="
return gv

example_args = (torch.randn(1, dtype=torch.float32),)
verify_model(Item(), example_args, {}, Expected)
verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True)


def test_norm():
Expand Down Expand Up @@ -6131,7 +6138,9 @@ def main(
example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),)

for (p, dim, keepdim), expected in norms:
verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected)
verify_model(
Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected, run_ep_decomposition=True
)


def test_eye():
Expand All @@ -6146,8 +6155,20 @@ def main(
input: R.Tensor((3, 5), dtype="float32")
) -> R.Tuple(R.Tensor((3, 5), dtype="float32")):
with R.dataflow():
lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32")
gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,)
lv: R.Tensor((3,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
)
lv1: R.Tensor((5,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
)
lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1)
lv4: R.Tensor((1,), dtype="float32") = R.full(
R.shape([1]), R.const(1.0, "float32"), dtype="float32"
)
lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5)
gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,)
R.output(gv)
return gv

Expand All @@ -6162,16 +6183,28 @@ def main(
input: R.Tensor((5,), dtype="float32")
) -> R.Tuple(R.Tensor((5, 5), dtype="float32")):
with R.dataflow():
lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32")
gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,)
lv: R.Tensor((5,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
)
lv1: R.Tensor((5,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64"
)
lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, axis=[-1])
lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1)
lv4: R.Tensor((1,), dtype="float32") = R.full(
R.shape([1]), R.const(1.0, "float32"), dtype="float32"
)
lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5)
gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,)
R.output(gv)
return gv

example_args1 = (torch.randn(3, 5, dtype=torch.float32),)
verify_model(Eye1(), example_args1, {}, Expected1)
verify_model(Eye1(), example_args1, {}, Expected1, run_ep_decomposition=True)

example_args2 = (torch.randn(5, dtype=torch.float32),)
verify_model(Eye2(), example_args2, {}, Expected2)
verify_model(Eye2(), example_args2, {}, Expected2, run_ep_decomposition=True)


def test_cross_entropy():
Expand All @@ -6187,21 +6220,39 @@ def forward(self, x):
@tvm.script.ir_module
class Expected1:
@R.function
def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
lv,
targets=R.const([0, 1, 2, 1], dtype="int64"),
reduction="mean",
ignore_index=-100,
lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
lv2: R.Tensor((4,), dtype="bool") = R.not_equal(
R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
)
lv3: R.Tensor((), dtype="int64") = R.const(0, "int64")
lv4: R.Tensor((4,), dtype="int64") = R.where(
lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3
)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
lv5: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv4, axis=[1])
lv6: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv1, lv5, axis=1)
lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1])
lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7)
lv9: R.Tensor((4,), dtype="bool") = R.not_equal(
R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
)
lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10)
lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
)
lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
R.output(gv)
return gv

example_args1 = (torch.randn(4, 3, dtype=torch.float32),)
verify_model(CrossEntropyModule(), example_args1, {}, Expected1)
verify_model(CrossEntropyModule(), example_args1, {}, Expected1, run_ep_decomposition=True)


def test_linspace():
Expand All @@ -6216,13 +6267,24 @@ def main(
input: R.Tensor((9, 9), dtype="float32")
) -> R.Tuple(R.Tensor((9,), dtype="float32")):
with R.dataflow():
lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
lv: R.Tensor((9,), dtype="int64") = R.arange(
R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64"
)
lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64"))
lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32")
lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32"))
lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32"))
lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv)
lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32")
lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32"))
lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7)
lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)
Comment on lines 6218 to +6282

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The generated IR for linspace is highly inefficient. It computes two expressions, lv4 and lv8, which are mathematically equivalent for the given inputs (i * 0.125 for i in [0, 8]). Then it uses R.where to select between these identical values. The entire where operation and the computation of lv5 through lv8 are redundant. The IR could be simplified to just compute lv4 and use that as the result. This suggests an issue in the PyTorch decomposition logic for linspace that should be investigated.

R.output(gv)
return gv

example_args = (torch.randn(9, 9, dtype=torch.float32),)
verify_model(Linspace(), example_args, {}, Expected)
verify_model(Linspace(), example_args, {}, Expected, run_ep_decomposition=True)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -6259,7 +6321,7 @@ def main(
R.output(gv)
return gv

verify_model(Model(), example_args, {}, Expected)
verify_model(Model(), example_args, {}, Expected, run_ep_decomposition=True)


def test_mm():
Expand All @@ -6285,7 +6347,7 @@ def main(
R.output(gv)
return gv

verify_model(MatrixMultiply(), example_args, {}, Expected)
verify_model(MatrixMultiply(), example_args, {}, Expected, run_ep_decomposition=True)


def test_lstm():
Expand Down
Loading