[Relax][TFLite] Fix MIRROR_PAD/ONE_HOT converters and add tests for PAD, PADV2, MIRROR_PAD, TOPK_V2, ONE_HOT#19373
Merged
Conversation
MIRROR_PAD and ONE_HOT converters and add unit tests
…HOT converters Two bugs in the TFLite Relax frontend converters are fixed: 1. convert_mirror_pad called relax.op.nn.mirror_pad which does not exist in the Relax op namespace. Replace with relax.op.nn.pad using pad_mode="reflect" for REFLECT mode, and raise OpAttributeUnImplemented for SYMMETRIC mode (no equivalent in Relax). 2. convert_one_hot passed on_value and off_value as Expr (constant tensor nodes) where relax.op.one_hot requires PrimValue, and included an extra dtype positional argument that the function signature does not accept. Fix by extracting the scalar from the tensor buffer and wrapping it in relax.PrimValue with the correct dtype. Unit tests are added for PAD, PADV2, MIRROR_PAD, TOPK_V2, and ONE_HOT following the verify() + tf.Module pattern. Each test includes an explicit expected IRModule verified with tvm.ir.assert_structural_equal. Partially closes apache#18971.
MIRROR_PAD and ONE_HOT converters and add unit testsc42e8a9 to
44b64b8
Compare
MIRROR_PAD/ONE_HOT converters and add tests for PAD, PADV2, MIRROR_PAD, TOPK_V2, ONE_HOT
Contributor
There was a problem hiding this comment.
Code Review
This pull request updates the TFLite frontend in TVM Relax by refactoring the MIRROR_PAD and ONE_HOT operators. MIRROR_PAD now utilizes relax.op.nn.pad for reflection padding, while ONE_HOT has been updated to wrap its on/off values as PrimValue. Additionally, several new tests were added to cover padding and one-hot operations. Review feedback suggests casting padding values to Python integers for better compatibility and simplifying the extraction of scalar values from numpy arrays in the one-hot conversion.
MIRROR_PAD/ONE_HOT converters and add tests for PAD, PADV2, MIRROR_PAD, TOPK_V2, ONE_HOTMIRROR_PAD/ONE_HOT converters and add tests for PAD, PADV2, MIRROR_PAD, TOPK_V2, ONE_HOT
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Part of #18971
Two bugs in the TFLite Relax frontend converters are fixed, and unit tests
are added for the Padding / Sparse / Other category operators claimed in
that tracking issue.
Bug fixes
convert_mirror_padCalled
relax.op.nn.mirror_padwhich does not exist in the Relax opnamespace. Replaced with
relax.op.nn.padusingpad_mode="reflect"forREFLECT mode (the modes are semantically equivalent). SYMMETRIC mode raises
OpAttributeUnImplementedas there is no direct Relax equivalent.convert_one_hoton_valueandoff_valuewere passed asExpr(constant tensor nodes),but
relax.op.one_hotrequiresPrimValuearguments.dtypepositional argument was passed, which the functionsignature does not accept.
Fixed by extracting the scalar from each tensor buffer and wrapping it in
relax.PrimValuewith the correct dtype viatvm.tirx.FloatImm/tvm.tirx.IntImm.Tests added
Each test uses the
verify()+tf.Modulepattern and includes an explicitexpected IRModule verified with
tvm.ir.assert_structural_equal.test_padPADtest_pad_v2PADV2constant_values=5.0test_mirror_padMIRROR_PADtest_topk_v2TOPK_V2test_one_hotONE_HOT