diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index f8ef2edc39d1..b0c6ac8f6722 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -384,6 +384,19 @@ TOPI_DEFINE_BCAST_OP(minimum, { return tvm::min(a, b); });
  */
 TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });
 
+/*!
+ * \fn atan2
+ * \brief Compute atan2(y, x) with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr (y-coordinates).
+ * \param B The second tensor, or Expr (x-coordinates).
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The result.
+ */
+TOPI_DEFINE_BCAST_OP(atan2, { return tvm::atan2(a, b); });
+
 /*!
  * \fn left_shift
  * \brief Compute A << B with auto-broadcasting.
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index b7a7e42c488d..155b6301f937 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
             "ADD_N": self.convert_add_n,
             "ARG_MAX": functools.partial(self._convert_arg_min_max, relax_op=_op.argmax),
             "ARG_MIN": functools.partial(self._convert_arg_min_max, relax_op=_op.argmin),
+            "ATAN2": functools.partial(self._convert_elemwise, relax_op=_op.atan2),
             "AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"),
             "BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
             "BATCH_MATMUL": self.convert_batch_matmul,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 6f985ef36cac..473e50ed30b5 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -47,6 +47,7 @@
 )
 from .binary import (
     add,
+    atan2,
     bitwise_and,
     bitwise_or,
     bitwise_xor,
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 939ba6927571..9480612e6f52 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -141,6 +141,24 @@ def power(x1: Expr, x2: Expr):
     return _ffi_api.power(x1, x2)  # type: ignore
 
 
+def atan2(x1: Expr, x2: Expr) -> Expr:
+    """Atan2 with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor (y-coordinates).
+    x2 : relax.Expr
+        The second input tensor (x-coordinates).
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.atan2(x1, x2)  # type: ignore
+
+
 def subtract(x1: Expr, x2: Expr) -> Expr:
     """Subtraction with numpy-style broadcasting.
 
diff --git a/python/tvm/relax/script/builder/ir.py b/python/tvm/relax/script/builder/ir.py
index f62164dbd767..84ad485a33bf 100644
--- a/python/tvm/relax/script/builder/ir.py
+++ b/python/tvm/relax/script/builder/ir.py
@@ -54,6 +54,7 @@
     assert_op,
     astype,
     atan,
+    atan2,
     atanh,
     bitwise_and,
     bitwise_not,
@@ -813,6 +814,7 @@ def dtype(value: py_str | DataType) -> Expr:
     "assert_op",
     "astype",
     "atan",
+    "atan2",
     "atanh",
     "bitwise_and",
     "bitwise_not",
diff --git a/python/tvm/relax/transform/legalize_ops/binary.py b/python/tvm/relax/transform/legalize_ops/binary.py
index 85e3f0644020..355fed86b982 100644
--- a/python/tvm/relax/transform/legalize_ops/binary.py
+++ b/python/tvm/relax/transform/legalize_ops/binary.py
@@ -49,6 +49,7 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
 register_legalize("relax.log_add_exp", _binary(topi.log_add_exp))
 register_legalize("relax.multiply", _binary(topi.multiply))
 register_legalize("relax.power", _binary(topi.power))
+register_legalize("relax.atan2", _binary(topi.atan2))
 register_legalize("relax.subtract", _binary(topi.subtract))
 register_legalize("relax.equal", _binary(topi.equal))
 register_legalize("relax.mod", _binary(topi.mod))
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index c00495b03237..e97730bc8850 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -249,6 +249,25 @@ def power(lhs, rhs):
     return _cpp.power(lhs, rhs)
 
 
+def atan2(lhs, rhs):
+    """Atan2 with auto-broadcasting.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor or Expr
+        The left operand (y-coordinates).
+    rhs : tvm.te.Tensor or Expr
+        The right operand (x-coordinates).
+
+    Returns
+    -------
+    ret : tvm.te.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.atan2(lhs, rhs)
+
+
 def left_shift(lhs, rhs):
     """Left shift with auto-broadcasting
 
diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc
index 71c00e09e42a..07c3364a9f35 100644
--- a/src/relax/op/tensor/binary.cc
+++ b/src/relax/op/tensor/binary.cc
@@ -203,6 +203,7 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
+RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(atan2);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod);
 RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod);
diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h
index b5650fad2735..a0dfbd66e6f9 100644
--- a/src/relax/op/tensor/binary.h
+++ b/src/relax/op/tensor/binary.h
@@ -81,6 +81,9 @@ Expr multiply(Expr x1, Expr x2);
 /*! \brief Power with numpy-style broadcasting. */
 Expr power(Expr x1, Expr x2);
 
+/*! \brief Atan2 with numpy-style broadcasting. */
+Expr atan2(Expr x1, Expr x2);
+
 /*! \brief Subtraction with numpy-style broadcasting. */
 Expr subtract(Expr x1, Expr x2);
 
diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index c90b20877101..cba8e29afab4 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -66,6 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .TOPI_DEF_BCAST_OP("topi.maximum", topi::maximum)
       .TOPI_DEF_BCAST_OP("topi.minimum", topi::minimum)
       .TOPI_DEF_BCAST_OP("topi.power", topi::power)
+      .TOPI_DEF_BCAST_OP("topi.atan2", topi::atan2)
       .TOPI_DEF_BCAST_OP("topi.left_shift", topi::left_shift)
       .TOPI_DEF_BCAST_OP("topi.logical_and", topi::logical_and)
       .TOPI_DEF_BCAST_OP("topi.logical_or", topi::logical_or)
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 69aab2d43b93..37211d337a88 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -769,6 +769,7 @@ def func(self, dims, value):
         (tf.divide, R.divide),
         (tf.math.floormod, R.floor_mod),
         (tf.math.floordiv, R.floor_divide),
+        (tf.math.atan2, R.atan2),
     ],
 )
 def test_binary(tf_op, relax_op):
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index 7049e6aaef87..0ac8cf1e9fd1 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -34,6 +34,7 @@ def test_op_correctness():
     assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide")
     assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
     assert relax.op.power(x, y).op == Op.get("relax.power")
+    assert relax.op.atan2(x, y).op == Op.get("relax.atan2")
     assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
     assert relax.op.mod(x, y).op == Op.get("relax.mod")
     assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod")
@@ -71,6 +72,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
     (relax.op.floor_divide, tirx.FloorDiv),
     (relax.op.multiply, tirx.Mul),
     (relax.op.power, tirx.pow),
+    (relax.op.atan2, tirx.atan2),
     (relax.op.subtract, tirx.Sub),
     (relax.op.maximum, tirx.Max),
     (relax.op.minimum, tirx.Min),
diff --git a/tests/python/relax/test_transform_legalize_ops_binary.py b/tests/python/relax/test_transform_legalize_ops_binary.py
index f9b2074eab4d..42355ba757d8 100644
--- a/tests/python/relax/test_transform_legalize_ops_binary.py
+++ b/tests/python/relax/test_transform_legalize_ops_binary.py
@@ -791,6 +791,123 @@ def power(
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_atan2():
+    # fmt: off
+    @tvm.script.ir_module
+    class Atan2:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
+            gv: R.Tensor((4, 3, 2, 3), "float32") = R.atan2(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def atan2(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_atan2: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tirx.noalias": True})
+            # with T.sblock("root"):
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
+                with T.sblock("T_atan2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] = T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32"))
+            return gv
+
+    # fmt: on
+
+    mod = LegalizeOps()(Atan2)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_atan2_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class Atan2:
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv: R.Tensor((a, b, c, d), "float32") = R.atan2(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def atan2(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_atan2: T.handle):
+            T.func_attr({"tirx.noalias": True})
+            c = T.int64()
+            d = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, d))
+            a = T.int64()
+            b = T.int64()
+            rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, T.int64(1)))
+            T_atan2 = T.match_buffer(var_T_atan2, (a, b, c, d))
+            for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
+                with T.sblock("T_atan2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+                    T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] = T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
+
+        @R.function
+        def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c", "d"), dtype="float32"):
+            a = T.int64()
+            b = T.int64()
+            c = T.int64()
+            d = T.int64()
+            gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Atan2)  # legalize the input module, not Expected
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_atan2_primvalue():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            gv = R.atan2(x, y)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([64, 32, 16], "float32"),
+            y: R.Prim("float32"),
+        ):
+            cls = Expected
+            gv = R.call_tir(cls.atan2, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def atan2(
+            lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+            rhs: T.float32,
+            output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
+        ):
+            T.func_attr({"tirx.noalias": True})
+            for i, j, k in T.grid(*lhs.shape):
+                with T.sblock("T_atan2"):
+                    vi, vj, vk = T.axis.remap("SSS", [i, j, k])
+                    output[vi, vj, vk] = T.atan2(lhs[vi, vj, vk], rhs)
+
+    After = LegalizeOps()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 def test_subtract():
     # fmt: off
     @tvm.script.ir_module