Skip to content

[Relay] Flexible shape dispatch transformation#11199

Merged
jwfromm merged 14 commits into
apache:mainfrom
jwfromm:flexible_shape_pass
May 6, 2022
Merged

[Relay] Flexible shape dispatch transformation#11199
jwfromm merged 14 commits into
apache:mainfrom
jwfromm:flexible_shape_pass

Conversation

@jwfromm

@jwfromm jwfromm commented May 3, 2022

Copy link
Copy Markdown
Contributor

This PR adds a new pass to relay.transform that creates a dispatcher around an input module to handle multiple input shapes. For example consider a case where I'd like to optimize my model to handle both batch_size=1 and batch_size=4. I can now do so elegantly as follows:

shape_dict = {'data': [1, 3, 224, 224]}
model_bs1 = tvmc.load('my_model.onnx', shape_dict=shape_dict)
tvmc.tune(model, log_file='batch_1.logs')

shape_dict = {'data': [4, 3, 224, 224]}
model_bs4 = tvmc.load('my_model.onnx', shape_dict=shape_dict)
tvmc.tune(model, log_file='batch_4.logs')

# Create dispatcher for multiple batch sizes.
flex_mod = relay.transform.FlexibleShapeDispatch(buckets=[1, 4])(model_bs1.mod)

with ApplyHistoryBest(['batch_1.logs', 'batch_4.logs']):
    exe = relay.vm.compile(flex_mod, "llvm")

# Now we can run inputs with either batch 1 or batch 4 and get the tuned performance!
batch_1 = np.random.rand(1, 3, 224, 224).astype("float32")
vm.benchmark(tvm.cpu(), batch_2, func_name="main")

batch_4 = np.random.rand(4, 3, 224, 224).astype("float32")
vm.benchmark(tvm.cpu(), batch_4, func_name="main")

As seen above FlexibleShapeDispatch is a simple halfway point between fully static and fully dynamic graphs that allows us to leverage TVM tuning. If an input shape is not provided in buckets, it will either run fully dynamically using relay.Any, or if the auto_pad argument is set for FlexibleShapeDispatch, padding will be applied to match the closest bucket.

There are a few special cases that this pass handles. Multiple dynamic inputs (like those you might see in BERT) can be handled by setting input_indices to indicate which inputs have a dynamic axis. affects_output can be set to False for cases where the output shape is not dependent on input dynamism which could occur in dynamic resolution cases or something.

To make applying tuning logs more convenient, I also added the ability to load and merge multiple files to both autotvm and autoscheduler.

Thanks @jroesch for providing the backbone of this implementation.

@jwfromm jwfromm requested review from AndrewZhaoLuo and jroesch May 3, 2022 02:29
Comment thread python/tvm/relay/transform/flexible_shape.py Outdated
Comment thread python/tvm/relay/transform/flexible_shape.py Outdated
Comment thread python/tvm/relay/transform/flexible_shape.py Outdated
Comment thread python/tvm/relay/transform/flexible_shape.py Outdated
Comment thread python/tvm/relay/transform/flexible_shape.py Outdated

@jroesch jroesch left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo feedback

@jroesch

jroesch commented May 3, 2022

Copy link
Copy Markdown
Member

cc @mbs-octoml

@jwfromm

jwfromm commented May 3, 2022

Copy link
Copy Markdown
Contributor Author

@jroesch I think the documentation should now be substantially improved based on your input. Let me know what you think.

@mbs-octoml mbs-octoml left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, some optional nits:

  • consider asserting dim is Any in override_shape to avoid surprises
  • rename dim -> axis, value -> dim
  • should we force a TypeInfer afterwards to be sure the Any->dim subst is sound?

@jwfromm

jwfromm commented May 4, 2022

Copy link
Copy Markdown
Contributor Author

All feedback has been addressed and tests are green. @jroesch do you want to give this one more pass?

@jwfromm jwfromm merged commit 98aa41e into apache:main May 6, 2022
shtinsa pushed a commit to Deelvin/tvm that referenced this pull request May 17, 2022
* Added pass that creates a semi-dynamic dispatcher around a relay module.

* Added automatic padding feature.

* Output slicing working.

* Multiple input support working i think.

* Added test file.

* Improve comments.

* Fix lint.

* Allow default values.

* Fix docstring.

* Improved documentation based on feedback.

* Add extra check for record loading.

* Improve variable names.

* Add type inference to make sure things worked.

* Added support for multiple outputs.
SebastianBoblest pushed a commit to SebastianBoblest/tvm that referenced this pull request May 27, 2022
* Added pass that creates a semi-dynamic dispatcher around a relay module.

* Added automatic padding feature.

* Output slicing working.

* Multiple input support working i think.

* Added test file.

* Improve comments.

* Fix lint.

* Allow default values.

* Fix docstring.

* Improved documentation based on feedback.

* Add extra check for record loading.

* Improve variable names.

* Add type inference to make sure things worked.

* Added support for multiple outputs.
@jwfromm jwfromm deleted the flexible_shape_pass branch April 12, 2023 15:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants