Add FLUX-only TeaCache inference cache hook #12
Co-authored-by: Cursor <cursoragent@cursor.com>
Completes public registration for TeaCache constants and keeps dummy_pt_objects in sync. Co-authored-by: Cursor <cursoragent@cursor.com>
Cursor Bugbot has reviewed your changes using default effort and found 2 potential issues.
Bugbot Autofix prepared fixes for both issues found in the latest run.
- ✅ Fixed: Mixin threshold blocks inference skip
- Changed TeaCacheConfigMixin rel_l1_thresh to float('inf') so polynomial-rescaled L1 from randn perturbations cannot block the required second-pass skip in _test_cache_inference.
- ✅ Fixed: Single block skip stalls step
- TeaCacheHeadHook now advances step_index on skip when advance_step_on_skip=True for the single-block apply_teacache path where the tail block hook is bypassed.
Or push these changes by commenting:
@cursor push 84b4512cc4
Preview (84b4512cc4)
diff --git a/src/diffusers/hooks/teacache.py b/src/diffusers/hooks/teacache.py
--- a/src/diffusers/hooks/teacache.py
+++ b/src/diffusers/hooks/teacache.py
@@ -132,11 +132,13 @@
         config: TeaCacheConfig,
         extract_modulated_input: Callable,
         coefficients: List[float],
+        advance_step_on_skip: bool = False,
     ):
         self.state_manager = state_manager
         self.config = config
         self.extract_modulated_input = extract_modulated_input
         self.coefficients = coefficients
+        self.advance_step_on_skip = advance_step_on_skip
         self._metadata = None

     def initialize_hook(self, module):
@@ -180,6 +182,9 @@
         if not should_compute:
             logger.debug(f"TeaCache: Skipping step {state.step_index}")
+            if self.advance_step_on_skip:
+                self._advance_step(state)
+
             output = hidden_states
             res = state.previous_residual
@@ -230,7 +235,15 @@
         self.state_manager.reset()
         return module

+    def _advance_step(self, state: TeaCacheState):
+        state.step_index += 1
+        if state.step_index >= self.config.num_inference_steps:
+            state.step_index = 0
+            state.accumulated_distance = 0.0
+            state.previous_residual = None
+            state.previous_modulated_input = None
+

 class TeaCacheBlockHook(ModelHook):
     def __init__(self, state_manager: StateManager, is_tail: bool = False, config: TeaCacheConfig = None):
         super().__init__()
@@ -350,7 +363,9 @@
         name, block = remaining_blocks[0]
         logger.info(f"TeaCache: Applying Head+Tail Hooks to single block '{name}'")
         _apply_teacache_block_hook(block, state_manager, config, is_tail=True)
-        _apply_teacache_head_hook(block, state_manager, config, extract_modulated_input, coefficients)
+        _apply_teacache_head_hook(
+            block, state_manager, config, extract_modulated_input, coefficients, advance_step_on_skip=True
+        )
         return

     head_block_name, head_block = remaining_blocks.pop(0)
@@ -372,13 +387,16 @@
     config: TeaCacheConfig,
     extract_modulated_input: Callable,
     coefficients: List[float],
+    advance_step_on_skip: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(block)
     if registry.get_hook(_TEACACHE_LEADER_BLOCK_HOOK) is not None:
         registry.remove_hook(_TEACACHE_LEADER_BLOCK_HOOK)
-    hook = TeaCacheHeadHook(state_manager, config, extract_modulated_input, coefficients)
+    hook = TeaCacheHeadHook(
+        state_manager, config, extract_modulated_input, coefficients, advance_step_on_skip=advance_step_on_skip
+    )
     registry.register_hook(hook, _TEACACHE_LEADER_BLOCK_HOOK)
diff --git a/tests/models/testing_utils/cache.py b/tests/models/testing_utils/cache.py
--- a/tests/models/testing_utils/cache.py
+++ b/tests/models/testing_utils/cache.py
@@ -643,10 +643,11 @@
     """
     # Default TeaCache config - can be overridden by subclasses.
-    # Uses num_inference_steps=4 so interior steps can be skipped during _test_cache_inference.
+    # Uses num_inference_steps=4 and an infinite rel_l1_thresh so the second
+    # inference step is always skipped, which is required by _test_cache_inference.
     TEA_CACHE_CONFIG = {
         "num_inference_steps": 4,
-        "rel_l1_thresh": 100.0,
+        "rel_l1_thresh": float("inf"),
     }
     def _get_cache_config(self):
Reviewed by Cursor Bugbot for commit 111bcf4. Configure here.
    TEA_CACHE_CONFIG = {
        "num_inference_steps": 4,
        "rel_l1_thresh": 100.0,
    }
Mixin threshold blocks inference skip
Medium Severity
TeaCacheConfigMixin uses rel_l1_thresh=100.0 assuming the second _test_cache_inference pass will skip, but _test_cache_inference perturbs hidden_states with randn_like, which often pushes the polynomial-rescaled L1 above 100 so step 1 fully recomputes. Cached and uncached outputs then match and the test fails.
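The threshold issue described above can be illustrated with a small stand-alone sketch. It assumes the score is a relative L1 distance passed through a polynomial, as the PR overview states. The coefficients below are made up for the illustration and are not the values in FLUX_TEACACHE_COEFFICIENTS, and the helper names (relative_l1, rescale, should_skip) are hypothetical rather than functions from the PR.

```python
import random

# Made-up coefficients (highest power first); NOT the FLUX_TEACACHE_COEFFICIENTS values.
ILLUSTRATIVE_COEFFICIENTS = [400.0, -300.0, 80.0, 5.0, 0.0]


def relative_l1(current, previous):
    """Sum of |current - previous| divided by sum of |previous|."""
    numerator = sum(abs(c - p) for c, p in zip(current, previous))
    denominator = sum(abs(p) for p in previous)
    return numerator / denominator


def rescale(x, coefficients):
    """Evaluate the polynomial at x with Horner's rule (highest power first)."""
    result = 0.0
    for coefficient in coefficients:
        result = result * x + coefficient
    return result


def should_skip(current, previous, threshold, coefficients=ILLUSTRATIVE_COEFFICIENTS):
    """Skip only while the rescaled distance stays below the threshold."""
    return rescale(relative_l1(current, previous), coefficients) < threshold


# A randn-style perturbation of a randn-style input gives a relative L1 near 1,
# which these illustrative coefficients map to a score well above 100.
random.seed(0)
previous = [random.gauss(0.0, 1.0) for _ in range(4096)]
current = [p + random.gauss(0.0, 1.0) for p in previous]

print(should_skip(current, previous, threshold=100.0))
print(should_skip(current, previous, threshold=float("inf")))
```

With a finite threshold, whether the second pass skips depends on the perturbation and the polynomial; with an infinite threshold every finite score compares as below it, which is why the autofix makes the skip unconditional in the test mixin.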
        logger.info(f"TeaCache: Applying Head+Tail Hooks to single block '{name}'")
        _apply_teacache_block_hook(block, state_manager, config, is_tail=True)
        _apply_teacache_head_hook(block, state_manager, config, extract_modulated_input, coefficients)
        return
Single block skip stalls step
Low Severity
When apply_teacache finds only one transformer block, the head hook is outermost and returns early on skip without invoking the co-located tail block hook, so _advance_step never runs and step_index stays stuck across forwards.
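The stall can be reproduced with a toy model of the single-block case. This is a simplified sketch, not the hook code from the PR: ToyState, toy_forward, and run are hypothetical names, and the only behavior borrowed from the diff is the step advance with wrap-around at num_inference_steps.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyState:
    step_index: int = 0
    accumulated_distance: float = 0.0
    previous_residual: Optional[float] = None


def advance_step(state: ToyState, num_inference_steps: int) -> None:
    """Advance the step counter and reset the state once the run wraps around."""
    state.step_index += 1
    if state.step_index >= num_inference_steps:
        state.step_index = 0
        state.accumulated_distance = 0.0
        state.previous_residual = None


def toy_forward(state: ToyState, should_compute: bool, num_inference_steps: int,
                advance_step_on_skip: bool) -> str:
    """One forward through a single block that carries both head and tail hooks."""
    if not should_compute:
        # The head hook returns early, so the tail hook on the same block never runs.
        if advance_step_on_skip:
            advance_step(state, num_inference_steps)
        return "skipped"
    # Full compute path: the tail hook stores a residual and advances the step.
    state.previous_residual = 1.0
    advance_step(state, num_inference_steps)
    return "computed"


def run(advance_step_on_skip: bool, num_inference_steps: int = 4) -> List[int]:
    """Return step_index after each of four forwards: compute, skip, skip, compute."""
    state = ToyState()
    trace = []
    for should_compute in [True, False, False, True]:
        toy_forward(state, should_compute, num_inference_steps, advance_step_on_skip)
        trace.append(state.step_index)
    return trace


print(run(advance_step_on_skip=False))  # [1, 1, 1, 2]
print(run(advance_step_on_skip=True))   # [1, 2, 3, 0]
```

Without the flag the counter stays at 1 across both skipped forwards, so the run never reaches its last step on schedule; with the flag it advances on every forward and wraps back to 0 after the fourth.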



Summary
- TeaCacheConfig + apply_teacache wired through CacheMixin.enable_cache()
- FLUX_TEACACHE_COEFFICIENTS vendored from TeaCache4FLUX
Test plan
- Hook unit tests (tests/hooks/test_teacache.py)
- TestFluxTransformerTeaCache mixin tests (CI)
- check_copies.py
- check_dummies.py
Note
Medium Risk
Hooks alter the denoising forward path for FLUX inference; wrong skips could affect image quality, though boundary steps always compute and unsupported models raise explicitly.
Overview
Adds TeaCache as a new inference cache for FLUX (FluxTransformer2DModel only in v1), wired like other caches via TeaCacheConfig, apply_teacache, and transformer.enable_cache().

At the first transformer block, the hook compares consecutive steps using a polynomial-rescaled relative L1 distance on the modulated input (FLUX coeffs in FLUX_TEACACHE_COEFFICIENTS). When the accumulated score stays below rel_l1_thresh, middle blocks are bypassed and the last step’s full-stack residual is replayed; first and last denoising steps always run.

Docs cover usage and a comparison with FirstBlockCache / MagCache. Tests include hook unit tests and TeaCacheTesterMixin on the Flux transformer.
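The skip schedule described in the overview can be sketched as a small decision loop. This is a minimal illustration under stated assumptions, not the PR's implementation: plan_steps is a hypothetical helper, the per-step rescaled distances are supplied directly instead of being computed from a model, and resetting the accumulator after each computed step is an assumption borrowed from the usual TeaCache formulation rather than something the PR text confirms.

```python
from typing import List


def plan_steps(rescaled_distances: List[float], rel_l1_thresh: float) -> List[str]:
    """Decide per step whether to compute or skip.

    rescaled_distances[i] is the polynomial-rescaled relative L1 distance between
    step i and step i - 1; the entry at index 0 is unused.
    """
    num_steps = len(rescaled_distances)
    accumulated = 0.0
    decisions = []
    for step in range(num_steps):
        is_boundary = step == 0 or step == num_steps - 1
        if is_boundary:
            # First and last denoising steps always run.
            compute = True
        else:
            accumulated += rescaled_distances[step]
            compute = accumulated >= rel_l1_thresh
        if compute:
            # Assumption: the accumulator restarts after every full computation.
            accumulated = 0.0
        decisions.append("compute" if compute else "skip")
    return decisions


print(plan_steps([0.0, 0.1, 0.2, 0.3, 0.1, 0.1], rel_l1_thresh=0.4))
# ['compute', 'skip', 'skip', 'compute', 'skip', 'compute']
```

With rel_l1_thresh set to float("inf") and four steps, the same loop yields compute, skip, skip, compute, which matches the test configuration's intent that interior steps are always skipped.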