Pass that removes reshapes post LowerTE#12215
Conversation
manupak
left a comment
There was a problem hiding this comment.
Thanks @ashutosh-arm. looking good!
I think we need unit tests for the pass as well.
(E.g. https://github.com/apache/tvm/blob/main/tests/python/relay/test_pass_partition_graph.py)
lhutton1
left a comment
There was a problem hiding this comment.
Great work @ashutosh-arm! Just some small things I picked up on..
areusch
left a comment
There was a problem hiding this comment.
don't have tons of context here, but left a couple suggestions
| return WithFields(GetRef<Let>(let), var, value, body); | ||
| } | ||
|
|
||
| /*! * \brief Returns preceding CallLowered when call is a CallLowered(Reshape) */ |
There was a problem hiding this comment.
i'm probably missing some context here, but what about just returning the args to reshape()?
There was a problem hiding this comment.
Graph contains let nodes in between the call_lowered(). I've included the following piece as part of the Rewrite_() as well.
/*
%1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...);
let %x: = on_device(%1, ...);
%2 = (%x,);
%3 = call_lowered(@tvmgen_default_fused_reshape, %2, ...,
"relay_attrs"=__dict__="relay.reshape_only"=1, ...);
*/
Change-Id: Iaf5a5f44776080b0b842af4b563d596134508de1
Change-Id: I1f45ee3b15fbe290fdce69832a850d7d85ea1681
Change-Id: I81462a552f467d88cf1288acef2f9cbacc3ff532
Change-Id: I8502bc74eb0914cfcaa86cb809d7c4a9c6e86c70
ca579b2 to
389cadb
Compare
|
Thanks @ashutosh-arm @manupa-arm @areusch! |
Introduces a Pass for removing intermediate reshapes post
LowerTE() in AOT compiler. This commit adds pass specific
tests and updates usmp generated workspace pools due to
reduction in number of allocations post reshape removals.
Note: this pass at present does not support first reshape
appearing in the graph. If seen as a useful case, it can be
added in the future.
cc: @manupa-arm @grant-arm