Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 52 additions & 30 deletions python/tvm/relay/op/contrib/ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def qnn_mul_pattern():
input_is_right = gen_mul_inputs(is_constant(), wildcard())
return input_is_left | input_is_right

def qnn_add_pattern():
def qnn_add_pattern(has_constant_input=False):
add_op = is_op("qnn.add")
gen_add_inputs = lambda x, y: add_op(
x,
Expand All @@ -227,11 +227,13 @@ def qnn_add_pattern():
is_constant(),
is_constant(),
)
two_inputs = gen_add_inputs(wildcard(), wildcard())
input_is_left = gen_add_inputs(wildcard(), is_constant())
input_is_right = gen_add_inputs(is_constant(), wildcard())

return input_is_left | input_is_right | two_inputs
if has_constant_input:
input_is_left = gen_add_inputs(wildcard(), is_constant())
input_is_right = gen_add_inputs(is_constant(), wildcard())
return input_is_left | input_is_right
else:
return gen_add_inputs(wildcard(), wildcard())
Comment thread
lhutton1 marked this conversation as resolved.

def qnn_conv2d_transpose_pattern():
pattern = is_op("qnn.conv2d_transpose")(
Expand Down Expand Up @@ -299,16 +301,24 @@ def check_leaky_relu(extract):

return _ethosn.leaky_relu(extract)

def check_mul(extract):
"""Check if Mul is supported."""
def check_mul_to_reinterpret_quantize(extract):
"""Check if Mul is supported by converting to reinterpret quantize"""
if not ethosn_available():
return False
# Do not support scalar constants for now
check_scalar = lambda i: isinstance(i, tvm.relay.Constant) and len(i.data.shape) == 0
if check_scalar(extract.args[0]) or check_scalar(extract.args[1]):
Comment thread
lhutton1 marked this conversation as resolved.

converted_extract = _ethosn.ConvertQnnMultiplyToReinterpretQuantize(extract)
if converted_extract:
return _ethosn.reinterpret_quantize(converted_extract)
return False

def check_mul_to_depthwise(extract):
"""Check if Mul is supported by converting to a depthwise operation."""
if not ethosn_available():
return False
extract = _ethosn.ConvertQnnMultiply(extract)
return _ethosn.conv2d(extract)
converted_extract = _ethosn.ConvertQnnMultiplyToDepthwise(extract)
if converted_extract:
return _ethosn.conv2d(converted_extract)
return False

def check_requantize(extract):
"""Check if requantize is supported."""
Expand All @@ -328,19 +338,40 @@ def check_add(extract):
"""Check if an addition is supported by Ethos-N."""
if not ethosn_available():
return False
# Do not support scalar constants for now
check_scalar = lambda i: isinstance(i, tvm.relay.Constant) and len(i.data.shape) == 0
if check_scalar(extract.args[0]) or check_scalar(extract.args[1]):
return False

inputs = extract.args[0:2]
if any([isinstance(i, tvm.relay.Constant) for i in inputs]):
extract = _ethosn.ConvertQnnAdd(extract)
return _ethosn.conv2d(extract)
return _ethosn.addition(extract)

def check_add_to_reinterpret_quantize(extract):
    """Report whether an addition can be offloaded as a reinterpret quantize.

    Returns False when Ethos-N is not available or when
    ``_ethosn.ConvertQnnAddToReinterpretQuantize`` gives back a falsy result.
    Otherwise returns whatever ``_ethosn.reinterpret_quantize`` reports for
    the converted expression.
    """
    if ethosn_available():
        candidate = _ethosn.ConvertQnnAddToReinterpretQuantize(extract)
        # Only ask the support query when the conversion produced something.
        if candidate:
            return _ethosn.reinterpret_quantize(candidate)
    return False

def check_add_to_depthwise(extract):
    """Check if addition can be converted to a depthwise operation.

    Returns False when Ethos-N is not available or when
    ``_ethosn.ConvertQnnAddToDepthwise`` gives back a falsy result.
    Otherwise returns the result of ``_ethosn.conv2d`` on the converted
    expression.
    """
    if not ethosn_available():
        return False
    converted_extract = _ethosn.ConvertQnnAddToDepthwise(extract)
    # Only query conv2d support when the conversion produced something.
    if converted_extract:
        return _ethosn.conv2d(converted_extract)
    return False

return [
("ethos-n.qnn_mul", qnn_mul_pattern(), check_mul),
Comment thread
lhutton1 marked this conversation as resolved.
(
"ethos-n.qnn_mul_to_reinterpret_quantize",
qnn_mul_pattern(),
check_mul_to_reinterpret_quantize,
),
("ethos-n.qnn_mul_to_depthwise", qnn_mul_pattern(), check_mul_to_depthwise),
(
"ethos-n.qnn_add_to_reinterpret_quantize",
qnn_add_pattern(True),
check_add_to_reinterpret_quantize,
),
("ethos-n.qnn_add_to_depthwise", qnn_add_pattern(True), check_add_to_depthwise),
("ethos-n.qnn_add", qnn_add_pattern(), check_add),
("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
("ethos-n.qnn_conv2d_transpose", qnn_conv2d_transpose_pattern(), check_conv2d_transpose),
Expand All @@ -355,15 +386,6 @@ def check_add(extract):
]


def _is_ethosn_composite(node):
    """Tell whether ``node`` is a call to a composite function named ``ethos-n.*``.

    The node must be a ``tvm.relay.expr.Call`` whose ``op`` is a
    ``tvm.relay.Function`` carrying a "Composite" attribute; the part of that
    attribute before the first "." must equal "ethos-n". Anything else
    yields False.
    """
    if not isinstance(node, tvm.relay.expr.Call):
        return False
    if not isinstance(node.op, tvm.relay.Function):
        return False
    if "Composite" not in node.op.attrs:
        return False

    composite_name = node.op.attrs["Composite"]
    target_prefix = composite_name.split(".")[0]
    return target_prefix == "ethos-n"


@tvm.ir.register_op_attr("nn.max_pool2d", "target.ethos-n")
def max_pool2d(expr):
"""Check if a max pool2d is supported by Ethos-N."""
Expand Down
39 changes: 39 additions & 0 deletions src/relay/backend/contrib/ethosn/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ void InferTensorsVisitor::InferCall(const CallNode* cn) {
RequantizeParams params;
err += EthosnAPI::Requantize(cn->op.as<FunctionNode>()->body, &params);
tensor_table_[cn->args[0]] = {params.input_info};
} else if (IsEthosnFunc(call, "ethos-n.qnn_reinterpret_quantize")) {
ReinterpretQuantizationParams params;
err += EthosnAPI::ReinterpretQuantize(cn->op.as<FunctionNode>()->body, &params);
tensor_table_[cn->args[0]] = {params.input_info};
} else if (IsEthosnFunc(call, "ethos-n.qnn_resize")) {
ResizeParams params;
err += EthosnAPI::Resize(cn->op.as<FunctionNode>()->body, &params);
Expand Down Expand Up @@ -333,6 +337,9 @@ sl::TensorsAndId ConstructNetworkVisitor::HandleCall(const CallNode* cn) {
} else if (IsEthosnFunc(call, "ethos-n.qnn_requantize")) {
if ((err = MakeRequantizeLayer(call, &tensor))) ReportFatalError(call, err);
return MakeOps(tensor);
} else if (IsEthosnFunc(call, "ethos-n.qnn_reinterpret_quantize")) {
if ((err = MakeReinterpretQuantizeLayer(call, &tensor))) ReportFatalError(call, err);
return MakeOps(tensor);
} else if (IsEthosnFunc(call, "ethos-n.qnn_resize")) {
if ((err = MakeResizeLayer(call, &tensor))) ReportFatalError(call, err);
return MakeOps(tensor);
Expand Down Expand Up @@ -654,6 +661,24 @@ EthosnError ConstructNetworkVisitor::MakeRequantizeLayer(const Call& call,
return EthosnError();
}

/*!
 * \brief Add a reinterpret quantization operation to the network for a composite call.
 * \param call The call whose op is the composite function to translate.
 * \param out Receives the value returned by AddReinterpretQuantization.
 * \return A default-constructed EthosnError on success; otherwise the error from
 * EthosnAPI::ReinterpretQuantize or one built from the caught
 * sl::NotSupportedException message.
 */
EthosnError ConstructNetworkVisitor::MakeReinterpretQuantizeLayer(
    const Call& call, sl::TensorAndId<sl::Operand>* out) {
  ReinterpretQuantizationParams params;
  params.input_info = GetTensorInfo(tensor_table_, call);

  // Fill in the remaining parameters from the body of the composite function.
  const auto* composite_fn = call->op.as<FunctionNode>();
  auto parse_status = EthosnAPI::ReinterpretQuantize(composite_fn->body, &params);
  if (parse_status) {
    return parse_status;
  }

  // First operand recorded for the call's first argument.
  auto operand = operand_table_[call->args[0]][0];

  try {
    *out = AddReinterpretQuantization(network_, *operand, params.reinterpret_quantize_info);
  } catch (const sl::NotSupportedException& unsupported) {
    return EthosnError(unsupported.what());
  }
  return EthosnError();
}

EthosnError ConstructNetworkVisitor::MakeResizeLayer(const Call& call,
sl::TensorAndId<sl::Operand>* out) {
ResizeParams params;
Expand Down Expand Up @@ -1022,6 +1047,20 @@ TVM_REGISTER_GLOBAL("relay.ethos-n.support.requantize")
err += EthosnError(reason);
});

// Registers the support query for reinterpret quantize: the packed function takes a Call as
// its first argument and stores a boolean in rv.
TVM_REGISTER_GLOBAL("relay.ethos-n.support.reinterpret_quantize")
    .set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
      Call call = args[0];
      ReinterpretQuantizationParams params;
      auto status = EthosnAPI::ReinterpretQuantize(call, &params);
      status += EthosnCompiler::SupportedSetup();

      // Start with an empty reason string for the support query to fill in.
      char reason[kReasonMaxLength];
      reason[0] = '\0';

      // The support query is only evaluated when no error has been accumulated so far.
      const bool is_supported =
          !status && EthosnCompiler::GetSupported()->IsReinterpretQuantizationSupported(
                         params.reinterpret_quantize_info, params.input_info,
                         &params.output_info, reason, sizeof(reason));
      *rv = is_supported;
      status += EthosnError(reason);
    });

TVM_REGISTER_GLOBAL("relay.ethos-n.support.resize")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
Call call = args[0];
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/contrib/ethosn/codegen_ethosn.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ class ConstructNetworkVisitor : public MixedModeVisitor, private ErrorReportingP
EthosnError MakeReluLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
EthosnError MakeLeakyReLULayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
EthosnError MakeRequantizeLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
EthosnError MakeReinterpretQuantizeLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);
EthosnError MakeResizeLayer(const Call& call, sl::TensorAndId<sl::Operand>* out);

/*! \brief A look-up table from Expr to layers. */
Expand Down
Loading