From 7e227302cbc08bc98d3d334a62a1ac476bafed3a Mon Sep 17 00:00:00 2001 From: fnhirwa Date: Sun, 14 Jun 2026 15:41:46 +0200 Subject: [PATCH 1/3] complex operators --- .../relax/frontend/tflite/tflite_frontend.py | 79 ++++++++++++++++++- tests/python/relax/test_frontend_tflite.py | 76 ++++++++++++++++++ 2 files changed, 154 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index d14643d75c60..409ddd3cfafc 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -207,6 +207,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "BROADCAST_ARGS": self.convert_broadcast_args, "CALL": self.convert_call, "CALL_ONCE": self.convert_call_once, + "COMPLEX_ABS": self.convert_complex_abs, "CAST": self.convert_cast, "CEIL": functools.partial(self._convert_unary_elemwise, relax_op=_op.ceil), "CONCATENATION": self.convert_concatenation, @@ -252,6 +253,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "HASHTABLE_LOOKUP": self.convert_hashtable_lookup, "HASHTABLE_SIZE": self.convert_hashtable_size, "IF": self.convert_if, + "IMAG": self.convert_imag, "L2_NORMALIZATION": self.convert_l2_normalization, "L2_POOL_2D": functools.partial(self.convert_pool2d, pool_type="l2"), "LEAKY_RELU": self.convert_leaky_relu, @@ -295,6 +297,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "RANDOM_STANDARD_NORMAL": self.convert_random_standard_normal, "RANDOM_UNIFORM": self.convert_random_uniform, "READ_VARIABLE": self.convert_read_variable, + "REAL": self.convert_real, "REDUCE_ALL": functools.partial(self._convert_reduce_bool, relax_op=_op.min), "REDUCE_ANY": functools.partial(self._convert_reduce_bool, relax_op=_op.max), "REDUCE_MAX": functools.partial(self._convert_reduce, relax_op=_op.max), @@ -303,6 +306,7 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "RELU": self.convert_relu, "RELU6": self.convert_relu6, "RELU_N1_TO_1": self.convert_relu_n1_to_1, + "RFFT2D": self.convert_rfft2d, "RESHAPE": self.convert_reshape, "RESIZE_BILINEAR": self.convert_resize_bilinear, "RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor, @@ -7580,6 +7584,69 @@ def convert_fake_quant(self, op): rounded = relax.op.floor(_op.add(_op.multiply(clamped_shifted, inv_scale), half)) return relax.op.add(_op.multiply(rounded, scale_expr), nudged_min_expr) + def convert_real(self, op): + """Convert TFLite REAL op. + + TFLite complex64 tensors are represented as float32[..., 2] in Relax, + where index 0 = real part, index 1 = imaginary part along the last axis + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + last_axis = int(input_tensor.struct_info.ndim) - 1 + # slice last axis at index 0, and squeeze to remove the last axis + real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) + return _op.squeeze(real, axis=[last_axis]) + + def convert_imag(self, op): + """Convert TFLite IMAG op. + + See convert_real for representation of complex64 tensors in Relax. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + last_axis = int(input_tensor.struct_info.ndim) - 1 + # slice last axis at index 1, and squeeze to remove the last axis + imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) + return _op.squeeze(imag, axis=[last_axis]) + + def convert_complex_abs(self, op): + """Convert TFLite COMPLEX_ABS op: sqrt(real^2 + imag^2) + + See convert_real for the float32[..., 2] complex representation convention. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = self.get_expr(input_tensors[0].tensor_idx) + last_axis = int(input_tensor.struct_info.ndim) - 1 + real = self.bb.emit( + _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) + ) + real = self.bb.emit(_op.squeeze(real, axis=[last_axis])) + imag = self.bb.emit( + _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) + ) + imag = self.bb.emit(_op.squeeze(imag, axis=[last_axis])) + real_sq = self.bb.emit(_op.multiply(real, real)) + imag_sq = self.bb.emit(_op.multiply(imag, imag)) + sum_expr = self.bb.emit(_op.add(real_sq, imag_sq)) + return _op.sqrt(sum_expr) + + def convert_rfft2d(self, op): + """Convert TFLite RFFT2D op. + + Not implemented: Relax has no native FFT operator and topi.signal.dft + has no C++ registered backend (tvm.get_global_func returns None). + Implement relax.op.signal.rfft2d first, then route here. + """ + raise tvm.error.OpNotImplemented( + "RFFT2D is not supported in the Relax TFLite frontend. " + "topi.signal.dft is pure Python TE with no TVM_REGISTER_GLOBAL entry " + "and cannot be called via call_dps_packed. " + "A native relax.op.signal.rfft2d op is required." + ) + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -8044,8 +8111,14 @@ def _input_type(model): input_shape = tuple(tensor.ShapeAsNumpy()) tensor_type = tensor.Type() input_name = get_tensor_name(subgraph, input_) + input_dtype = _decode_type(tensor_type) + # Relax models complex64 tensors as float32[..., 2] where the trailing + # dimension stores real/imag parts. + if input_dtype == "complex64": + input_shape = input_shape + (2,) + input_dtype = "float32" shape_dict[input_name] = input_shape - dtype_dict[input_name] = _decode_type(tensor_type) + dtype_dict[input_name] = input_dtype return shape_dict, dtype_dict @@ -8183,6 +8256,10 @@ def func(self, data): dtype = ( _dtype_dict[model_input_name] if model_input_name in _dtype_dict else "float32" ) + if dtype == "complex64": + dtype = "float32" + if shape is not None: + shape = tuple(shape) + (2,) input_var = relax.Var( name_hint=model_input_name, struct_info=relax.TensorStructInfo(shape=shape, dtype=dtype), diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e4483b9d41cc..5492eb8edd32 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -13020,5 +13020,81 @@ def test_unidirectional_sequence_rnn_time_major(): assert tuple(int(d) for d in out_shape) == (batch, time, num_units) +def test_real(): + class Real(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.real(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + # slice real part (index 0 along last axis) + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[0], end=[1], strides=[1] + ) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + R.output(gv) + return gv + + verify(Real, Expected) + + +def test_imag(): + class Imag(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.imag(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + # slice imaginary part (index 1 along last axis) + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[1], end=[2], strides=[1] + ) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + R.output(gv) + return gv + + verify(Imag, Expected) + + +def test_complex_abs(): + class ComplexAbs(tf.Module): + @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.complex64)]) + def func(self, x): + return tf.math.abs(x) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + lv0: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[0], end=[1], strides=[1] + ) + real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[2]) + lv1: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, axes=[2], begin=[1], end=[2], strides=[1] + ) + imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[2]) + lv2: R.Tensor((2, 4), dtype="float32") = R.multiply(real, real) + lv3: R.Tensor((2, 4), dtype="float32") = R.multiply(imag, imag) + lv4: R.Tensor((2, 4), dtype="float32") = R.add(lv2, lv3) + gv: R.Tensor((2, 4), dtype="float32") = R.sqrt(lv4) + R.output(gv) + return gv + + verify(ComplexAbs, Expected) + + if __name__ == "__main__": pytest.main(["-s", __file__]) From 9c464770ec19b0212731f2a6d5c62458f8979091 Mon Sep 17 00:00:00 2001 From: fnhirwa Date: Sun, 14 Jun 2026 16:18:08 +0200 Subject: [PATCH 2/3] apply gemini suggestions --- .../relax/frontend/tflite/tflite_frontend.py | 19 ++++++++----------- tests/python/relax/test_frontend_tflite.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 409ddd3cfafc..e2a7eed8f1e7 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -7593,10 +7593,9 @@ def convert_real(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = self.get_expr(input_tensors[0].tensor_idx) - last_axis = int(input_tensor.struct_info.ndim) - 1 # slice last axis at index 0, and squeeze to remove the last axis - real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) - return _op.squeeze(real, axis=[last_axis]) + real = _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1]) + return _op.squeeze(real, axis=[-1]) def convert_imag(self, op): """Convert TFLite IMAG op. @@ -7606,10 +7605,9 @@ def convert_imag(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = self.get_expr(input_tensors[0].tensor_idx) - last_axis = int(input_tensor.struct_info.ndim) - 1 # slice last axis at index 1, and squeeze to remove the last axis - imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) - return _op.squeeze(imag, axis=[last_axis]) + imag = _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1]) + return _op.squeeze(imag, axis=[-1]) def convert_complex_abs(self, op): """Convert TFLite COMPLEX_ABS op: sqrt(real^2 + imag^2) @@ -7619,15 +7617,14 @@ def convert_complex_abs(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 1, "input tensors length should be 1" input_tensor = self.get_expr(input_tensors[0].tensor_idx) - last_axis = int(input_tensor.struct_info.ndim) - 1 real = self.bb.emit( - _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[last_axis]) + _op.strided_slice(input_tensor, begin=[0], end=[1], strides=[1], axes=[-1]) ) - real = self.bb.emit(_op.squeeze(real, axis=[last_axis])) + real = self.bb.emit(_op.squeeze(real, axis=[-1])) imag = self.bb.emit( - _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[last_axis]) + _op.strided_slice(input_tensor, begin=[1], end=[2], strides=[1], axes=[-1]) ) - imag = self.bb.emit(_op.squeeze(imag, axis=[last_axis])) + imag = self.bb.emit(_op.squeeze(imag, axis=[-1])) real_sq = self.bb.emit(_op.multiply(real, real)) imag_sq = self.bb.emit(_op.multiply(imag, imag)) sum_expr = self.bb.emit(_op.add(real_sq, imag_sq)) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 5492eb8edd32..dc572e1edd75 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -13034,9 +13034,9 @@ def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="flo with R.dataflow(): # slice real part (index 0 along last axis) lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[0], end=[1], strides=[1] + x, axes=[-1], begin=[0], end=[1], strides=[1] ) - gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) R.output(gv) return gv @@ -13057,9 +13057,9 @@ def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="flo with R.dataflow(): # slice imaginary part (index 1 along last axis) lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[1], end=[2], strides=[1] + x, axes=[-1], begin=[1], end=[2], strides=[1] ) - gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[2]) + gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) R.output(gv) return gv @@ -13079,13 +13079,13 @@ def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="flo R.func_attr({"num_input": 1}) with R.dataflow(): lv0: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[0], end=[1], strides=[1] + x, axes=[-1], begin=[0], end=[1], strides=[1] ) - real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[2]) + real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[-1]) lv1: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[2], begin=[1], end=[2], strides=[1] + x, axes=[-1], begin=[1], end=[2], strides=[1] ) - imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[2]) + imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[-1]) lv2: R.Tensor((2, 4), dtype="float32") = R.multiply(real, real) lv3: R.Tensor((2, 4), dtype="float32") = R.multiply(imag, imag) lv4: R.Tensor((2, 4), dtype="float32") = R.add(lv2, lv3) From ae35bcb0e915f5287e70549c532f1cd1d614cd06 Mon Sep 17 00:00:00 2001 From: fnhirwa Date: Tue, 16 Jun 2026 20:08:54 +0200 Subject: [PATCH 3/3] proper complex64 lowering --- .../relax/frontend/tflite/tflite_frontend.py | 29 ++++-------- tests/python/relax/test_frontend_tflite.py | 46 +++++++++++++------ 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index e2a7eed8f1e7..3bd87d0af414 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -306,7 +306,6 @@ def __init__(self, model, subgraph, exp_tab, ctx, conversion_state=None): "RELU": self.convert_relu, "RELU6": self.convert_relu6, "RELU_N1_TO_1": self.convert_relu_n1_to_1, - "RFFT2D": self.convert_rfft2d, "RESHAPE": self.convert_reshape, "RESIZE_BILINEAR": self.convert_resize_bilinear, "RESIZE_NEAREST_NEIGHBOR": self.convert_resize_nearest_neighbor, @@ -1010,6 +1009,7 @@ def get_tensor_type_as_numpy(self, tensor_wrapper): TensorType.UINT32: np.uint32, TensorType.UINT64: np.uint64, TensorType.BOOL: np.bool_, + TensorType.COMPLEX64: np.complex64, }[tensor_wrapper.tensor.Type()] # pylint: disable=no-else-return @@ -1055,6 +1055,8 @@ def get_tensor_type_str(self, tensor_type): return "uint64" if tensor_type == TensorType.BOOL: return "bool" + if tensor_type == TensorType.COMPLEX64: + return "complex64" raise NotImplementedError(f"Tensor type {tensor_type!s} is currently not supported") def _get_shape_expr_from_tensor(self, shape_tensor, prefix): @@ -7630,20 +7632,6 @@ def convert_complex_abs(self, op): sum_expr = self.bb.emit(_op.add(real_sq, imag_sq)) return _op.sqrt(sum_expr) - def convert_rfft2d(self, op): - """Convert TFLite RFFT2D op. - - Not implemented: Relax has no native FFT operator and topi.signal.dft - has no C++ registered backend (tvm.get_global_func returns None). - Implement relax.op.signal.rfft2d first, then route here. - """ - raise tvm.error.OpNotImplemented( - "RFFT2D is not supported in the Relax TFLite frontend. " - "topi.signal.dft is pure Python TE with no TVM_REGISTER_GLOBAL entry " - "and cannot be called via call_dps_packed. " - "A native relax.op.signal.rfft2d op is required." - ) - def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -7673,6 +7661,12 @@ def get_tensor_expr(self, tensor, is_sparse=False): type_str = self.get_tensor_type_str(tensor.tensor.Type()) value = self.get_tensor_value_or_prefetched(tensor, is_sparse) + # complex64 constants have no native Relax dtype. Reinterpret the + # interleaved float32 storage as float32[..., 2] to match the + # convention used for input tensors. + if type_str == "complex64": + value = value.view(np.float32).reshape(value.shape + (2,)) + type_str = "float32" return self.exp_tab.new_const(value, dtype=type_str, source_name=tensor.tensor.Name()) def get_tensor_shape(self, tensor_wrapper): @@ -8109,11 +8103,6 @@ def _input_type(model): tensor_type = tensor.Type() input_name = get_tensor_name(subgraph, input_) input_dtype = _decode_type(tensor_type) - # Relax models complex64 tensors as float32[..., 2] where the trailing - # dimension stores real/imag parts. - if input_dtype == "complex64": - input_shape = input_shape + (2,) - input_dtype = "float32" shape_dict[input_name] = input_shape dtype_dict[input_name] = input_dtype diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index dc572e1edd75..e9a842b8dfe8 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -13032,9 +13032,13 @@ class Expected: def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - # slice real part (index 0 along last axis) lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[-1], begin=[0], end=[1], strides=[1] + x, + (R.prim_value(-1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, ) gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) R.output(gv) @@ -13055,9 +13059,13 @@ class Expected: def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - # slice imaginary part (index 1 along last axis) lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[-1], begin=[1], end=[2], strides=[1] + x, + (R.prim_value(-1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, ) gv: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) R.output(gv) @@ -13078,18 +13086,28 @@ class Expected: def main(x: R.Tensor((2, 4, 2), dtype="float32")) -> R.Tensor((2, 4), dtype="float32"): R.func_attr({"num_input": 1}) with R.dataflow(): - lv0: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[-1], begin=[0], end=[1], strides=[1] + lv: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, + (R.prim_value(-1),), + (R.prim_value(0),), + (R.prim_value(1),), + (R.prim_value(1),), + assume_inbound=False, ) - real: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv0, axis=[-1]) - lv1: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( - x, axes=[-1], begin=[1], end=[2], strides=[1] + lv1: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv, axis=[-1]) + lv2: R.Tensor((2, 4, 1), dtype="float32") = R.strided_slice( + x, + (R.prim_value(-1),), + (R.prim_value(1),), + (R.prim_value(2),), + (R.prim_value(1),), + assume_inbound=False, ) - imag: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv1, axis=[-1]) - lv2: R.Tensor((2, 4), dtype="float32") = R.multiply(real, real) - lv3: R.Tensor((2, 4), dtype="float32") = R.multiply(imag, imag) - lv4: R.Tensor((2, 4), dtype="float32") = R.add(lv2, lv3) - gv: R.Tensor((2, 4), dtype="float32") = R.sqrt(lv4) + lv3: R.Tensor((2, 4), dtype="float32") = R.squeeze(lv2, axis=[-1]) + lv4: R.Tensor((2, 4), dtype="float32") = R.multiply(lv1, lv1) + lv5: R.Tensor((2, 4), dtype="float32") = R.multiply(lv3, lv3) + lv6: R.Tensor((2, 4), dtype="float32") = R.add(lv4, lv5) + gv: R.Tensor((2, 4), dtype="float32") = R.sqrt(lv6) R.output(gv) return gv