From 9d639e152107da2dd10d7721a58ccb3f7408564c Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 24 Jun 2026 22:06:51 +0000 Subject: [PATCH] Fix FasterCache crash and SmoothedEnergyGuidance hook install bugs - Guard FasterCacheBlockHook attention skip until cache is populated; guidance-distilled models previously crashed on the first forward inside the skip timestep range with TypeError on None cache. - Add missing _count_prepared increment in SmoothedEnergyGuidance.prepare_models so SEG hooks are actually installed during inference. - Pass indices=[layer] when building SmoothedEnergyGuidanceConfig from seg_guidance_layers shorthand (same int-vs-list bug as LayerSkipConfig). - Invoke ruff via python -m in check_copies.py so copy checking works when ruff is not on PATH. Add regression tests for both fixes. Co-authored-by: Simon Lynch --- .../guiders/smoothed_energy_guidance.py | 5 +- src/diffusers/hooks/faster_cache.py | 4 +- .../guiders/test_smoothed_energy_guidance.py | 75 +++++++++++++++++++ tests/hooks/test_faster_cache.py | 44 +++++++++++ utils/check_copies.py | 3 +- 5 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 tests/guiders/test_smoothed_energy_guidance.py create mode 100644 tests/hooks/test_faster_cache.py diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 4767607..c4aabf8 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -126,7 +126,9 @@ def __init__( raise ValueError( f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}." ) - seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers] + seg_guidance_config = [ + SmoothedEnergyGuidanceConfig(indices=[layer], fqn="auto") for layer in seg_guidance_layers + ] if isinstance(seg_guidance_config, dict): seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config) @@ -145,6 +147,7 @@ def __init__( self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))] def prepare_models(self, denoiser: torch.nn.Module) -> None: + self._count_prepared += 1 if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1: for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config): _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name) diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py index 682cebe..baef2b7 100644 --- a/src/diffusers/hooks/faster_cache.py +++ b/src/diffusers/hooks/faster_cache.py @@ -428,7 +428,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0 should_skip_attention = not should_compute_attention if should_skip_attention: - should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size + should_skip_attention = self.state.cache is not None and ( + self.is_guidance_distilled or self.state.batch_size != batch_size + ) if should_skip_attention: logger.debug("FasterCache - Skipping attention and using approximation") diff --git a/tests/guiders/test_smoothed_energy_guidance.py b/tests/guiders/test_smoothed_energy_guidance.py new file mode 100644 index 0000000..5fdf03b --- /dev/null +++ b/tests/guiders/test_smoothed_energy_guidance.py @@ -0,0 +1,75 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +import torch + +from diffusers.guiders.smoothed_energy_guidance import SmoothedEnergyGuidance +from diffusers.models.attention import Attention +from diffusers.models.attention_processor import AttnProcessor + + +class DummyBlock(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn1 = Attention( + query_dim=4, + cross_attention_dim=None, + heads=1, + dim_head=4, + processor=AttnProcessor(), + ) + + def forward(self, hidden_states): + return self.attn1(hidden_states) + + +class DummyTransformer(torch.nn.Module): + def __init__(self, num_blocks=10): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock() for _ in range(num_blocks)]) + + +def test_seg_shorthand_layers_builds_list_indices(): + guider = SmoothedEnergyGuidance( + guidance_scale=7.5, + seg_guidance_scale=2.8, + seg_guidance_layers=[7, 8, 9], + ) + assert guider.seg_guidance_config[0].indices == [7] + assert guider.seg_guidance_config[1].indices == [8] + assert guider.seg_guidance_config[2].indices == [9] + + +def test_seg_prepare_models_increments_count_and_installs_hooks_on_third_pass(): + guider = SmoothedEnergyGuidance( + guidance_scale=7.5, + seg_guidance_scale=2.8, + seg_guidance_layers=[7], + ) + guider.set_state(step=10, num_inference_steps=50, timestep=torch.tensor([500])) + + denoiser = DummyTransformer(num_blocks=10) + + guider.prepare_models(denoiser) + assert guider._count_prepared == 1 + + guider.prepare_models(denoiser) + assert guider._count_prepared == 2 + + with patch("diffusers.guiders.smoothed_energy_guidance._apply_smoothed_energy_guidance_hook") as mock_apply: + guider.prepare_models(denoiser) + assert guider._count_prepared == 3 + mock_apply.assert_called_once() diff --git a/tests/hooks/test_faster_cache.py b/tests/hooks/test_faster_cache.py new file mode 100644 index 0000000..9552cd1 --- /dev/null +++ b/tests/hooks/test_faster_cache.py @@ -0,0 +1,44 @@ +# Copyright 2025 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from diffusers.hooks import HookRegistry +from diffusers.hooks.faster_cache import FasterCacheBlockHook + + +class DummyAttention(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * 2.0 + + +def test_guidance_distilled_first_forward_in_skip_range_does_not_crash(): + module = DummyAttention() + hook = FasterCacheBlockHook( + block_skip_range=2, + timestep_skip_range=(0, 1000), + is_guidance_distilled=True, + weight_callback=lambda _: 0.5, + current_timestep_callback=lambda: 500, + ) + registry = HookRegistry.check_if_exists_or_initialize(module) + registry.register_hook(hook, "faster_cache_block") + hook.initialize_hook(module) + + hidden_states = torch.randn(2, 4) + output = module(hidden_states) + assert output.shape == hidden_states.shape + + output = module(hidden_states) + assert output.shape == hidden_states.shape diff --git a/utils/check_copies.py b/utils/check_copies.py index 001366c..338a25d 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -18,6 +18,7 @@ import os import re import subprocess +import sys # All paths are set with the intent you should run this script from the root of the repo with the command @@ -94,7 +95,7 @@ def get_indent(code): def run_ruff(code): - command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] + command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) stdout, _ = process.communicate(input=code.encode()) return stdout.decode()