diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f7538f0837c6..eeb6480c3491 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -17,7 +17,7 @@ # pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks # pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except # pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda -# pylint: disable=missing-function-docstring +# pylint: disable=missing-function-docstring, redefined-builtin """PT: PyTorch frontend.""" import functools import itertools @@ -45,7 +45,7 @@ from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated from .common import lstm_cell, try_infer_value, unbind -from .pytorch_utils import is_version_greater_than +from .pytorch_utils import is_version_greater_than, getattr_attr_name __all__ = ["from_pytorch"] @@ -1393,10 +1393,19 @@ def log_softmax(self, inputs, input_types): def sigmoid(self, inputs, input_types): data = inputs[0] - return _op.tensor.sigmoid(data) + + def func(x): + return _op.tensor.sigmoid(x) + + if self.is_quantized_tensor(data): + assert len(inputs) == 3, "Input quant param not found in op inputs" + input_scale = _expr.const(inputs[1]) + input_zero_point = _expr.const(inputs[2]) + return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func) + + return func(data) def softplus(self, inputs, input_types): - data = inputs[0] dtype = input_types[0] beta = _expr.const(float(inputs[1]), dtype=dtype) return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta @@ -1583,7 +1592,8 @@ def func(x): assert len(inputs) == 6, "Input quant param not found in op inputs" input_scale = _expr.const(inputs[4]) input_zero_point = _expr.const(inputs[5]) - return qnn_torch.quantized_mean(data, input_scale, input_zero_point, func) + # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp + return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func) return func(data) @@ -1755,9 +1765,7 @@ def pad(inputs, input_types): return pad - def clamp(self, inputs, input_types): - data = inputs[0] - + def clamp_common(self, data, min=None, max=None): def get_v(v, default_v): if isinstance(v, _expr.Constant): return float(v.data.numpy()) @@ -1769,10 +1777,32 @@ def get_v(v, default_v): return v return default_v - amin = get_v(inputs[1], np.finfo(np.float32).min) - amax = get_v(inputs[2], np.finfo(np.float32).max) + dtype = self.infer_type(data).dtype + + type_info = np.finfo(dtype) if "float" in dtype else np.iinfo(dtype) + + # TODO(masahi): Properly handle inf in a one-way clamp case. + if min is not None and max is not None: + amin = get_v(min, type_info.min) + amax = get_v(max, type_info.max) + elif min is not None: + amin = get_v(min, type_info.min) + amax = type_info.max + else: + amin = type_info.min + amax = get_v(max, type_info.max) + return _op.clip(data, amin, amax) + def clamp(self, inputs, _): + return self.clamp_common(inputs[0], min=inputs[1], max=inputs[2]) + + def clamp_min(self, inputs, input_types): + return self.clamp_common(inputs[0], min=inputs[1]) + + def clamp_max(self, inputs, input_types): + return self.clamp_common(inputs[0], max=inputs[1]) + def to(self, inputs, input_types): data = inputs[0] dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2] @@ -1847,7 +1877,8 @@ def func(x): assert isinstance(inputs[-1], int) input_scale = _expr.const(inputs[-2]) input_zero_point = _expr.const(inputs[-1]) - return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func) + # currently piggy backs to fp32, it gets identical output as torch + return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func) return func(data) @@ -3020,6 +3051,8 @@ def create_convert_map(self): "aten::isinf": self.make_unary("isinf"), "aten::isnan": self.make_unary("isnan"), "aten::clamp": self.clamp, + "aten::clamp_min": self.clamp_min, + "aten::clamp_max": self.clamp_max, "aten::detach": self.identity, "aten::upsample_bilinear2d": self.make_upsample("linear"), "aten::upsample_bicubic2d": self.make_upsample("cubic"), @@ -3531,15 +3564,8 @@ def _get_users(node): return [use.user for use in _get_uses(node)] -def _getattr_attr_name(node): - attribute_names = node.attributeNames() - assert len(attribute_names) == 1 - attr_name = node.s(attribute_names[0]) - return attr_name - - def _getattr_full_name(getattrs, sep="."): - return sep.join([_getattr_attr_name(node) for node in getattrs]) + return sep.join([getattr_attr_name(node) for node in getattrs]) def _get_pytorch_value_type(typ, default_dtype="float32"): @@ -3941,6 +3967,7 @@ def from_pytorch( weight_quant_params = qnn_torch.get_weight_quant_params( script_module, packed_param_map.values() ) + qnn_torch.inline_input_quant_params_for_fx(graph, tensors) input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph) qnn_torch.add_quant_params_to_outputs( outputs, diff --git a/python/tvm/relay/frontend/pytorch_utils.py b/python/tvm/relay/frontend/pytorch_utils.py index b4fe1681937f..c87ea8af6d33 100644 --- a/python/tvm/relay/frontend/pytorch_utils.py +++ b/python/tvm/relay/frontend/pytorch_utils.py @@ -39,6 +39,12 @@ def is_version_greater_than(ver): ) +def getattr_attr_name(node): + attribute_names = node.attributeNames() + assert len(attribute_names) == 1 + return node.s(attribute_names[0]) + + def dyn_strided_slice_pattern(inp, end): """A pattern to detect dynamic strided slice op.""" zero = is_constant() diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py index 729b89738445..83af057e8b5d 100644 --- a/python/tvm/relay/frontend/qnn_torch.py +++ b/python/tvm/relay/frontend/qnn_torch.py @@ -25,7 +25,7 @@ from tvm.relay.frontend.common import infer_shape from .common import logger -from .pytorch_utils import is_version_greater_than +from .pytorch_utils import is_version_greater_than, getattr_attr_name class QNNParam: @@ -292,8 +292,8 @@ def dfs(current_node): for arg in current_node.inputs(): return dfs(arg.node()) - # shouldn't happen - assert False, "No producer for %s" % (str(current_node)) + # If input_value is not quantized, we reach here. + return None, None return dfs(input_value.node()) @@ -437,6 +437,7 @@ def add_input_quant_params_to_op_inputs(graph): "quantized::mul": 2, "aten::dequantize": 1, "aten::mean": 1, + "aten::sigmoid": 1, "aten::upsample_nearest2d": 1, "aten::upsample_bilinear2d": 1, "aten::relu_": 1, @@ -473,8 +474,9 @@ def add_input_quant_params_to_op_inputs(graph): else: for i in range(num_quantized_inputs[operator]): scale, zp = _get_quant_param_for_input(node.inputsAt(i)) - input_scales.append(scale) - input_zero_points.append(zp) + if scale is not None and zp is not None: + input_scales.append(scale) + input_zero_points.append(zp) if operator in ["quantized::add_scalar", "quantized::mul_scalar"]: scalar = node.inputsAt(1).node().f("value") @@ -488,9 +490,10 @@ def add_input_quant_params_to_op_inputs(graph): node.addInput(scale) node.addInput(zp) - if "conv" in operator or "linear" in operator: + if "quantized::conv" in operator or "quantized::linear" in operator: # This is required for quantizing the bias - input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value") + assert len(input_scales) == 1, "One quantized parameter expected for qconv or qlinear." + input_scales_for_bias[node.inputsAt(1).debugName()] = input_scales[0].node().f("value") return input_scales_for_bias @@ -503,26 +506,67 @@ def add_quant_params(params, quant_params): params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias) +def inline_input_quant_params_for_fx(graph, params): + """ + Canonicalize input scale and zero point access for FX-quantized graphs. + We expect input qparams to aten::quantize_per_tensor to be prim::Constant, but that's + not the case for FX-based quantized models as shown below. + We replace prim::GetAttr with prim::Constant so that FX-based quantized models can be + converted in the same way as eager-mode based quantized models. + + Before: + %pan_input_zero_point_1 : Tensor = prim::GetAttr[name="pan_input_zero_point_1"](%backbone) + %pan_input_scale_1 : Tensor = prim::GetAttr[name="pan_input_scale_1"](%backbone) + ... + %quantize_per_tensor_2 ... = aten::quantize_per_tensor(..., + %pan_input_scale_1, %pan_input_zero_point_1, ...) + + After: + %2402 : int = prim::Constant[value=0]() + %2403 : float = prim::Constant[value=1.]() + %quantize_per_tensor_2 ... = aten::quantize_per_tensor(..., %2403, %2402, ...) + """ + import torch + + def get_full_attr_name(current): + current_attr = getattr_attr_name(current) + inputs = list(current.inputs()) + if len(inputs) == 1 and inputs[0].node().kind() == "prim::GetAttr": + return get_full_attr_name(inputs[0].node()) + "." + current_attr + return current_attr + + for node in graph.findAllNodes("prim::GetAttr", recurse=True): + out_name = node.output().debugName() + + if "_input_scale" in out_name or "_input_zero_point" in out_name: + full_attr = get_full_attr_name(node) + assert full_attr in params, "%s not found in param dict." % full_attr + param_np = params[full_attr].numpy() + new_const_node = graph.create("prim::Constant") + new_const_node.insertBefore(node) + + if "_input_scale" in out_name: + new_const_node.f_("value", param_np) + new_const_node.output().setType(torch._C.FloatType.get()) + else: + new_const_node.i_("value", param_np.item()) + new_const_node.output().setType(torch._C.IntType.get()) + + node.replaceAllUsesWith(new_const_node) + + def apply_with_upcast(data, func): inp = _op.cast(data, dtype="int32") out = func(inp) return _op.cast(out, "uint8") -def quantized_mean(data, input_scale, input_zero_point, func_fp32): - # refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp +def apply_with_fp32_fallback(data, input_scale, input_zero_point, func_fp32): dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point) out = func_fp32(dequantized) return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1) -def quantized_upsample(data, input_scale, input_zero_point, func_fp32): - # currently piggy backs to fp32, it gets identical output as torch - data = relay.qnn.op.dequantize(data, input_scale, input_zero_point) - out = func_fp32(data) - return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1) - - def quantized_relu(data, input_zero_point): # refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp zp = _op.cast(input_zero_point, dtype="uint8") @@ -531,8 +575,14 @@ def quantized_relu(data, input_zero_point): def _quantize_per_tensor(): def _impl(inputs, _): + dim = len(infer_shape(inputs[0])) + if dim > 1: + axis = 1 + else: + axis = 0 + return relay.qnn.op.quantize( - inputs[0], _expr.const(inputs[1]), _expr.const(inputs[2]), out_dtype="uint8", axis=1 + inputs[0], _expr.const(inputs[1]), _expr.const(inputs[2]), out_dtype="uint8", axis=axis ) return _impl diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3fbef494f16d..9471bb3c0659 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2790,6 +2790,10 @@ def forward(self, *args): verify_model(Clamp3().float().eval(), input_data=input_data) verify_model(Clamp_MinExpr_MaxConstant().float().eval(), input_data=input_data) + verify_model(lambda inp: torch.clamp_min(inp, 0.5), input_data) + inp_uint8 = torch.randint(low=0, high=256, size=(100, 100), dtype=torch.uint8) + verify_model(lambda inp: torch.clamp_max(inp, 125), inp_uint8) + @tvm.testing.uses_gpu def test_forward_clamp_(): diff --git a/tests/python/frontend/pytorch/test_fx_quant.py b/tests/python/frontend/pytorch/test_fx_quant.py new file mode 100644 index 000000000000..f35094a83137 --- /dev/null +++ b/tests/python/frontend/pytorch/test_fx_quant.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Tests on fx-quantized torch model conversion """ +import torch +import torchvision +import numpy as np +from torch.quantization import get_default_qconfig +from torch.quantization.quantize_fx import prepare_fx, convert_fx +from torchvision.models.efficientnet import efficientnet_b4 +from torchvision.models.resnet import resnet50 +from tvm import relay + + +def quantize(model): + qconfig = get_default_qconfig("fbgemm") + qconfig_dict = {"": qconfig} + return convert_fx(prepare_fx(model, qconfig_dict)) + + +def quantize_and_build(model, in_size): + inp = torch.rand(1, 3, in_size, in_size) + input_name = "inp" + qmodel = quantize(model) + + with torch.no_grad(): + script_module = torch.jit.trace(qmodel, inp) + mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)]) + mod = relay.transform.InferType()(mod) + + # Make sure that the model is quantized + assert "qnn.conv2d" in mod.astext(show_meta_data=False) + + # Skip building since it is slow on CI + # relay.build(mod, params=params, target="llvm") + + +def test_ssd_vgg(): + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + features = self.model.backbone(inp) + features = list(features.values()) + out = self.model.head(features) + return out["bbox_regression"], out["cls_logits"] + + model_func = torchvision.models.detection.ssd300_vgg16 + model = TraceWrapper(model_func(num_classes=50, pretrained_backbone=True)).eval() + quantize_and_build(model, 300) + + +def test_deeplab_v3(): + class TraceWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, inp): + out = self.model(inp) + return out["out"] + + deeplabv3 = torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=True) + model = TraceWrapper(deeplabv3.eval()).eval() + quantize_and_build(model, 300) + + +def test_imagenet(): + for model_func in [resnet50, efficientnet_b4]: + quantize_and_build(model_func(pretrained=True).eval(), 224)