-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[Relax][Frontend][TFLite] Fix dynamic FILL/SPLIT_V partial implementations #19433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c23665d
34093e8
0152656
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ | |
| import numpy as np | ||
|
|
||
| import tvm | ||
| from tvm import relax | ||
| from tvm import relax, tirx | ||
| from tvm.relax import op as _op | ||
|
|
||
| from .tflite_flexbuffer import FlexBufferDecoder | ||
|
|
@@ -1770,14 +1770,24 @@ def convert_fill(self, op): | |
| input_tensors = self.get_input_tensors(op) | ||
| assert len(input_tensors) == 2, "input tensors length should be 2" | ||
|
|
||
| if self.has_expr(input_tensors[0].tensor_idx): | ||
| raise tvm.error.OpNotImplemented( | ||
| "For dims parameter of Fill operator, only constant values are supported." | ||
| ) | ||
|
|
||
| in_dims = list(self.get_tensor_value(input_tensors[0])) | ||
| dims_tensor = input_tensors[0] | ||
| in_value_expr = self.get_expr(input_tensors[1].tensor_idx) | ||
| out = relax.op.full(in_dims, in_value_expr) | ||
|
|
||
| if self.has_expr(dims_tensor.tensor_idx): | ||
| dims_expr = self.get_expr(dims_tensor.tensor_idx) | ||
| dims_ndim = int(self.get_tensor_shape(dims_tensor)[0]) | ||
|
|
||
| # Bind runtime dims to fresh symbolic shape vars so the imported | ||
| # module remains well formed before LegalizeOps runs. | ||
| dims_expr = self.bb.match_cast(dims_expr, relax.TensorStructInfo([dims_ndim], "int32")) | ||
| dims_expr = self.bb.normalize(relax.op.astype(dims_expr, "int64")) | ||
| shape_dataflow_var = self.bb.emit(relax.op.tensor_to_shape(dims_expr)) | ||
| shape_vars = [tirx.Var(f"fill_dim_{i}", "int64") for i in range(dims_ndim)] | ||
| self.bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars)) | ||
| out = relax.op.full(relax.ShapeExpr(shape_vars), in_value_expr) | ||
| else: | ||
| in_dims = list(self.get_tensor_value(dims_tensor)) | ||
| out = relax.op.full(in_dims, in_value_expr) | ||
|
|
||
| return out | ||
|
|
||
|
|
@@ -2331,29 +2341,64 @@ def convert_split(self, op): | |
| def convert_split_v(self, op): | ||
| """SPLIT_V implementation.""" | ||
| input_tensors = self.get_input_tensors(op) | ||
| output_tensors = self.get_output_tensors(op) | ||
|
|
||
| assert len(input_tensors) == 3, "input tensors length should be 3" | ||
|
|
||
| input_tensor = input_tensors[0] | ||
| input_tensor_idx = input_tensor.tensor_idx | ||
| in_expr = self.get_expr(input_tensor_idx) | ||
|
|
||
| if self.has_expr(input_tensors[1].tensor_idx): | ||
| raise tvm.error.OpNotImplemented( | ||
| "For size_splits parameter of SPLIT_V operator, only constant values are supported." | ||
| ) | ||
| size_splits = list(self.get_tensor_value(input_tensors[1])) | ||
| size_splits = tuple(np.cumsum(size_splits)[:-1]) | ||
|
|
||
| axis_tensor = input_tensors[2] | ||
| split_axis = self.get_tensor_value(axis_tensor) | ||
| split_axis = int(self.get_tensor_value(axis_tensor)) | ||
|
|
||
| size_splits_tensor = input_tensors[1] | ||
|
|
||
| if self.has_expr(size_splits_tensor.tensor_idx): | ||
| # Dynamic size_splits case: decompose into dynamic strided slices. | ||
| size_splits_expr = self.get_expr(size_splits_tensor.tensor_idx) | ||
| cumsum = relax.op.cumsum(size_splits_expr, axis=0, dtype="int64") | ||
| # Pad a leading zero so that cumsum[i-1] can be read uniformly | ||
| # via strided_slice even for i == 0. | ||
| zero = relax.const(np.array([0], dtype="int64"), "int64") | ||
| padded_cumsum = relax.op.concat([zero, cumsum], axis=0) | ||
| # TFLite fixes the tuple arity in the graph, even when the split | ||
| # sizes themselves are supplied at runtime. | ||
| num_splits = len(output_tensors) | ||
| rank = len(in_expr.struct_info.shape) | ||
|
|
||
| # end_base is the full input shape; only split_axis changes per slice. | ||
| end_base = relax.op.shape_to_tensor(relax.op.shape_of(in_expr)) | ||
| begin_base = relax.const(np.zeros((rank,), dtype="int64"), "int64") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, these are naming-only suggestions. Since the variables are local and the current intent is already documented by comments, I’d prefer to keep the diff minimal. |
||
| strides = relax.const(np.ones((rank,), dtype="int64"), "int64") | ||
| scatter_idx = relax.const([split_axis], "int64") | ||
|
|
||
| outputs = [] | ||
| for i in range(num_splits): | ||
| start_val = relax.op.strided_slice( | ||
| padded_cumsum, axes=[0], begin=[i], end=[i + 1] | ||
| ) | ||
| end_val = relax.op.strided_slice( | ||
| padded_cumsum, axes=[0], begin=[i + 1], end=[i + 2] | ||
| ) | ||
|
|
||
| begin = relax.op.scatter_elements(begin_base, scatter_idx, start_val) | ||
| end = relax.op.scatter_elements(end_base, scatter_idx, end_val) | ||
| slice_i = relax.op.dynamic_strided_slice(in_expr, begin, end, strides) | ||
| outputs.append(slice_i) | ||
|
|
||
| out = relax.Tuple(outputs) | ||
| else: | ||
| # Static size_splits case | ||
| size_splits = list(self.get_tensor_value(size_splits_tensor)) | ||
| size_splits = tuple(np.cumsum(size_splits)[:-1]) | ||
| out = relax.op.split(in_expr, size_splits, axis=split_axis) | ||
|
|
||
| out = relax.op.split(in_expr, size_splits, axis=int(split_axis)) | ||
| # Relay does not like a TupleWrapper of 1 element, further this | ||
| # only shows up with tf1.13 if we use a split with num_splits==1. | ||
| # In tf 1.14 this doesn't appear as it is automatically a reshape | ||
| # operation. | ||
| if isinstance(out, relax.Tuple) and out.size == 1: | ||
| if isinstance(out, relax.Tuple) and len(out.fields) == 1: | ||
| out = out[0] | ||
|
|
||
| return out | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable
num_splitsis used but it is better to uselen(output_tensors)directly if it is only used once, or rename it tonum_outputsto be more descriptive of what it represents in the context of TFLite output tensors.