
[Relax][Frontend][KVCache] Extend masked sequence prefill to causal left-padding#19431

Merged
tlopex merged 3 commits into apache:main from xthomaswang:relax/masked-sequence-prefill on Apr 25, 2026

@xthomaswang

Contributor

This PR extends _attention_sequence_prefill_with_mask to support a second mask regime for decoder-style embedding workloads.

Summary

  • Keep the existing right-padded bidirectional behavior as mask_mode="padded".
  • Add mask_mode="causal_padded_left" for left-padded causal sequence prefill.
  • Add a softmax_update_causal_padded_left macro for the online softmax mask.
  • Add tests for causal left-padding with zero, full, mixed, and GQA valid lengths.

Motivation

This is a TVM-side kernel dependency for the first-class embedding serving work tracked in mlc-ai/mlc-llm#3451.

The existing masked sequence prefill kernel supports encoder-style batches where real tokens occupy the valid prefix [0, valid_len) and padding is on the right.

Decoder-style embedding batches, such as those on the decoder-only embedding path, commonly left-pad variable-length inputs so that the final real token (or EOS) lands in the same final column across the batch. This allows last-token pooling to read output[:, -1, :], while still requiring causal masking within each valid suffix.
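As a rough illustration of the layout described above, the following pure-Python sketch left-pads a batch of token-id lists. The helper name `left_pad` and the `pad_id` value are hypothetical and not part of this PR; the only point is that every row's last real token ends up in the final column.

```python
def left_pad(batch, pad_id=0):
    """Left-pad variable-length token lists to a common length.

    Hypothetical helper for illustration only; not part of TVM or MLC-LLM.
    Returns (padded_rows, valid_lens).
    """
    seq_len = max(len(tokens) for tokens in batch)
    padded = [[pad_id] * (seq_len - len(tokens)) + list(tokens) for tokens in batch]
    valid_lens = [len(tokens) for tokens in batch]
    return padded, valid_lens


batch = [[11, 12, 13], [21], [31, 32]]
padded, valid_lens = left_pad(batch)

# Real tokens of row b occupy columns [seq_len - valid_lens[b], seq_len),
# so reading the last column is enough for last-token pooling.
last_tokens = [row[-1] for row in padded]
```

With right-padding, the same pooling step would instead need a per-row index of `valid_len - 1`.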

For each batch row:

  • mask_mode="padded": real tokens are [0, valid_len).
  • mask_mode="causal_padded_left": real tokens are [seq_len - valid_len, seq_len), and attention within that suffix is causal (col <= row).
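The two regimes can be sketched as boolean visibility matrices for a single batch row. This is a plain-Python reference, not the kernel itself; `build_mask` is a hypothetical name, and masking the padded query rows in "padded" mode is an assumption made here for symmetry rather than something stated in this PR.

```python
def build_mask(seq_len, valid_len, mask_mode):
    """Return mask[row][col] == True where query `row` may attend to key `col`.

    Hypothetical reference helper; the real kernel evaluates the predicate
    per tile instead of materialising a matrix.
    """
    if mask_mode == "padded":
        # Right-padded, bidirectional: real tokens are [0, valid_len).
        # Assumption: padded query rows are masked out as well.
        def visible(row, col):
            return row < valid_len and col < valid_len
    elif mask_mode == "causal_padded_left":
        # Left-padded, causal: real tokens are [seq_len - valid_len, seq_len).
        pad = seq_len - valid_len

        def visible(row, col):
            return row >= pad and pad <= col <= row
    else:
        raise ValueError(f"unknown mask_mode: {mask_mode!r}")
    return [[visible(r, c) for c in range(seq_len)] for r in range(seq_len)]


padded_mask = build_mask(4, 2, "padded")
causal_left_mask = build_mask(4, 2, "causal_padded_left")
```

For seq_len=4 and valid_len=2, the "padded" mask is a 2x2 block of True in the top-left corner, while the "causal_padded_left" mask is a lower-triangular 2x2 block in the bottom-right corner.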

Testing

  • git diff --check
  • Attempted:
    python -m pytest -q tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py -k 'causal_padded_left or valid_len_mixed'

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces support for causal attention with left-padding in the LLM prefill kernels by adding a new mask_mode option, causal_padded_left. The changes include a new TIR macro softmax_update_causal_padded_left and updates to the _attention_sequence_prefill_with_mask kernel to handle generalized sequence validity predicates. Feedback focuses on generalizing the new macro to handle cases where query and key/value sequence lengths differ by explicitly passing kv_len and adjusting the causal mask logic accordingly.
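The generalization requested here, where the query and key/value lengths may differ, can be restated in plain Python. The variable names follow the macro arguments quoted later in this thread (valid_len, qo_len, kv_len); the function name `is_visible` is hypothetical, and this is a sketch of the predicate only, not of the kernel.

```python
def is_visible(row, col, valid_len, qo_len, kv_len):
    """Left-padded causal predicate with separate query and key/value lengths.

    Hypothetical reference function. Real queries occupy
    [qo_len - valid_len, qo_len) and real keys occupy
    [kv_len - valid_len, kv_len).
    """
    pad_q = qo_len - valid_len    # index of the first real query row
    pad_kv = kv_len - valid_len   # index of the first real key column
    row_valid = pad_q <= row < qo_len
    # Causal bound shifted by the length difference: col <= row + (kv_len - qo_len).
    col_valid = pad_kv <= col < kv_len - qo_len + row + 1
    return row_valid and col_valid


# qo_len == kv_len: the bound reduces to col <= row.
same_len = [c for c in range(4) if is_visible(3, c, valid_len=2, qo_len=4, kv_len=4)]
# qo_len < kv_len: query row 0 lines up with key column kv_len - qo_len.
diff_len = [c for c in range(4) if is_visible(0, c, valid_len=2, qo_len=2, kv_len=4)]
```

When qo_len equals kv_len, the offset term vanishes and the predicate matches the col <= row rule given in the PR description.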

Comment thread python/tvm/relax/frontend/nn/llm/_kernel_common.py
Comment thread python/tvm/relax/frontend/nn/llm/_prefill_kernels.py Outdated
@xthomaswang

Contributor Author

/gemini review

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor


Code Review

This pull request introduces support for causal attention with left-padding in prefill kernels by adding the softmax_update_causal_padded_left macro and updating _attention_sequence_prefill_with_mask to support multiple masking modes. New tests verify these changes across various scenarios. Feedback was provided to optimize the new macro by hoisting constant calculations and thread guards out of inner loops for better performance.

Comment on lines +386 to +433
    @T.macro
    def softmax_update_causal_padded_left(
        S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem: T.Buffer,
        m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
        ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
        valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
    ):
        # Three-phase online softmax with left-padding + causal mask. Real
        # queries occupy [qo_len - valid_len, qo_len); real keys occupy
        # [kv_len - valid_len, kv_len). Causal keeps
        # col <= row + (kv_len - qo_len) within those valid suffixes.
        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
            if row < tile_x:
                with T.sblock("update1"):
                    m_prev[i] = m_smem[row]
                    m_new[i] = m_smem[row]
                    row_: T.int32 = (LH_start + row) // group_size
                    pad_q: T.int32 = qo_len - valid_len
                    pad_kv: T.int32 = kv_len - valid_len
                    for j in T.serial(tile_z):
                        col_: T.int32 = L_kv_start + j
                        if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q), tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
            with T.sblock("update"):
                for j in T.serial(tile_z):
                    if row < tile_x:
                        row_: T.int32 = (LH_start + row) // group_size
                        pad_q: T.int32 = qo_len - valid_len
                        pad_kv: T.int32 = kv_len - valid_len
                        col_: T.int32 = L_kv_start + j
                        if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q), tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                        else:
                            S_smem[row, j] = T.exp2(-5e4 - m_new[i])
        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
            if row < tile_x:
                with T.sblock("update"):
                    for j in T.serial(tile_z):
                        d_new[i] += S_smem[row, j]
                    m_smem[row] = m_new[i]
                    d_smem[row] = d_new[i]
                    m_prev_smem[row] = m_prev[i]
        T.tvm_storage_sync("shared")
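For readers unfamiliar with the three-phase update, here is a minimal single-row sketch in plain Python. It is not the TIR macro: the function name, the initial values of the running max and denominator, and the explicit rescaling of earlier weights (standing in for the kernel's rescaling of its output accumulator) are all assumptions made for illustration. Scores are treated as base-2 logits, mirroring the use of T.exp2 in the macro.

```python
MASKED = -5e4  # sentinel logit that the macro substitutes for masked entries


def online_softmax_row(scores, visible, tile_z):
    """Tile-by-tile softmax over one row, skipping masked columns.

    Hypothetical reference routine; `visible[j]` is the mask predicate
    already evaluated for column j.
    """
    m, d = MASKED, 1.0  # assumed initial running max and denominator
    weights = []        # unnormalised weights relative to the current max
    for start in range(0, len(scores), tile_z):
        tile = range(start, min(start + tile_z, len(scores)))
        # Phase 1: raise the running max over visible entries, rescale d.
        m_prev, m_new = m, m
        for j in tile:
            if visible[j]:
                m_new = max(m_new, scores[j])
        scale = 2.0 ** (m_prev - m_new)
        d_new = d * scale
        weights = [w * scale for w in weights]
        # Phase 2: exponentiate; masked entries use the sentinel logit.
        tile_w = [2.0 ** ((scores[j] if visible[j] else MASKED) - m_new) for j in tile]
        # Phase 3: accumulate the denominator and store the running state.
        d_new += sum(tile_w)
        weights.extend(tile_w)
        m, d = m_new, d_new
    return [w / d for w in weights]


scores = [0.5, -1.0, 2.0, 1.0, 0.0, 3.0]
visible = [False, False, True, True, True, False]  # left pad of 2, causal row 4
probs = online_softmax_row(scores, visible, tile_z=2)
```

In this example the first tile is entirely padding. Its placeholder weights are scaled to zero as soon as a later tile raises the running max to a real score, which is why fully masked leading tiles do not disturb the final result for rows that have at least one visible key.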

Contributor


medium

The softmax_update_causal_padded_left macro can be optimized for efficiency. Currently, several calculations that are constant for the inner loop (over j) are performed repeatedly inside it. Additionally, moving the thread guard if row < tile_x outside the inner loop in Phase 2 can improve performance by avoiding redundant checks for every column in the tile.

Specifically:

  1. In Phase 1, the row_valid check can wrap the j loop.
  2. In Phase 2, the thread guard and constant calculations (row_, pad_q, pad_kv) should be moved outside the j loop.
    @T.macro
    def softmax_update_causal_padded_left(
        S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem: T.Buffer,
        m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
        ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
        valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
    ):
        # Three-phase online softmax with left-padding + causal mask. Real
        # queries occupy [qo_len - valid_len, qo_len); real keys occupy
        # [kv_len - valid_len, kv_len). Causal keeps
        # col <= row + (kv_len - qo_len) within those valid suffixes.
        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
            if row < tile_x:
                with T.sblock("update1"):
                    m_prev[i] = m_smem[row]
                    m_new[i] = m_smem[row]
                    row_: T.int32 = (LH_start + row) // group_size
                    pad_q: T.int32 = qo_len - valid_len
                    pad_kv: T.int32 = kv_len - valid_len
                    if tirx.And(row_ < qo_len, row_ >= pad_q):
                        for j in T.serial(tile_z):
                            col_: T.int32 = L_kv_start + j
                            if tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1):
                                m_new[i] = T.max(m_new[i], S_smem[row, j])
                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
            if row < tile_x:
                row_: T.int32 = (LH_start + row) // group_size
                pad_q: T.int32 = qo_len - valid_len
                pad_kv: T.int32 = kv_len - valid_len
                row_valid = tirx.And(row_ < qo_len, row_ >= pad_q)
                with T.sblock("update"):
                    for j in T.serial(tile_z):
                        col_: T.int32 = L_kv_start + j
                        if tirx.And(row_valid, tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                        else:
                            S_smem[row, j] = T.exp2(-5e4 - m_new[i])
        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
            if row < tile_x:
                with T.sblock("update"):
                    for j in T.serial(tile_z):
                        d_new[i] += S_smem[row, j]
                    m_smem[row] = m_new[i]
                    d_smem[row] = d_new[i]
                    m_prev_smem[row] = m_prev[i]
        T.tvm_storage_sync("shared")

Contributor Author


I'm keeping this aligned with the existing softmax update macros and am not planning to change it right now.

@tlopex tlopex left a comment

Member


LGTM

@tlopex tlopex merged commit 82293c8 into apache:main Apr 25, 2026
11 checks passed
2 participants