From bf2cd9e6e713494dee413e2e4ae76dedf410036f Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Sat, 27 Dec 2025 13:12:34 +0800 Subject: [PATCH] Add edge padding mode --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 8 ++++---- tests/python/relax/test_frontend_onnx.py | 4 ++++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 24a4014f840a..b41f56bfd42f 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1881,8 +1881,8 @@ def _impl_v2(cls, bb, inputs, attr, params): elif pad_mode == "reflect": return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT") else: - # TODO(gigiblender) Support edge mode. - raise NotImplementedError("Pad mode {} not implemented".format(pad_mode)) + # edge mode - replicate border values + return bb.emit_te(topi.nn.replicate_pad, inputs[0], pad_before, pad_after) @classmethod def _impl_v11(cls, bb, inputs, attr, params): @@ -1911,8 +1911,8 @@ def _impl_v11(cls, bb, inputs, attr, params): elif pad_mode == "reflect": return bb.emit_te(topi.nn.mirror_pad, inputs[0], pad_before, pad_after, "REFLECT") else: - # TODO(gigiblender) Support edge mode. - raise NotImplementedError("Pad mode {} not implemented".format(pad_mode)) + # edge mode - replicate border values + return bb.emit_te(topi.nn.replicate_pad, inputs[0], pad_before, pad_after) class Tile(OnnxOpConverter): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 23348cf84757..447e1ac99d63 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -2440,6 +2440,8 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0): verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0) verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0) verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") + verify_pad((2, 3), [1, 1, 1, 1], "edge") + verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "edge") @pytest.mark.parametrize("dynamic", [True, False]) @@ -2496,6 +2498,8 @@ def verify_pad(input_shape, pads, mode="constant", value=0.0): verify_pad((2, 3), [1, 0, 0, 1], "constant", 0.0) verify_pad((3, 2), [0, 0, 1, 0], "constant", 5.0) verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "reflect") + verify_pad((2, 3), [1, 1, 1, 1], "edge") + verify_pad((1, 3, 4, 5), [0, 1, 1, 1, 0, 0, 1, 1], "edge") @pytest.mark.parametrize("fp_arith", [np.float16, np.float32])