[WIP][Megatron-LM] feat: reduce extra qkv transpose in attn#625

Merged
Xiaoming-AMD merged 2 commits into main from dev/zhangrb/refine_turbo_attn
Mar 31, 2026
Conversation

@RuibinCheung
Contributor

  • Keep the input in SBHD layout to reduce extra Q/K/V transposes in attention.

Copilot AI review requested due to automatic review settings March 24, 2026 07:26
@RuibinCheung RuibinCheung changed the title [Megatron-LM] feat: reduce extra qkv transpose in attn [WIP][Megatron-LM] feat: reduce extra qkv transpose in attn Mar 24, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR updates the Megatron Primus-Turbo attention wrapper to keep tensors in an SBHD-oriented flow and adjust the attention call/output reshaping to reduce redundant QKV layout transposes.

Changes:

  • Tightens allowed qkv_format handling to a single supported layout.
  • Updates Q/K/V layout conversion before calling flash_attn_*.
  • Updates output layout conversion back to SBHD and reshapes to (S, B, hidden).
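The layout flow described above can be sketched with NumPy as a stand-in for PyTorch tensors (a hypothetical illustration of the SBHD → BSHD → SBHD round-trip; the real code operates on torch tensors inside the Primus-Turbo attention wrapper, and all sizes below are made up):

```python
import numpy as np

# Hypothetical sizes: S (sequence), B (batch), H (heads), D (head dim).
S, B, H, D = 5, 2, 4, 8

# Q/K/V arrive in SBHD layout ...
q_sbhd = np.zeros((S, B, H, D), dtype=np.float32)

# ... are swapped to BSHD once before the flash-attn call ...
q_bshd = q_sbhd.transpose(1, 0, 2, 3)
assert q_bshd.shape == (B, S, H, D)

# ... and the BSHD output is converted back to SBHD and flattened
# to (S, B, hidden), where hidden = H * D.
out_bshd = q_bshd  # stand-in for the attention output
out_sbh = np.ascontiguousarray(out_bshd.transpose(1, 0, 2, 3)).reshape(S, B, H * D)
assert out_sbh.shape == (S, B, H * D)
```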

Comment on lines +472 to +473
# NOTE(ruibin): The layout of q, k and v is (S, B, H, D). But attn accept the shape of qkv is (B, S, H, D).
query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]

Copilot AI Mar 24, 2026

The NOTE has a couple grammar issues that make it harder to parse quickly (e.g., “attn accept” / “shape of qkv is”). Consider rephrasing to something like: “Q/K/V come in (S, B, H, D), but flash_attn_* expects (B, S, H, D).”

Comment on lines +460 to +471
SUPPORTED_QKV_FORMATS = "sbhd"

packed_seq_kwargs = (
    {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}
    if packed_seq_params is not None
    else {}
)

qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
assert qkv_format in ("sbhd", "bhsd"), "qkv_format only support bshd, but got {qkv_format}"
if qkv_format == "sbhd":
    query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
assert (
    qkv_format in SUPPORTED_QKV_FORMATS
), f"qkv_format only support {SUPPORTED_QKV_FORMATS}, but got {qkv_format}"

Copilot AI Mar 24, 2026

SUPPORTED_QKV_FORMATS is a string, so qkv_format in SUPPORTED_QKV_FORMATS performs substring matching (e.g., an invalid value like "bhd" would pass because it’s contained in "sbhd"). Make this a container of allowed formats (e.g., a tuple/set) or change to a single SUPPORTED_QKV_FORMAT equality check so validation is correct.
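The pitfall the reviewer points out can be reproduced in plain Python, independent of the PR code (a minimal sketch; `SUPPORTED` is a hypothetical name for the fixed container):

```python
SUPPORTED_QKV_FORMATS = "sbhd"

# String membership is substring matching, so invalid values slip through:
assert "bhd" in SUPPORTED_QKV_FORMATS  # passes, but "bhd" is not a valid format

# A one-element tuple makes membership an equality check instead:
SUPPORTED = ("sbhd",)
assert "sbhd" in SUPPORTED
assert "bhd" not in SUPPORTED  # now correctly rejected
```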

# NOTE(ruibin): The layout of q, k and v is (S, B, H, D). But attn accept the shape of qkv is (B, S, H, D).
query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]

Copilot AI Mar 24, 2026

permute(1, 0, 2, 3) returns a non-contiguous view, but previously this path called .contiguous() after transposing. If pt.ops.flash_attn_* expects contiguous Q/K/V (or incurs an implicit copy), this can cause a runtime error or unexpected perf regressions. Consider using transpose(0, 1).contiguous() (same layout change) or permute(...).contiguous() before calling self.attn.

Suggested change
query, key, value = [x.permute(1, 0, 2, 3) for x in (query, key, value)]
query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
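For background on the view-vs-copy distinction, NumPy behaves analogously to torch here (a sketch only; `ascontiguousarray` plays the role of torch's `.contiguous()`, and the shape is a made-up (S, B, H, D)):

```python
import numpy as np

q_sbhd = np.zeros((3, 2, 4, 8), dtype=np.float32)  # hypothetical (S, B, H, D)

# Axis permutation returns a view with swapped strides, not a contiguous copy.
q_bshd = q_sbhd.transpose(1, 0, 2, 3)
assert not q_bshd.flags["C_CONTIGUOUS"]

# ascontiguousarray materializes the copy, like torch's .contiguous().
q_bshd_c = np.ascontiguousarray(q_bshd)
assert q_bshd_c.flags["C_CONTIGUOUS"] and q_bshd_c.shape == (2, 3, 4, 8)
```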

@Xiaoming-AMD Xiaoming-AMD merged commit 7665157 into main Mar 31, 2026
2 checks passed
Xiaoming-AMD added a commit that referenced this pull request Apr 1, 2026


3 participants