Skip to content

Commit 84b4512

Browse files
committed
Fix TeaCache test threshold and single-block step advancement on skip
- Use float('inf') rel_l1_thresh in TeaCacheConfigMixin so _test_cache_inference reliably skips on the second pass despite large randn perturbations - Advance step_index from TeaCacheHeadHook when skipping on single-block models where the co-located tail block hook is bypassed
1 parent 111bcf4 commit 84b4512

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

src/diffusers/hooks/teacache.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,13 @@ def __init__(
132132
config: TeaCacheConfig,
133133
extract_modulated_input: Callable,
134134
coefficients: List[float],
135+
advance_step_on_skip: bool = False,
135136
):
136137
self.state_manager = state_manager
137138
self.config = config
138139
self.extract_modulated_input = extract_modulated_input
139140
self.coefficients = coefficients
141+
self.advance_step_on_skip = advance_step_on_skip
140142
self._metadata = None
141143

142144
def initialize_hook(self, module):
@@ -180,6 +182,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
180182
if not should_compute:
181183
logger.debug(f"TeaCache: Skipping step {state.step_index}")
182184

185+
if self.advance_step_on_skip:
186+
self._advance_step(state)
187+
183188
output = hidden_states
184189
res = state.previous_residual
185190

@@ -230,6 +235,14 @@ def reset_state(self, module):
230235
self.state_manager.reset()
231236
return module
232237

238+
def _advance_step(self, state: TeaCacheState):
239+
state.step_index += 1
240+
if state.step_index >= self.config.num_inference_steps:
241+
state.step_index = 0
242+
state.accumulated_distance = 0.0
243+
state.previous_residual = None
244+
state.previous_modulated_input = None
245+
233246

234247
class TeaCacheBlockHook(ModelHook):
235248
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: TeaCacheConfig = None):
@@ -350,7 +363,9 @@ def apply_teacache(module: torch.nn.Module, config: TeaCacheConfig) -> None:
350363
name, block = remaining_blocks[0]
351364
logger.info(f"TeaCache: Applying Head+Tail Hooks to single block '{name}'")
352365
_apply_teacache_block_hook(block, state_manager, config, is_tail=True)
353-
_apply_teacache_head_hook(block, state_manager, config, extract_modulated_input, coefficients)
366+
_apply_teacache_head_hook(
367+
block, state_manager, config, extract_modulated_input, coefficients, advance_step_on_skip=True
368+
)
354369
return
355370

356371
head_block_name, head_block = remaining_blocks.pop(0)
@@ -372,13 +387,16 @@ def _apply_teacache_head_hook(
372387
config: TeaCacheConfig,
373388
extract_modulated_input: Callable,
374389
coefficients: List[float],
390+
advance_step_on_skip: bool = False,
375391
) -> None:
376392
registry = HookRegistry.check_if_exists_or_initialize(block)
377393

378394
if registry.get_hook(_TEACACHE_LEADER_BLOCK_HOOK) is not None:
379395
registry.remove_hook(_TEACACHE_LEADER_BLOCK_HOOK)
380396

381-
hook = TeaCacheHeadHook(state_manager, config, extract_modulated_input, coefficients)
397+
hook = TeaCacheHeadHook(
398+
state_manager, config, extract_modulated_input, coefficients, advance_step_on_skip=advance_step_on_skip
399+
)
382400
registry.register_hook(hook, _TEACACHE_LEADER_BLOCK_HOOK)
383401

384402

tests/models/testing_utils/cache.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -643,10 +643,11 @@ class TeaCacheConfigMixin:
643643
"""
644644

645645
# Default TeaCache config - can be overridden by subclasses.
646-
# Uses num_inference_steps=4 so interior steps can be skipped during _test_cache_inference.
646+
# Uses num_inference_steps=4 and an infinite rel_l1_thresh so the second
647+
# inference step is always skipped, which is required by _test_cache_inference.
647648
TEA_CACHE_CONFIG = {
648649
"num_inference_steps": 4,
649-
"rel_l1_thresh": 100.0,
650+
"rel_l1_thresh": float("inf"),
650651
}
651652

652653
def _get_cache_config(self):

0 commit comments

Comments
 (0)