diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1479d6f23913..c5eb0420a3c4 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3161,11 +3161,10 @@ def _impl_v1(cls, bb, inputs, attr, params): if pos_ids is None: pos_ids = relax.const([list(range(seq_len))] * batch_size, dtype="int64") - # TODO(jwfromm) Replace with relax ops once take has better support. - word_vec = bb.emit_te(topi.take, word_emb, input_ids, 0) + word_vec = relax.op.take(word_emb, input_ids, axis=0) if segment_ids: - segment_vec = bb.emit_te(topi.take, segment_emb, segment_ids, 0) - pos_vec = bb.emit_te(topi.take, pos_emb, pos_ids, 0) + segment_vec = relax.op.take(segment_emb, segment_ids, axis=0) + pos_vec = relax.op.take(pos_emb, pos_ids, axis=0) vec_sum = relax.op.add(word_vec, pos_vec) if segment_ids: