Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/diffusers/hooks/faster_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,11 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
"https://github.com/huggingface/diffusers/issues."
)

if config.current_timestep_callback is None:
raise ValueError(
"The `current_timestep_callback` function must be provided in the configuration to apply FasterCache."
)

if config.attention_weight_callback is None:
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
Expand Down
11 changes: 11 additions & 0 deletions src/diffusers/hooks/first_block_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConf
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))

if not remaining_blocks:
logger.warning("FirstBlockCache: No transformer blocks found to apply hooks.")
return

if len(remaining_blocks) == 1:
head_block_name, head_block = remaining_blocks[0]
logger.debug(f"Applying FirstBlockCache head+tail hooks to single block '{head_block_name}'")
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
_apply_fbc_block_hook(head_block, state_manager, is_tail=True)
return

head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)

Expand Down
8 changes: 5 additions & 3 deletions src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1810,9 +1810,11 @@ def _load_pretrained_model(
save_offload_index(offload_index, offload_folder)
offload_index = None

if offload_state_dict:
load_offloaded_weights(model, state_dict_index, state_dict_folder)
shutil.rmtree(state_dict_folder)
if offload_state_dict and state_dict_index is not None and len(state_dict_index) > 0:
load_offloaded_weights(model, state_dict_index, state_dict_folder)

if offload_state_dict and state_dict_folder is not None:
shutil.rmtree(state_dict_folder)

if len(error_msgs) > 0:
error_msg = "\n\t".join(error_msgs)
Expand Down
36 changes: 36 additions & 0 deletions tests/hooks/test_faster_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.hooks.faster_cache import FasterCacheConfig, apply_faster_cache
from diffusers.models import ModelMixin


class DummyTransformer(ModelMixin):
    """Minimal `ModelMixin` subclass that wraps a single 4x4 linear layer."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, hidden_states):
        """Return `hidden_states` passed through the linear layer."""
        projected = self.linear(hidden_states)
        return projected


def test_apply_faster_cache_requires_timestep_callback():
    """`apply_faster_cache` must raise a `ValueError` when the config has no `current_timestep_callback`."""
    dummy_model = DummyTransformer()
    # Only the skip range is set, so `current_timestep_callback` is left at the config's default.
    config_without_callback = FasterCacheConfig(spatial_attention_block_skip_range=2)

    with pytest.raises(ValueError, match="current_timestep_callback"):
        apply_faster_cache(dummy_model, config_without_callback)
45 changes: 45 additions & 0 deletions tests/hooks/test_first_block_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from diffusers.models import ModelMixin


class DummyBlock(torch.nn.Module):
    """Parameter-free stand-in for a transformer block: doubles its input."""

    def forward(self, hidden_states, **kwargs):
        """Return `hidden_states` scaled by 2.0; any extra keyword arguments are accepted and ignored."""
        doubled = hidden_states * 2.0
        return doubled


class SingleBlockTransformer(ModelMixin):
    """Model whose `transformer_blocks` list contains exactly one `DummyBlock`."""

    def __init__(self):
        super().__init__()
        only_block = DummyBlock()
        self.transformer_blocks = torch.nn.ModuleList([only_block])

    def forward(self, hidden_states):
        """Run `hidden_states` through every block in `transformer_blocks`, in order."""
        output = hidden_states
        for transformer_block in self.transformer_blocks:
            output = transformer_block(output)
        return output


def test_apply_first_block_cache_single_block():
    """Applying FirstBlockCache to a model with a single transformer block must complete without raising."""
    # `DummyBlock` is registered with both return indices set to `None` before the hooks are applied.
    block_metadata = TransformerBlockMetadata(
        return_hidden_states_index=None,
        return_encoder_hidden_states_index=None,
    )
    TransformerBlockRegistry.register(DummyBlock, block_metadata)

    single_block_model = SingleBlockTransformer()
    cache_config = FirstBlockCacheConfig(threshold=0.05)
    apply_first_block_cache(single_block_model, cache_config)
48 changes: 48 additions & 0 deletions tests/models/test_offload_state_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from diffusers import UNet2DModel


def test_offload_state_dict_loads_cpu_staged_weights(tmp_path):
    """Round-trip a tiny `UNet2DModel` through `save_pretrained` / `from_pretrained(offload_state_dict=True)`.

    The reloaded model must carry the same `conv_in` weight and produce the same output as the model that was
    saved, when everything is mapped to CPU.

    Args:
        tmp_path: pytest's per-test temporary directory; used as the checkpoint location.
    """
    torch.manual_seed(0)
    model = UNet2DModel(
        block_out_channels=(4, 8),
        layers_per_block=1,
        norm_num_groups=4,
        sample_size=8,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    # Per the diffusers docs, `from_pretrained` returns the model in evaluation mode. Put the reference model in
    # the same mode so the output comparison does not depend on train/eval-sensitive layers.
    model.eval()

    sample = torch.randn(1, 3, 8, 8)
    timestep = torch.tensor([0])
    # Only forward values are compared, so no autograd graph is needed.
    with torch.no_grad():
        reference_output = model(sample, timestep).sample
    reference_weight = model.conv_in.weight.detach().clone()

    model.save_pretrained(tmp_path)

    loaded = UNet2DModel.from_pretrained(
        tmp_path,
        device_map={"": "cpu"},
        offload_state_dict=True,
        low_cpu_mem_usage=True,
    )

    assert torch.allclose(loaded.conv_in.weight, reference_weight)
    with torch.no_grad():
        loaded_output = loaded(sample, timestep).sample
    assert torch.allclose(loaded_output, reference_output)
3 changes: 2 additions & 1 deletion utils/check_copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
import re
import subprocess
import sys


# All paths are set with the intent you should run this script from the root of the repo with the command
Expand Down Expand Up @@ -94,7 +95,7 @@ def get_indent(code):


def run_ruff(code):
    """Format `code` with `ruff format` and return the formatted source.

    Ruff is launched as a module of the running interpreter (`sys.executable -m ruff`) rather than through a bare
    `ruff` executable name.

    Args:
        code (`str`): Source text piped to ruff on stdin.

    Returns:
        `str`: Whatever ruff wrote to stdout, decoded. stderr is captured and discarded and the exit status is not
        checked, so a failed run returns the (possibly empty) stdout as-is.
    """
    command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"]
    completed = subprocess.run(
        command,
        input=code.encode(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        check=False,
    )
    return completed.stdout.decode()
Expand Down