
[Kernel] feat: Add MXFP6-E2M3 activation support to mixed_moe_gemm_2stage #709

Open
amd-satre wants to merge 2 commits into ROCm:main from amd-satre:feat/fp6-moe-gemm


@amd-satre

Contributor

Motivation

Support 6-bit (MXFP6-E2M3) activations, i.e. the A operand, in the MoE GEMM kernels.

Technical Details

Adds a_dtype="fp6" (MXFP6-E2M3) activation support to both stage-1 (gate+up) and stage-2 (down) MoE grouped GEMMs in mixed_moe_gemm_2stage.py, enabling W_MXFP4_A_MXFP6 inference on gfx950 (CDNA4).

Test Plan

  • test_moe_stage2_standalone[a6w4-*]

Test Result

Passed

Submission Checklist

Signed-off-by: Shreyas Atre <satre@amd.com>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>

@amd-satre amd-satre changed the title from "[Kernel] feat: fp6 support (A operand) moe gemm" to "[Kernel] feat: Add MXFP6-E2M3 activation support to mixed_moe_gemm_2stage" Jun 19, 2026
@amd-satre amd-satre force-pushed the feat/fp6-moe-gemm branch 2 times, most recently from 055c64e to 74651d6 Compare June 19, 2026 18:39
…tage

Extends the MoE stage-1 (gate+up) and stage-2 (down) GEMM kernels to
accept MXFP6-E2M3 (6-bit, E2M3, block-32 E8M0 scale) activations paired
with MXFP4-E2M1 weights, exposed as a_dtype="fp6" in both
compile_mixed_moe_gemm1 and compile_mixed_moe_gemm2.

Kernel changes (kernels/mixed_moe_gemm_2stage.py):
- is_f6_a / is_f4_or_f6_a flags; a_dtype validation extended to "fp6"
- cbsz=2 for MXFP6-E2M3 A (vs cbsz=4 for MXFP4, cbsz=0 for FP8)
- a_per_lane_kpack_bytes=32 for fp6: cbsz=2 MFMA reads A in FP8-padded
  layout — 24 B of packed FP6 codes + 8 B zero pad per K=32 block
- Three LDS loads per K-block to fill the 32-byte A register slot;
  4th slot zero-filled (cbsz=2 MFMA discards it)
- a_elem_vec_pack stays 1 for fp6 (1 stored byte per logical element)
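
The 24 B + 8 B layout described in the list above can be illustrated with a small host-side sketch. This is not code from the PR: the function names, the `CBSZ_BY_A_DTYPE` table name, and the little-endian bit order inside the 24 packed bytes are assumptions made for illustration only; the kernel's actual in-register bit layout may differ.

```python
# Hypothetical illustration of the FP8-padded FP6 layout: 32 six-bit codes
# occupy 24 bytes (32 * 6 = 192 bits), followed by 8 zero bytes, giving the
# 32-byte slot per K=32 block mentioned in the commit message.

# cbsz values as stated in the commit message (table name is hypothetical).
CBSZ_BY_A_DTYPE = {"fp8": 0, "fp6": 2, "fp4": 4}

BLOCK_ELEMS = 32      # elements per K block
PACKED_BYTES = 24     # 32 elements * 6 bits / 8
PAD_BYTES = 8         # zero padding up to the 32-byte slot


def pack_fp6_block(codes):
    """Pack 32 six-bit codes into 24 bytes and append 8 zero bytes.

    ASSUMPTION: element i occupies bits [6*i, 6*i + 6) in little-endian
    order. The real hardware bit order is not taken from the PR.
    """
    if len(codes) != BLOCK_ELEMS:
        raise ValueError("expected exactly 32 codes per block")
    if any(not 0 <= c < 64 for c in codes):
        raise ValueError("each code must fit in 6 bits")
    bits = 0
    for i, code in enumerate(codes):
        bits |= code << (6 * i)
    return bits.to_bytes(PACKED_BYTES, "little") + bytes(PAD_BYTES)


def unpack_fp6_block(block):
    """Inverse of pack_fp6_block; the 8 padding bytes are ignored."""
    if len(block) != PACKED_BYTES + PAD_BYTES:
        raise ValueError("expected a 32-byte block")
    bits = int.from_bytes(block[:PACKED_BYTES], "little")
    return [(bits >> (6 * i)) & 0x3F for i in range(BLOCK_ELEMS)]
```

Under these assumptions, the three LDS loads mentioned above would correspond to the three 8-byte chunks of the packed region, and the zero-filled fourth slot to the padding.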

Test infrastructure (tests/):
- fp4_utils.py: fp6_e2m3_to_f32 (LUT-based E2M3 decoder) and
  per_1x32_f6_quant (returns a_pad, scale, a_unpacked)
- test_ref.py: _dequant_mxfp6_per_1x32; a2_kind override on
  torch_moe_gemm2 to select mxfp6 dequant without dtype-shape ambiguity
- test_moe_gemm.py: test_moe_stage2_standalone parametrized with a6w4
  (gfx950+); reference comparison enabled via a_unpacked (no skip_ref);
  realworld shapes: Mixtral-8x7B, Qwen3-30B-A3B (T=128/512)
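
For readers unfamiliar with the format, the LUT-based E2M3 decode and per-1x32 block quantization named in the list above might look roughly like the sketch below. It is not the PR's `fp6_e2m3_to_f32` or `per_1x32_f6_quant`: the `_sketch` function names, return shapes, and the nearest-value rounding (ties go to the smaller magnitude rather than round-to-nearest-even) are assumptions. The element encoding (1 sign bit, 2 exponent bits with bias 1, 3 mantissa bits, maximum magnitude 7.5) and the power-of-two E8M0 scale with bias 127 follow the OCP Microscaling format definition.

```python
import math

E2M3_MAX = 7.5   # largest E2M3 magnitude: (1 + 7/8) * 2**(3 - 1)
E2M3_EMAX = 2    # exponent of the largest E2M3 normal value
E8M0_BIAS = 127


def build_e2m3_lut():
    """Decode table for all 64 six-bit codes (sign | 2-bit exp | 3-bit mantissa)."""
    lut = []
    for code in range(64):
        sign = -1.0 if code & 0x20 else 1.0
        exp = (code >> 3) & 0x3
        man = code & 0x7
        if exp == 0:
            mag = man / 8.0                            # subnormal, bias 1
        else:
            mag = (1.0 + man / 8.0) * 2.0 ** (exp - 1)
        lut.append(sign * mag)
    return lut


E2M3_LUT = build_e2m3_lut()


def fp6_e2m3_to_f32_sketch(code):
    """Hypothetical stand-in for a LUT-based decoder."""
    return E2M3_LUT[code & 0x3F]


def quant_block_sketch(values):
    """Quantize 32 floats to (e8m0_scale_byte, list of 32 six-bit codes)."""
    if len(values) != 32:
        raise ValueError("expected a 1x32 block")
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        shared_exp = 0
    else:
        # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2) = e - 1.
        shared_exp = (math.frexp(amax)[1] - 1) - E2M3_EMAX
        shared_exp = max(-127, min(127, shared_exp))
    scale = 2.0 ** shared_exp
    codes = []
    for v in values:
        mag = min(abs(v) / scale, E2M3_MAX)
        # Codes 0..31 are the non-negative magnitudes; pick the nearest one.
        idx = min(range(32), key=lambda c: abs(E2M3_LUT[c] - mag))
        codes.append(idx | (0x20 if v < 0 else 0))
    return shared_exp + E8M0_BIAS, codes


def dequant_block_sketch(scale_byte, codes):
    """Reference dequantization: decoded element times 2**(scale_byte - 127)."""
    scale = 2.0 ** (scale_byte - E8M0_BIAS)
    return [E2M3_LUT[c] * scale for c in codes]
```

A reference comparison like the one the tests enable would then check the kernel output against a GEMM computed on the dequantized activations.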

Signed-off-by: Shreyas Atre <satre@amd.com>
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
@amd-satre amd-satre force-pushed the feat/fp6-moe-gemm branch from 74651d6 to 780f08d Compare June 19, 2026 18:40
@coderfeli

Collaborator

@amd-satre why was fp6 MoE added here? We have never heard of such model configs.

2 participants