From 6aba0427c3c204d049321de3d341b8eac8f34abf Mon Sep 17 00:00:00 2001 From: Rohan Date: Fri, 16 Jul 2021 19:40:52 +0000 Subject: [PATCH 1/2] [TensorRT, BYOC] Handling a corner case in TRT RemoveDropout pass --- python/tvm/relay/op/contrib/tensorrt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index cbe6a22f4a4d..5710de5bf73f 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -22,7 +22,7 @@ from tvm import relay from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem +from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem, Let from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -1033,6 +1033,8 @@ def visit_tuple_getitem(self, op): visit = super().visit_tuple_getitem(op) if visit.index != 0: return visit + if isinstance(visit.tuple_value, Call) and isinstance(visit.tuple_value.op, Let): + return visit if ( isinstance(visit.tuple_value, Call) and visit.tuple_value.op.name == "nn.dropout" From 4908819adb5c60316a1bcae7f70354933f9bbdd5 Mon Sep 17 00:00:00 2001 From: Rohan Date: Tue, 20 Jul 2021 18:46:46 +0000 Subject: [PATCH 2/2] changing visit logic --- python/tvm/relay/op/contrib/tensorrt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index 5710de5bf73f..cec7c4d141cb 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -22,7 +22,8 @@ from tvm import relay from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name -from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem, Let +from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem +from tvm.ir import Op from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -1033,10 +1034,9 @@ def visit_tuple_getitem(self, op): visit = super().visit_tuple_getitem(op) if visit.index != 0: return visit - if isinstance(visit.tuple_value, Call) and isinstance(visit.tuple_value.op, Let): - return visit if ( isinstance(visit.tuple_value, Call) + and isinstance(visit.tuple_value.op, Op) and visit.tuple_value.op.name == "nn.dropout" and visit.index == 0 ):