From f0474514da189b06da9dce194ab41dc03595ca16 Mon Sep 17 00:00:00 2001
From: flashmouse
Date: Mon, 15 Jun 2026 16:48:55 +0800
Subject: [PATCH] fix attention equal

---
Review notes (free text between the three-dash separator and the first
"diff --git" line is not part of the commit message):

  * nn.py, inside _te_attention: the product batch_size * num_head is bound
    once to a local `bs`, which replaces the repeated expression in the
    q/k/v reshapes and in the reshape that follows the bias add. Both
    topi.nn.batch_matmul calls now receive an explicit `oshape`
    ([bs, seq_len, seq_len_kv] for q*k, [bs, seq_len, head_dim_v] for s*v).
  * test_transform_legalize_ops_nn.py: adds test_dynamic_batch_attention,
    which runs LegalizeOps() on two modules whose leading tensor dimension is
    the symbolic "batch_size" -- one plain attention call and one with bias,
    scale and causal_mask="BottomRight". The test only checks that
    legalization completes; it does not compare against an expected module.
  * NOTE(review): presumably the explicit oshape spares batch_matmul from
    having to prove that the two symbolic batch extents are equal -- confirm
    against the issue linked in the test docstring.

 python/tvm/relax/transform/legalize_ops/nn.py | 13 +++---
 .../relax/test_transform_legalize_ops_nn.py   | 35 +++++++++++++++++++
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 51d23de0f761..35d81f968b37 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -714,10 +714,11 @@ def _te_attention(
     q = topi.transpose(q, [0, 2, 1, 3])
     k = topi.transpose(k, [0, 2, 1, 3])
     v = topi.transpose(v, [0, 2, 1, 3])
-    q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
-    k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
-    v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
-    p = topi.nn.batch_matmul(q, k)
+    bs = batch_size * num_head
+    q = topi.reshape(q, [bs, seq_len, head_dim])
+    k = topi.reshape(k, [bs, seq_len_kv, head_dim])
+    v = topi.reshape(v, [bs, seq_len_kv, head_dim_v])
+    p = topi.nn.batch_matmul(q, k, oshape=[bs, seq_len, seq_len_kv])
     if scale is not None:
         p = topi.multiply(p, scale)
     else:
@@ -725,7 +726,7 @@ def _te_attention(
     if bias is not None:
         p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
         p = topi.add(p, bias)
-        p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
+        p = topi.reshape(p, [bs, seq_len, seq_len_kv])
     if causal_mask is None:
         s = topi.nn.softmax(p)
     else:
@@ -741,7 +742,7 @@ def _te_attention(
         )
         p_masked_sum = topi.sum(p_masked_exp, axis=-1, keepdims=True)
         s = topi.divide(p_masked_exp, p_masked_sum)
-    o = topi.nn.batch_matmul(s, v, transpose_b=False)
+    o = topi.nn.batch_matmul(s, v, transpose_b=False, oshape=[bs, seq_len, head_dim_v])
     o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
     return topi.transpose(o, [0, 2, 1, 3])
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 4a708b5da1f4..8136997cf66c 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3727,6 +3727,41 @@ def main(
     LegalizeOps()(Attention)
 
 
+def test_dynamic_batch_attention():
+    """The batch dimension may be dynamic (symbolic).
+
+    fix https://github.com/apache/tvm/issues/19696
+    """
+
+    @tvm.script.ir_module
+    class Attention:
+        @R.function
+        def main(
+            q: R.Tensor(("batch_size", 16, 32, 8), "float32"),
+            k: R.Tensor(("batch_size", 8, 32, 8), "float32"),
+            v: R.Tensor(("batch_size", 8, 32, 16), "float32"),
+        ):
+            gv = R.nn.attention(q, k, v)
+            return gv
+
+    LegalizeOps()(Attention)
+
+    @tvm.script.ir_module
+    class AttentionBias:
+        @R.function
+        def main(
+            q: R.Tensor(("batch_size", 16, 32, 8), "float32"),
+            k: R.Tensor(("batch_size", 8, 32, 8), "float32"),
+            v: R.Tensor(("batch_size", 8, 32, 16), "float32"),
+            bias: R.Tensor(("batch_size", 32, 16, 8), "float32"),
+        ):
+            scale = T.FloatImm("float32", 0.1)
+            gv = R.nn.attention(q, k, v, bias, scale=scale, causal_mask="BottomRight")
+            return gv
+
+    LegalizeOps()(AttentionBias)
+
+
 def test_nll_loss():
     # fmt: off
     @tvm.script.ir_module