Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/tvm/topi/broadcast.h
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,19 @@ TOPI_DEFINE_BCAST_OP(minimum, { return tvm::min(a, b); });
*/
TOPI_DEFINE_BCAST_OP(power, { return tvm::pow(a, b); });

/*!
* \fn atan2
* \brief Compute atan2(y, x) with auto-broadcasting.
*
* \param A The first tensor, or Expr (y-coordinates).
* \param B The second tensor, or Expr (x-coordinates).
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return The result.
*/
TOPI_DEFINE_BCAST_OP(atan2, { return tvm::atan2(a, b); });

/*!
* \fn left_shift
* \brief Compute A << B with auto-broadcasting.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab, ctx):
"ADD_N": self.convert_add_n,
"ARG_MAX": functools.partial(self._convert_arg_min_max, relax_op=_op.argmax),
"ARG_MIN": functools.partial(self._convert_arg_min_max, relax_op=_op.argmin),
"ATAN2": functools.partial(self._convert_elemwise, relax_op=_op.atan2),
"AVERAGE_POOL_2D": functools.partial(self.convert_pool2d, pool_type="average"),
"BATCH_TO_SPACE_ND": self.convert_batch_to_space_nd,
"BATCH_MATMUL": self.convert_batch_matmul,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
)
from .binary import (
add,
atan2,
bitwise_and,
bitwise_or,
bitwise_xor,
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relax/op/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ def power(x1: Expr, x2: Expr):
return _ffi_api.power(x1, x2) # type: ignore


def atan2(x1: Expr, x2: Expr) -> Expr:
"""Atan2 with numpy-style broadcasting.

Parameters
----------
x1 : relax.Expr
The first input tensor (y-coordinates).
x2 : relax.Expr
The second input tensor (x-coordinates).

Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.atan2(x1, x2) # type: ignore


def subtract(x1: Expr, x2: Expr) -> Expr:
"""Subtraction with numpy-style broadcasting.

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/script/builder/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
assert_op,
astype,
atan,
atan2,
atanh,
bitwise_and,
bitwise_not,
Expand Down Expand Up @@ -813,6 +814,7 @@ def dtype(value: py_str | DataType) -> Expr:
"assert_op",
"astype",
"atan",
"atan2",
"atanh",
"bitwise_and",
"bitwise_not",
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def binary_call_te(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.log_add_exp", _binary(topi.log_add_exp))
register_legalize("relax.multiply", _binary(topi.multiply))
register_legalize("relax.power", _binary(topi.power))
register_legalize("relax.atan2", _binary(topi.atan2))
register_legalize("relax.subtract", _binary(topi.subtract))
register_legalize("relax.equal", _binary(topi.equal))
register_legalize("relax.mod", _binary(topi.mod))
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/topi/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,25 @@ def power(lhs, rhs):
return _cpp.power(lhs, rhs)


def atan2(lhs, rhs):
"""Atan2 with auto-broadcasting.

Parameters
----------
lhs : tvm.te.Tensor or Expr
The left operand (y-coordinates).
rhs : tvm.te.Tensor or Expr
The right operand (x-coordinates).

Returns
-------
ret : tvm.te.Tensor or Expr
Returns Expr if both operands are Expr.
Otherwise returns Tensor.
"""
return _cpp.atan2(lhs, rhs)


def left_shift(lhs, rhs):
"""Left shift with auto-broadcasting

Expand Down
1 change: 1 addition & 0 deletions src/relax/op/tensor/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(log_add_exp);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(power);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(atan2);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(mod);
RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_mod);
Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/tensor/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ Expr multiply(Expr x1, Expr x2);
/*! \brief Power with numpy-style broadcasting. */
Expr power(Expr x1, Expr x2);

/*! \brief Atan2 with numpy-style broadcasting. */
Expr atan2(Expr x1, Expr x2);

/*! \brief Subtraction with numpy-style broadcasting. */
Expr subtract(Expr x1, Expr x2);

Expand Down
1 change: 1 addition & 0 deletions src/topi/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.TOPI_DEF_BCAST_OP("topi.maximum", topi::maximum)
.TOPI_DEF_BCAST_OP("topi.minimum", topi::minimum)
.TOPI_DEF_BCAST_OP("topi.power", topi::power)
.TOPI_DEF_BCAST_OP("topi.atan2", topi::atan2)
.TOPI_DEF_BCAST_OP("topi.left_shift", topi::left_shift)
.TOPI_DEF_BCAST_OP("topi.logical_and", topi::logical_and)
.TOPI_DEF_BCAST_OP("topi.logical_or", topi::logical_or)
Expand Down
1 change: 1 addition & 0 deletions tests/python/relax/test_frontend_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,7 @@ def func(self, dims, value):
(tf.divide, R.divide),
(tf.math.floormod, R.floor_mod),
(tf.math.floordiv, R.floor_divide),
(tf.math.atan2, R.atan2),
],
)
def test_binary(tf_op, relax_op):
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relax/test_op_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_op_correctness():
assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide")
assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
assert relax.op.power(x, y).op == Op.get("relax.power")
assert relax.op.atan2(x, y).op == Op.get("relax.atan2")
assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
assert relax.op.mod(x, y).op == Op.get("relax.mod")
assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod")
Expand Down Expand Up @@ -71,6 +72,7 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
(relax.op.floor_divide, tirx.FloorDiv),
(relax.op.multiply, tirx.Mul),
(relax.op.power, tirx.pow),
(relax.op.atan2, tirx.atan2),
(relax.op.subtract, tirx.Sub),
(relax.op.maximum, tirx.Max),
(relax.op.minimum, tirx.Min),
Expand Down
117 changes: 117 additions & 0 deletions tests/python/relax/test_transform_legalize_ops_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,123 @@ def power(
tvm.ir.assert_structural_equal(Expected, After)


def test_atan2():
# fmt: off
@tvm.script.ir_module
class Atan2:
@R.function
def main(x: R.Tensor((1, 2, 3), "float32"), y: R.Tensor((4, 3, 2, 1), "float32")) -> R.Tensor((4, 3, 2, 3), "float32"):
gv: R.Tensor((4, 3, 2, 3), "float32") = R.atan2(x, y)
return gv

@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
def atan2(rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(1)), "float32"), T_atan2: T.Buffer((T.int64(4), T.int64(3), T.int64(2), T.int64(3)), "float32")):
T.func_attr({"tirx.noalias": True})
# with T.sblock("root"):
for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(3), T.int64(2), T.int64(3)):
with T.sblock("T_atan2"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] = T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])

@R.function
def main(x: R.Tensor((1, 2, 3), dtype="float32"), y: R.Tensor((4, 3, 2, 1), dtype="float32")) -> R.Tensor((4, 3, 2, 3), dtype="float32"):
gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((4, 3, 2, 3), dtype="float32"))
return gv

# fmt: on

mod = LegalizeOps()(Atan2)
tvm.ir.assert_structural_equal(mod, Expected)


def test_atan2_symbolic():
# fmt: off
@tvm.script.ir_module
class Atan2:
@R.function
def main(x: R.Tensor((1, "c", "d"), "float32"), y: R.Tensor(("a", "b", "c", 1), "float32")) -> R.Tensor(("a", "b", "c", "d"), "float32"):
a = T.int64()
b = T.int64()
c = T.int64()
d = T.int64()
gv: R.Tensor((a, b, c, d), "float32") = R.atan2(x, y)
return gv

@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
def atan2(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_T_atan2: T.handle):
T.func_attr({"tirx.noalias": True})
c = T.int64()
d = T.int64()
rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(1), c, d))
a = T.int64()
b = T.int64()
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (a, b, c, T.int64(1)))
T_atan2 = T.match_buffer(var_T_atan2, (a, b, c, d))
for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
with T.sblock("T_atan2"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])
T.writes(T_atan2[v_ax0, v_ax1, v_ax2, v_ax3])
T_atan2[v_ax0, v_ax1, v_ax2, v_ax3] = T.atan2(rxplaceholder[T.int64(0), v_ax2, v_ax3], rxplaceholder_1[v_ax0, v_ax1, v_ax2, T.int64(0)])

@R.function
def main(x: R.Tensor((1, "c", "d"), dtype="float32"), y: R.Tensor(("a", "b", "c", 1), dtype="float32")) -> R.Tensor(("a", "b", "c", "d"), dtype="float32"):
a = T.int64()
b = T.int64()
c = T.int64()
d = T.int64()
gv = R.call_tir(Expected.atan2, (x, y), out_sinfo=R.Tensor((a, b, c, d), dtype="float32"))
return gv
# fmt: on

mod = LegalizeOps()(Expected)
tvm.ir.assert_structural_equal(mod, Expected)


def test_atan2_primvalue():
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([64, 32, 16], "float32"),
y: R.Prim("float32"),
):
gv = R.atan2(x, y)
return gv

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([64, 32, 16], "float32"),
y: R.Prim("float32"),
):
cls = Expected
gv = R.call_tir(cls.atan2, (x, y), R.Tensor([64, 32, 16], dtype="float32"))
return gv

@T.prim_func(private=True)
def atan2(
lhs: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
rhs: T.float32,
output: T.Buffer([T.int64(64), T.int64(32), T.int64(16)], "float32"),
):
T.func_attr({"tirx.noalias": True})
for i, j, k in T.grid(*lhs.shape):
with T.sblock("T_atan2"):
vi, vj, vk = T.axis.remap("SSS", [i, j, k])
output[vi, vj, vk] = T.atan2(lhs[vi, vj, vk], rhs)

After = LegalizeOps()(Before)
tvm.ir.assert_structural_equal(Expected, After)


def test_subtract():
# fmt: off
@tvm.script.ir_module
Expand Down
Loading