-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(3) #18410
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5823,7 +5823,7 @@ def main( | |
| return gv | ||
|
|
||
| example_input = torch.randn(5, 3, dtype=torch.float32) | ||
| verify_model(Cumprod(), (example_input,), {}, Expected) | ||
| verify_model(Cumprod(), (example_input,), {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_where(): | ||
|
|
@@ -5849,7 +5849,7 @@ def main( | |
| x = torch.randn(5, 3, dtype=torch.float32) | ||
| y = torch.randn(5, 3, dtype=torch.float32) | ||
|
|
||
| verify_model(Where(), (condition, x, y), {}, Expected) | ||
| verify_model(Where(), (condition, x, y), {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_bucketize(): | ||
|
|
@@ -5874,7 +5874,7 @@ def main( | |
| input_tensor = torch.arange(0, 20) | ||
| boundaries = torch.arange(0, 20, 2) | ||
|
|
||
| verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected) | ||
| verify_model(Bucketize(), (input_tensor, boundaries), {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_argsort(): | ||
|
|
@@ -5890,12 +5890,18 @@ def main(x: R.Tensor((5, 3), dtype="float32")) -> R.Tuple(R.Tensor((5, 3), dtype | |
| lv: R.Tensor((5, 3), dtype="int32") = R.argsort( | ||
| x, axis=1, descending=True, dtype="int32" | ||
| ) | ||
| gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv,) | ||
| lv1: R.Tensor((5, 3), dtype="float32") = R.gather_elements(x, lv, axis=1) | ||
| lv2: R.Tuple(R.Tensor((5, 3), dtype="float32"), R.Tensor((5, 3), dtype="int32")) = ( | ||
| lv1, | ||
| lv, | ||
| ) | ||
| lv3: R.Tensor((5, 3), dtype="int32") = lv2[1] | ||
| gv: R.Tuple(R.Tensor((5, 3), dtype="int32")) = (lv3,) | ||
|
Comment on lines
+5893
to
+5899
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The generated IR for While a Dead Code Elimination (DCE) pass might clean this up, it would be more efficient to have a more direct translation for |
||
| R.output(gv) | ||
| return gv | ||
|
|
||
| example_args = (torch.randn(5, 3, dtype=torch.float32),) | ||
| verify_model(Argsort(), example_args, {}, Expected) | ||
| verify_model(Argsort(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_topk(): | ||
|
|
@@ -5923,7 +5929,7 @@ def main( | |
| return gv | ||
|
|
||
| example_args = (torch.randn(5, 3, dtype=torch.float32),) | ||
| verify_model(Topk(), example_args, {}, Expected) | ||
| verify_model(Topk(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_dynamic_shape(): | ||
|
|
@@ -5972,7 +5978,7 @@ def main( | |
| return gv | ||
|
|
||
| example_args = (torch.randn(5, 1, dtype=torch.float32),) | ||
| verify_model(BroadcastTo(), example_args, {}, Expected) | ||
| verify_model(BroadcastTo(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_narrow(): | ||
|
|
@@ -5992,6 +5998,7 @@ def main( | |
| (R.prim_value(1),), | ||
| (R.prim_value(0),), | ||
| (R.prim_value(2),), | ||
| (R.prim_value(1),), | ||
| assume_inbound=False, | ||
| ) | ||
| gv: R.Tuple(R.Tensor((5, 2), dtype="float32")) = (lv,) | ||
|
|
@@ -6000,7 +6007,7 @@ def main( | |
| return gv | ||
|
|
||
| example_args = (torch.randn(5, 3, dtype=torch.float32),) | ||
| verify_model(Narrow(), example_args, {}, Expected) | ||
| verify_model(Narrow(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_item(): | ||
|
|
@@ -6019,7 +6026,7 @@ def main(input: R.Tensor((1,), dtype="float32")) -> R.Tuple(R.Tensor((), dtype=" | |
| return gv | ||
|
|
||
| example_args = (torch.randn(1, dtype=torch.float32),) | ||
| verify_model(Item(), example_args, {}, Expected) | ||
| verify_model(Item(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_norm(): | ||
|
|
@@ -6131,7 +6138,9 @@ def main( | |
| example_args = (torch.randn(1, 3, 5, 3, dtype=torch.float32),) | ||
|
|
||
| for (p, dim, keepdim), expected in norms: | ||
| verify_model(Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected) | ||
| verify_model( | ||
| Norm(p, dim=dim, keepdim=keepdim), example_args, {}, expected, run_ep_decomposition=True | ||
| ) | ||
|
|
||
|
|
||
| def test_eye(): | ||
|
|
@@ -6146,8 +6155,20 @@ def main( | |
| input: R.Tensor((3, 5), dtype="float32") | ||
| ) -> R.Tuple(R.Tensor((3, 5), dtype="float32")): | ||
| with R.dataflow(): | ||
| lv: R.Tensor((3, 5), dtype="float32") = R.eye(3, 5, dtype="float32") | ||
| gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv,) | ||
| lv: R.Tensor((3,), dtype="int64") = R.arange( | ||
| R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64" | ||
| ) | ||
| lv1: R.Tensor((5,), dtype="int64") = R.arange( | ||
| R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" | ||
| ) | ||
| lv2: R.Tensor((3, 1), dtype="int64") = R.expand_dims(lv, axis=[-1]) | ||
| lv3: R.Tensor((3, 5), dtype="bool") = R.equal(lv2, lv1) | ||
| lv4: R.Tensor((1,), dtype="float32") = R.full( | ||
| R.shape([1]), R.const(1.0, "float32"), dtype="float32" | ||
| ) | ||
| lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32") | ||
| lv6: R.Tensor((3, 5), dtype="float32") = R.where(lv3, lv4, lv5) | ||
| gv: R.Tuple(R.Tensor((3, 5), dtype="float32")) = (lv6,) | ||
| R.output(gv) | ||
| return gv | ||
|
|
||
|
|
@@ -6162,16 +6183,28 @@ def main( | |
| input: R.Tensor((5,), dtype="float32") | ||
| ) -> R.Tuple(R.Tensor((5, 5), dtype="float32")): | ||
| with R.dataflow(): | ||
| lv: R.Tensor((5, 5), dtype="float32") = R.eye(5, dtype="float32") | ||
| gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv,) | ||
| lv: R.Tensor((5,), dtype="int64") = R.arange( | ||
| R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" | ||
| ) | ||
| lv1: R.Tensor((5,), dtype="int64") = R.arange( | ||
| R.prim_value(0), R.prim_value(5), R.prim_value(1), dtype="int64" | ||
| ) | ||
| lv2: R.Tensor((5, 1), dtype="int64") = R.expand_dims(lv, axis=[-1]) | ||
| lv3: R.Tensor((5, 5), dtype="bool") = R.equal(lv2, lv1) | ||
| lv4: R.Tensor((1,), dtype="float32") = R.full( | ||
| R.shape([1]), R.const(1.0, "float32"), dtype="float32" | ||
| ) | ||
| lv5: R.Tensor((), dtype="float32") = R.const(0.0, "float32") | ||
| lv6: R.Tensor((5, 5), dtype="float32") = R.where(lv3, lv4, lv5) | ||
| gv: R.Tuple(R.Tensor((5, 5), dtype="float32")) = (lv6,) | ||
| R.output(gv) | ||
| return gv | ||
|
|
||
| example_args1 = (torch.randn(3, 5, dtype=torch.float32),) | ||
| verify_model(Eye1(), example_args1, {}, Expected1) | ||
| verify_model(Eye1(), example_args1, {}, Expected1, run_ep_decomposition=True) | ||
|
|
||
| example_args2 = (torch.randn(5, dtype=torch.float32),) | ||
| verify_model(Eye2(), example_args2, {}, Expected2) | ||
| verify_model(Eye2(), example_args2, {}, Expected2, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_cross_entropy(): | ||
|
|
@@ -6187,21 +6220,39 @@ def forward(self, x): | |
| @tvm.script.ir_module | ||
| class Expected1: | ||
| @R.function | ||
| def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")): | ||
| def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")): | ||
| with R.dataflow(): | ||
| lv: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(x, axis=-1) | ||
| lv1: R.Tensor((), dtype="float32") = R.nn.nll_loss( | ||
| lv, | ||
| targets=R.const([0, 1, 2, 1], dtype="int64"), | ||
| reduction="mean", | ||
| ignore_index=-100, | ||
| lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32") | ||
| lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1) | ||
| lv2: R.Tensor((4,), dtype="bool") = R.not_equal( | ||
| R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") | ||
| ) | ||
| lv3: R.Tensor((), dtype="int64") = R.const(0, "int64") | ||
| lv4: R.Tensor((4,), dtype="int64") = R.where( | ||
| lv2, R.const([0, 1, 2, 1], dtype="int64"), lv3 | ||
| ) | ||
| gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,) | ||
| lv5: R.Tensor((4, 1), dtype="int64") = R.expand_dims(lv4, axis=[1]) | ||
| lv6: R.Tensor((4, 1), dtype="float32") = R.gather_elements(lv1, lv5, axis=1) | ||
| lv7: R.Tensor((4,), dtype="float32") = R.squeeze(lv6, axis=[1]) | ||
| lv8: R.Tensor((4,), dtype="float32") = R.negative(lv7) | ||
| lv9: R.Tensor((4,), dtype="bool") = R.not_equal( | ||
| R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") | ||
| ) | ||
| lv10: R.Tensor((), dtype="float32") = R.const(0.0, "float32") | ||
| lv11: R.Tensor((4,), dtype="float32") = R.where(lv9, lv8, lv10) | ||
| lv12: R.Tensor((4,), dtype="bool") = R.not_equal( | ||
| R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64") | ||
| ) | ||
| lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False) | ||
| lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32") | ||
| lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False) | ||
| lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14) | ||
| gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,) | ||
| R.output(gv) | ||
| return gv | ||
|
|
||
| example_args1 = (torch.randn(4, 3, dtype=torch.float32),) | ||
| verify_model(CrossEntropyModule(), example_args1, {}, Expected1) | ||
| verify_model(CrossEntropyModule(), example_args1, {}, Expected1, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_linspace(): | ||
|
|
@@ -6216,13 +6267,24 @@ def main( | |
| input: R.Tensor((9, 9), dtype="float32") | ||
| ) -> R.Tuple(R.Tensor((9,), dtype="float32")): | ||
| with R.dataflow(): | ||
| lv: R.Tensor((9,), dtype="float32") = R.arange(0, 1.0625, 0.125, dtype="float32") | ||
| gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv,) | ||
| lv: R.Tensor((9,), dtype="int64") = R.arange( | ||
| R.prim_value(0), R.prim_value(9), R.prim_value(1), dtype="int64" | ||
| ) | ||
| lv1: R.Tensor((9,), dtype="bool") = R.less(lv, R.const(4, "int64")) | ||
| lv2: R.Tensor((9,), dtype="float32") = R.astype(lv, dtype="float32") | ||
| lv3: R.Tensor((9,), dtype="float32") = R.multiply(lv2, R.const(0.125, "float32")) | ||
| lv4: R.Tensor((9,), dtype="float32") = R.add(lv3, R.const(0.0, "float32")) | ||
| lv5: R.Tensor((9,), dtype="int64") = R.subtract(R.const(8, "int64"), lv) | ||
| lv6: R.Tensor((9,), dtype="float32") = R.astype(lv5, dtype="float32") | ||
| lv7: R.Tensor((9,), dtype="float32") = R.multiply(lv6, R.const(0.125, "float32")) | ||
| lv8: R.Tensor((9,), dtype="float32") = R.subtract(R.const(1.0, "float32"), lv7) | ||
| lv9: R.Tensor((9,), dtype="float32") = R.where(lv1, lv4, lv8) | ||
| gv: R.Tuple(R.Tensor((9,), dtype="float32")) = (lv9,) | ||
|
Comment on lines
6218
to
+6282
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The generated IR for |
||
| R.output(gv) | ||
| return gv | ||
|
|
||
| example_args = (torch.randn(9, 9, dtype=torch.float32),) | ||
| verify_model(Linspace(), example_args, {}, Expected) | ||
| verify_model(Linspace(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
|
|
@@ -6259,7 +6321,7 @@ def main( | |
| R.output(gv) | ||
| return gv | ||
|
|
||
| verify_model(Model(), example_args, {}, Expected) | ||
| verify_model(Model(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_mm(): | ||
|
|
@@ -6285,7 +6347,7 @@ def main( | |
| R.output(gv) | ||
| return gv | ||
|
|
||
| verify_model(MatrixMultiply(), example_args, {}, Expected) | ||
| verify_model(MatrixMultiply(), example_args, {}, Expected, run_ep_decomposition=True) | ||
|
|
||
|
|
||
| def test_lstm(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The logic to get
dimcan be simplified into a single line by chainingdict.getcalls. This makes the code more concise and easier to read.