[QUARK-403] Add MiniMax-2.1 support #237
0c962ff to e2a847f
@valarLip can you please take a look at this one?
e2a847f to 881b16f
Found a big issue with this model. Fix: the q_norm/k_norm weights are now sized to match the checkpoint (total_num_heads * head_dim for q_norm, replicated across TP ranks). The forward pass computes the global variance via an all-reduce of the per-rank sums of squares, then each rank multiplies by its contiguous head-slice of the weight vector. After the fix, this is the GSM8K result:
88a25cd to e11e392
Thank you for your great work.
e11e392 to 6bb2ea5
@thpereir Hi, I tested your code with MiniMax-M2.1 and MiniMax-M2.1-MXFP4, and the accuracy is lower than expected.
Could you help check?
Hi Lirong, could you post the command you used?
The FP4 model behaves the same, on both tp=2 and tp=4, and aiter is at the latest version.
@ZhangLirong-amd If you check our Quark MXFP4 model card, you will see that it does not use plain lm-eval to get its 0.9348 score. It uses vLLM's own GSM8K script, which does extra things such as using stop tokens, so the scores are not directly comparable. Regarding the original model, I don't see a GSM8K score on its model card. What command did you use to serve this model with vLLM and compare it against ATOM?
This is the script I tested on vLLM, and I got a score of 0.95. I used the same lm_eval script for both ATOM and vLLM. I also checked the GSM8K results for ATOM, and the failures were indeed calculation errors in the math problems. I believe this issue needs to be addressed. Right now, I can only identify a difference in the Q, K we got in
6bb2ea5 to 01e6c1b
@ZhangLirong-amd I fixed the issue for tp=2. Please take another look. The problem was still in weight loading for the q_norm/k_norm weights (per-layer RMSNorm): the sharding was incorrect. I was only able to fix the issue for tp=2. With tp=4 or tp=8, padding needs to be added to the shards, and this padding and the weight shuffle together are causing problems. I am still investigating this. We now have parity with vLLM in this case (tp=2).
Server:
lm-eval results:
Let's merge this, and then I will continue working on tp=4/tp=8.
@thpereir Now I get 0.94 accuracy for MiniMax-M2.1-MXFP4 and MiniMax-M2.5 with tp=2/4. There still seem to be some problems with tp=8. Thanks for the fix.
05abbae to 4091157
Introduces atom/models/minimax_m2.py with full support for MiniMax-M2.1
under ATOM's TP-parallel serving stack.

Architecture support:
- MiniMaxM2Attention: GQA with rotary embeddings (partial rotary, 50%)
  and optional qk_norm (enabled in M2.1)
- MiniMaxM2SparseMoeBlock: 256-expert sparse MoE, top-8, sigmoid routing
  with per-expert routing bias (use_routing_bias)
- MiniMaxM2DecoderLayer / MiniMaxM2Model / MiniMaxM2ForCausalLM
- Packed QKV mapping (q_proj/k_proj/v_proj -> qkv_proj) for weight loading
- Pipeline-parallel (PP) support via PPMissingLayer / IntermediateTensors
- Expert weight mapping: w1/w2/w3 -> gate/down/up proj
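
The two name mappings listed above can be illustrated with a small, standalone sketch. Only the mappings themselves (q_proj/k_proj/v_proj -> qkv_proj and w1/w2/w3 -> gate/down/up proj) come from the commit message; the function name `remap_weight_name`, the shard-id convention, and the example checkpoint key layout are assumptions made for illustration and are not ATOM's actual loader API.

```python
from typing import Optional, Tuple

# Mappings taken from the commit message; everything else here is hypothetical.
PACKED_QKV = {"q_proj": "qkv_proj", "k_proj": "qkv_proj", "v_proj": "qkv_proj"}
EXPERT_PROJ = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}


def remap_weight_name(name: str) -> Tuple[str, Optional[str]]:
    """Return (remapped_name, shard_id).

    shard_id is "q", "k", or "v" when the weight belongs to the packed QKV
    projection, and None otherwise.
    """
    parts = name.split(".")
    shard_id = None
    for i, part in enumerate(parts):
        if part in PACKED_QKV:
            shard_id = part[0]            # "q_proj" -> "q"
            parts[i] = PACKED_QKV[part]   # all three land in qkv_proj
        elif part in EXPERT_PROJ:
            parts[i] = EXPERT_PROJ[part]  # w1/w2/w3 -> gate/down/up
    return ".".join(parts), shard_id


# Example keys are illustrative, not read from a real checkpoint.
print(remap_weight_name("model.layers.0.self_attn.k_proj.weight"))
print(remap_weight_name("model.layers.0.block_sparse_moe.experts.3.w2.weight"))
```

Matching on whole dotted components rather than substrings avoids accidentally rewriting names that merely contain `w1` or `q_proj` as part of a longer identifier.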

qk_norm: correct TP-distributed global RMSNorm
- q_norm and k_norm weights sized to match checkpoint:
    q_norm: [total_num_heads * head_dim] = [6144] (replicated across TP ranks)
    k_norm: [total_num_kv_heads * head_dim] = [1024] (replicated across TP ranks)
- Forward computes global variance via all-reduce of per-rank sum-of-squares,
  then each rank applies its contiguous head-slice of the weight vector
- Handles kv_heads < tp_size by using the full k_norm weight without slicing
- The previous, incorrect implementation used per-rank sizes ([768]/[128]
  with TP=8), causing weight loading to silently skip the norms (shape
  mismatch) and leaving them at all-ones, which reduced GSM8K from ~0.87
  to 0.10
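
The qk_norm scheme described above can be checked numerically with a single-process simulation. This is a minimal sketch, not the code in atom/models/minimax_m2.py: the all-reduce is replaced by a plain Python sum over simulated ranks, the `eps` value is an assumption, and the vector length is assumed to be divisible by `tp_size` (the kv_heads < tp_size case is not covered).

```python
import math
from typing import List


def rmsnorm_reference(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """Unsharded RMSNorm over the full flattened (heads * head_dim) vector."""
    variance = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(variance + eps)
    return [v * scale * w for v, w in zip(x, weight)]


def rmsnorm_tp_simulated(x: List[float], weight: List[float], tp_size: int,
                         eps: float = 1e-6) -> List[float]:
    """Simulate TP ranks that each hold a contiguous slice of x and the full weight."""
    shard = len(x) // tp_size
    # Each rank computes the sum of squares of its local slice only.
    local_sq = [sum(v * v for v in x[r * shard:(r + 1) * shard]) for r in range(tp_size)]
    # Stand-in for an all-reduce (sum): every rank would see the same total.
    global_sq = sum(local_sq)
    scale = 1.0 / math.sqrt(global_sq / len(x) + eps)
    out: List[float] = []
    for r in range(tp_size):
        lo, hi = r * shard, (r + 1) * shard
        # Each rank applies only its contiguous slice of the replicated weight.
        out.extend(v * scale * w for v, w in zip(x[lo:hi], weight[lo:hi]))
    return out


x = [0.5 * i - 3.0 for i in range(16)]
w = [1.0 + 0.1 * i for i in range(16)]
ref = rmsnorm_reference(x, w)
sharded = rmsnorm_tp_simulated(x, w, tp_size=4)
print(max(abs(a - b) for a, b in zip(ref, sharded)))  # tiny floating-point difference
```

Using the per-rank variance instead of the all-reduced total would give a different scale on every rank, so the sharded output would no longer match the unsharded reference.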

Fix MoE routing to match vLLM: use grouped_topk, fp32 gate weights and router
logits; fix FusedMoE has_bias default to False; fix SwiGLU branch condition.
4091157 to fce8222
Thanks for the support. TP 8 MiniMax still has some problems, though; can you help support it as a next step? We noticed Inference-Max running MiniMax with TP 8 on H200, and we want to compare against it. Thanks. cc @haoyangli0109
Yes @ZhangLirong-amd, I am working on TP 8!
    ) -> bool:
        """Match the target string or regular expression"""

        # Replace layer name if packed_modules_mapping is offered
Hi @thpereir, may I ask why this piece of code was deleted? Can I add it back? Its removal affects Qwen3.5's weight loading and causes a serious accuracy regression on Qwen3.5.
@ganyi1996ppo Let me take a look. It was affecting MiniMax loading. I will work on a fix that covers both models.
Can you share the exact Qwen3.5 model you're using to test?
OK, I opened #452 to address this issue. The remap is already in place; we just had to add Qwen3.5 to the list.
Motivation
Add support for MiniMax-M2.1 to ATOM.
Technical Details
MiniMax-M2.1 applies a sigmoid to the router logits during expert selection instead of using plain top-k alone.
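As a rough illustration of sigmoid-based routing with a per-expert routing bias, the sketch below scores each expert with a sigmoid, adds the bias, and keeps the top-k experts. Several details are assumptions, not statements about the ATOM or vLLM implementation: the bias is assumed to affect only which experts are selected, the final weights are assumed to be the unbiased sigmoid scores renormalized to sum to 1, and a plain top-k is used here in place of grouped_topk. The function name `sigmoid_topk_route` is hypothetical.

```python
import math
from typing import List, Tuple


def sigmoid_topk_route(logits: List[float], bias: List[float],
                       top_k: int) -> Tuple[List[int], List[float]]:
    """Pick top_k experts for one token from sigmoid scores plus a routing bias."""
    # Independent per-expert scores in (0, 1), unlike a softmax over experts.
    scores = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    # Assumption: the bias shifts the selection only, not the mixing weights.
    biased = [s + b for s, b in zip(scores, bias)]
    chosen = sorted(range(len(scores)), key=lambda i: biased[i], reverse=True)[:top_k]
    # Assumption: weights are the unbiased scores, renormalized over the chosen experts.
    total = sum(scores[i] for i in chosen)
    weights = [scores[i] / total for i in chosen]
    return chosen, weights


experts, weights = sigmoid_topk_route(
    logits=[2.0, 0.0, -1.0, 1.0], bias=[0.0, 0.0, 3.0, 0.0], top_k=2
)
print(experts, weights)
```

In this example the large bias on expert 2 gets it selected even though its sigmoid score is the lowest, while its mixing weight stays smaller than that of expert 0.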
Test Plan
Run server:
Run lm-eval:
Test Result
Results are below what we obtained on vLLM, so we need to debug further.
Submission Checklist