Add FLUX-only TeaCache inference cache hook #12

Open
srlynch1 wants to merge 2 commits into main from feat/teacache
Conversation

srlynch1 (Contributor) commented Jun 24, 2026

Summary

  • Add FLUX-only TeaCache inference cache via block hooks (closes #12589)
  • TeaCacheConfig + apply_teacache wired through CacheMixin.enable_cache()
  • Polynomial-rescaled modulated-input L1 skip metric from the TeaCache paper
  • FLUX_TEACACHE_COEFFICIENTS vendored from TeaCache4FLUX

Test plan

  • 8 fast hook unit tests (tests/hooks/test_teacache.py)
  • TestFluxTransformerTeaCache mixin tests (CI)
  • check_copies.py
  • check_dummies.py

Note

Medium Risk
Hooks alter the denoising forward path for FLUX inference; wrong skips could affect image quality, though boundary steps always compute and unsupported models raise explicitly.

Overview
Adds TeaCache as a new inference cache for FLUX (FluxTransformer2DModel only in v1), wired like other caches via TeaCacheConfig, apply_teacache, and transformer.enable_cache().

At the first transformer block, the hook compares consecutive steps using a polynomial-rescaled relative L1 distance on the modulated input (the FLUX coefficients live in FLUX_TEACACHE_COEFFICIENTS). While the accumulated score stays below rel_l1_thresh, the middle blocks are bypassed and the full-stack residual cached from the most recent computed step is replayed; the first and last denoising steps always run.
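
The skip decision described above can be sketched roughly as follows. This is a minimal pure-Python illustration, not the code in this PR: `SketchState`, `rel_l1`, `rescale`, and `should_compute` are hypothetical names, the inputs are plain lists rather than tensors, the coefficients are an identity polynomial rather than the values in FLUX_TEACACHE_COEFFICIENTS, and the step counter is advanced inline instead of by a separate tail hook.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


def rel_l1(current: Sequence[float], previous: Sequence[float]) -> float:
    """Relative L1 distance: mean(|current - previous|) / mean(|previous|)."""
    num = sum(abs(c - p) for c, p in zip(current, previous)) / len(previous)
    den = sum(abs(p) for p in previous) / len(previous)
    return num / den


def rescale(coefficients: Sequence[float], x: float) -> float:
    """Evaluate a polynomial, highest-degree coefficient first (Horner's rule)."""
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result


@dataclass
class SketchState:  # hypothetical stand-in for the hook's per-run state
    step_index: int = 0
    accumulated_distance: float = 0.0
    previous_modulated_input: Optional[List[float]] = None


def should_compute(state, modulated_input, coefficients, rel_l1_thresh, num_inference_steps):
    """Return True if the full block stack should run for this step."""
    is_boundary = state.step_index in (0, num_inference_steps - 1)
    if is_boundary or state.previous_modulated_input is None:
        # First and last steps always compute and reset the accumulator.
        compute = True
        state.accumulated_distance = 0.0
    else:
        distance = rel_l1(modulated_input, state.previous_modulated_input)
        state.accumulated_distance += rescale(coefficients, distance)
        compute = state.accumulated_distance >= rel_l1_thresh
        if compute:
            state.accumulated_distance = 0.0
    state.previous_modulated_input = list(modulated_input)
    state.step_index += 1  # simplification: done inline here
    return compute


# Illustrative run: identity polynomial (1.0 * x + 0.0), threshold 0.08, 4 steps.
state = SketchState()
inputs = [[1.0, 1.0], [1.1, 1.0], [1.2, 1.0], [1.2, 1.0]]
decisions = [should_compute(state, x, [1.0, 0.0], 0.08, 4) for x in inputs]
print(decisions)
```

In this run the second step is skipped because its distance alone is under the threshold, while the third step computes because the accumulated distance has crossed it; the boundary steps compute regardless.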

Docs cover usage and a comparison with FirstBlockCache / MagCache. Tests include hook unit tests and TeaCacheTesterMixin on the Flux transformer.

Reviewed by Cursor Bugbot for commit 111bcf4. Bugbot is set up for automated code reviews on this repo. Configure here.

srlynch1 and others added 2 commits June 24, 2026 07:39
Co-authored-by: Cursor <cursoragent@cursor.com>
Completes public registration for TeaCache constants and keeps dummy_pt_objects in sync.

Co-authored-by: Cursor <cursoragent@cursor.com>

cursor (bot) left a comment

Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.

Bugbot Autofix prepared fixes for both issues found in the latest run.

  • ✅ Fixed: Mixin threshold blocks inference skip
    • Changed TeaCacheConfigMixin rel_l1_thresh to float('inf') so polynomial-rescaled L1 from randn perturbations cannot block the required second-pass skip in _test_cache_inference.
  • ✅ Fixed: Single block skip stalls step
    • TeaCacheHeadHook now advances step_index on skip when advance_step_on_skip=True for the single-block apply_teacache path where the tail block hook is bypassed.

Or push these changes by commenting:

@cursor push 84b4512cc4
Preview (84b4512cc4)
diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py
--- a/src/diffusers/hooks/teacache.py
+++ b/src/diffusers/hooks/teacache.py
@@ -132,11 +132,13 @@
         config: TeaCacheConfig,
         extract_modulated_input: Callable,
         coefficients: List[float],
+        advance_step_on_skip: bool = False,
     ):
         self.state_manager = state_manager
         self.config = config
         self.extract_modulated_input = extract_modulated_input
         self.coefficients = coefficients
+        self.advance_step_on_skip = advance_step_on_skip
         self._metadata = None
 
     def initialize_hook(self, module):
@@ -180,6 +182,9 @@
         if not should_compute:
             logger.debug(f"TeaCache: Skipping step {state.step_index}")
 
+            if self.advance_step_on_skip:
+                self._advance_step(state)
+
             output = hidden_states
             res = state.previous_residual
 
@@ -230,7 +235,15 @@
         self.state_manager.reset()
         return module
 
+    def _advance_step(self, state: TeaCacheState):
+        state.step_index += 1
+        if state.step_index >= self.config.num_inference_steps:
+            state.step_index = 0
+            state.accumulated_distance = 0.0
+            state.previous_residual = None
+            state.previous_modulated_input = None
 
+
 class TeaCacheBlockHook(ModelHook):
     def __init__(self, state_manager: StateManager, is_tail: bool = False, config: TeaCacheConfig = None):
         super().__init__()
@@ -350,7 +363,9 @@
         name, block = remaining_blocks[0]
         logger.info(f"TeaCache: Applying Head+Tail Hooks to single block '{name}'")
         _apply_teacache_block_hook(block, state_manager, config, is_tail=True)
-        _apply_teacache_head_hook(block, state_manager, config, extract_modulated_input, coefficients)
+        _apply_teacache_head_hook(
+            block, state_manager, config, extract_modulated_input, coefficients, advance_step_on_skip=True
+        )
         return
 
     head_block_name, head_block = remaining_blocks.pop(0)
@@ -372,13 +387,16 @@
     config: TeaCacheConfig,
     extract_modulated_input: Callable,
     coefficients: List[float],
+    advance_step_on_skip: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
 
     if registry.get_hook(_TEACACHE_LEADER_BLOCK_HOOK) is not None:
         registry.remove_hook(_TEACACHE_LEADER_BLOCK_HOOK)
 
-    hook = TeaCacheHeadHook(state_manager, config, extract_modulated_input, coefficients)
+    hook = TeaCacheHeadHook(
+        state_manager, config, extract_modulated_input, coefficients, advance_step_on_skip=advance_step_on_skip
+    )
     registry.register_hook(hook, _TEACACHE_LEADER_BLOCK_HOOK)
 
 

diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py
--- a/tests/models/testing_utils/cache.py
+++ b/tests/models/testing_utils/cache.py
@@ -643,10 +643,11 @@
     """
 
     # Default TeaCache config - can be overridden by subclasses.
-    # Uses num_inference_steps=4 so interior steps can be skipped during _test_cache_inference.
+    # Uses num_inference_steps=4 and an infinite rel_l1_thresh so the second
+    # inference step is always skipped, which is required by _test_cache_inference.
     TEA_CACHE_CONFIG = {
         "num_inference_steps": 4,
-        "rel_l1_thresh": 100.0,
+        "rel_l1_thresh": float("inf"),
     }
 
     def _get_cache_config(self):

TEA_CACHE_CONFIG = {
    "num_inference_steps": 4,
    "rel_l1_thresh": 100.0,
}

Mixin threshold blocks inference skip

Medium Severity

TeaCacheConfigMixin uses rel_l1_thresh=100.0 assuming the second _test_cache_inference pass will skip, but _test_cache_inference perturbs hidden_states with randn_like, which often pushes the polynomial-rescaled L1 above 100 so step 1 fully recomputes. Cached and uncached outputs then match and the test fails.
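
To illustrate why a finite threshold can fail to guarantee a skip while an infinite one cannot, here is a small sketch. The polynomial `STEEP` (100 * x**2) and the input values are made up for illustration and are not the vendored FLUX coefficients or real test tensors; `rel_l1` and `rescale` are hypothetical helpers.

```python
import math
from typing import Sequence


def rel_l1(current: Sequence[float], previous: Sequence[float]) -> float:
    """Relative L1 distance: mean(|current - previous|) / mean(|previous|)."""
    num = sum(abs(c - p) for c, p in zip(current, previous)) / len(previous)
    den = sum(abs(p) for p in previous) / len(previous)
    return num / den


def rescale(coefficients: Sequence[float], x: float) -> float:
    """Evaluate a polynomial, highest-degree coefficient first (Horner's rule)."""
    result = 0.0
    for c in coefficients:
        result = result * x + c
    return result


# Hypothetical steep polynomial: 100 * x**2. Not the real FLUX coefficients.
STEEP = [100.0, 0.0, 0.0]

previous = [1.0, 1.0, 1.0, 1.0]
perturbed = [3.0, -1.0, 2.0, 0.0]  # stands in for a large random perturbation

score = rescale(STEEP, rel_l1(perturbed, previous))

skips_with_finite_thresh = score < 100.0   # a large score defeats the skip
skips_with_inf_thresh = score < math.inf   # any finite score is below infinity
print(score, skips_with_finite_thresh, skips_with_inf_thresh)
```

Any finite score compares below `float("inf")`, which is why the autofix makes the interior-step skip independent of how large the perturbation happens to be.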

logger.info(f"TeaCache: Applying Head+Tail Hooks to single block '{name}'")
_apply_teacache_block_hook(block, state_manager, config, is_tail=True)
_apply_teacache_head_hook(block, state_manager, config, extract_modulated_input, coefficients)
return

Single block skip stalls step

Low Severity

When apply_teacache finds only one transformer block, the head hook is outermost and returns early on skip without invoking the co-located tail block hook, so _advance_step never runs and step_index stays stuck across forwards.
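
The stall can be reproduced with a toy simulation. This sketch only mimics the control flow described in the comment, under the assumption that every interior step skips; `SketchState`, `advance_step`, and `run_forwards` are hypothetical names, and the wrap-around reset mirrors the `_advance_step` helper shown in the preview diff.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchState:  # hypothetical stand-in for the hook state
    step_index: int = 0
    accumulated_distance: float = 0.0
    previous_residual: Optional[float] = None


def advance_step(state: SketchState, num_inference_steps: int) -> None:
    """Advance the step counter and reset state when a full run completes."""
    state.step_index += 1
    if state.step_index >= num_inference_steps:
        state.step_index = 0
        state.accumulated_distance = 0.0
        state.previous_residual = None


def run_forwards(num_forwards: int, num_inference_steps: int, advance_on_skip: bool) -> List[int]:
    """Simulate the single-block path and record step_index at each forward."""
    state = SketchState()
    seen = []
    for _ in range(num_forwards):
        seen.append(state.step_index)
        is_boundary = state.step_index in (0, num_inference_steps - 1)
        if is_boundary:
            # Full compute: the co-located tail hook runs and advances the step.
            advance_step(state, num_inference_steps)
        elif advance_on_skip:
            # Skip: the head hook returns early but advances the step itself.
            advance_step(state, num_inference_steps)
        # Otherwise: skip with the tail hook bypassed, so nothing advances.
    return seen


stalled = run_forwards(5, 4, advance_on_skip=False)
fixed = run_forwards(5, 4, advance_on_skip=True)
print(stalled, fixed)
```

Without the flag, the counter stays at the first interior step for every later forward; with it, the counter walks through all four steps and wraps back to zero for the next run.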
