[Relax][PyTorch] Support advanced range constraints (addition)#18452
Conversation
9e283e1 to
aed2500
Compare
|
cc @mshr-h |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request adds support for advanced range constraints with addition expressions from PyTorch's dynamic shapes. The changes involve parsing sympy expressions into TIR PrimExprs and storing them as a new function attribute. A new test case is added to verify this functionality. My review includes a high-severity fix for a potential TypeError in the sympy expression parser and a medium-severity comment on a potential inconsistency in the new test case's expected output.
| class Expected: | ||
| @R.function | ||
| def main( | ||
| x: R.Tensor(("s0", 4), dtype="float32"), y: R.Tensor(("s0___1", 4), dtype="float32") | ||
| ) -> R.Tuple(R.Tensor(("s0 + s0___1", 4), dtype="float32")): | ||
| s0 = T.int64(is_size_var=True) | ||
| s0___1 = T.int64(is_size_var=True) | ||
| R.func_attr( | ||
| { | ||
| "tir_var_expr": {"s0 + 1": 1 + s0}, | ||
| "tir_var_lower_bound": {"s0": 1, "s0 + 1": 2}, | ||
| "tir_var_upper_bound": {"s0": 64, "s0 + 1": 65}, | ||
| } | ||
| ) | ||
| with R.dataflow(): | ||
| lv: R.Tensor((s0 + s0___1, 4), dtype="float32") = R.concat((x, y), axis=0) | ||
| gv: R.Tuple(R.Tensor((s0 + s0___1, 4), dtype="float32")) = (lv,) | ||
| R.output(gv) | ||
| return gv |
There was a problem hiding this comment.
The Expected IRModule seems to have some inconsistencies with what the translator is expected to generate. Given dynamic_shapes = {"x": {0: batch}, "y": {0: batch + 1}}, the translator should generate a SizeVar named "s0 + 1" for the dynamic dimension of y. Therefore, the y parameter in main should probably have the shape R.Tensor(("s0 + 1", 4), ...) instead of R.Tensor(("s0___1", 4), ...). Consequently, the output tensor shape would be R.Tensor(("s0 + s0 + 1", 4), ...) and the free variable s0___1 would not be needed. Could you please double-check the Expected module definition?
3b3e36a to
a061f94
Compare
a061f94 to
a232be6
Compare
Related Issue
Why
How
SymPyaddition expressions from PyTorch's range_constraints