[WIP][Megatron-LM] feat: reduce extra qkv transpose in attn #625
Xiaoming-AMD merged 2 commits into main from
Conversation
RuibinCheung
commented
Mar 24, 2026
- Keep the input in SBHD layout to avoid extra q, k, v transposes in attention.
Pull request overview
This PR updates the Megatron Primus-Turbo attention wrapper to keep tensors in an SBHD-oriented flow and adjust the attention call/output reshaping to reduce redundant QKV layout transposes.
Changes:
- Tightens allowed qkv_format handling to a single supported layout.
- Updates Q/K/V layout conversion before calling flash_attn_*.
- Updates output layout conversion back to SBHD and reshapes to (S, B, hidden).
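A shape-level sketch of the flow the overview describes, using NumPy as a stand-in for torch (dimension sizes are illustrative, not from the PR):

```python
import numpy as np

S, B, H, D = 4, 2, 8, 16  # illustrative sizes

# Q/K/V arrive in SBHD layout.
q = np.zeros((S, B, H, D))

# Convert to BSHD for the flash_attn_* call.
q_bshd = q.transpose(1, 0, 2, 3)
assert q_bshd.shape == (B, S, H, D)

# The attention output comes back as (B, S, H, D); convert it back
# to SBHD and flatten the head dims to get (S, B, hidden).
out = q_bshd.transpose(1, 0, 2, 3).reshape(S, B, H * D)
assert out.shape == (S, B, H * D)
```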
# NOTE(ruibin): The layout of q, k and v is (S, B, H, D). But attn accept the shape of qkv is (B, S, H, D).
query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]
The NOTE has a couple of grammar issues that make it harder to parse quickly (e.g., "attn accept" / "shape of qkv is"). Consider rephrasing to something like: "Q/K/V come in (S, B, H, D), but flash_attn_* expects (B, S, H, D)."
SUPPORTED_QKV_FORMATS = "sbhd"

packed_seq_kwargs = (
    {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
    if packed_seq_params is not None
    else {}
)

qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
assert qkv_format in ("sbhd", "bhsd"), "qkv_format only support bshd, but got {qkv_format}"
if qkv_format == "sbhd":
    query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
assert (
    qkv_format in SUPPORTED_QKV_FORMATS
), f"qkv_format only support {SUPPORTED_QKV_FORMATS}, but got {qkv_format}"
SUPPORTED_QKV_FORMATS is a string, so qkv_format in SUPPORTED_QKV_FORMATS performs substring matching (e.g., an invalid value like "bhd" would pass because it’s contained in "sbhd"). Make this a container of allowed formats (e.g., a tuple/set) or change to a single SUPPORTED_QKV_FORMAT equality check so validation is correct.
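A minimal standalone sketch of why the string-membership check is too permissive (plain Python, independent of the PR's module):

```python
# Membership on a string does substring matching, so an invalid
# format that happens to be a substring of "sbhd" slips through.
SUPPORTED_QKV_FORMATS = "sbhd"
assert "bhd" in SUPPORTED_QKV_FORMATS      # passes, but "bhd" is not a valid layout

# A tuple (or set) restores exact-match semantics.
SUPPORTED_QKV_FORMATS = ("sbhd",)
assert "bhd" not in SUPPORTED_QKV_FORMATS  # invalid value now rejected
assert "sbhd" in SUPPORTED_QKV_FORMATS     # valid value still accepted
```

An equality check against a single `SUPPORTED_QKV_FORMAT` string would work equally well if only one layout is ever supported.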
    qkv_format in SUPPORTED_QKV_FORMATS
), f"qkv_format only support {SUPPORTED_QKV_FORMATS}, but got {qkv_format}"
# NOTE(ruibin): The layout of q, k and v is (S, B, H, D). But attn accept the shape of qkv is (B, S, H, D).
query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]
permute(1, 0, 2, 3) returns a non-contiguous view, but previously this path called .contiguous() after transposing. If pt.ops.flash_attn_* expects contiguous Q/K/V (or incurs an implicit copy), this can cause a runtime error or unexpected perf regressions. Consider using transpose(0, 1).contiguous() (same layout change) or permute(...).contiguous() before calling self.attn.
Suggested change:
- query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]
+ query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
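The contiguity point can be demonstrated with NumPy, used here as a stand-in for torch since the stride semantics are analogous (`np.ascontiguousarray` playing the role of torch's `.contiguous()`):

```python
import numpy as np

# (S, B, H, D) dummy tensor; values are irrelevant, only memory layout matters.
x = np.zeros((4, 2, 8, 16))

# Swapping the first two axes produces a view with permuted strides,
# not a C-contiguous buffer.
view = x.transpose(1, 0, 2, 3)
assert not view.flags["C_CONTIGUOUS"]

# Forcing a copy restores row-major layout for the new axis order.
copied = np.ascontiguousarray(view)
assert copied.flags["C_CONTIGUOUS"]
assert copied.shape == (2, 4, 8, 16)
```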