[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests (3) #18410

Merged
mshr-h merged 3 commits into apache:main from tlopex:dcf6 on Nov 1, 2025

tlopex (Member) commented Oct 31, 2025

This PR fixes cumprod, where, bucketize, argsort, topk, narrow, item, norm, eye, cross_entropy, linspace, and mm.

gemini-code-assist (Contributor)

Summary of Changes

Hello @tlopex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request significantly enhances the PyTorch frontend for Relax by introducing robust support for decomposed operators and refining the Intermediate Representation (IR) for a range of common operations. It ensures that complex PyTorch operators like argsort, eye, cross_entropy, and linspace are correctly broken down into fundamental Relax primitives, improving the accuracy and robustness of model translation and compilation. Additionally, it extends the compatibility of squeeze and item operators to handle more variations, contributing to a more comprehensive and reliable frontend.

Highlights

  • Decomposed Operator Support: This PR introduces enhanced support for decomposed operators within the PyTorch frontend for Relax, ensuring that complex operations are correctly broken down into simpler primitives for better compilation.
  • Operator IR Fixes: Specific fixes and improved Intermediate Representation (IR) handling have been implemented for a wide range of operators including cumprod, where, bucketize, argsort, topk, narrow, item, norm, eye, cross_entropy, linspace, and mm.
  • Squeeze Operator Compatibility: The _squeeze handler now correctly recognizes and processes both 'dim' and 'dims' parameters, improving compatibility with various PyTorch squeeze invocations.
  • Item Operator Extension: Support for the _local_scalar_dense.default operation has been added, mapping it to the _item handler to broaden the coverage of scalar item extraction.

gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request adds support for decomposed operators from PyTorch's ExportedProgram and updates several operator tests to use this feature. My review focuses on the correctness and efficiency of the newly generated Intermediate Representations (IRs). I've identified a few issues. Most critically, the expected IR in test_cross_entropy has incorrect shape annotations, which makes the test invalid. Additionally, the decomposed IR for test_linspace and test_argsort is inefficient due to redundant computations. I've also included a suggestion to simplify the implementation of the _squeeze operator.

Comment on lines +6223 to 6251
+def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
     with R.dataflow():
-        lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1)
-        lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss(
-            lv,
-            targets=R.const([0, 1, 2, 1], dtype="int64"),
-            reduction="mean",
-            ignore_index=-100,
+        lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=1)
+        lv1: R.Tensor((4,), dtype="bool") = R.not_equal(
+            R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
         )
+        lv2: R.Tensor((), dtype="int64") = R.const(0, "int64")
+        lv3: R.Tensor((4,), dtype="int64") = R.where(
+            lv1, R.const([0, 1, 2, 1], dtype="int64"), lv2
+        )
-        gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+        lv4: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv3, axis=[1])
+        lv5: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv, lv4, axis=1)
+        lv6: R.Tensor((4,), dtype="float32") = R.squeeze(lv5, axis=[1])
+        lv7: R.Tensor((4,), dtype="float32") = R.negative(lv6)
+        lv8: R.Tensor((4,), dtype="bool") = R.not_equal(
+            R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+        )
+        lv9: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+        lv10: R.Tensor((4,), dtype="float32") = R.where(lv8, lv7, lv9)
+        lv11: R.Tensor((4,), dtype="bool") = R.not_equal(
+            R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
+        )
+        lv12: R.Tensor((4,), dtype="bool") = R.sum(lv11, axis=[], keepdims=False)
+        lv13: R.Tensor((4,), dtype="float32") = R.astype(lv12, dtype="float32")
+        lv14: R.Tensor((4,), dtype="float32") = R.sum(lv10, axis=[], keepdims=False)
+        lv15: R.Tensor((4,), dtype="float32") = R.divide(lv14, lv13)
+        gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv15,)
         R.output(gv)
     return gv

critical

The expected IR for test_cross_entropy has incorrect shape annotations. torch.nn.CrossEntropyLoss with the default reduction='mean' should return a scalar tensor. However, the return type of the main function is annotated as R.Tuple(R.Tensor((4,), dtype="float32")).

Looking at the IR, lv12 and lv14 are results of R.sum with axis=[], which should produce scalar tensors (shape ()), but they are annotated with shape (4,). Consequently, lv15 (the final result) is also annotated with shape (4,) instead of (). The function signature and intermediate type annotations should be corrected to reflect that a scalar is being computed.
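
The shape argument above can be checked with a small pure-Python sketch of the same decomposition. This is only an illustration under stated assumptions: the helper names `log_softmax` and `decomposed_cross_entropy` are hypothetical, the logits are made up, and plain lists stand in for Relax tensors. The comments map each step to the `lv*` variables in the quoted IR.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def decomposed_cross_entropy(x, targets, ignore_index=-100):
    """Hypothetical stand-in for the decomposed IR, with reduction='mean'."""
    log_probs = [log_softmax(row) for row in x]                  # lv: shape (4, 3)
    valid = [t != ignore_index for t in targets]                 # lv1 / lv8 / lv11: shape (4,)
    safe = [t if ok else 0 for t, ok in zip(targets, valid)]     # lv3: shape (4,)
    picked = [-row[t] for row, t in zip(log_probs, safe)]        # lv4..lv7: shape (4,)
    masked = [p if ok else 0.0 for p, ok in zip(picked, valid)]  # lv10: shape (4,)
    count = float(sum(valid))                                    # lv12, lv13: a single number
    total = sum(masked)                                          # lv14: a single number
    return total / count                                         # lv15: a single number


# Made-up logits: uniform rows, so every class has probability 1/3.
x = [[0.0, 0.0, 0.0] for _ in range(4)]
loss = decomposed_cross_entropy(x, [0, 1, 2, 1])
print(loss)
```

Both reductions collapse the length-4 vectors to one number each, so the final quotient is a single value, which is consistent with the reviewer's point that the result should be annotated with shape `()`.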

Comment on lines 6218 to +6281
     with R.dataflow():
-        lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32")
-        gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,)
+        lv: R.Tensor((9,), dtype="int64") = R.arange(
+            R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64"
+        )
+        lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64"))
+        lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32")
+        lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32"))
+        lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32"))
+        lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv)
+        lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32")
+        lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32"))
+        lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7)
+        lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8)
+        gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,)

high

The generated IR for linspace is highly inefficient. It computes two expressions, lv4 and lv8, which are mathematically equivalent for the given inputs (i * 0.125 for i in [0, 8]). Then it uses R.where to select between these identical values. The entire where operation and the computation of lv5 through lv8 are redundant. The IR could be simplified to just compute lv4 and use that as the result. This suggests an issue in the PyTorch decomposition logic for linspace that should be investigated.
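
The equivalence claimed above can be checked with a short pure-Python sketch. It mirrors the quoted IR under stated assumptions: `linspace_decomposed` is a hypothetical helper, plain lists stand in for tensors, and the midpoint `steps // 2` matches the constant 4 in `lv1` only for this test's 9 steps.

```python
def linspace_decomposed(start, end, steps):
    """Hypothetical mirror of the two-branch IR quoted in this comment."""
    step = (end - start) / (steps - 1)   # 0.125 for (0.0, 1.0, 9)
    half = steps // 2                    # 4, the threshold used by lv1
    # Branch computed forward from `start` (lv2..lv4).
    from_start = [start + i * step for i in range(steps)]
    # Branch computed backward from `end` (lv5..lv8).
    from_end = [end - (steps - 1 - i) * step for i in range(steps)]
    # R.where(lv1, lv4, lv8): forward branch below the midpoint, backward above.
    selected = [a if i < half else b
                for i, (a, b) in enumerate(zip(from_start, from_end))]
    return from_start, from_end, selected


from_start, from_end, selected = linspace_decomposed(0.0, 1.0, 9)
print(selected)
```

With a step of 0.125, every intermediate value is exactly representable in binary floating point, so the two branches agree element for element. For step sizes that are not exactly representable, the branches may differ in their last bits; this sketch does not establish whether that difference matters for the frontend.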

Comment on lines 1724 to +1727
 dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+# Support both "dim" and "dims" parameters
+if dim is None:
+    dim = node.kwargs.get("dims", None)

medium

The logic to get dim can be simplified into a single line by chaining dict.get calls. This makes the code more concise and easier to read.

Suggested change
-dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
-# Support both "dim" and "dims" parameters
-if dim is None:
-    dim = node.kwargs.get("dims", None)
+dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", node.kwargs.get("dims"))
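
The two variants can be compared side by side with a small sketch. The `fake_node` helper and the test cases are hypothetical stand-ins for an FX node (only `args` and `kwargs` are modeled); the two lookup functions copy the original and suggested logic from this comment.

```python
from types import SimpleNamespace


def get_dim_original(node):
    """Lookup as written in the PR: fall back to 'dims' whenever dim is None."""
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
    if dim is None:
        dim = node.kwargs.get("dims", None)
    return dim


def get_dim_suggested(node):
    """One-line lookup from the suggestion: 'dims' is only a default for a missing 'dim' key."""
    return node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", node.kwargs.get("dims"))


def fake_node(*args, **kwargs):
    """Hypothetical stand-in for an FX node, exposing only args and kwargs."""
    return SimpleNamespace(args=args, kwargs=kwargs)


cases = {
    "positional": fake_node("x", 1),
    "dim_kwarg": fake_node("x", dim=1),
    "dims_kwarg": fake_node("x", dims=[0, 2]),
    "neither": fake_node("x"),
    "dim_none_with_dims": fake_node("x", dim=None, dims=[0]),
}
results = {name: (get_dim_original(n), get_dim_suggested(n)) for name, n in cases.items()}
for name, pair in results.items():
    print(name, pair)
```

The variants agree on the common cases. They differ only when `dim` is present but explicitly `None` while `dims` is also set: the original falls through to `dims`, the one-liner returns `None`. Whether an exported graph can ever contain such a node is not established in this thread.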

Comment on lines +5893 to +5899
+lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1)
+lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = (
+    lv1,
+    lv,
+)
+lv3: R.Tensor((5, 3), dtype="int32") = lv2[1]
+gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,)

medium

The generated IR for argsort is unnecessarily complex. It computes both sorted values (lv1) and indices (lv), then creates a tuple (lv2), only to extract the indices (lv3). The sorted values are computed via gather_elements but are never used for the final result. This seems to be a result of decomposing torch.argsort into torch.sort and then taking the indices.

While a Dead Code Elimination (DCE) pass might clean this up, it would be more efficient to have a more direct translation for argsort that doesn't compute the sorted values if they are not needed.
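
The redundancy can be illustrated with a minimal pure-Python sketch, assuming an ascending sort over a single row. The function names are hypothetical and a plain list stands in for one row of the tensor; this is not the frontend's code.

```python
def sort_with_indices(row):
    """Sort-style result: both the sorted values and the permutation."""
    indices = sorted(range(len(row)), key=row.__getitem__)
    values = [row[i] for i in indices]   # analogous to the gather_elements step (lv1)
    return values, indices               # analogous to the tuple lv2


def argsort_via_sort(row):
    """Decomposed form: build the (values, indices) pair, keep only the indices."""
    return sort_with_indices(row)[1]     # analogous to lv3 = lv2[1]


def argsort_direct(row):
    """Direct form: compute the permutation without gathering the values."""
    return sorted(range(len(row)), key=row.__getitem__)


row = [0.3, -1.0, 2.5]
print(argsort_via_sort(row), argsort_direct(row))
```

Both functions return the same indices; the decomposed form additionally performs the gather whose result is discarded, which is the work a dead code elimination pass would be expected to drop.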

tlopex (Member, Author) commented Nov 1, 2025

cc @mshr-h

mshr-h merged commit 9249061 into apache:main on Nov 1, 2025
13 checks passed