diff --git a/python/tvm/relay/op/contrib/tensorrt.py b/python/tvm/relay/op/contrib/tensorrt.py index cbe6a22f4a4d..cec7c4d141cb 100644 --- a/python/tvm/relay/op/contrib/tensorrt.py +++ b/python/tvm/relay/op/contrib/tensorrt.py @@ -23,6 +23,7 @@ from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem +from tvm.ir import Op from tvm.relay.expr_functor import ExprMutator, ExprVisitor logger = logging.getLogger("TensorRT") @@ -1035,6 +1036,7 @@ def visit_tuple_getitem(self, op): return visit if ( isinstance(visit.tuple_value, Call) + and isinstance(visit.tuple_value.op, Op) and visit.tuple_value.op.name == "nn.dropout" and visit.index == 0 ):