Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/diffusers/guiders/smoothed_energy_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def __init__(
raise ValueError(
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
)
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
seg_guidance_config = [
SmoothedEnergyGuidanceConfig(indices=[layer], fqn="auto") for layer in seg_guidance_layers
]

if isinstance(seg_guidance_config, dict):
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
Expand All @@ -145,6 +147,7 @@ def __init__(
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]

def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/hooks/faster_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
should_skip_attention = not should_compute_attention
if should_skip_attention:
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
should_skip_attention = self.state.cache is not None and (
self.is_guidance_distilled or self.state.batch_size != batch_size
)

if should_skip_attention:
logger.debug("FasterCache - Skipping attention and using approximation")
Expand Down
75 changes: 75 additions & 0 deletions tests/guiders/test_smoothed_energy_guidance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import torch

from diffusers.guiders.smoothed_energy_guidance import SmoothedEnergyGuidance
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import AttnProcessor


class DummyBlock(torch.nn.Module):
def __init__(self):
super().__init__()
self.attn1 = Attention(
query_dim=4,
cross_attention_dim=None,
heads=1,
dim_head=4,
processor=AttnProcessor(),
)

def forward(self, hidden_states):
return self.attn1(hidden_states)


class DummyTransformer(torch.nn.Module):
def __init__(self, num_blocks=10):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([DummyBlock() for _ in range(num_blocks)])


def test_seg_shorthand_layers_builds_list_indices():
guider = SmoothedEnergyGuidance(
guidance_scale=7.5,
seg_guidance_scale=2.8,
seg_guidance_layers=[7, 8, 9],
)
assert guider.seg_guidance_config[0].indices == [7]
assert guider.seg_guidance_config[1].indices == [8]
assert guider.seg_guidance_config[2].indices == [9]


def test_seg_prepare_models_increments_count_and_installs_hooks_on_third_pass():
guider = SmoothedEnergyGuidance(
guidance_scale=7.5,
seg_guidance_scale=2.8,
seg_guidance_layers=[7],
)
guider.set_state(step=10, num_inference_steps=50, timestep=torch.tensor([500]))

denoiser = DummyTransformer(num_blocks=10)

guider.prepare_models(denoiser)
assert guider._count_prepared == 1

guider.prepare_models(denoiser)
assert guider._count_prepared == 2

with patch("diffusers.guiders.smoothed_energy_guidance._apply_smoothed_energy_guidance_hook") as mock_apply:
guider.prepare_models(denoiser)
assert guider._count_prepared == 3
mock_apply.assert_called_once()
44 changes: 44 additions & 0 deletions tests/hooks/test_faster_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers.hooks import HookRegistry
from diffusers.hooks.faster_cache import FasterCacheBlockHook


class DummyAttention(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * 2.0


def test_guidance_distilled_first_forward_in_skip_range_does_not_crash():
module = DummyAttention()
hook = FasterCacheBlockHook(
block_skip_range=2,
timestep_skip_range=(0, 1000),
is_guidance_distilled=True,
weight_callback=lambda _: 0.5,
current_timestep_callback=lambda: 500,
)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, "faster_cache_block")
hook.initialize_hook(module)

hidden_states = torch.randn(2, 4)
output = module(hidden_states)
assert output.shape == hidden_states.shape

output = module(hidden_states)
assert output.shape == hidden_states.shape
3 changes: 2 additions & 1 deletion utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import re
import subprocess
import sys


# All paths are set with the intent you should run this script from the root of the repo with the command
Expand Down Expand Up @@ -94,7 +95,7 @@ def get_indent(code):


def run_ruff(code):
command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()
Expand Down