Skip to content

refactor: delegate ATOM KV cache subsystem to attention builders#659

Merged
valarLip merged 2 commits into
mainfrom
refactor/per-req-cache-abstraction
Apr 29, 2026
Merged

refactor: delegate ATOM KV cache subsystem to attention builders#659
valarLip merged 2 commits into
mainfrom
refactor/per-req-cache-abstraction

Conversation

@valarLip

Copy link
Copy Markdown
Collaborator

Summary

Generalize the GDN per-request state decoupling (#602) into a complete model-agnostic KV cache abstraction owned by the AttentionMetadataBuilder hierarchy. ModelRunner becomes blind to attention type — it walks modules and dispatches via builder hooks; per-attention-type tensor layouts (MLA 576-dim packed, GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2 indexer cache, GDN per-req mamba state) all live next to their respective builder.

ModelRunner net: -526 LOC. The if/elif chains over use_mla / is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes, allocate_kv_cache, and the binding loop are all gone. Future stateful attentions (e.g. DeepseekV4 ring buffer + compressor state) plug in by subclassing AttentionMetadataBuilder without touching Scheduler / BlockManager / ModelRunner.

New AttentionMetadataBuilder hooks (defaults are no-ops)

Hook Purpose
compute_per_req_cache_bytes() / slots_per_req() bytes/slot for the per-request state pool
allocate_per_req_cache(num_slots) dict of named per-request state tensors
compute_block_bytes() per-block bytes for the KV pool budget
allocate_kv_cache_tensors(num_kv_heads, num_draft_layers) dict of named primary KV tensors (kv_cache, kv_scale, index_cache, aligned_index_dim, _kv_layer_cache_store)
build_kv_cache_tensor(layer_id, module) vLLM-style KVCacheTensor for one module, or None if foreign type; owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)

Builder overrides

  • AiterAttentionMetadataBuilder — split-K/V MHA + MiMo-V2 per-module
  • AiterMLAMetadataBuilder — 576-dim MLA + V3.2 indexer
  • GDNAttentionMetadataBuilder — hybrid full-attn rows + GDN mamba slot pool; chains super() for MHA modules in hybrid models. Absorbs the formerly-runner-owned gated_delta_net_state_shape/_dtypes helpers and the side-effect init of full_attention_interval / num_full_attn / num_gdn_attn_state.

Naming: group vs. slot

Distinguishes group (per-request unit) from slot (raw tensor index). One group occupies slots_per_req() contiguous slots in the underlying tensor.

Old New
Sequence.mamba_state_slot .per_req_cache_group
seq.mamba_enabled .has_per_req_cache
batch.mamba_state_slots .per_req_cache_groups
BlockManager.mamba_* .per_req_cache_* (free pool, accounting)
config.mamba_equiv_per_req .per_req_cache_equiv_blocks
config.num_mamba_groups .num_per_req_cache_groups
ModelRunner.max_mamba_slots .max_per_req_cache_slots (tensor dim)

Removed

  • ModelRunner._compute_mamba_per_slot_bytes (moved to GDNAttentionMetadataBuilder.compute_per_req_cache_bytes)
  • ModelRunner.gated_delta_net_state_shape / _dtypes (moved to GDNAttentionMetadataBuilder._state_shape / _state_dtypes)
  • The 4-way if/elif dispatch in _compute_block_bytes, allocate_kv_cache, and binding loop

Sanity check

ModelRunner.__init__ now asserts that any builder returning compute_per_req_cache_bytes() > 0 has its model_type registered in InputOutputProcessor._per_req_cache_model_types(), catching the silent-corruption misconfiguration where a stateful attention is added but Sequence-construction never gets the has_per_req_cache=True flag.

Test plan

  • tests/test_per_req_cache_decoupling.py: 24/24 pass
  • Core suite (block_manager, sequence, scheduler, request, io_processor_fanout, prefix_cache_accuracy): 118/118 pass
  • Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion quality unchanged
  • Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, NUM_CONCURRENT=64):
  • Docs synced (scheduling_kv_cache_guide.md, architecture_guide.md, configuration_guide.md, model_support_guide.md)

Repro

# Unit tests
python -m pytest tests/test_per_req_cache_decoupling.py tests/test_block_manager.py \
  tests/test_sequence.py tests/test_scheduler.py tests/test_request.py \
  tests/test_io_processor_fanout.py tests/test_prefix_cache_accuracy.py -q

# Qwen3.5 smoke (4 prompts, deterministic)
bash /app/logs_claude/run_simple_inference_streamed.sh /data/Qwen3.5-397B-A17B-FP8 4 --temperature 0.0

# Qwen3.5 GSM8K
bash /app/logs_claude/start_atom_server.sh /data/Qwen3.5-397B-A17B-FP8 4 8001
NUM_CONCURRENT=64 bash /app/logs_claude/run_gsm8k_eval.sh /data/Qwen3.5-397B-A17B-FP8 8001 5

Generalize the GDN per-request state decoupling (#602) into a complete
model-agnostic KV abstraction owned by the AttentionMetadataBuilder
hierarchy. ModelRunner is now blind to attention type — it walks modules
and dispatches; per-attention-type tensor layouts (MLA 576-dim packed,
GDN-hybrid full-attn-only rows, MiMo-V2 per-module deferred, V3.2
indexer cache, GDN per-req mamba state) all live next to their
respective builder.

ModelRunner net: -526 LOC. The if/elif chains over use_mla /
is_qwen_next / is_mimo_v2 / is_deepseek_v32 in _compute_block_bytes,
allocate_kv_cache, and the binding loop are all gone. Future stateful
attentions (DeepseekV4 ring buffer + compressor state) plug in by
subclassing AttentionMetadataBuilder without touching scheduler /
block_manager / ModelRunner.

New AttentionMetadataBuilder hooks (defaults are no-ops):
  - compute_per_req_cache_bytes() / slots_per_req()
      bytes/slot for the per-request state pool
  - allocate_per_req_cache(num_slots)
      dict of named per-request state tensors
  - compute_block_bytes()
      per-block bytes for the KV pool budget
  - allocate_kv_cache_tensors(num_kv_heads, num_draft_layers)
      dict of named primary KV cache tensors (kv_cache, kv_scale,
      index_cache, aligned_index_dim, _kv_layer_cache_store)
  - build_kv_cache_tensor(layer_id, module)
      vLLM-style KVCacheTensor for one module, or None if foreign type;
      owns module setattr (k_cache/v_cache/k_scale/v_scale/kv_cache)

Builder overrides:
  - AiterAttentionMetadataBuilder: split-K/V MHA + MiMo-V2 per-module
  - AiterMLAMetadataBuilder: 576-dim MLA + V3.2 indexer
  - GDNAttentionMetadataBuilder: hybrid full-attn rows + GDN mamba slot
    pool; chains super() for MHA modules in hybrid models. Absorbs the
    formerly-runner-owned gated_delta_net_state_shape/dtypes helpers
    and the side-effect init of full_attention_interval / num_full_attn
    / num_gdn_attn_state.

Naming distinguishes group (per-request unit) from slot (raw tensor
index). One group occupies `slots_per_req()` contiguous slots in the
underlying tensor:
  Sequence.mamba_state_slot     -> .per_req_cache_group
  seq.mamba_enabled             -> .has_per_req_cache
  batch.mamba_state_slots       -> .per_req_cache_groups
  BlockManager.mamba_*          -> .per_req_cache_*  (free pool, accounting)
  config.mamba_equiv_per_req    -> .per_req_cache_equiv_blocks
  config.num_mamba_groups       -> .num_per_req_cache_groups
  ModelRunner.max_mamba_slots   -> .max_per_req_cache_slots  (tensor dim)

Removed (moved to builders):
  ModelRunner._compute_mamba_per_slot_bytes
  ModelRunner.gated_delta_net_state_shape / _dtypes

Sanity check: ModelRunner.__init__ now asserts that any builder
returning compute_per_req_cache_bytes() > 0 has its model_type
registered in InputOutputProcessor._per_req_cache_model_types(),
catching the silent-corruption misconfiguration where a stateful
attention is added but Sequence-construction never gets the
has_per_req_cache=True flag.

Verified:
  - tests/test_per_req_cache_decoupling.py: 24/24 pass
  - core suite (block_manager, sequence, scheduler, request,
    io_processor_fanout, prefix_cache_accuracy): 118/118 pass
  - Qwen3.5-397B-A17B-FP8 tp=4 simple_inference: 4-prompt completion
    quality unchanged
  - Qwen3.5-397B-A17B-FP8 tp=4 GSM8K (5-shot, 64 concurrent):
      flexible-extract = 0.8757 +/- 0.0091  (baseline 0.8711 from #602)
      strict-match     = 0.8605 +/- 0.0095
Copilot AI review requested due to automatic review settings April 28, 2026 16:11
@valarLip valarLip merged commit 99f0990 into main Apr 29, 2026
52 of 54 checks passed
@valarLip valarLip deleted the refactor/per-req-cache-abstraction branch April 29, 2026 08:37
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request May 8, 2026
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives both the metadata-flow special-cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request May 8, 2026
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives both the metadata-flow special-cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request May 8, 2026
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives both the metadata-flow special-cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
yhl-amd added a commit to yhl-amd/ATOM that referenced this pull request May 8, 2026
Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-ROCm#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives both the metadata-flow special-cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
valarLip added a commit that referenced this pull request May 10, 2026
* [Kimi] support Eagle3 speculative decoding for Kimi K2.5

Adds Eagle3 spec decode for Kimi K2.5 (MLA target + standard MHA draft):

- Eagle3LlamaModel: 1-layer Llama draft (dual-norm input, wide QKV,
  independent embed/lm_head) matching the lightseekorg/kimi-k2.5-eagle3
  checkpoint
- Eagle3DraftBuilder: implements the post-#659 builder protocol
  (compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
  for the draft's independent non-MLA KV cache, attached to the runner from
  EagleProposer.__init__ via runner.eagle3_draft_builder. ModelRunner
  delegates KV pool sizing, allocation, and per-module binding through this
  hook with no eagle3-specific code in the runner KV path
- Aux hidden state pipeline: target forward returns
  (hidden, aux_hidden_states), captured through CUDAGraph via
  graph_aux_hidden and fed to the draft's combine_hidden_states (fc) as
  input
- SpeculativeConfig: --method eagle3 + --draft-model CLI; eagle3 vs MTP
  branching at construction time; fail-fast if draft is MLA
- Scheduler: spec_stats only updated when speculation actually ran
  (matches vLLM's gating)
- propose: draft-perspective predicate `draft_uses_mha = hasattr(runner,
  "eagle3_draft_builder")` drives both the metadata-flow special-cases
  (slot_mapping re-slice, context_lens += 1, tuple-unpack of the draft
  return value); is_eagle3 string comparison is gone from the hot path

Result on Kimi-K2.5-MXFP4 + kimi-k2.5-eagle3, 8x MI355X, gsm8k 5-shot:
acceptance 67.85%, accuracy 93.78%.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* ci: add nightly Eagle3 spec-decode accuracy test for Kimi-K2.5

Reuses the base Kimi-K2.5-MXFP4 model + lightseekorg/kimi-k2.5-eagle3
draft, runs at TP=8 (Eagle3 draft KV needs full 8-rank sharding) under
nightly schedule. Local case_verify_v9_gluon measured GSM8K 5-shot
flexible-extract = 0.9257 (vLLM = 0.9280); threshold set to 0.91 with
~1.5pp noise headroom.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Co-authored-by: Lingpeng Jin <103567126+valarLip@users.noreply.github.com>
ClementLinCF added a commit that referenced this pull request May 11, 2026
origin/main #659 moved the per-attention-type KV cache subsystem
(compute_block_bytes / allocate_kv_cache_tensors / build_kv_cache_tensor)
from ModelRunner into the attention builder hierarchy. The initial
gemma4-dev rebase landed Gemma 4 standalone work as inline overrides in
model_runner.py with a TODO to migrate later; this commit delivers that
migration.

In atom/model_ops/attentions/aiter_attention.py three new branches handle
Gemma 4 alongside the existing MiMo-V2-Flash and standard MHA paths:

  - compute_block_bytes() accounts for two split kv_cache / kv_scale
    tensors (sliding + full-attention layers) using their respective
    head_dim and num_kv_heads.
  - allocate_kv_cache_tensors() allocates the two [2, L_kind, ...]
    tensors and a layer_id -> (kind, slot) dispatch map for O(1) routing
    in the binding pass.
  - _build_gemma4_kv_cache_tensor() looks up the slot map, views the
    appropriate slice, and sets the standard module attrs (k_cache,
    v_cache, k_scale, v_scale, max_model_len) just like the other paths.

In atom/model_engine/model_runner.py:

  - Add is_gemma4() (structural detection via hf_config attrs) and
    gemma4_layer_partition() next to is_mimo_v2().
  - Restore _compute_block_bytes() and the KV binding loop to the
    single-line builder delegations origin/main #659 introduced. The
    inline _per_layer_kv_cache override and _get_per_layer_kv_dims
    helper added by the original gemma4-dev standalone commit are
    removed.
  - Add kv_cache_full / kv_scale_full to the exit() cleanup attribute
    list so the new tensors are released on shutdown.

No functional change for non-Gemma-4 models. Gemma 4's KV layout is
preserved (same [2, L_kind, blocks, block_size, kv_heads, head_dim]
per-tensor shape, just allocated via the builder).

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant