Skip to content

[BYOC][TRT] Add DFPattern support for TRT backend#10759

Merged
mbaret merged 8 commits into
apache:mainfrom
mikepapadim:pattern_trt
Apr 4, 2022
Merged

[BYOC][TRT] Add DFPattern support for TRT backend#10759
mbaret merged 8 commits into
apache:mainfrom
mikepapadim:pattern_trt

Conversation

@mikepapadim

@mikepapadim mikepapadim commented Mar 24, 2022

Copy link
Copy Markdown
Contributor

This PR adds DFPattern support for the TRT backend without removing the existing predicate registry.

Adds and extends the following:

  • In tensorrt.py: Add a pattern_table for all the supported ops and consumes the pre-existing op_registry checks
  • Adds an additional pass as unmerge_composites.cc. This is required for the TRT backend as it expects a single primitive function to work with, while the MergeComposite and PartitionGraph will produce a single function for each Composite pattern.
  • Adds test_inline_composites.py which tests the newly introduced pass.

Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules.
This is to ensure backwards compatibility."

Original Pass orderding:

  seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            RemoveDropoutPass(),
            transform.RemoveUnusedFunctions(),
            transform.ConvertLayout(
                {
                    "nn.conv1d": ["NCW", "default"],
                    "nn.conv2d": ["NCHW", "default"],
                    "nn.conv3d": ["NCDHW", "default"],
                    "nn.conv2d_transpose": ["NCHW", "default"],
                }
            ),
            transform.FoldConstant(),
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InferType(),
        ]
    )

Pass ordering with MergeComposites and UnmergeComposites:

  seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            RemoveDropoutPass(),
            transform.RemoveUnusedFunctions(),
            transform.ConvertLayout(
                {
                    "nn.conv1d": ["NCW", "default"],
                    "nn.conv2d": ["NCHW", "default"],
                    "nn.conv3d": ["NCDHW", "default"],
                    "nn.conv2d_transpose": ["NCHW", "default"],
                }
            ),
            transform.FoldConstant(),
            transform.MergeComposite(pattern_table()),                                   <-------- Change #1
            transform.AnnotateTarget("tensorrt"),
            transform.MergeCompilerRegions(),
            transform.PartitionGraph(),
            transform.InlineComposites("tensorrt"),                                     <-------- Change #2
            transform.InferType(),
        ]
    )

@mbs-octoml @mbaret @masahi

@mikepapadim mikepapadim force-pushed the pattern_trt branch 2 times, most recently from a4f84ce to 53b0c48 Compare March 24, 2022 12:48

@mbaret mbaret left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this pass needs its own unit tests so it can be tested outside of the TRT partitioning flow.

Comment thread python/tvm/relay/op/contrib/tensorrt.py Outdated

@_register_external_dynamic_check_func("nn.batch_matmul")
def batch_matmul_annotate_fn(expr):
def batch_matmul_annotate_fn(expr): # pylint: disable=unused-variable

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change required?

for tup in pattern_table:
if len(tup) == 2:
pattern_name, pattern = tup
check = lambda extract: True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain this change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missmatch in black autoformat it


namespace relay {

class Unmerger : ExprMutator {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MixedModeMutator is now preferred where possible.

Comment on lines +48 to +49
Function gv = GetRef<Function>(function_var_node);
const auto* fn = gv.as<FunctionNode>();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed - it looks like we already start with the FunctionNode?


Expr VisitExpr_(const CallNode* call_node) final {
Call vanilla_call = GetAnyCall(call_node);
const auto* global_var_node = vanilla_call->op.as<GlobalVarNode>();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused

Comment on lines +62 to +65
// Attrs need to be empty at this point to avoid propagating Composite and
// PartitionedFromPattern that fiddling TRT code gen for registered ops.
auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, {});
return Bind(func->body, bind_map);

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand this, why can't we just do

return Bind(fn->body, bind_map);

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1


if (!base_func->GetAttr<String>(attr::kCompiler).defined() &&
base_func->GetAttr<String>(attr::kCompiler) != target) {
return module;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be better to continue; here rather than return, otherwise it seems if any partitioning for a different target has taken place, this will bail out.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

*/

/*!
* \file src/relay/transforms/unmerge_composites.cc

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personal preference for the name here would be either InlineComposite or RemoveComposite, not a huge deal though, if no one else agrees we can keep it as Unmerge.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InlineComposite makes sense, I will rename it


/*!
* \file src/relay/transforms/unmerge_composites.cc
* \brief Undo the partioned graphs originate from merge composite.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 'Inline composite functions for a given target' describes this a bit better.

@mikepapadim

Copy link
Copy Markdown
Contributor Author

PTAL

@mbaret

mbaret commented Mar 25, 2022

Copy link
Copy Markdown
Contributor

I still think this needs a couple of simple unit tests to confirm the behaviour. Also ping @mbs-octoml if you want to take a quick look.

@mikepapadim

Copy link
Copy Markdown
Contributor Author

@mbs-octoml

@mikepapadim

Copy link
Copy Markdown
Contributor Author

@mbaret PTAL. Under test_inline_composites.py are a couple o unit-tests for testing the new pass.

Comment on lines +123 to +125
print("merge composite reusult")
print(result)
print("---------------------")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably omit the prints.

Comment on lines +164 to +174
def expected():
a = relay.var("a", shape=(10, 10))
b = relay.var("b", shape=(10, 10))

# add_relu function
in_1 = relay.var("in_1", shape=(10, 10))
in_2 = relay.var("in_2", shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
return add_relu

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is the same as before() (given a and b aren't used). If all we really want to test is that doing InlineComposites undoes MergeComposite, we can probably just test that the result is equal to the input.

"""Utility function to check inline composites results."""
result = run_opt_pass(
graph, relay.transform.MergeComposite(pattern_table), import_prelude=import_prelude
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should put some form of check here just to confirm that a composite function has been created (so we know MergeComposite didn't just skip everything if for instance there was a pattern error).

relu relu

"""
pattern_table = [("add", make_add_pattern()), ("nn.relu", make_relu_pattern())]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem to match the description above.

"""


def make_conv_bias_relu_pattern():

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern doesn't seem to be used, I think either add a test for it or remove.

Comment thread tests/python/relay/test_pass_inline_composites.py Outdated

@mbaret mbaret left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@mbaret mbaret merged commit 98580a2 into apache:main Apr 4, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
This PR adds DFPattern support for the TRT backend without removing the existing predicate registry.

Adds and extends the following:

In tensorrt.py: Add a pattern_table for all the supported ops and consumes the pre-existing op_registry checks
Adds an additional pass as unmerge_composites.cc. This is required for the TRT backend as it expects a single primitive 
function to work with, while the MergeComposite and PartitionGraph will produce a single function for each Composite 
pattern.

Adds test_inline_composites.py which tests the newly introduced pass.
Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules.
This is to ensure backwards compatibility."
mehrdadh pushed a commit to mehrdadh/tvm that referenced this pull request Apr 11, 2022
This PR adds DFPattern support for the TRT backend without removing the existing predicate registry.

Adds and extends the following:

In tensorrt.py: Add a pattern_table for all the supported ops and consumes the pre-existing op_registry checks
Adds an additional pass as unmerge_composites.cc. This is required for the TRT backend as it expects a single primitive 
function to work with, while the MergeComposite and PartitionGraph will produce a single function for each Composite 
pattern.

Adds test_inline_composites.py which tests the newly introduced pass.
Both the pattern-based and predicate-based pass sequences produce syntactically equivalent IRModules.
This is to ensure backwards compatibility."
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants