diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index dfc0298979e6..0f99193fe525 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1483,6 +1483,23 @@ def _impl_v12(cls, inputs, attr, params): return result +class Squeeze(OnnxOpConverter): + """Operator converter for Squeeze.""" + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get("axes", None) + return _op.squeeze(*inputs, axis) + + @classmethod + def _impl_v13(cls, inputs, attr, params): + axis = inputs[1] + dtype = infer_type(axis).checked_type.dtype + rank = _op.shape_of(_op.shape_of(inputs[0], dtype), dtype) + axis = _op.where(axis < _op.const(0, dtype), axis + rank, axis) + return _op.squeeze(inputs[0], fold_constant(axis)) + + class Split(OnnxOpConverter): """Operator converter for Split.""" @@ -2806,7 +2823,8 @@ def _impl_v12(cls, inputs, attr, params): alpha = _op.const(attr.get("alpha", 1.0), dtype) zero = _op.const(0, dtype) one = _op.const(1, dtype) - return _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + out = _op.maximum(zero, x) + _op.minimum(zero, alpha * (_op.exp(x / alpha) - one)) + return out class MaxRoiPool(OnnxOpConverter): @@ -4136,7 +4154,7 @@ def _get_convert_map(opset): "ScatterElements": Scatter.get_converter(opset), "ScatterND": ScatterND.get_converter(opset), "EyeLike": EyeLike.get_converter(opset), - "Squeeze": AttrCvt("squeeze", {"axes": "axis"}), + "Squeeze": Squeeze.get_converter(opset), "Unsqueeze": Unsqueeze.get_converter(opset), "Pad": Pad.get_converter(opset), "Shape": Shape.get_converter(opset), diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index c8235ec9375a..c909764319d9 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -26,6 +26,7 @@ _reg.register_broadcast_schedule("dyn.broadcast_to") _reg.register_injective_schedule("dyn.reshape") _reg.register_injective_schedule("dyn.expand_dims") +_reg.register_injective_schedule("dyn.squeeze") _reg.register_broadcast_schedule("dyn.tile") _reg.register_injective_schedule("dyn.one_hot") _reg.register_injective_schedule("dyn.full") @@ -258,3 +259,24 @@ def _sparse_to_dense_shape_func(output_shape, ndim): @_reg.register_shape_func("dyn.sparse_to_dense", True) def sparse_to_dense_shape_func(attrs, inputs, out_ndims): return [_sparse_to_dense_shape_func(inputs[3], out_ndims[0])] + + +@script +def _squeeze_shape_func_input_data(data, axis, ndims): + out = output_tensor((ndims,), "int64") + out_i = 0 + for i in const_range(data.shape[0]): + not_in_axis = True + for j in const_range(axis.shape[0]): + if i == axis[j]: + not_in_axis = False + if not_in_axis: + out[out_i] = int64(data[i]) + out_i += 1 + + return out + + +@_reg.register_shape_func("dyn.squeeze", [False, True]) +def dynamic_squeeze_shape_func(attrs, inputs, out_ndims): + return [_squeeze_shape_func_input_data(inputs[0], inputs[1], out_ndims[0])] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index fe1a73ca231a..234e76b11813 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -149,7 +149,7 @@ def squeeze(data, axis=None): data : tvm.relay.Expr The input data to the operator. - axis : None or List[int] + axis : None or List[int] or Expr The set of axes to remove. If axis = None, remove all axis of dimensions 1. If any specified axis has dimension that does not equal 1, it is an error. @@ -159,6 +159,10 @@ def squeeze(data, axis=None): result : tvm.relay.Expr The squeezed result. """ + if isinstance(axis, Constant): + axis = list(axis.data.numpy()) + if isinstance(axis, Expr): + return _dyn_make.squeeze(data, axis) return _make.squeeze(data, axis) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index 848d058f0af3..64baa6066522 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -692,6 +692,63 @@ RELAY_REGISTER_OP("dyn.expand_dims") .set_attr("FTVMCompute", ExpandDimsCompute) .set_attr("TOpPattern", kInjective); +bool DynSqueezeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // [data, axes, output] + ICHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* axes = types[1].as(); + if (axes == nullptr) { + return false; + } + + ICHECK_EQ(axes->shape.size(), 1) << "Got" << axes->shape.size() << "expected 1"; + ICHECK(axes->shape[0].as()) << "axes expected to be static rank"; + size_t output_rank = data->shape.size() - axes->shape[0].as()->value; + std::vector result_shape(output_rank, Any()); + reporter->Assign(types[2], TensorType(result_shape, data->dtype)); + return true; +} + +Array SqueezeCompute(const Attrs& attrs, const Array& inputs, + const Type& out_type) { + const auto* out_ttype = out_type.as(); + ICHECK(out_ttype != nullptr); + Array newshape; + for (auto val : out_ttype->shape) { + newshape.push_back(val.as()->ToVar()); + } + return {topi::reshape(inputs[0], newshape)}; +} + +Expr MakeDynSqueeze(Expr data, Expr axes) { + auto attrs = make_object(); + static const Op& op = Op::Get("dyn.squeeze"); + return Call(op, {data, axes}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.dyn._make.squeeze").set_body_typed(MakeDynSqueeze); + +RELAY_REGISTER_OP("dyn.squeeze") + .describe(R"code(Remove axes of value 1 in input tensor at the dimensions given by axes + +- **data**: The input data to the operator. +- **axes**: The axes to squeeze. + +)code" TVM_ADD_FILELINE) + .set_num_inputs(2) + .set_attrs_type() + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("axes", "Tensor", "The axes to squeeze.") + .set_support_level(3) + .add_type_rel("DynSqueeze", DynSqueezeRel) + .set_attr("FTVMCompute", SqueezeCompute) + .set_attr("TOpPattern", kInjective) + .set_attr("TReshapeOp", true); + } // namespace dyn } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 318022fb86f5..751271d2add3 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -45,6 +45,15 @@ class DynamicToStaticMutator : public MixedModeMutator { } return Expr(nullptr); }}, + {Op::Get("dyn.squeeze"), + [this](const CallNode* call_node) { + auto args = PrepareArgs(call_node); + if (const ConstantNode* axis = args[1].as()) { + ICHECK_EQ(axis->data->ndim, 1); + return MakeSqueeze(call_node->args[0], ToVector(axis->data)); + } + return Expr(nullptr); + }}, {Op::Get("dyn.tile"), [this](const CallNode* call_node) { auto args = PrepareArgs(call_node); diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index e049a195dc17..a4ffcc39b52b 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4950,8 +4950,6 @@ def verify_eyelike(indata): "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", "test_split_zero_size_splits", - "test_squeeze", - "test_squeeze_negative_axes", "test_strnormalizer_export_monday_casesensintive_lower", "test_strnormalizer_export_monday_casesensintive_nochangecase", "test_strnormalizer_export_monday_casesensintive_upper", diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 8c57e1dc4a9f..22583eda4a40 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -92,6 +92,22 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +def test_squeeze(): + def verify_squeeze(shape, dtype, axis): + x = relay.var("x", relay.TensorType(shape, dtype)) + assert axis is not None + np_axis = tuple(axis) + axis = relay.var("axis", relay.TensorType([len(axis)], "int64")) + squeeze = relay.squeeze(x, axis=axis) + func = relay.Function([x, axis], squeeze) + x_data = np.random.random_sample(shape).astype(dtype) + ref_res = np.squeeze(x_data, axis=np_axis) + verify_func(func, [x_data, np.array(np_axis).astype("int64")], ref_res) + + verify_squeeze((1, 3, 1), "float32", [0]) + verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2]) + + @tvm.testing.uses_gpu def test_dyn_expand_dims(): def verify_expand_dims( diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index a34c4ac6f705..5b61733bbd76 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -72,6 +72,31 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) +@tvm.testing.uses_gpu +def test_dynamic_to_static_squeeze(): + def verify_squeeze(shape, axis, oshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("y", relay.TensorType(axis, "float32")) + z = relay.squeeze(x, relay.shape_of(y)) + func = run_infer_type(relay.Function([x, y], z)) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + + zz = func2.body + assert isinstance(zz, relay.Call) + assert zz.op == relay.op.get("squeeze") + assert "axis=" in zz.astext() + assert zz.checked_type == relay.ty.TensorType(oshape, "float32") + + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=axis).astype("float32") + ref_res = np.squeeze(x_data, axis) + verify_func(func2, [x_data, y_data], ref_res) + + verify_squeeze((1, 3, 4, 1), (0,), (3, 4, 1)) + verify_squeeze((1, 3, 4, 1), (3,), (1, 3, 4)) + verify_squeeze((1, 3, 4, 1), (0, 3), (3, 4)) + + @tvm.testing.uses_gpu def test_dynamic_to_static_double_reshape(): def verify_reshape(shape, newshape):