
[kernel] add fused_qk_rmsnorm_per_token_quant kernel #2958

Merged
gbyu-amd merged 14 commits into main from guanbao/fuse_qknorm_per_token_quant
May 9, 2026

@gbyu-amd

Contributor

Motivation

Some Quark models, e.g., amd/DeepSeek-R1-0528-MXFP4-MTP-MoEFP4 and amd/Kimi-K2-Thinking-MXFP4-AttnFP8, have FP8-weight linear layers in attention and adopt the PTPC (per-token activation, per-channel weight) quantization recipe. This PR therefore adds a fused_qk_rmsnorm_per_token_quant kernel, which will be used in ATOM/vLLM-ATOM.
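
For readers unfamiliar with the operation, the sketch below shows in plain Python what an RMSNorm followed by dynamic per-token quantization computes for a single token. It is a reference illustration only, not the kernel added in this PR. The function names are hypothetical, `FP8_MAX = 448.0` assumes the OCP E4M3 format (the FNUZ variant used on some hardware has a maximum of 240), and rounding to actual FP8 values is omitted.

```python
import math

FP8_MAX = 448.0  # assumption: largest finite value of OCP FP8 E4M3


def rmsnorm(x, weight, eps=1e-6):
    """RMSNorm over one token: x / sqrt(mean(x^2) + eps) * weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def per_token_quant(x, qmax=FP8_MAX):
    """Dynamic per-token quantization: one scale for the whole token (row).

    Returns (scaled values clamped to [-qmax, qmax], scale).
    A real kernel would also round to representable FP8 values.
    """
    amax = max(abs(v) for v in x)
    scale = amax / qmax if amax > 0 else 1.0
    q = [max(-qmax, min(qmax, v / scale)) for v in x]
    return q, scale


def rmsnorm_per_token_quant(x, weight, eps=1e-6):
    """Unfused reference: normalize, then quantize the normalized token."""
    return per_token_quant(rmsnorm(x, weight, eps))
```

Multiplying the quantized values by the returned scale recovers the normalized token up to quantization error. A fused kernel computes the same result but can avoid writing the normalized tensor to memory between the two steps.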

Technical Details

Test Plan

Test Result

Submission Checklist

@gbyu-amd gbyu-amd requested a review from a team April 29, 2026 09:39
@github-actions

Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2958 --add-label <label>

@gbyu-amd gbyu-amd marked this pull request as draft April 29, 2026 11:46
@gbyu-amd gbyu-amd marked this pull request as ready for review April 29, 2026 13:24
@gbyu-amd gbyu-amd force-pushed the guanbao/fuse_qknorm_per_token_quant branch from e77d654 to ec8257b May 6, 2026 02:22
zejunchen-zejun previously approved these changes May 7, 2026
xytpai previously approved these changes May 7, 2026
@xytpai xytpai dismissed stale reviews from zejunchen-zejun and themself via e5cd7e7 May 7, 2026 04:31
@valarLip

valarLip commented May 8, 2026

Collaborator

You can merge it once the ATOM test passes.

@gbyu-amd

gbyu-amd commented May 9, 2026

Contributor Author

ATOM test passed as well. Merge it now.

@gbyu-amd gbyu-amd merged commit 594d1a9 into main May 9, 2026
42 of 43 checks passed
@gbyu-amd gbyu-amd deleted the guanbao/fuse_qknorm_per_token_quant branch May 9, 2026 00:34
@bingxche

bingxche commented May 9, 2026

Contributor

Hi, this PR breaks SGLang. Could you please revert first? @valarLip @gbyu-amd

The PR renames the public fused_qk_rmsnorm in aiter/ops/fused_qk_norm_rope_cache_quant.py to a private _fused_qk_rmsnorm (with a different signature), without keeping a backward-compatible alias. SGLang imports this name at module load time in python/sglang/srt/models/deepseek_common/attention_forward_methods/forward_mla.py:

from aiter.ops.fused_qk_norm_rope_cache_quant import (
    fused_qk_rmsnorm as fused_qk_rmsnorm_bf16,
)

After your PR, this import raises ImportError: cannot import name 'fused_qk_rmsnorm'. Because it's a top-level import, ~28 SGLang model modules that transitively depend on it (deepseek_v2, deepseek_nextn, deepseek_v4, kimi_k25, glm4_moe, longcat_flash, mistral_large_3, etc.) all fail to register. SGLang then thinks DeepseekV3ForCausalLM has no native implementation, falls back to the Transformers backend, and ultimately crashes with KeyError: 'sglang'.
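
One common way to avoid this kind of breakage is to keep the old public name as a thin deprecated wrapper around the new implementation for at least one release. The sketch below illustrates the pattern with toy functions. Here `fused_op` plays the role of the old public name and `_fused_op` the renamed private implementation; both names, their signatures, and their bodies are hypothetical and do not reflect aiter's actual argument lists.

```python
import warnings


def _fused_op(q_out, q, k_out, k):
    """Hypothetical new private implementation (toy body: copy inputs to outputs)."""
    q_out[:] = q
    k_out[:] = k


def fused_op(q, k):
    """Hypothetical old public entry point, kept as a deprecated shim.

    It adapts the old call shape to the new one, so existing
    `from module import fused_op` statements keep working.
    """
    warnings.warn(
        "fused_op(q, k) is deprecated; migrate to the new interface",
        DeprecationWarning,
        stacklevel=2,
    )
    q_out, k_out = [0.0] * len(q), [0.0] * len(k)
    _fused_op(q_out, q, k_out, k)
    return q_out, k_out
```

With a shim like this, downstream top-level imports still resolve after the rename, and callers get a warning instead of an ImportError while they migrate.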


https://github.com/ROCm/aiter/actions/runs/25586397205/job/75123761210#step:10:317

@gbyu-amd

gbyu-amd commented May 9, 2026

Contributor Author

Hi @bingxche, this PR has unified the API into fused_qk_rmsnorm:

def fused_qk_rmsnorm(
    q_out_quantized: Optional[Tensor] = None,
    q_out_scale: Optional[Tensor] = None,
    q: Optional[Tensor] = None,
    q_weight: Optional[Tensor] = None,
    q_epsilon: float = 1e-6,
    q_out_unquantized: Optional[Tensor] = None,
    k_out: Optional[Tensor] = None,
    q_res_out: Optional[Tensor] = None,
    k: Optional[Tensor] = None,
    k_weight: Optional[Tensor] = None,
    k_epsilon: Optional[float] = None,
    q_residual: Optional[Tensor] = None,
    gemma_norm: bool = False,
    quant_type: Optional[QuantType] = QuantType.No,
    group_size: Optional[int] = None,
    transpose_scale: bool = False,
) -> None:
    # Centralized interface
    if quant_type == QuantType.No:
        _fused_qk_rmsnorm(
            q_out_quantized, q, q_weight, q_epsilon, k_out, k, k_weight, k_epsilon
        )

which fuses qk_rmsnorm with or without q quantization. Putting the dispatch logic inside the aiter kernel should make the framework-side code cleaner. Could you update the SGLang code to align with this API?
cc @valarLip
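
The centralized-dispatch idea described above can be illustrated with a toy example. This is a sketch under stated assumptions, not aiter's code: the `QuantType` enum is a local stand-in (only the member name `No` comes from the snippet above; `per_Token` is hypothetical), `fused_norm` is a made-up entry point, and the two backends are stubs that merely report which path ran.

```python
from enum import Enum
from typing import Callable, Dict


class QuantType(Enum):
    No = 0         # member name taken from the snippet above
    per_Token = 1  # hypothetical member, for illustration only


def _norm_only(q, k):
    """Stub backend: normalization without quantization."""
    return "norm_only"


def _norm_per_token_quant(q, k):
    """Stub backend: normalization fused with per-token quantization."""
    return "norm_per_token_quant"


# Table mapping each quantization mode to the backend that handles it.
_BACKENDS: Dict[QuantType, Callable] = {
    QuantType.No: _norm_only,
    QuantType.per_Token: _norm_per_token_quant,
}


def fused_norm(q, k=None, quant_type=QuantType.No):
    """Single public entry point; callers never select a backend themselves."""
    try:
        backend = _BACKENDS[quant_type]
    except KeyError:
        raise NotImplementedError(f"unsupported quant_type: {quant_type}") from None
    return backend(q, k)
```

Framework code can then pass the quantization mode from the model configuration and leave backend selection to the library, instead of branching on it at each call site.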

sunway513 added a commit that referenced this pull request May 14, 2026
Both pyproject.toml (build-system) and requirements.txt (runtime) were
inconsistent on this branch — pyproject was at 0.1.4 (stale, not on
PyPI for manylinux_2_28), requirements at 0.1.6. Main is at 0.1.7
since #2958-era kernels need flydsl 0.1.7 IR API.

Wheels rebuilt from this HEAD will declare Requires-Dist: flydsl ==0.1.7,
matching what main publishes.

8 participants