Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/guiders/frequency_decoupled_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = No
pred_guided_pyramid.append(pred)
else:
# Add the current pred_cond_pyramid level as the "non-FDG" prediction
pred_guided_pyramid.append(pred_cond_freq)
pred_guided_pyramid.append(pred_cond_pyramid[level])

# Convert from frequency space back to data (e.g. pixel) space by applying inverse freq transform
pred = build_image_from_pyramid(pred_guided_pyramid)
Expand Down
46 changes: 27 additions & 19 deletions src/diffusers/hooks/mag_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,29 @@ def reset(self):
self.calibration_ratios = []


def _advance_mag_cache_step(state: MagCacheState, config: MagCacheConfig) -> None:
state.step_index += 1
if state.step_index >= config.num_inference_steps:
if config.calibrate:
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
print(f"{state.calibration_ratios}\n")
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")

state.step_index = 0
state.accumulated_ratio = 1.0
state.accumulated_steps = 0
state.accumulated_err = 0.0
state.previous_residual = None
state.calibration_ratios = []


class MagCacheHeadHook(ModelHook):
_is_stateful = True

def __init__(self, state_manager: StateManager, config: MagCacheConfig):
def __init__(self, state_manager: StateManager, config: MagCacheConfig, advance_on_skip: bool = False):
self.state_manager = state_manager
self.config = config
self.advance_on_skip = advance_on_skip
self._metadata = None

def initialize_hook(self, module):
Expand Down Expand Up @@ -260,6 +277,9 @@ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
"Cannot apply residual safely. Returning input without residual."
)

if self.advance_on_skip:
_advance_mag_cache_step(state, self.config)

if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
Expand Down Expand Up @@ -377,21 +397,7 @@ def _perform_calibration_step(self, state: MagCacheState, current_residual: torc
state.calibration_ratios.append(ratio)

def _advance_step(self, state: MagCacheState):
state.step_index += 1
if state.step_index >= self.config.num_inference_steps:
# End of inference loop
if self.config.calibrate:
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
print(f"{state.calibration_ratios}\n")
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")

# Reset state
state.step_index = 0
state.accumulated_ratio = 1.0
state.accumulated_steps = 0
state.accumulated_err = 0.0
state.previous_residual = None
state.calibration_ratios = []
_advance_mag_cache_step(state, self.config)


def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
Expand Down Expand Up @@ -425,7 +431,7 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
name, block = remaining_blocks[0]
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
_apply_mag_cache_head_hook(block, state_manager, config)
_apply_mag_cache_head_hook(block, state_manager, config, advance_on_skip=True)
return

head_block_name, head_block = remaining_blocks.pop(0)
Expand All @@ -441,14 +447,16 @@ def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)


def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
def _apply_mag_cache_head_hook(
block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig, advance_on_skip: bool = False
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)

# Automatically remove existing hook to allow re-application (e.g. switching modes)
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)

hook = MagCacheHeadHook(state_manager, config)
hook = MagCacheHeadHook(state_manager, config, advance_on_skip=advance_on_skip)
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)


Expand Down
36 changes: 36 additions & 0 deletions tests/guiders/test_frequency_decoupled_guidance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.guiders.frequency_decoupled_guidance import FrequencyDecoupledGuidance


pytest.importorskip("kornia")


def test_frequency_decoupled_guidance_disabled_level_uses_cond_pyramid():
guider = FrequencyDecoupledGuidance(
guidance_scales=[0.0, 7.5],
use_original_formulation=True,
)
guider.set_state(step=0, num_inference_steps=10, timestep=torch.tensor([500]))
pred_cond = torch.randn(1, 3, 32, 32)
pred_uncond = torch.randn(1, 3, 32, 32)

output = guider.forward(pred_cond, pred_uncond)

assert output.pred is not None
assert output.pred.shape == pred_cond.shape
32 changes: 32 additions & 0 deletions tests/hooks/test_mag_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,35 @@ def test_mag_cache_calibration():
# Let's ensure list is empty after reset (end of step 1)
ratios_after = _get_calibration_data(model)
assert ratios_after == []


class SingleBlockDummyTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([DummyBlock()])

def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
return hidden_states


def test_mag_cache_single_block_step_reset():
"""Single-block models must advance step_index on cache skips."""
model = SingleBlockDummyTransformer()
config = MagCacheConfig(
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
)
apply_mag_cache(model, config)
_set_context(model, "test_context")

out0 = model(torch.tensor([[[10.0]]]))
assert torch.allclose(out0, torch.tensor([[[20.0]]]))

out1 = model(torch.tensor([[[11.0]]]))
assert torch.allclose(out1, torch.tensor([[[21.0]]]))

out2 = model(torch.tensor([[[12.0]]]))
assert torch.allclose(out2, torch.tensor([[[24.0]]])), (
f"Expected compute after reset (24.0), got {out2.item()}"
)
3 changes: 2 additions & 1 deletion utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import re
import subprocess
import sys


# All paths are set with the intent you should run this script from the root of the repo with the command
Expand Down Expand Up @@ -94,7 +95,7 @@ def get_indent(code):


def run_ruff(code):
command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE)
stdout, _ = process.communicate(input=code.encode())
return stdout.decode()
Expand Down