Skip to content

[acceleration] Model-agnostic acceleration plugin layer (torch.compile / attention backend / feature caching)#196

Open
Jayce-Ping wants to merge 14 commits into
mainfrom
feat/acceleration-plugin-layer
Open

[acceleration] Model-agnostic acceleration plugin layer (torch.compile / attention backend / feature caching)#196
Jayce-Ping wants to merge 14 commits into
mainfrom
feat/acceleration-plugin-layer

Conversation

@Jayce-Ping

@Jayce-Ping Jayce-Ping commented Jun 29, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds a model-agnostic, registry-based acceleration plugin layer (src/flow_factory/acceleration/) that speeds up training without touching trainer or model math, gated for train-inference correctness by the trainer's RL paradigm.

  • Two ordered config slots under acceleration:shared (lossless/stage=both, applied to both rollout and the training forward via setup()) and rollout (Stage-3 only via rollout_context(), may be lossy).
  • Accelerators: attention_backend (folds in the removed model.attn_backend knob), torch_compile, diffusers_cache, cache_dit.
  • acceleration/validator.py enforces the safety contract (constraints [Feat] Support Pairwise and Groupwise reward computation #7/[Feat] Extend EMA schedules #26): a lossy rollout accelerator is rejected on coupled trainers; a lossy stage=both accelerator (torch_compile) is allowed everywhere but warns on coupled that the on-policy ratio is ≈1, not bit-exact.
  • model.attn_backend is removed (fails fast with a migration message → use the attention_backend accelerator).

Correctness — GRPO on-policy ratio (zero tolerance: must be exactly 1, atol=0)

Verified on 16×H20 (2-node) under both DeepSpeed ZeRO-2 and FSDP2 HYBRID_SHARD, per-element assert at the first on-policy step on all 16 ranks.

method backend / mode max|ratio-1| (worst of 16) verdict
eager SDPA 0.000e+00 bit-exact
attention_backend _flash_3_hub (FA3) 0.000e+00 (ZeRO-2 and FSDP2) bit-exact
torch_compile full 1.13e-5 (2/16; FSDP2 run 0/16) not bit-exact (intermittent)
attention_backend+torch_compile _flash_3_hub + full 1.31e-5 not bit-exact (compile)
torch_compile (Qwen-Image) regional 1.69e-5 (1/16) not bit-exact
torch_compile (no-CFG) full 7.15e-7 (1/16) not bit-exact

Conclusion: eager and attention_backend are strictly bit-exact. torch_compile (any mode/model, CFG or not) leaves an intermittent ~1e-5 residual on a minority of samples — numerically on-policy (≪ clip_range=1e-4) but not bit-exact. Root cause: Inductor compiles a separate graph for grad vs no-grad mode; the grad-forced rollout fix removes the dominant divergence but rollout and training remain distinct kernel invocations and bf16 accumulation is non-associative. Hence torch_compile is now classified lossy (validator warns on coupled).

Note: regional compile is unavailable on SD3.5 (SD3Transformer2DModel has no _repeated_blocks) → clean fail-fast (works on Qwen-Image).

Speedup — lossless (GRPO / SD3.5 LoRA, 16-GPU ZeRO-2, CFG g=4.5, 8 steps; rank0 steady-state median)

method rollout (s) rollout × train-step (s) train × peak VRAM
eager (baseline) 13.65 1.00× 15.38 1.00× 14.49 GB
attention_backend(_flash_3_hub) 13.03 1.05× 15.40 1.00× 14.49 GB
torch_compile(full) 11.44 1.19× 13.04 1.18× 12.00 GB (−17%)
_flash_3_hub+torch_compile(full) 10.38 1.32× 11.01 1.40× 12.00 GB
  • FA3 is ~neutral on H20 (bandwidth-bound Hopper; attention is not the bottleneck).
  • torch.compile is the real win (~1.2× + −17% VRAM; combined with FA3 ~1.3–1.4×) — at the cost of strict bit-exactness.

Speedup — lossy rollout caching (DiffusionNFT / Qwen-Image LoRA, 16-GPU ZeRO-2, 28 steps; same seed; rank0 steady-state)

Fresh post-removal measurement (rank0, median of n=6 steady-state epochs, 28 inference steps):

method rollout (s) rollout × train-step (s) reward drift vs baseline
cache off (baseline) 27.02 1.00× 19.52
diffusers_cache (first_block, thr 0.08) 22.01 1.23× 19.77 negligible
cache_dit (default / aggressive) 25.5-26.2 ~1.0× 0 (bit-identical — no-op; accelerator removed)
  • Caching is rollout-only (train-step unaffected) and validator-gated to decoupled/distillation.
  • Fixed (commit 247a64a): periodic eval + diffusers_cache previously crashed with a diffusers FirstBlockCache "No context is set" error. Root cause: diffusers' HookRegistry._get_child_registries() caches its child list on first use; during eval (cache disabled) the adapter still opens cache_context(...), caching an empty child list before the block cache-hooks exist, so the later training _set_context skipped them. The accelerator now invalidates that stale _child_registries_cache after enable_cache. Verified on 16xH20 Qwen-Image NFT (eval_freq=1, rc=0, zero errors).
  • Verified: torch_compile (regional) + diffusers_cache coexist — both apply, run completes (rc=0), and per-epoch rollout time is stable (no per-epoch Dynamo recompile churn).
  • diffusers_cache works: ~1.2× rollout (1.23× here on n=6; 1.18-1.26× across runs) with negligible reward drift. At 8 inference steps it never triggers (no gain, bit-identical output) — it needs enough denoising steps to engage.
  • cache_dit accelerator was REMOVED (commit 0170953). diffusers_cache is now the single lossy rollout backend; a config still naming cache_dit fails fast with an actionable error pointing to diffusers_cache. Full investigation that led to removal:
    • Root cause of the no-op: CacheDitAccelerator called cache_dit.enable_cache(adapter.pipeline), which sets persistent_context=False for a real pipeline — caching then only activates inside pipeline.__call__, but Flow-Factory's rollout calls the transformer directly (adapter.inference()), so cache_dit's per-call context is never created and every step logs "Cache context not exist … skip cache" (bit-identical output, default and aggressive configs alike).
    • Engagement fix (validated): passing the bare base transformer instead (cache_dit wraps it in a FakeDiffusionPipelinepersistent_context=True) makes caching actually engage on Qwen-Image (epoch 0 ran with no "skip cache").
    • Remaining blockers (why it's not shipped): (1) cache_dit's enable_cache/disable_cache is not idempotent across the per-epoch rollout scoping — re-enabling on epoch 2 asserts hasattr(pipe, "_context_manager"); (2) it conflicts with the adapter's native diffusers cache_context (_diffusers_hook is not None) — cache_dit is a competing caching system to the diffusers-native one that diffusers_cache already uses; (3) on SD3.5 the bare transformer class SD3Transformer2DModel isn't registered (only StableDiffusion3Pipeline is), so cache_dit's bare-module path is unsupported there.
    • Conclusion: a robust cache_dit integration would need to coexist with / replace the adapters' diffusers cache_context and use cache_dit's context-manager + refresh_context for per-rollout scoping, built on private dev-build internals — substantial work for no gain over diffusers-native caching. So the CacheDitAccelerator was removed; diffusers_cache is the lossy backend (clean per-epoch lifecycle, ~1.2× rollout on Qwen, negligible reward drift).

Test plan

  • GRPO on-policy ratio bit-exact assert (eager/FA3 pass; compile ~1e-5) — ZeRO-2 + FSDP2, 16 GPU.
  • Lossless speedup (FA3 / compile / combo) vs eager — SD3.5, 16 GPU.
  • Lossy NFT caching speedup + reward drift — Qwen-Image, 16 GPU (diffusers_cache ~1.2×; cache_dit found to be a no-op — bug noted).
  • Paradigm-gated fail-fast (lossy-rollout on coupled rejected; lossy-shared on coupled warns).

Jayce-Ping and others added 12 commits June 27, 2026 15:47
… plugin layer (lossless)

Introduce a registry-based acceleration plugin layer that respects the
algorithm/model decoupling, plus the first lossless accelerators.

- acceleration/: BaseAccelerator (safety/stage contract), registry with
  direct-path fallback, paradigm-gated validator, CompileAccelerator
  (torch.compile, regional/full), AttentionBackendAccelerator (exact backends).
- hparams: AccelerationArguments with shared/rollout slots; wired into Arguments
  (field + nested_map) and exported. Off by default (backward compatible).
- trainers: BaseTrainer builds and validates accelerators after prepare, applies
  the shared accelerator via setup() and wraps the Stage-3 rollout loop with the
  rollout accelerator context; per-trainer paradigm tags (coupled/decoupled/
  distillation) drive the lossy-safety gate (constraints.md #7, #20a, #26).
Rollout-only (Stage 3) feature caching, gated by the paradigm validator to
decoupled/distillation trainers so train-inference consistency is preserved.

- DiffusersCacheAccelerator: zero-extra-dep, wraps diffusers CacheMixin
  enable_cache/disable_cache with policies first_block (default) / faster /
  pyramid / taylorseer / magcache; enabled per rollout epoch, torn down on exit.
- CacheDitAccelerator: optional cache-dit backend (richer DBCache/TaylorSeer),
  imported defensively with an install hint when absent (constraints.md #22a).
- registry: register both lossy accelerators.
- pyproject: add optional 'acceleration' extra (cache-dit) and include in 'all'.
- guidance/acceleration.md: safety model, config schema, accelerator table,
  model cache-readiness, and how to add a new accelerator.
- architecture.md: register the acceleration registry (now four), accelerator
  table, validator/paradigm gating note, and the new-accelerator extension point.
- AGENTS.md: link the new guide.
- examples: commented acceleration block in the NFT/Qwen-Image (decoupled) config.
…erator; compile after post_init

Two correctness/cleanup fixes from review:

1. Remove AttentionBackendAccelerator: it duplicated model.attn_backend
   (BaseAdapter._set_attention_backend, run in adapter __init__) with no added
   capability, and re-set the backend a second time after prepare. Attention
   backend stays a single canonical knob (model.attn_backend). Drop the registry
   entry, file, and docs/example references.

2. Apply the shared (lossless) accelerator AFTER adapter.post_init() instead of
   inside _initialization(). post_init() performs state-checkpoint resume,
   _init_ema(), and _init_ref_parameters(); compiling first was fragile. Now
   _init_acceleration() only builds+validates; _apply_shared_acceleration() runs
   torch.compile after weights are final. In-place compile (nn.Module.compile /
   compile_repeated_blocks) keeps state_dict keys and parameter identity stable,
   so save/load checkpoint, LoRA, and copy_-based EMA/ref/named-param swaps remain
   correct. Also fix the regional-mode guard to check _repeated_blocks (the method
   always exists on ModelMixin) and document the LoRA disable_adapter recompile.
…erator the single path

Per review preference, invert the previous fix: instead of removing the
AttentionBackendAccelerator and keeping BaseAdapter._set_attention_backend, remove
the in-adapter call and route attention-backend selection exclusively through the
accelerator, for one consistent transformer-acceleration mechanism.

- Remove BaseAdapter._set_attention_backend() call and method from models/abc.py.
- Restore AttentionBackendAccelerator: reads backend from model.attn_backend (with
  optional 'backend' param override) and forwards it verbatim to diffusers'
  set_attention_backend (incl. approximate backends like 'sage' — matches old
  behavior; no whitelist). Re-register in the accelerator registry.
- BaseTrainer._apply_shared_acceleration now applies attention backend (from
  model.attn_backend) first, then the shared accelerator (e.g. torch.compile), so
  the compiled graph captures the chosen backend. Runs after prepare()/post_init().
- model.attn_backend stays the config knob (used in examples); no config breakage.
  Standalone inference uses raw diffusers pipelines, not the adapter, so dropping
  the constructor call has no effect there.
- Docs (architecture.md, guidance/acceleration.md, acceleration_args help) updated.
…umerical safety

Answers the review questions on attention-backend timing and lossy backends:

- Correctness axis is SYMMETRIC application (stage), not bit-exactness. A
  stage='both' transform runs in both rollout inference() and training forward()
  on the shared module, so they stay train-inference consistent even when the
  transform is numerically approximate (e.g. Sage int8 attention). We therefore
  do NOT reject 'sage' and do not need per-instance 'two safety states'.
- Validator: the shared slot now requires stage=='both' (drop the
  safety=='lossless' requirement); the rollout slot requires stage=='rollout' and
  still gates safety=='lossy' to decoupled/distillation paradigms. safety is only
  consulted for rollout-stage accelerators.
- Clarify BaseAccelerator safety/stage semantics and the validator/guidance docs.

Timing (confirmed, no code change): set_attention_backend only sets
processor._attention_backend attributes + a global active backend (touches no
params), so it is safe AFTER accelerator.prepare() and is applied BEFORE compile
(the dispatch is read at forward time); applying after prepare is also required
for context-parallel backend validation. _apply_shared_acceleration runs after
post_init (final weights), attention backend first, then compile.
…ttn_backend into the acceleration layer + ordered multi-accelerator

Remove the standalone `model.attn_backend` knob and express attention-backend
selection only as an `attention_backend` entry in the acceleration layer. Each
acceleration slot (`shared` / `rollout`) becomes an ordered list of {name, params}
entries (list order = application order), mirroring MultiRewardArguments, so
multiple accelerators can be combined deterministically (e.g. attention_backend
before torch_compile so the compiled graph captures the backend).

- hparams/acceleration_args.py: add AccelerationSpec; AccelerationArguments.shared
  and .rollout are List[AccelerationSpec]; from_dict accepts list or single-entry
  shorthand; to_dict round-trips.
- hparams/model_args.py: drop attn_backend field; __post_init__ fail-fast migration
  error if a stale `model.attn_backend` is present.
- acceleration/attention_backend.py: require an explicit `backend` param; raise when
  no transformer supports set_attention_backend (e.g. Bagel forces fa2 at load).
- trainers/abc.py: build/validate accelerator lists; _apply_shared_acceleration runs
  setup() in order; _rollout_acceleration nests rollout contexts via ExitStack.
- examples: replace commented attn_backend lines with a commented acceleration block
  (coupled => shared-only; decoupled => shared+rollout). Bagel examples drop the inert
  attn_backend line and document the forced flash_attention_2 path.
- docs: guidance/acceleration.md, architecture.md, README.md updated to the list form.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…n-inference consistency for coupled algorithms

torch.compile (Inductor) compiles a separate, numerically non-identical graph for
grad vs no-grad mode (Dynamo guards on grad_mode). Rollout ran under no_grad while
the training forward runs with grad, so the recomputed on-policy log-prob diverged
from the rollout log-prob (~1.5e-5) and broke the coupled PPO ratio==1 invariant
(GRPO/GRPO-Guard/DPPO). This is the ONLY source — CFG is applied identically in both
stages (batch 2N each, verified) and is not a separate cause.

Fix (CompileAccelerator + trainer + base adapter):
- requires_grad_rollout flag: the trainer's _rollout_grad_context runs the rollout
  with grad enabled and sets adapter._rollout_detach when a compile accelerator is
  active (otherwise unchanged: no_grad).
- The compiled transformer's call entry points (forward + _compiled_call_impl) are
  wrapped to force torch.enable_grad(), overriding the @torch.no_grad() on
  inference(), so rollout and training execute the same grad-mode graph.
- The wrapper returns the grad-carrying output directly and does NOT detach it: an
  inner detach lets Inductor pick a divergent inference-optimized kernel and silently
  re-introduces the drift. Instead the per-step latent feedback is detached in
  BaseAdapter.cast_latents (gated by _rollout_detach), which breaks the autograd graph
  chain across denoising steps so rollout memory stays bounded.
- Compile/wrap the BASE transformer under PEFT/LoRA (_peel_peft -> get_base_model),
  not the PeftModel: compiling the wrapper drifts (~2.0) and hides _repeated_blocks
  (so regional compile now also works under LoRA). LoRA submodules stay inside the
  compiled graph and still train.

Result (verified, SD3.5, first on-policy step): max|ratio-1| = 0.000e+00 (bit-exact)
for coupled training with AND without CFG (old_lp == new_lp exactly; rollout and
training both call the transformer at the identical 2N shape under CFG). A
train-vs-train recompute is exactly 0.0, confirming compile is deterministic (no
RNG/autotune nondeterminism). Non-compile runs are unchanged.

Full analysis: .scratch/torch_compile_consistency_report.md; guidance/acceleration.md updated.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…y (not bit-exact across rollout/train) + fail-fast tidy

torch.compile is stage='both' (applied to the shared transformer) but Inductor compiles
a separate graph for grad vs no-grad mode; even with the grad-forced rollout fix an
intermittent ~1e-5 on-policy ratio residual remains on a minority of samples (verified
16xH20, 2-node ZeRO-2/FSDP2, SD3.5 + Qwen-Image, regional/full, CFG and no-CFG). So it is
NOT bit-exact across stages.

- Re-class CompileAccelerator safety='lossy' (lossless now means bit-exact across stages,
  not merely "applied to the shared module"); update abc/validator/torch_compile docstrings.
- validator: WARN (do not reject) when a stage='both' lossy accelerator runs on a coupled
  trainer -- it stays within clip_range so it remains allowed, but the on-policy ratio is
  ~1, not exactly 1. eager + attention_backend stay lossless/bit-exact.
- fail-fast: drop the swallowing try/except in CompileAccelerator._peel_peft.
- init BaseAdapter._rollout_detach=False and read it directly (drop getattr defaults);
  remove the dead acceleration_args None-branch in _init_acceleration.
- correct guidance/acceleration.md + architecture.md tables/notes that overclaimed bit-exact.
…sers_cache is the single lossy backend

cache-dit only caches inside the pipeline `__call__` it monkeypatches, but Flow-Factory's
rollout drives the transformer directly (adapter.inference()), so CacheDitAccelerator was a
silent no-op (every step logged "Cache context not exist, skip cache"). Its transformer-only
path (FakeDiffusionPipeline -> persistent_context) does engage but assumes enable-once: the
class-level _is_cached flag is not reset on disable, so the per-epoch rollout_context
enable/disable cycle crashes on epoch 2; it also conflicts with the cache-ready adapters'
own diffusers `cache_context` (two caching systems on the same blocks). It adds nothing over
the diffusers-native `diffusers_cache` (FBCache/TaylorSeer/FasterCache, ~1.2x rollout, clean
per-epoch lifecycle), so it is removed rather than force-fit on private dev-build internals.

- Delete acceleration/cache_dit.py; drop it from the registry.
- Add _REMOVED_ACCELERATORS guard: a config still naming `cache_dit` fails fast with an
  actionable message pointing to `diffusers_cache`.
- Drop the `acceleration`/cache-dit optional-dependency group (diffusers caching needs none).
- Update guidance/acceleration.md, architecture.md, and example-config comments.
…t is set")

A decoupled trainer with a lossy `diffusers_cache` rollout accelerator AND periodic eval
(eval_freq>0) crashed on the first cached training rollout after an eval:
`ValueError: No context is set` from diffusers FirstBlockCache.

Root cause (diffusers HookRegistry): `_get_child_registries()` caches the child-registry
list on first use. During eval the cache is disabled, but the adapter still opens
`transformer.cache_context(...)`, which creates the top HookRegistry and caches an EMPTY
child list (no block cache-hooks exist yet). When training then enables the cache, the FBC
block hooks are added, but the stale empty cache makes `_set_context` skip them, so the
block forward reads an unset context and raises.

Fix: after enable_cache, invalidate the stale `_child_registries_cache` on the unwrapped
transformer (the exact object the adapter's `cache_context` targets) so the next context
build rediscovers the freshly added block hooks. No-op on the common no-eval path.

Verified on 16xH20 (2-node) Qwen-Image NFT + diffusers_cache + eval_freq=1: eval->train
interleave runs clean (rc=0, zero "No context" errors). Also smoke-verified torch_compile
(regional) + diffusers_cache coexist with stable per-epoch rollout time (no recompile churn).
@Jayce-Ping Jayce-Ping force-pushed the feat/acceleration-plugin-layer branch from bde04c7 to 247a64a Compare June 30, 2026 03:33
acceleration_plugin_layer.plan.md is an internal design/roadmap note, not part of the
shipped feature; remove it from the PR and gitignore .cursor/plans/ so scratch plans
don't re-enter the tree. The tracked .cursor/rules/*.mdc project rules are unaffected.
@Jayce-Ping Jayce-Ping marked this pull request as ready for review June 30, 2026 03:41
…compile

Route BaseTrainer.evaluate() through _rollout_grad_context() instead of a
bare torch.no_grad(). The compiled transformer's forward unconditionally
forces enable_grad (for rollout/train graph consistency), so a no_grad eval
still built an autograd graph that chained across the whole denoising loop
(memory grew with num_inference_steps -> OOM). The shared context enables the
per-step cast_latents detach so eval memory stays bounded. With no
grad-forcing accelerator active it is exactly torch.no_grad() (no behavior
change). evaluate()/generate_samples() are base-class-only, so this covers
all trainers.

Also fix the accelerator table in architecture.md: add a blank line so the
trailing "Configured via the acceleration: block" paragraph renders below the
table instead of being absorbed as a malformed single-column row.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant