@@ -132,11 +132,13 @@ def __init__(
132132 config : TeaCacheConfig ,
133133 extract_modulated_input : Callable ,
134134 coefficients : List [float ],
135+ advance_step_on_skip : bool = False ,
135136 ):
136137 self .state_manager = state_manager
137138 self .config = config
138139 self .extract_modulated_input = extract_modulated_input
139140 self .coefficients = coefficients
141+ self .advance_step_on_skip = advance_step_on_skip
140142 self ._metadata = None
141143
142144 def initialize_hook (self , module ):
@@ -180,6 +182,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
180182 if not should_compute :
181183 logger .debug (f"TeaCache: Skipping step { state .step_index } " )
182184
185+ if self .advance_step_on_skip :
186+ self ._advance_step (state )
187+
183188 output = hidden_states
184189 res = state .previous_residual
185190
@@ -230,6 +235,14 @@ def reset_state(self, module):
230235 self .state_manager .reset ()
231236 return module
232237
238+ def _advance_step (self , state : TeaCacheState ):
239+ state .step_index += 1
240+ if state .step_index >= self .config .num_inference_steps :
241+ state .step_index = 0
242+ state .accumulated_distance = 0.0
243+ state .previous_residual = None
244+ state .previous_modulated_input = None
245+
233246
234247class TeaCacheBlockHook (ModelHook ):
235248 def __init__ (self , state_manager : StateManager , is_tail : bool = False , config : TeaCacheConfig = None ):
@@ -350,7 +363,9 @@ def apply_teacache(module: torch.nn.Module, config: TeaCacheConfig) -> None:
350363 name , block = remaining_blocks [0 ]
351364 logger .info (f"TeaCache: Applying Head+Tail Hooks to single block '{ name } '" )
352365 _apply_teacache_block_hook (block , state_manager , config , is_tail = True )
353- _apply_teacache_head_hook (block , state_manager , config , extract_modulated_input , coefficients )
366+ _apply_teacache_head_hook (
367+ block , state_manager , config , extract_modulated_input , coefficients , advance_step_on_skip = True
368+ )
354369 return
355370
356371 head_block_name , head_block = remaining_blocks .pop (0 )
@@ -372,13 +387,16 @@ def _apply_teacache_head_hook(
372387 config : TeaCacheConfig ,
373388 extract_modulated_input : Callable ,
374389 coefficients : List [float ],
390+ advance_step_on_skip : bool = False ,
375391) -> None :
376392 registry = HookRegistry .check_if_exists_or_initialize (block )
377393
378394 if registry .get_hook (_TEACACHE_LEADER_BLOCK_HOOK ) is not None :
379395 registry .remove_hook (_TEACACHE_LEADER_BLOCK_HOOK )
380396
381- hook = TeaCacheHeadHook (state_manager , config , extract_modulated_input , coefficients )
397+ hook = TeaCacheHeadHook (
398+ state_manager , config , extract_modulated_input , coefficients , advance_step_on_skip = advance_step_on_skip
399+ )
382400 registry .register_hook (hook , _TEACACHE_LEADER_BLOCK_HOOK )
383401
384402
0 commit comments