[acceleration] Model-agnostic acceleration plugin layer (torch.compile / attention backend / feature caching)#196
Open
Jayce-Ping wants to merge 14 commits into
Open
[acceleration] Model-agnostic acceleration plugin layer (torch.compile / attention backend / feature caching)#196Jayce-Ping wants to merge 14 commits into
Jayce-Ping wants to merge 14 commits into
Conversation
… plugin layer (lossless) Introduce a registry-based acceleration plugin layer that respects the algorithm/model decoupling, plus the first lossless accelerators. - acceleration/: BaseAccelerator (safety/stage contract), registry with direct-path fallback, paradigm-gated validator, CompileAccelerator (torch.compile, regional/full), AttentionBackendAccelerator (exact backends). - hparams: AccelerationArguments with shared/rollout slots; wired into Arguments (field + nested_map) and exported. Off by default (backward compatible). - trainers: BaseTrainer builds and validates accelerators after prepare, applies the shared accelerator via setup() and wraps the Stage-3 rollout loop with the rollout accelerator context; per-trainer paradigm tags (coupled/decoupled/ distillation) drive the lossy-safety gate (constraints.md #7, #20a, #26).
Rollout-only (Stage 3) feature caching, gated by the paradigm validator to decoupled/distillation trainers so train-inference consistency is preserved. - DiffusersCacheAccelerator: zero-extra-dep, wraps diffusers CacheMixin enable_cache/disable_cache with policies first_block (default) / faster / pyramid / taylorseer / magcache; enabled per rollout epoch, torn down on exit. - CacheDitAccelerator: optional cache-dit backend (richer DBCache/TaylorSeer), imported defensively with an install hint when absent (constraints.md #22a). - registry: register both lossy accelerators. - pyproject: add optional 'acceleration' extra (cache-dit) and include in 'all'.
- guidance/acceleration.md: safety model, config schema, accelerator table, model cache-readiness, and how to add a new accelerator. - architecture.md: register the acceleration registry (now four), accelerator table, validator/paradigm gating note, and the new-accelerator extension point. - AGENTS.md: link the new guide. - examples: commented acceleration block in the NFT/Qwen-Image (decoupled) config.
…erator; compile after post_init Two correctness/cleanup fixes from review: 1. Remove AttentionBackendAccelerator: it duplicated model.attn_backend (BaseAdapter._set_attention_backend, run in adapter __init__) with no added capability, and re-set the backend a second time after prepare. Attention backend stays a single canonical knob (model.attn_backend). Drop the registry entry, file, and docs/example references. 2. Apply the shared (lossless) accelerator AFTER adapter.post_init() instead of inside _initialization(). post_init() performs state-checkpoint resume, _init_ema(), and _init_ref_parameters(); compiling first was fragile. Now _init_acceleration() only builds+validates; _apply_shared_acceleration() runs torch.compile after weights are final. In-place compile (nn.Module.compile / compile_repeated_blocks) keeps state_dict keys and parameter identity stable, so save/load checkpoint, LoRA, and copy_-based EMA/ref/named-param swaps remain correct. Also fix the regional-mode guard to check _repeated_blocks (the method always exists on ModelMixin) and document the LoRA disable_adapter recompile.
…erator the single path Per review preference, invert the previous fix: instead of removing the AttentionBackendAccelerator and keeping BaseAdapter._set_attention_backend, remove the in-adapter call and route attention-backend selection exclusively through the accelerator, for one consistent transformer-acceleration mechanism. - Remove BaseAdapter._set_attention_backend() call and method from models/abc.py. - Restore AttentionBackendAccelerator: reads backend from model.attn_backend (with optional 'backend' param override) and forwards it verbatim to diffusers' set_attention_backend (incl. approximate backends like 'sage' — matches old behavior; no whitelist). Re-register in the accelerator registry. - BaseTrainer._apply_shared_acceleration now applies attention backend (from model.attn_backend) first, then the shared accelerator (e.g. torch.compile), so the compiled graph captures the chosen backend. Runs after prepare()/post_init(). - model.attn_backend stays the config knob (used in examples); no config breakage. Standalone inference uses raw diffusers pipelines, not the adapter, so dropping the constructor call has no effect there. - Docs (architecture.md, guidance/acceleration.md, acceleration_args help) updated.
…umerical safety Answers the review questions on attention-backend timing and lossy backends: - Correctness axis is SYMMETRIC application (stage), not bit-exactness. A stage='both' transform runs in both rollout inference() and training forward() on the shared module, so they stay train-inference consistent even when the transform is numerically approximate (e.g. Sage int8 attention). We therefore do NOT reject 'sage' and do not need per-instance 'two safety states'. - Validator: the shared slot now requires stage=='both' (drop the safety=='lossless' requirement); the rollout slot requires stage=='rollout' and still gates safety=='lossy' to decoupled/distillation paradigms. safety is only consulted for rollout-stage accelerators. - Clarify BaseAccelerator safety/stage semantics and the validator/guidance docs. Timing (confirmed, no code change): set_attention_backend only sets processor._attention_backend attributes + a global active backend (touches no params), so it is safe AFTER accelerator.prepare() and is applied BEFORE compile (the dispatch is read at forward time); applying after prepare is also required for context-parallel backend validation. _apply_shared_acceleration runs after post_init (final weights), attention backend first, then compile.
…ttn_backend into the acceleration layer + ordered multi-accelerator
Remove the standalone `model.attn_backend` knob and express attention-backend
selection only as an `attention_backend` entry in the acceleration layer. Each
acceleration slot (`shared` / `rollout`) becomes an ordered list of {name, params}
entries (list order = application order), mirroring MultiRewardArguments, so
multiple accelerators can be combined deterministically (e.g. attention_backend
before torch_compile so the compiled graph captures the backend).
- hparams/acceleration_args.py: add AccelerationSpec; AccelerationArguments.shared
and .rollout are List[AccelerationSpec]; from_dict accepts list or single-entry
shorthand; to_dict round-trips.
- hparams/model_args.py: drop attn_backend field; __post_init__ fail-fast migration
error if a stale `model.attn_backend` is present.
- acceleration/attention_backend.py: require an explicit `backend` param; raise when
no transformer supports set_attention_backend (e.g. Bagel forces fa2 at load).
- trainers/abc.py: build/validate accelerator lists; _apply_shared_acceleration runs
setup() in order; _rollout_acceleration nests rollout contexts via ExitStack.
- examples: replace commented attn_backend lines with a commented acceleration block
(coupled => shared-only; decoupled => shared+rollout). Bagel examples drop the inert
attn_backend line and document the forced flash_attention_2 path.
- docs: guidance/acceleration.md, architecture.md, README.md updated to the list form.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…n-inference consistency for coupled algorithms torch.compile (Inductor) compiles a separate, numerically non-identical graph for grad vs no-grad mode (Dynamo guards on grad_mode). Rollout ran under no_grad while the training forward runs with grad, so the recomputed on-policy log-prob diverged from the rollout log-prob (~1.5e-5) and broke the coupled PPO ratio==1 invariant (GRPO/GRPO-Guard/DPPO). This is the ONLY source — CFG is applied identically in both stages (batch 2N each, verified) and is not a separate cause. Fix (CompileAccelerator + trainer + base adapter): - requires_grad_rollout flag: the trainer's _rollout_grad_context runs the rollout with grad enabled and sets adapter._rollout_detach when a compile accelerator is active (otherwise unchanged: no_grad). - The compiled transformer's call entry points (forward + _compiled_call_impl) are wrapped to force torch.enable_grad(), overriding the @torch.no_grad() on inference(), so rollout and training execute the same grad-mode graph. - The wrapper returns the grad-carrying output directly and does NOT detach it: an inner detach lets Inductor pick a divergent inference-optimized kernel and silently re-introduces the drift. Instead the per-step latent feedback is detached in BaseAdapter.cast_latents (gated by _rollout_detach), which breaks the autograd graph chain across denoising steps so rollout memory stays bounded. - Compile/wrap the BASE transformer under PEFT/LoRA (_peel_peft -> get_base_model), not the PeftModel: compiling the wrapper drifts (~2.0) and hides _repeated_blocks (so regional compile now also works under LoRA). LoRA submodules stay inside the compiled graph and still train. Result (verified, SD3.5, first on-policy step): max|ratio-1| = 0.000e+00 (bit-exact) for coupled training with AND without CFG (old_lp == new_lp exactly; rollout and training both call the transformer at the identical 2N shape under CFG). A train-vs-train recompute is exactly 0.0, confirming compile is deterministic (no RNG/autotune nondeterminism). Non-compile runs are unchanged. Full analysis: .scratch/torch_compile_consistency_report.md; guidance/acceleration.md updated. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…y (not bit-exact across rollout/train) + fail-fast tidy torch.compile is stage='both' (applied to the shared transformer) but Inductor compiles a separate graph for grad vs no-grad mode; even with the grad-forced rollout fix an intermittent ~1e-5 on-policy ratio residual remains on a minority of samples (verified 16xH20, 2-node ZeRO-2/FSDP2, SD3.5 + Qwen-Image, regional/full, CFG and no-CFG). So it is NOT bit-exact across stages. - Re-class CompileAccelerator safety='lossy' (lossless now means bit-exact across stages, not merely "applied to the shared module"); update abc/validator/torch_compile docstrings. - validator: WARN (do not reject) when a stage='both' lossy accelerator runs on a coupled trainer -- it stays within clip_range so it remains allowed, but the on-policy ratio is ~1, not exactly 1. eager + attention_backend stay lossless/bit-exact. - fail-fast: drop the swallowing try/except in CompileAccelerator._peel_peft. - init BaseAdapter._rollout_detach=False and read it directly (drop getattr defaults); remove the dead acceleration_args None-branch in _init_acceleration. - correct guidance/acceleration.md + architecture.md tables/notes that overclaimed bit-exact.
…sers_cache is the single lossy backend cache-dit only caches inside the pipeline `__call__` it monkeypatches, but Flow-Factory's rollout drives the transformer directly (adapter.inference()), so CacheDitAccelerator was a silent no-op (every step logged "Cache context not exist, skip cache"). Its transformer-only path (FakeDiffusionPipeline -> persistent_context) does engage but assumes enable-once: the class-level _is_cached flag is not reset on disable, so the per-epoch rollout_context enable/disable cycle crashes on epoch 2; it also conflicts with the cache-ready adapters' own diffusers `cache_context` (two caching systems on the same blocks). It adds nothing over the diffusers-native `diffusers_cache` (FBCache/TaylorSeer/FasterCache, ~1.2x rollout, clean per-epoch lifecycle), so it is removed rather than force-fit on private dev-build internals. - Delete acceleration/cache_dit.py; drop it from the registry. - Add _REMOVED_ACCELERATORS guard: a config still naming `cache_dit` fails fast with an actionable message pointing to `diffusers_cache`. - Drop the `acceleration`/cache-dit optional-dependency group (diffusers caching needs none). - Update guidance/acceleration.md, architecture.md, and example-config comments.
…t is set") A decoupled trainer with a lossy `diffusers_cache` rollout accelerator AND periodic eval (eval_freq>0) crashed on the first cached training rollout after an eval: `ValueError: No context is set` from diffusers FirstBlockCache. Root cause (diffusers HookRegistry): `_get_child_registries()` caches the child-registry list on first use. During eval the cache is disabled, but the adapter still opens `transformer.cache_context(...)`, which creates the top HookRegistry and caches an EMPTY child list (no block cache-hooks exist yet). When training then enables the cache, the FBC block hooks are added, but the stale empty cache makes `_set_context` skip them, so the block forward reads an unset context and raises. Fix: after enable_cache, invalidate the stale `_child_registries_cache` on the unwrapped transformer (the exact object the adapter's `cache_context` targets) so the next context build rediscovers the freshly added block hooks. No-op on the common no-eval path. Verified on 16xH20 (2-node) Qwen-Image NFT + diffusers_cache + eval_freq=1: eval->train interleave runs clean (rc=0, zero "No context" errors). Also smoke-verified torch_compile (regional) + diffusers_cache coexist with stable per-epoch rollout time (no recompile churn).
bde04c7 to
247a64a
Compare
acceleration_plugin_layer.plan.md is an internal design/roadmap note, not part of the shipped feature; remove it from the PR and gitignore .cursor/plans/ so scratch plans don't re-enter the tree. The tracked .cursor/rules/*.mdc project rules are unaffected.
…compile Route BaseTrainer.evaluate() through _rollout_grad_context() instead of a bare torch.no_grad(). The compiled transformer's forward unconditionally forces enable_grad (for rollout/train graph consistency), so a no_grad eval still built an autograd graph that chained across the whole denoising loop (memory grew with num_inference_steps -> OOM). The shared context enables the per-step cast_latents detach so eval memory stays bounded. With no grad-forcing accelerator active it is exactly torch.no_grad() (no behavior change). evaluate()/generate_samples() are base-class-only, so this covers all trainers. Also fix the accelerator table in architecture.md: add a blank line so the trailing "Configured via the acceleration: block" paragraph renders below the table instead of being absorbed as a malformed single-column row.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds a model-agnostic, registry-based acceleration plugin layer (
src/flow_factory/acceleration/) that speeds up training without touching trainer or model math, gated for train-inference correctness by the trainer's RLparadigm.acceleration:—shared(lossless/stage=both, applied to both rollout and the training forward viasetup()) androllout(Stage-3 only viarollout_context(), may be lossy).attention_backend(folds in the removedmodel.attn_backendknob),torch_compile,diffusers_cache,cache_dit.acceleration/validator.pyenforces the safety contract (constraints [Feat] Support Pairwise and Groupwise reward computation #7/[Feat] Extend EMA schedules #26): a lossyrolloutaccelerator is rejected on coupled trainers; a lossystage=bothaccelerator (torch_compile) is allowed everywhere but warns on coupled that the on-policy ratio is ≈1, not bit-exact.model.attn_backendis removed (fails fast with a migration message → use theattention_backendaccelerator).Correctness — GRPO on-policy ratio (zero tolerance: must be exactly 1, atol=0)
Verified on 16×H20 (2-node) under both DeepSpeed ZeRO-2 and FSDP2 HYBRID_SHARD, per-element assert at the first on-policy step on all 16 ranks.
attention_backend_flash_3_hub(FA3)torch_compileattention_backend+torch_compile_flash_3_hub+ fulltorch_compile(Qwen-Image)torch_compile(no-CFG)Conclusion: eager and
attention_backendare strictly bit-exact.torch_compile(any mode/model, CFG or not) leaves an intermittent ~1e-5 residual on a minority of samples — numerically on-policy (≪clip_range=1e-4) but not bit-exact. Root cause: Inductor compiles a separate graph for grad vs no-grad mode; the grad-forced rollout fix removes the dominant divergence but rollout and training remain distinct kernel invocations and bf16 accumulation is non-associative. Hencetorch_compileis now classified lossy (validator warns on coupled).Note: regional compile is unavailable on SD3.5 (
SD3Transformer2DModelhas no_repeated_blocks) → clean fail-fast (works on Qwen-Image).Speedup — lossless (GRPO / SD3.5 LoRA, 16-GPU ZeRO-2, CFG g=4.5, 8 steps; rank0 steady-state median)
attention_backend(_flash_3_hub)torch_compile(full)_flash_3_hub+torch_compile(full)torch.compileis the real win (~1.2× + −17% VRAM; combined with FA3 ~1.3–1.4×) — at the cost of strict bit-exactness.Speedup — lossy rollout caching (DiffusionNFT / Qwen-Image LoRA, 16-GPU ZeRO-2, 28 steps; same seed; rank0 steady-state)
Fresh post-removal measurement (rank0, median of n=6 steady-state epochs, 28 inference steps):
diffusers_cache(first_block, thr 0.08)cache_dit(default / aggressive)247a64a): periodic eval + diffusers_cache previously crashed with a diffusers FirstBlockCache "No context is set" error. Root cause: diffusers'HookRegistry._get_child_registries()caches its child list on first use; during eval (cache disabled) the adapter still openscache_context(...), caching an empty child list before the block cache-hooks exist, so the later training_set_contextskipped them. The accelerator now invalidates that stale_child_registries_cacheafterenable_cache. Verified on 16xH20 Qwen-Image NFT (eval_freq=1, rc=0, zero errors).diffusers_cacheworks: ~1.2× rollout (1.23× here on n=6; 1.18-1.26× across runs) with negligible reward drift. At 8 inference steps it never triggers (no gain, bit-identical output) — it needs enough denoising steps to engage.cache_ditaccelerator was REMOVED (commit0170953).diffusers_cacheis now the single lossy rollout backend; a config still namingcache_ditfails fast with an actionable error pointing todiffusers_cache. Full investigation that led to removal:CacheDitAcceleratorcalledcache_dit.enable_cache(adapter.pipeline), which setspersistent_context=Falsefor a real pipeline — caching then only activates insidepipeline.__call__, but Flow-Factory's rollout calls the transformer directly (adapter.inference()), so cache_dit's per-call context is never created and every step logs "Cache context not exist … skip cache" (bit-identical output, default and aggressive configs alike).FakeDiffusionPipeline→persistent_context=True) makes caching actually engage on Qwen-Image (epoch 0 ran with no "skip cache").enable_cache/disable_cacheis not idempotent across the per-epoch rollout scoping — re-enabling on epoch 2 assertshasattr(pipe, "_context_manager"); (2) it conflicts with the adapter's native diffuserscache_context(_diffusers_hook is not None) — cache_dit is a competing caching system to the diffusers-native one thatdiffusers_cachealready uses; (3) on SD3.5 the bare transformer classSD3Transformer2DModelisn't registered (onlyStableDiffusion3Pipelineis), so cache_dit's bare-module path is unsupported there.cache_contextand use cache_dit's context-manager +refresh_contextfor per-rollout scoping, built on private dev-build internals — substantial work for no gain over diffusers-native caching. So theCacheDitAcceleratorwas removed;diffusers_cacheis the lossy backend (clean per-epoch lifecycle, ~1.2× rollout on Qwen, negligible reward drift).Test plan
diffusers_cache~1.2×;cache_ditfound to be a no-op — bug noted).