[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(3)#18410
Conversation
Summary of ChangesHello @tlopex, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the PyTorch frontend for Relax by introducing robust support for decomposed operators and refining the Intermediate Representation (IR) for a range of common operations. It ensures that complex PyTorch operators like Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds support for decomposed operators from PyTorch's ExportedProgram and updates several operator tests to use this feature. My review focuses on the correctness and efficiency of the newly generated Intermediate Representations (IRs). I've identified a few issues: a critical problem in the test_cross_entropy where the expected IR has incorrect shape annotations, making the test invalid. Additionally, the decomposed IR for test_linspace and test_argsort is inefficient due to redundant computations. I've also included a suggestion to simplify the implementation of the _squeeze operator.
| def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")): | ||
| with R.dataflow(): | ||
| lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1) | ||
| lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss( | ||
| lv, | ||
| targets=R.const([0, 1, 2, 1], dtype="int64"), | ||
| reduction="mean", | ||
| ignore_index=-100, | ||
| lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=1) | ||
| lv1: R.Tensor((4,), dtype="bool") = R.not_equal( | ||
| R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") | ||
| ) | ||
| lv2: R.Tensor((), dtype="int64") = R.const(0, "int64") | ||
| lv3: R.Tensor((4,), dtype="int64") = R.where( | ||
| lv1, R.const([0, 1, 2, 1], dtype="int64"), lv2 | ||
| ) | ||
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) | ||
| lv4: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv3, axis=[1]) | ||
| lv5: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv, lv4, axis=1) | ||
| lv6: R.Tensor((4,), dtype="float32") = R.squeeze(lv5, axis=[1]) | ||
| lv7: R.Tensor((4,), dtype="float32") = R.negative(lv6) | ||
| lv8: R.Tensor((4,), dtype="bool") = R.not_equal( | ||
| R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") | ||
| ) | ||
| lv9: R.Tensor((), dtype="float32") = R.const(0.0, "float32") | ||
| lv10: R.Tensor((4,), dtype="float32") = R.where(lv8, lv7, lv9) | ||
| lv11: R.Tensor((4,), dtype="bool") = R.not_equal( | ||
| R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") | ||
| ) | ||
| lv12: R.Tensor((4,), dtype="bool") = R.sum(lv11, axis=[], keepdims=False) | ||
| lv13: R.Tensor((4,), dtype="float32") = R.astype(lv12, dtype="float32") | ||
| lv14: R.Tensor((4,), dtype="float32") = R.sum(lv10, axis=[], keepdims=False) | ||
| lv15: R.Tensor((4,), dtype="float32") = R.divide(lv14, lv13) | ||
| gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv15,) | ||
| R.output(gv) | ||
| return gv |
There was a problem hiding this comment.
The expected IR for test_cross_entropy has incorrect shape annotations. torch.nn.CrossEntropyLoss with the default reduction='mean' should return a scalar tensor. However, the return type of the main function is annotated as R.Tuple(R.Tensor((4,), dtype="float32")).
Looking at the IR, lv12 and lv14 are results of R.sum with axis=[], which should produce scalar tensors (shape ()), but they are annotated with shape (4,). Consequently, lv15 (the final result) is also annotated with shape (4,) instead of (). The function signature and intermediate type annotations should be corrected to reflect that a scalar is being computed.
| with R.dataflow(): | ||
| lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") | ||
| gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) | ||
| lv: R.Tensor((9,), dtype="int64") = R.arange( | ||
| R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64" | ||
| ) | ||
| lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64")) | ||
| lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32") | ||
| lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32")) | ||
| lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32")) | ||
| lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv) | ||
| lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32") | ||
| lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32")) | ||
| lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7) | ||
| lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8) | ||
| gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,) |
There was a problem hiding this comment.
The generated IR for linspace is highly inefficient. It computes two expressions, lv4 and lv8, which are mathematically equivalent for the given inputs (i * 0.125 for i in [0, 8]). Then it uses R.where to select between these identical values. The entire where operation and the computation of lv5 through lv8 are redundant. The IR could be simplified to just compute lv4 and use that as the result. This suggests an issue in the PyTorch decomposition logic for linspace that should be investigated.
| dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) | ||
| # Support both "dim" and "dims" parameters | ||
| if dim is None: | ||
| dim = node.kwargs.get("dims", None) |
There was a problem hiding this comment.
The logic to get dim can be simplified into a single line by chaining dict.get calls. This makes the code more concise and easier to read.
| dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) | |
| # Support both "dim" and "dims" parameters | |
| if dim is None: | |
| dim = node.kwargs.get("dims", None) | |
| dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", node.kwargs.get("dims")) |
| lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1) | ||
| lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = ( | ||
| lv1, | ||
| lv, | ||
| ) | ||
| lv3: R.Tensor((5, 3), dtype="int32") = lv2[1] | ||
| gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,) |
There was a problem hiding this comment.
The generated IR for argsort is unnecessarily complex. It computes both sorted values (lv1) and indices (lv), then creates a tuple (lv2), only to extract the indices (lv3). The sorted values are computed via gather_elements but are never used for the final result. This seems to be a result of decomposing torch.argsort into torch.sort and then taking the indices.
While a Dead Code Elimination (DCE) pass might clean this up, it would be more efficient to have a more direct translation for argsort that doesn't compute the sorted values if they are not needed.
|
cc @mshr-h |
This pr fixes
cumprod,where,bucketize,argsort,topk,narrow,item,norm,eye,cross_entropy,linspaceandmm