
[Bug.Relay.InferType] Type Inference Report Mismatch after One Operator Is Removed #8432

Description

@Johnson9009

Standard Output and Error Message

[00:32:22] /home/zhaqia01/workspaces/tvm/src/ir/transform.cc:616: PrintIR():
#[version = "0.0.5"]
def @main(%x1: Tensor[(1, 224, 224, 3), int8], %x2: Tensor[(7, 7, 3, 64), int8]) -> Tensor[(1, 224, 224, 64), int32] {
  %0 = nn.pad(%x1, 0 /* ty=int32 */, pad_width=[[0, 0], [3, 3], [3, 3], [0, 0]]) /* ty=Tensor[(1, 230, 230, 3), int8] */;
  nn.conv2d(%0, %x2, padding=[0, 0, 0, 0], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32") /* ty=Tensor[(1, 224, 224, 64), int32] */
}

The Relay type checker is unable to show the following types match.
In particular dimension 1 conflicts: 218 does not match 224.dimension 2 conflicts: 218 does not match 224.
The Relay type checker is unable to show the following types match.
In particular `Tensor[(1, 224, 224, 64), int32]` does not match `Tensor[(1, 218, 218, 64), int32]`
note: run with `TVM_BACKTRACE=1` environment variable to display a backtrace.
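You can check the two shapes in the error by hand with the usual convolution output-size formula, out = (in + pad_before + pad_after - dilation * (k - 1) - 1) // stride + 1. Below is a minimal sketch for the NHWC/HWIO layouts used in this report. `conv2d_nhwc_hwio_shape` is a hypothetical helper, not a TVM API. It assumes the 4-element padding order (top, left, bottom, right), with strides and dilation defaulting to 1 as in the repro.

```python
def conv2d_nhwc_hwio_shape(data_shape, weight_shape, padding=(0, 0, 0, 0),
                           strides=(1, 1), dilation=(1, 1)):
    """Output shape of an NHWC x HWIO conv2d; padding is (top, left, bottom, right)."""
    n, h, w, _ = data_shape
    kh, kw, _, out_c = weight_shape
    top, left, bottom, right = padding
    # Effective kernel extent once dilation is applied.
    eff_kh = dilation[0] * (kh - 1) + 1
    eff_kw = dilation[1] * (kw - 1) + 1
    out_h = (h + top + bottom - eff_kh) // strides[0] + 1
    out_w = (w + left + right - eff_kw) // strides[1] + 1
    return (n, out_h, out_w, out_c)


# First InferType: nn.pad has already grown H and W to 230; conv2d padding is 0.
print(conv2d_nhwc_hwio_shape((1, 230, 230, 3), (7, 7, 3, 64)))   # (1, 224, 224, 64)
# After SimplifyPad: the 224x224 input reaches conv2d with padding still 0.
print(conv2d_nhwc_hwio_shape((1, 224, 224, 3), (7, 7, 3, 64)))   # (1, 218, 218, 64)
```

So 218 is exactly what a 7x7 conv2d with zero padding produces on the unpadded 224x224 input. 224 is the shape inferred while the 3-pixel pad was still present.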

Test Case to Reproduce

import tvm
from tvm import relay


class PadSimplifier(relay.ExprMutator):
    def __init__(self):
        super().__init__()

    def visit_call(self, call):
        call = super().visit_call(call)
        # Only rewrite nn.conv2d calls whose data input is an nn.pad call.
        if ((call.op != relay.op.get("nn.conv2d")) or
            (not isinstance(call.args[0], relay.Call)) or
            (call.args[0].op != relay.op.get("nn.pad"))):
            return call

        conv2d = call
        pad, weight = conv2d.args
        data, pad_value = pad.args
        # Only fold constant-mode padding whose pad value is the constant 0.
        if ((pad.attrs.pad_mode != "constant") or
            (not isinstance(pad_value, relay.Constant)) or
            (pad_value.data.numpy() != 0)):
            return conv2d

        # To reproduce the issue, simply rebuild the conv2d call on the unpadded data, reusing its original attrs.
        return relay.Call(conv2d.op, [data, weight], conv2d.attrs, conv2d.type_args, conv2d.span)


@relay.transform.function_pass(opt_level=0)
class SimplifyPad:
    def transform_function(self, func, ir_mod, pass_ctx):
        return PadSimplifier().visit(func)


dtype = "int8"
dshape = (1, 224, 224, 3)
kshape = (7, 7, 3, 64)

x1 = relay.var("x1", shape=dshape, dtype=dtype)
x2 = relay.var("x2", shape=kshape, dtype=dtype)
expr = relay.nn.pad(x1, [[0, 0], [3, 3], [3, 3], [0, 0]])
expr = relay.nn.conv2d(expr, x2, data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")
nn_mod = tvm.IRModule.from_expr(expr)

passes = [
    relay.transform.InferType(),
    tvm.transform.PrintIR(),
    SimplifyPad(),
]

with tvm.transform.PassContext(opt_level=3):
    nn_mod = tvm.transform.Sequential(passes)(nn_mod)
print(nn_mod)
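The repro deliberately reuses `conv2d.attrs` (with `padding=[0, 0, 0, 0]`) after dropping `nn.pad`. As a result, the rewritten conv2d no longer sees the 3-pixel border. A real pad-folding pass would also have to move the pad widths into conv2d's `padding`. The pure-Python sketch below covers only that bookkeeping. `fold_pad_width` is a hypothetical helper, and it assumes TVM's 4-element conv2d padding order (top, left, bottom, right). In an actual pass, the merged list would presumably be passed to `relay.nn.conv2d(..., padding=...)` along with the other original attributes, instead of reusing `conv2d.attrs` as-is.

```python
def fold_pad_width(pad_width, data_layout="NHWC", conv_padding=(0, 0, 0, 0)):
    """Merge nn.pad widths into a conv2d padding list (hypothetical helper).

    pad_width:    per-axis [before, after] pairs, as given to nn.pad.
    data_layout:  layout string of the conv2d data input, e.g. "NHWC".
    conv_padding: the conv2d's existing (top, left, bottom, right) padding.

    Returns the merged [top, left, bottom, right] padding, or None when the
    pad touches a non-spatial axis (such a pad cannot be folded into conv2d).
    """
    h_axis = data_layout.index("H")
    w_axis = data_layout.index("W")
    for axis, (before, after) in enumerate(pad_width):
        if axis not in (h_axis, w_axis) and (before, after) != (0, 0):
            return None
    top, bottom = pad_width[h_axis]
    left, right = pad_width[w_axis]
    extra = (top, left, bottom, right)
    return [p + e for p, e in zip(conv_padding, extra)]


# The pad from the repro: 3 pixels on each side of H and W in NHWC.
print(fold_pad_width([[0, 0], [3, 3], [3, 3], [0, 0]]))   # [3, 3, 3, 3]
# Padding the channel axis cannot be expressed as conv2d padding.
print(fold_pad_width([[0, 0], [3, 3], [3, 3], [1, 1]]))   # None
```

Returning None for a non-spatial pad lets the caller leave the original `nn.pad` + `nn.conv2d` pair untouched. This mirrors how the repro's mutator returns the call unchanged when its other preconditions fail.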

Current Clue

The first InferType pass works well: before the SimplifyPad pass runs, the inferred output shape of nn.conv2d is (1, 224, 224, 64). SimplifyPad then removes the nn.pad operator. Because SimplifyPad is a function pass, InferType is executed again automatically after it, and the error occurs in this second InferType pass.

  pass_ctx->diag_ctx.value().Render();
  pass_ctx->diag_ctx = previous;
  // TODO(@jroesch): move away from eager type checking for performance reasons
  // make issue.
  return transform::InferType()(updated_mod);
}

When the second InferType pass calls Conv2DRel, the value of the parameter "types" is "[TensorType([1, 224, 224, 3], int8), TensorType([7, 7, 3, 64], int8), TensorType([1, 224, 224, 64], int32)]". The last item of "types" may be wrong, because during the first InferType pass the value of this parameter is "[TensorType([1, 230, 230, 3], int8), TensorType([7, 7, 3, 64], int8), IncompleteTypeNode(0, 0x5d6e270)]".
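This is a toy model, not TVM's actual TypeSolver, of why the third slot matters. In the first pass `types[2]` is an IncompleteType, so the relation can simply fill in the shape it computes. In the second pass `types[2]` already holds a concrete `(1, 224, 224, 64)`, so the computed `(1, 218, 218, 64)` has to unify with it and fails with the same per-dimension conflicts as in the error output. Where that concrete type comes from is an assumption here. One plausible source is the `-> Tensor[(1, 224, 224, 64), int32]` return annotation printed on `@main` after the first InferType, if the mutator carries it over into the rebuilt function. `unify_output` is a hypothetical name.

```python
def unify_output(current, computed):
    """Toy stand-in for how a type relation treats its output slot.

    current:  None models an IncompleteType; a tuple models a concrete shape.
    computed: the shape the relation derives from its input types.
    """
    if current is None:
        # Output still unknown (first InferType): just fill it in.
        return computed
    conflicts = [
        f"dimension {i} conflicts: {c} does not match {k}"
        for i, (k, c) in enumerate(zip(current, computed))
        if k != c
    ]
    if len(current) != len(computed) or conflicts:
        raise TypeError(" ".join(conflicts) or "rank mismatch")
    return current


# First pass: types[2] is incomplete, so the computed shape is accepted.
print(unify_output(None, (1, 224, 224, 64)))
# Second pass: types[2] is already concrete, but the pad-free input yields 218.
try:
    unify_output((1, 224, 224, 64), (1, 218, 218, 64))
except TypeError as err:
    print(err)
```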

bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 3);
  const auto* data = types[0].as<TensorTypeNode>();

The relevant type solver code is hard to understand, so I would like to know: is this a bug, or is the pass I wrote missing something?
Thanks a lot.

Metadata


Assignees

No one assigned

    Labels

    flow:relay (The overall lowering flow for tvm.relay.build, including BYOC core, excluding tvm.driver.build), relay:op (src/relay/op)
