diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 732950ca6869..5536c369db8f 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -31,7 +31,7 @@ import numpy as np import tvm -from tvm import relax +from tvm import relax, tirx from tvm.relax import op as _op from .tflite_flexbuffer import FlexBufferDecoder @@ -1770,14 +1770,24 @@ def convert_fill(self, op): input_tensors = self.get_input_tensors(op) assert len(input_tensors) == 2, "input tensors length should be 2" - if self.has_expr(input_tensors[0].tensor_idx): - raise tvm.error.OpNotImplemented( - "For dims parameter of Fill operator, only constant values are supported." - ) - - in_dims = list(self.get_tensor_value(input_tensors[0])) + dims_tensor = input_tensors[0] in_value_expr = self.get_expr(input_tensors[1].tensor_idx) - out = relax.op.full(in_dims, in_value_expr) + + if self.has_expr(dims_tensor.tensor_idx): + dims_expr = self.get_expr(dims_tensor.tensor_idx) + dims_ndim = int(self.get_tensor_shape(dims_tensor)[0]) + + # Bind runtime dims to fresh symbolic shape vars so the imported + # module remains well formed before LegalizeOps runs. 
+ dims_expr = self.bb.match_cast(dims_expr, relax.TensorStructInfo([dims_ndim], "int32")) + dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64")) + shape_dataflow_var = self.bb.emit(relax.op.tensor_to_shape(dims_expr)) + shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in range(dims_ndim)] + self.bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) + out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr) + else: + in_dims = list(self.get_tensor_value(dims_tensor)) + out = relax.op.full(in_dims, in_value_expr) return out @@ -2331,6 +2341,7 @@ def convert_split(self, op): def convert_split_v(self, op): """SPLIT_V implementation.""" input_tensors = self.get_input_tensors(op) + output_tensors = self.get_output_tensors(op) assert len(input_tensors) == 3, "input tensors length should be 3" @@ -2338,22 +2349,56 @@ def convert_split_v(self, op): input_tensor_idx = input_tensor.tensor_idx in_expr = self.get_expr(input_tensor_idx) - if self.has_expr(input_tensors[1].tensor_idx): - raise tvm.error.OpNotImplemented( - "For size_splits parameter of SPLIT_V operator, only constant values are supported." - ) - size_splits = list(self.get_tensor_value(input_tensors[1])) - size_splits = tuple(np.cumsum(size_splits)[:-1]) - axis_tensor = input_tensors[2] - split_axis = self.get_tensor_value(axis_tensor) + split_axis = int(self.get_tensor_value(axis_tensor)) + + size_splits_tensor = input_tensors[1] + + if self.has_expr(size_splits_tensor.tensor_idx): + # Dynamic size_splits case: decompose into dynamic strided slices. + size_splits_expr = self.get_expr(size_splits_tensor.tensor_idx) + cumsum = relax.op.cumsum(size_splits_expr, axis=0, dtype="int64") + # Pad a leading zero so that cumsum[i-1] can be read uniformly + # via strided_slice even for i == 0. 
+ zero = relax.const(np.array([0], dtype="int64"), "int64") + padded_cumsum = relax.op.concat([zero, cumsum], axis=0) + # TFLite fixes the tuple arity in the graph, even when the split + # sizes themselves are supplied at runtime. + num_splits = len(output_tensors) + rank = len(in_expr.struct_info.shape) + + # end_base is the full input shape; only split_axis changes per slice. + end_base = relax.op.shape_to_tensor(relax.op.shape_of(in_expr)) + begin_base = relax.const(np.zeros((rank,), dtype="int64"), "int64") + strides = relax.const(np.ones((rank,), dtype="int64"), "int64") + scatter_idx = relax.const([split_axis], "int64") + + outputs = [] + for i in range(num_splits): + start_val = relax.op.strided_slice( + padded_cumsum, axes=[0], begin=[i], end=[i + 1] + ) + end_val = relax.op.strided_slice( + padded_cumsum, axes=[0], begin=[i + 1], end=[i + 2] + ) + + begin = relax.op.scatter_elements(begin_base, scatter_idx, start_val) + end = relax.op.scatter_elements(end_base, scatter_idx, end_val) + slice_i = relax.op.dynamic_strided_slice(in_expr, begin, end, strides) + outputs.append(slice_i) + + out = relax.Tuple(outputs) + else: + # Static size_splits case + size_splits = list(self.get_tensor_value(size_splits_tensor)) + size_splits = tuple(np.cumsum(size_splits)[:-1]) + out = relax.op.split(in_expr, size_splits, axis=split_axis) - out = relax.op.split(in_expr, size_splits, axis=int(split_axis)) # Relay does not like a TupleWrapper of 1 element, further this # only shows up with tf1.13 if we use a split with num_splits==1. # In tf 1.14 this doesn't appear as it is automatically a reshape # operation. 
- if isinstance(out, relax.Tuple) and out.size == 1: + if isinstance(out, relax.Tuple) and len(out.fields) == 1: out = out[0] return out diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index 99b4449ebf79..6708859caaff 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -23,7 +23,8 @@ from tvm import tirx, topi from ...block_builder import BlockBuilder -from ...expr import Call, Expr, PrimValue, const +from ...expr import Call, Expr, PrimValue, ShapeExpr, const +from ...struct_info import ShapeStructInfo from .common import LegalizeFunc, _try_convert_to_scalar_const, register_legalize @@ -34,10 +35,21 @@ def full_call_te(bb: BlockBuilder, call: Call) -> Expr: if fill_value is None else fill_value ) + shape = call.args[0].struct_info.shape if is_like else call.args[0] + + if isinstance(shape, ShapeExpr): + output_shape = shape.values + else: + assert isinstance(shape.struct_info, ShapeStructInfo) + assert shape.struct_info.ndim >= 0 + + shape = bb.emit(shape) + output_shape = [tirx.Var(f"s{i}", "int64") for i in range(shape.struct_info.ndim)] + bb.match_cast(shape, ShapeStructInfo(output_shape)) return bb.call_te( topi.full, - call.args[0].struct_info.shape if is_like else call.args[0], + output_shape, call.struct_info.dtype, _fill_value, primfunc_name_hint=primfunc_name, diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index 88ea67ab99e1..fb2aa448e31b 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -96,6 +96,8 @@ TVM_REGISTER_OP("relax.full") .add_argument("shape", "Shape", "The shape of the created tensor.") .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull) + .set_attr<Bool>("RequiresArgumentShapes", Bool(false)) + .set_attr<Bool>("FDataDependent", Bool(true)) .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow) .set_attr<Bool>("FPurity", Bool(true)); diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 908868faf0c4..23002c866898 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -208,6 +208,26 @@ def main(x: R.Tensor((1, 30), dtype="float32")) -> R.Tensor((1, 3, 10), dtype="f verify(Split, Expected) +def test_split_v_dynamic(): + """SPLIT_V with runtime split sizes imports shape-aware Relax IR.""" + + class TfSplitVDynamic(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(10,), dtype=tf.float32), + tf.TensorSpec(shape=(3,), dtype=tf.int32), + ] + ) + def func(self, x, size_splits): + return tf.split(x, size_splits, axis=0) + + cf = TfSplitVDynamic().func.get_concrete_function() + mod = _get_mod_from_cfunc(cf) + ir = mod.script() + assert "R.dynamic_strided_slice" in ir + assert "R.scatter_elements" in ir + + def test_pack(): class Pack(tf.Module): @tf.function( @@ -592,6 +612,28 @@ def main( verify(TfInput, Expected) +def test_fill_dynamic_dims(): + """FILL with runtime dims legalizes and compiles.""" + + class TfFillDynamic(tf.Module): + @tf.function( + input_signature=[ + tf.TensorSpec(shape=(2,), dtype=tf.int32), + tf.TensorSpec(shape=(), dtype=tf.float32), + ] + ) + def func(self, dims, value): + return tf.fill(dims, value) + + cf = TfFillDynamic().func.get_concrete_function() + mod = _get_mod_from_cfunc(cf) + ir = mod.script() + assert "R.tensor_to_shape" in ir + assert "R.full" in ir + tvm.compile(mod, tvm.target.Target("llvm")) + verify(cf) + + @pytest.mark.parametrize( "tf_op, relax_op", [