diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 1bd3084c210b..dce687b436d1 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -122,10 +122,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { * \sa MatchShapeCode */ void MatchShape(ffi::PackedArgs args, ffi::Any* rv) { - // input shape the first argument can take in tensor or shape. + // input shape the first argument can take in tensor, DLTensor* or shape. ffi::Shape input_shape; - if (auto opt_nd = args[0].as()) { - input_shape = opt_nd.value().Shape(); + if (auto opt_tensor = args[0].as()) { + input_shape = opt_tensor.value().Shape(); + } else if (auto opt_dltensor = args[0].try_cast()) { + DLTensor* ptr = opt_dltensor.value(); + input_shape = ffi::Shape(ptr->shape, ptr->shape + ptr->ndim); } else { input_shape = args[0].cast(); }