diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index f5b88b0c6ad5..d70f5d837e0f 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -132,6 +132,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil),
             "CONCATENATION": self.convert_concatenation,
             "CONV_2D": functools.partial(self.convert_conv, conv_type="conv2d"),
+            "CONV_3D": self.convert_conv3d,
             "COS": functools.partial(self._convert_unary_elemwise, relax_op=_op.cos),
             "CUMSUM": self.convert_cumsum,
             "DENSIFY": self.convert_densify,
@@ -2449,6 +2450,155 @@ def convert_conv(self, op, conv_type):
         out = self.convert_fused_activation_function(out, fused_activation_fn)
         return out
 
+    def convert_conv3d(self, op):
+        """Convert a TFLite CONV_3D operator to ``relax.op.nn.conv3d``.
+
+        The data tensor is consumed as NDHWC and the kernel as DHWIO, so neither
+        operand is transposed. If a third input tensor is supplied it is added as
+        bias, and the fused activation function is applied last.
+
+        Parameters
+        ----------
+        op
+            The TFLite operator to convert; its builtin options must be Conv3DOptions.
+
+        Returns
+        -------
+        out
+            The Relax expression for the convolution, including bias and fused
+            activation when present.
+
+        Raises
+        ------
+        tvm.error.OpNotImplemented
+            If the input, weight or output tensor carries quantization parameters.
+        tvm.error.OpAttributeUnImplemented
+            If the padding mode is neither VALID nor SAME.
+        """
+
+        from tflite.BuiltinOptions import BuiltinOptions
+        from tflite.Conv3DOptions import Conv3DOptions
+        from tflite.Padding import Padding
+        from tflite.TensorType import TensorType
+
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) >= 2, "input tensors length should be >= 2"
+
+        input_tensor = input_tensors[0]
+        input_tensor_idx = input_tensor.tensor_idx
+        weight_tensor = input_tensors[1]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        # Fail fast on quantized graphs, before any constant is materialised. The
+        # weight is checked too, so that quantized weights paired with a float
+        # input are rejected here instead of producing a mixed-dtype convolution.
+        if input_tensor.qnn_params or weight_tensor.qnn_params or output_tensor.qnn_params:
+            raise tvm.error.OpNotImplemented(
+                "Quantized Conv3D is not yet supported in the Relax frontend."
+            )
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.Conv3DOptions
+        op_options = op.BuiltinOptions()
+        conv3d_options = Conv3DOptions()
+        conv3d_options.Init(op_options.Bytes, op_options.Pos)
+
+        stride_d = conv3d_options.StrideD()
+        stride_h = conv3d_options.StrideH()
+        stride_w = conv3d_options.StrideW()
+        dilation_d = conv3d_options.DilationDFactor()
+        dilation_h = conv3d_options.DilationHFactor()
+        dilation_w = conv3d_options.DilationWFactor()
+        padding = conv3d_options.Padding()
+        fused_activation_fn = conv3d_options.FusedActivationFunction()
+
+        _, input_d, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor))
+        # TFLite Conv3D kernel layout is already DHWIO:
+        # KD KH KW IC OC
+        kernel_d, kernel_h, kernel_w, in_channels, _ = to_int_list(
+            self.get_tensor_shape(weight_tensor)
+        )
+
+        params = {
+            "strides": [stride_d, stride_h, stride_w],
+            "dilation": [dilation_d, dilation_h, dilation_w],
+            "padding": [0, 0, 0, 0, 0, 0],
+            "data_layout": "NDHWC",
+            "kernel_layout": "DHWIO",
+        }
+
+        if input_c != in_channels:
+            assert input_c % in_channels == 0, (
+                "Input channels is not divisible by kernel in_channels."
+            )
+            # Floor division keeps the group count an exact integer.
+            params["groups"] = input_c // in_channels
+
+        if padding == Padding.SAME:
+            dilated_kernel_d = dilation_d * (kernel_d - 1) + 1
+            dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
+            dilated_kernel_w = dilation_w * (kernel_w - 1) + 1
+            pad_front, pad_back = get_pad_value(input_d, dilated_kernel_d, stride_d)
+            pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
+            pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
+            # All "before" pads first, then all "after" pads.
+            params["padding"] = [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right]
+        elif padding != Padding.VALID:
+            raise tvm.error.OpAttributeUnImplemented(
+                f"Padding format {padding} is not supported for operator Conv3D."
+            )
+
+        # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
+        weight_tensor_type = weight_tensor.tensor.Type()
+        assert weight_tensor_type in (
+            TensorType.INT8,
+            TensorType.UINT8,
+            TensorType.FLOAT32,
+        )
+        weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
+
+        in_expr = self.get_expr(input_tensor_idx)
+
+        # TFLite Conv3D kernel is already in DHWIO layout, no transpose needed.
+        if self.has_expr(weight_tensor.tensor_idx):
+            weight_expr = self.get_expr(weight_tensor.tensor_idx)
+        else:
+            if self.is_prefetched(weight_tensor.tensor_idx):
+                weight_value = self.get_prefetched_node(weight_tensor.tensor_idx)
+            else:
+                weight_value = self.get_tensor_value(weight_tensor)
+
+            weight_expr = self.exp_tab.new_const(
+                weight_value,
+                dtype=weight_tensor_type_str,
+                source_name=weight_tensor.tensor.Name(),
+            )
+
+        out = relax.op.nn.conv3d(in_expr, weight_expr, **params)
+
+        # if we have bias
+        if len(input_tensors) == 3:
+            bias_tensor = input_tensors[2]
+            if bias_tensor.tensor_idx != -1:
+                bias_tensor_type = bias_tensor.tensor.Type()
+                # bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
+                assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
+                bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
+                if self.has_expr(bias_tensor.tensor_idx):
+                    bias_expr = self.get_expr(bias_tensor.tensor_idx)
+                else:
+                    bias_expr = self.exp_tab.new_const(
+                        self.get_tensor_value(bias_tensor),
+                        dtype=bias_tensor_type_str,
+                        source_name=bias_tensor.tensor.Name(),
+                    )
+                out = relax.op.add(out, bias_expr)
+
+        out = self.convert_fused_activation_function(out, fused_activation_fn)
+        return out
+
     def convert_split(self, op):
         """split implementation."""
 
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index e4c237887e6e..d0401e464984 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1611,6 +1611,89 @@ def main(
     verify(Conv2DModule, Expected)
 
 
+def _make_conv3d_module(data_shape, kernel_shape, strides, padding):
+    class Conv3DModule(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=data_shape, dtype=tf.float32),
+                tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
+            ]
+        )
+        def func(self, data, kernel):
+            return tf.nn.conv3d(
+                input=data,
+                filters=kernel,
+                strides=strides,
+                padding=padding,
+            )
+
+    return Conv3DModule
+
+
+def test_conv3d_valid():
+    Conv3DModule = _make_conv3d_module(
+        (1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "VALID"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
+            kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"),
+        ) -> R.Tensor((1, 6, 6, 6, 16), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((1, 6, 6, 6, 16), dtype="float32") = R.nn.conv3d(
+                    data,
+                    kernel,
+                    strides=[1, 1, 1],
+                    padding=[0, 0, 0, 0, 0, 0],
+                    dilation=[1, 1, 1],
+                    groups=1,
+                    data_layout="NDHWC",
+                    kernel_layout="DHWIO",
+                    out_layout="NDHWC",
+                    out_dtype="void",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Conv3DModule, Expected)
+
+
+def test_conv3d_same():
+    Conv3DModule = _make_conv3d_module(
+        (1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
+            kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"),
+        ) -> R.Tensor((1, 8, 8, 8, 16), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                gv: R.Tensor((1, 8, 8, 8, 16), dtype="float32") = R.nn.conv3d(
+                    data,
+                    kernel,
+                    strides=[1, 1, 1],
+                    padding=[1, 1, 1, 1, 1, 1],
+                    dilation=[1, 1, 1],
+                    groups=1,
+                    data_layout="NDHWC",
+                    kernel_layout="DHWIO",
+                    out_layout="NDHWC",
+                    out_dtype="void",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Conv3DModule, Expected)
+
+
 def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding):
     class Pool2DModule(tf.Module):
         @tf.function(