[Relax][Frontend][KVCache] Extend masked sequence prefill to causal left-padding #19431
Code Review
This pull request introduces support for causal attention with left-padding in the LLM prefill kernels by adding a new mask_mode option, causal_padded_left. The changes include a new TIR macro softmax_update_causal_padded_left and updates to the _attention_sequence_prefill_with_mask kernel to handle generalized sequence validity predicates. Feedback focuses on generalizing the new macro to handle cases where query and key/value sequence lengths differ by explicitly passing kv_len and adjusting the causal mask logic accordingly.
/gemini review
Code Review
This pull request introduces support for causal attention with left-padding in prefill kernels by adding the softmax_update_causal_padded_left macro and updating _attention_sequence_prefill_with_mask to support multiple masking modes. New tests verify these changes across various scenarios. Feedback was provided to optimize the new macro by hoisting constant calculations and thread guards out of inner loops for better performance.
@T.macro
def softmax_update_causal_padded_left(
    S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem: T.Buffer,
    m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
    ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
    valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
):
    # Three-phase online softmax with left-padding + causal mask. Real
    # queries occupy [qo_len - valid_len, qo_len); real keys occupy
    # [kv_len - valid_len, kv_len). Causal keeps
    # col <= row + (kv_len - qo_len) within those valid suffixes.
    for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
        row: T.int32 = i * bdx * num_warps + ty * bdx + tx
        if row < tile_x:
            with T.sblock("update1"):
                m_prev[i] = m_smem[row]
                m_new[i] = m_smem[row]
                row_: T.int32 = (LH_start + row) // group_size
                pad_q: T.int32 = qo_len - valid_len
                pad_kv: T.int32 = kv_len - valid_len
                for j in T.serial(tile_z):
                    col_: T.int32 = L_kv_start + j
                    if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q), tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
                        m_new[i] = T.max(m_new[i], S_smem[row, j])
                d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
    for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
        row: T.int32 = i * bdx * num_warps + ty * bdx + tx
        with T.sblock("update"):
            for j in T.serial(tile_z):
                if row < tile_x:
                    row_: T.int32 = (LH_start + row) // group_size
                    pad_q: T.int32 = qo_len - valid_len
                    pad_kv: T.int32 = kv_len - valid_len
                    col_: T.int32 = L_kv_start + j
                    if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q), tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
                        S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                    else:
                        S_smem[row, j] = T.exp2(-5e4 - m_new[i])
    for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
        row: T.int32 = i * bdx * num_warps + ty * bdx + tx
        if row < tile_x:
            with T.sblock("update"):
                for j in T.serial(tile_z):
                    d_new[i] += S_smem[row, j]
                m_smem[row] = m_new[i]
                d_smem[row] = d_new[i]
                m_prev_smem[row] = m_prev[i]
    T.tvm_storage_sync("shared")
The softmax_update_causal_padded_left macro can be optimized for efficiency. Currently, several calculations that are constant for the inner loop (over j) are performed repeatedly inside it. Additionally, moving the thread guard if row < tile_x outside the inner loop in Phase 2 can improve performance by avoiding redundant checks for every column in the tile.
Specifically:
- In Phase 1, the row_valid check can wrap the j loop.
- In Phase 2, the thread guard and constant calculations (row_, pad_q, pad_kv) should be moved outside the j loop.
@T.macro
def softmax_update_causal_padded_left(
    S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem: T.Buffer,
    m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
    ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
    valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
):
    # Three-phase online softmax with left-padding + causal mask. Real
    # queries occupy [qo_len - valid_len, qo_len); real keys occupy
    # [kv_len - valid_len, kv_len). Causal keeps
    # col <= row + (kv_len - qo_len) within those valid suffixes.
    for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
        row: T.int32 = i * bdx * num_warps + ty * bdx + tx
        if row < tile_x:
            with T.sblock("update1"):
                m_prev[i] = m_smem[row]
                m_new[i] = m_smem[row]
                row_: T.int32 = (LH_start + row) // group_size
                pad_q: T.int32 = qo_len - valid_len
                pad_kv: T.int32 = kv_len - valid_len
                if tirx.And(row_ < qo_len, row_ >= pad_q):
                    for j in T.serial(tile_z):
                        col_: T.int32 = L_kv_start + j
                        if tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1):
                            m_new[i] = T.max(m_new[i], S_smem[row, j])
                d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
    for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
        row: T.int32 = i * bdx * num_warps + ty * bdx + tx
        if row < tile_x:
            row_: T.int32 = (LH_start + row) // group_size
            pad_q: T.int32 = qo_len - valid_len
            pad_kv: T.int32 = kv_len - valid_len
            row_valid = tirx.And(row_ < qo_len, row_ >= pad_q)
            with T.sblock("update"):
                for j in T.serial(tile_z):
                    col_: T.int32 = L_kv_start + j
                    if tirx.And(row_valid, tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
                        S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                    else:
                        S_smem[row, j] = T.exp2(-5e4 - m_new[i])
    for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
        row: T.int32 = i * bdx * num_warps + ty * bdx + tx
        if row < tile_x:
            with T.sblock("update"):
                for j in T.serial(tile_z):
                    d_new[i] += S_smem[row, j]
                m_smem[row] = m_new[i]
                d_smem[row] = d_new[i]
                m_prev_smem[row] = m_prev[i]
    T.tvm_storage_sync("shared")
Keeping this aligned with the existing softmax update macros; not planning to change this right now.
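For readers following the thread, the three phases of the macro can be mirrored in plain Python for a single query row. This is only an illustrative sketch, not the kernel: the function names `is_visible` and `online_softmax_row` are hypothetical, the initial running state `(m, d) = (-5e4, 0.0)` is an assumption of the sketch, and the thread/tile indexing (`tx`, `ty`, `group_size`) is left out. The visibility predicate and the -5e4 sentinel are copied from the macro.

```python
import math

NEG = -5e4  # masked-score sentinel, same constant as in the macro


def is_visible(row, col, valid_len, qo_len, kv_len):
    """Predicate from the macro: real query row, real key column, causal."""
    pad_q = qo_len - valid_len
    pad_kv = kv_len - valid_len
    row_valid = pad_q <= row < qo_len
    col_valid = pad_kv <= col < kv_len - qo_len + row + 1
    return row_valid and col_valid


def online_softmax_row(scores, row, valid_len, qo_len, kv_len, tile_z):
    """Walk one query row over KV tiles, returning the running (max, denominator)."""
    m, d = NEG, 0.0  # assumed initial state for this sketch
    for start in range(0, kv_len, tile_z):
        cols = range(start, min(start + tile_z, kv_len))
        vis = [is_visible(row, c, valid_len, qo_len, kv_len) for c in cols]
        # Phase 1: raise the running max over visible scores, rescale the old sum.
        m_prev = m
        for c, v in zip(cols, vis):
            if v:
                m = max(m, scores[c])
        d *= 2.0 ** (m_prev - m)
        # Phase 2: exponentiate; masked entries use the sentinel and underflow to 0.
        p = [2.0 ** ((scores[c] if v else NEG) - m) for c, v in zip(cols, vis)]
        # Phase 3: accumulate the tile into the running denominator.
        d += sum(p)
    return m, d


scores = [0.5, 2.0, -1.0, 3.0]
m, d = online_softmax_row(scores, row=2, valid_len=3, qo_len=4, kv_len=4, tile_z=2)
# Direct (non-tiled) reference over the visible columns of row 2, which are 1 and 2.
reference = sum(2.0 ** (scores[c] - m) for c in (1, 2))
print(m, d, math.isclose(d, reference))
```

Because the row-validity test does not depend on the column, evaluating it once per row (as the suggestion does) or once per element (as the PR does) yields the same mask; the difference is only in how often the predicate is computed.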
This PR extends _attention_sequence_prefill_with_mask to support a second mask regime for decoder-style embedding workloads.

Summary
- mask_mode="padded".
- mask_mode="causal_padded_left" for left-padded causal sequence prefill.
- softmax_update_causal_padded_left macro for the online softmax mask.

Motivation
This is a TVM-side kernel dependency for the first-class embedding serving work tracked in mlc-ai/mlc-llm#3451.
The existing masked sequence prefill kernel supports encoder-style batches where real tokens occupy the valid prefix [0, valid_len) and padding is on the right.
Decoder-style embedding batches, such as the decoder-only embedding path, commonly left-pad variable-length inputs so the final real token / EOS lands at the same final column across the batch. This allows last-token pooling to read output[:, -1, :], while still requiring causal masking within each valid suffix.
For each batch row:
- mask_mode="padded": real tokens are [0, valid_len).
- mask_mode="causal_padded_left": real tokens are [seq_len - valid_len, seq_len), with col <= row.

Testing
- git diff --check
- python -m pytest -q tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py -k 'causal_padded_left or valid_len_mixed'
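The two mask regimes described under Motivation can be written out as small boolean matrices. The sketch below is a reference illustration in plain Python, not code from the PR: the helper names `padded_mask` and `causal_padded_left_mask` are hypothetical, it assumes equal query and key lengths (seq_len), and it assumes the "padded" mode masks both padded rows and padded columns without a causal constraint.

```python
def padded_mask(seq_len, valid_len):
    """mask_mode="padded": real tokens occupy the prefix [0, valid_len)."""
    return [
        [r < valid_len and c < valid_len for c in range(seq_len)]
        for r in range(seq_len)
    ]


def causal_padded_left_mask(seq_len, valid_len):
    """mask_mode="causal_padded_left": real tokens occupy the suffix
    [seq_len - valid_len, seq_len), and each real row sees col <= row."""
    pad = seq_len - valid_len
    return [
        [r >= pad and pad <= c <= r for c in range(seq_len)]
        for r in range(seq_len)
    ]


# One batch row with 4 slots, 2 of them real.
for name, mask in (
    ("padded", padded_mask(4, 2)),
    ("causal_padded_left", causal_padded_left_mask(4, 2)),
):
    print(name)
    for row in mask:
        print("".join("x" if keep else "." for keep in row))
```

With left-padding, the last row is a real token whenever valid_len >= 1 and it attends to the whole valid suffix, which is why last-token pooling can read the final position for every batch row.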