Skip to content
65 changes: 46 additions & 19 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
# pylint: disable=missing-function-docstring
# pylint: disable=missing-function-docstring, redefined-builtin
"""PT: PyTorch frontend."""
import functools
import itertools
Expand Down Expand Up @@ -45,7 +45,7 @@
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import lstm_cell, try_infer_value, unbind
from .pytorch_utils import is_version_greater_than
from .pytorch_utils import is_version_greater_than, getattr_attr_name

__all__ = ["from_pytorch"]

Expand Down Expand Up @@ -1393,10 +1393,19 @@ def log_softmax(self, inputs, input_types):

def sigmoid(self, inputs, input_types):
data = inputs[0]
return _op.tensor.sigmoid(data)

def func(x):
return _op.tensor.sigmoid(x)

if self.is_quantized_tensor(data):
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[1])
input_zero_point = _expr.const(inputs[2])
return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func)

return func(data)

def softplus(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
beta = _expr.const(float(inputs[1]), dtype=dtype)
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.0, dtype=dtype)) / beta
Expand Down Expand Up @@ -1583,7 +1592,8 @@ def func(x):
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
return qnn_torch.quantized_mean(data, input_scale, input_zero_point, func)
# refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func)

return func(data)

Expand Down Expand Up @@ -1755,9 +1765,7 @@ def pad(inputs, input_types):

return pad

def clamp(self, inputs, input_types):
data = inputs[0]

def clamp_common(self, data, min=None, max=None):
def get_v(v, default_v):
if isinstance(v, _expr.Constant):
return float(v.data.numpy())
Expand All @@ -1769,10 +1777,32 @@ def get_v(v, default_v):
return v
return default_v

amin = get_v(inputs[1], np.finfo(np.float32).min)
amax = get_v(inputs[2], np.finfo(np.float32).max)
dtype = self.infer_type(data).dtype

type_info = np.finfo(dtype) if "float" in dtype else np.iinfo(dtype)

# TODO(masahi): Properly handle inf in a one-way clamp case.
if min is not None and max is not None:
amin = get_v(min, type_info.min)
amax = get_v(max, type_info.max)
elif min is not None:
amin = get_v(min, type_info.min)
amax = type_info.max
else:
amin = type_info.min
amax = get_v(max, type_info.max)

return _op.clip(data, amin, amax)

def clamp(self, inputs, _):
return self.clamp_common(inputs[0], min=inputs[1], max=inputs[2])

def clamp_min(self, inputs, input_types):
return self.clamp_common(inputs[0], min=inputs[1])

def clamp_max(self, inputs, input_types):
return self.clamp_common(inputs[0], max=inputs[1])

def to(self, inputs, input_types):
data = inputs[0]
dtype = inputs[1] if inputs[1] is not None and not isinstance(inputs[1], str) else inputs[2]
Expand Down Expand Up @@ -1847,7 +1877,8 @@ def func(x):
assert isinstance(inputs[-1], int)
input_scale = _expr.const(inputs[-2])
input_zero_point = _expr.const(inputs[-1])
return qnn_torch.quantized_upsample(data, input_scale, input_zero_point, func)
# currently piggy backs to fp32, it gets identical output as torch
return qnn_torch.apply_with_fp32_fallback(data, input_scale, input_zero_point, func)

return func(data)

Expand Down Expand Up @@ -3020,6 +3051,8 @@ def create_convert_map(self):
"aten::isinf": self.make_unary("isinf"),
"aten::isnan": self.make_unary("isnan"),
"aten::clamp": self.clamp,
"aten::clamp_min": self.clamp_min,
"aten::clamp_max": self.clamp_max,
"aten::detach": self.identity,
"aten::upsample_bilinear2d": self.make_upsample("linear"),
"aten::upsample_bicubic2d": self.make_upsample("cubic"),
Expand Down Expand Up @@ -3531,15 +3564,8 @@ def _get_users(node):
return [use.user for use in _get_uses(node)]


def _getattr_attr_name(node):
attribute_names = node.attributeNames()
assert len(attribute_names) == 1
attr_name = node.s(attribute_names[0])
return attr_name


def _getattr_full_name(getattrs, sep="."):
return sep.join([_getattr_attr_name(node) for node in getattrs])
return sep.join([getattr_attr_name(node) for node in getattrs])


def _get_pytorch_value_type(typ, default_dtype="float32"):
Expand Down Expand Up @@ -3941,6 +3967,7 @@ def from_pytorch(
weight_quant_params = qnn_torch.get_weight_quant_params(
script_module, packed_param_map.values()
)
qnn_torch.inline_input_quant_params_for_fx(graph, tensors)
input_scales_for_bias = qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(
outputs,
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/frontend/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ def is_version_greater_than(ver):
)


def getattr_attr_name(node):
attribute_names = node.attributeNames()
assert len(attribute_names) == 1
return node.s(attribute_names[0])


def dyn_strided_slice_pattern(inp, end):
"""A pattern to detect dynamic strided slice op."""
zero = is_constant()
Expand Down
84 changes: 67 additions & 17 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tvm.relay.frontend.common import infer_shape

from .common import logger
from .pytorch_utils import is_version_greater_than
from .pytorch_utils import is_version_greater_than, getattr_attr_name


class QNNParam:
Expand Down Expand Up @@ -292,8 +292,8 @@ def dfs(current_node):
for arg in current_node.inputs():
return dfs(arg.node())

# shouldn't happen
assert False, "No producer for %s" % (str(current_node))
# If input_value is not quantized, we reach here.
return None, None

return dfs(input_value.node())

Expand Down Expand Up @@ -437,6 +437,7 @@ def add_input_quant_params_to_op_inputs(graph):
"quantized::mul": 2,
"aten::dequantize": 1,
"aten::mean": 1,
"aten::sigmoid": 1,
"aten::upsample_nearest2d": 1,
"aten::upsample_bilinear2d": 1,
"aten::relu_": 1,
Expand Down Expand Up @@ -473,8 +474,9 @@ def add_input_quant_params_to_op_inputs(graph):
else:
for i in range(num_quantized_inputs[operator]):
scale, zp = _get_quant_param_for_input(node.inputsAt(i))
input_scales.append(scale)
input_zero_points.append(zp)
if scale is not None and zp is not None:
input_scales.append(scale)
input_zero_points.append(zp)

if operator in ["quantized::add_scalar", "quantized::mul_scalar"]:
scalar = node.inputsAt(1).node().f("value")
Expand All @@ -488,9 +490,10 @@ def add_input_quant_params_to_op_inputs(graph):
node.addInput(scale)
node.addInput(zp)

if "conv" in operator or "linear" in operator:
if "quantized::conv" in operator or "quantized::linear" in operator:
# This is required for quantizing the bias
input_scales_for_bias[node.inputsAt(1).debugName()] = scale.node().f("value")
assert len(input_scales) == 1, "One quantized parameter expected for qconv or qlinear."
input_scales_for_bias[node.inputsAt(1).debugName()] = input_scales[0].node().f("value")

return input_scales_for_bias

Expand All @@ -503,26 +506,67 @@ def add_quant_params(params, quant_params):
params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)


def inline_input_quant_params_for_fx(graph, params):
"""
Canonicalize input scale and zero point access for FX-quantized graphs.
We expect input qparams to aten::quantize_per_tensor to be prim::Constant, but that's
not the case for FX-based quantized models as shown below.
We replace prim::GetAttr with prim::Constant so that FX-based quantized models can be
converted in the same way as eager-mode based quantized models.

Before:
%pan_input_zero_point_1 : Tensor = prim::GetAttr[name="pan_input_zero_point_1"](%backbone)
%pan_input_scale_1 : Tensor = prim::GetAttr[name="pan_input_scale_1"](%backbone)
...
%quantize_per_tensor_2 ... = aten::quantize_per_tensor(...,
%pan_input_scale_1, %pan_input_zero_point_1, ...)

After:
%2402 : int = prim::Constant[value=0]()
%2403 : float = prim::Constant[value=1.]()
%quantize_per_tensor_2 ... = aten::quantize_per_tensor(..., %2403, %2402, ...)
"""
import torch

def get_full_attr_name(current):
current_attr = getattr_attr_name(current)
inputs = list(current.inputs())
if len(inputs) == 1 and inputs[0].node().kind() == "prim::GetAttr":
return get_full_attr_name(inputs[0].node()) + "." + current_attr
return current_attr

for node in graph.findAllNodes("prim::GetAttr", recurse=True):
out_name = node.output().debugName()

if "_input_scale" in out_name or "_input_zero_point" in out_name:
full_attr = get_full_attr_name(node)
assert full_attr in params, "%s not found in param dict." % full_attr
param_np = params[full_attr].numpy()
new_const_node = graph.create("prim::Constant")
new_const_node.insertBefore(node)

if "_input_scale" in out_name:
new_const_node.f_("value", param_np)
new_const_node.output().setType(torch._C.FloatType.get())
else:
new_const_node.i_("value", param_np.item())
new_const_node.output().setType(torch._C.IntType.get())

node.replaceAllUsesWith(new_const_node)


def apply_with_upcast(data, func):
inp = _op.cast(data, dtype="int32")
out = func(inp)
return _op.cast(out, "uint8")


def quantized_mean(data, input_scale, input_zero_point, func_fp32):
# refer to aten/src/ATen/native/quantized/cpu/qreduction.cpp
def apply_with_fp32_fallback(data, input_scale, input_zero_point, func_fp32):
dequantized = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
out = func_fp32(dequantized)
return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1)


def quantized_upsample(data, input_scale, input_zero_point, func_fp32):
# currently piggy backs to fp32, it gets identical output as torch
data = relay.qnn.op.dequantize(data, input_scale, input_zero_point)
out = func_fp32(data)
return relay.qnn.op.quantize(out, input_scale, input_zero_point, out_dtype="uint8", axis=1)


def quantized_relu(data, input_zero_point):
# refer to aten/src/ATen/native/quantized/cpu/qrelu.cpp
zp = _op.cast(input_zero_point, dtype="uint8")
Expand All @@ -531,8 +575,14 @@ def quantized_relu(data, input_zero_point):

def _quantize_per_tensor():
def _impl(inputs, _):
dim = len(infer_shape(inputs[0]))
if dim > 1:
axis = 1
else:
axis = 0

return relay.qnn.op.quantize(
inputs[0], _expr.const(inputs[1]), _expr.const(inputs[2]), out_dtype="uint8", axis=1
inputs[0], _expr.const(inputs[1]), _expr.const(inputs[2]), out_dtype="uint8", axis=axis
)

return _impl
Expand Down
4 changes: 4 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2790,6 +2790,10 @@ def forward(self, *args):
verify_model(Clamp3().float().eval(), input_data=input_data)
verify_model(Clamp_MinExpr_MaxConstant().float().eval(), input_data=input_data)

verify_model(lambda inp: torch.clamp_min(inp, 0.5), input_data)
inp_uint8 = torch.randint(low=0, high=256, size=(100, 100), dtype=torch.uint8)
verify_model(lambda inp: torch.clamp_max(inp, 125), inp_uint8)


@tvm.testing.uses_gpu
def test_forward_clamp_():
Expand Down
85 changes: 85 additions & 0 deletions tests/python/frontend/pytorch/test_fx_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Tests on fx-quantized torch model conversion """
import torch
import torchvision
import numpy as np
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torchvision.models.efficientnet import efficientnet_b4
from torchvision.models.resnet import resnet50
from tvm import relay


def quantize(model):
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}
return convert_fx(prepare_fx(model, qconfig_dict))


def quantize_and_build(model, in_size):
inp = torch.rand(1, 3, in_size, in_size)
input_name = "inp"
qmodel = quantize(model)

with torch.no_grad():
script_module = torch.jit.trace(qmodel, inp)
mod, _ = relay.frontend.from_pytorch(script_module, [(input_name, inp.shape)])
mod = relay.transform.InferType()(mod)

# Make sure that the model is quantized
assert "qnn.conv2d" in mod.astext(show_meta_data=False)

# Skip building since it is slow on CI
# relay.build(mod, params=params, target="llvm")


def test_ssd_vgg():
class TraceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, inp):
features = self.model.backbone(inp)
features = list(features.values())
out = self.model.head(features)
return out["bbox_regression"], out["cls_logits"]

model_func = torchvision.models.detection.ssd300_vgg16
model = TraceWrapper(model_func(num_classes=50, pretrained_backbone=True)).eval()
quantize_and_build(model, 300)


def test_deeplab_v3():
class TraceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model

def forward(self, inp):
out = self.model(inp)
return out["out"]

deeplabv3 = torchvision.models.segmentation.deeplabv3_mobilenet_v3_large(pretrained=True)
model = TraceWrapper(deeplabv3.eval()).eval()
quantize_and_build(model, 300)


def test_imagenet():
for model_func in [resnet50, efficientnet_b4]:
quantize_and_build(model_func(pretrained=True).eval(), 224)