Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions tests/hooks/test_first_block_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from diffusers.models import ModelMixin


class DummyBlock(torch.nn.Module):
    """Parameter-free stand-in for a transformer block: the forward pass doubles its input."""

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        # `encoder_hidden_states` and any extra keyword arguments are accepted
        # but never read; only `hidden_states` determines the result.
        doubled = hidden_states * 2.0
        return doubled


class DummyTransformer(ModelMixin):
    """Toy model holding two `DummyBlock`s in `transformer_blocks` and applying them in order."""

    def __init__(self):
        super().__init__()
        blocks = [DummyBlock() for _ in range(2)]
        self.transformer_blocks = torch.nn.ModuleList(blocks)

    def forward(self, hidden_states, encoder_hidden_states=None):
        # Each block receives the previous block's output; the encoder states
        # are forwarded unchanged to every block.
        current = hidden_states
        for layer in self.transformer_blocks:
            current = layer(current, encoder_hidden_states=encoder_hidden_states)
        return current


class TupleOutputBlock(torch.nn.Module):
    """Block stand-in returning a `(hidden_states, encoder_hidden_states)` pair."""

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        # The hidden states are doubled; the encoder states are handed back
        # untouched. Extra keyword arguments are accepted and ignored.
        doubled = hidden_states * 2.0
        result = (doubled, encoder_hidden_states)
        return result


class TupleTransformer(ModelMixin):
    """Toy model whose two blocks each return a `(hidden_states, encoder_hidden_states)` pair."""

    def __init__(self):
        super().__init__()
        blocks = [TupleOutputBlock() for _ in range(2)]
        self.transformer_blocks = torch.nn.ModuleList(blocks)

    def forward(self, hidden_states, encoder_hidden_states=None):
        # Both streams are threaded through every block: each block's pair of
        # outputs becomes the next block's pair of inputs.
        for layer in self.transformer_blocks:
            block_outputs = layer(hidden_states, encoder_hidden_states=encoder_hidden_states)
            hidden_states, encoder_hidden_states = block_outputs
        return hidden_states, encoder_hidden_states


@pytest.fixture(autouse=True)
def register_dummy_blocks():
    """Register block metadata for both dummy block classes ahead of each test.

    `DummyBlock` is registered with both return indices set to `None`;
    `TupleOutputBlock` is registered with index 0 for the hidden states and
    index 1 for the encoder hidden states. Nothing is unregistered afterwards.
    """
    registrations = (
        (DummyBlock, {"return_hidden_states_index": None, "return_encoder_hidden_states_index": None}),
        (TupleOutputBlock, {"return_hidden_states_index": 0, "return_encoder_hidden_states_index": 1}),
    )
    for block_cls, return_indices in registrations:
        TransformerBlockRegistry.register(block_cls, TransformerBlockMetadata(**return_indices))


def _set_context(model, context_name):
    """Call `_set_context(context_name)` on the `_diffusers_hook` of every submodule that has one.

    Submodules without a `_diffusers_hook` attribute are skipped.
    """
    hooked_modules = (m for m in model.modules() if hasattr(m, "_diffusers_hook"))
    for hooked in hooked_modules:
        hooked._diffusers_hook._set_context(context_name)


def test_first_block_cache_skipping_logic():
    """
    With a permissive threshold, a small change in the first-block residual makes the second step reuse the cached
    tail residual instead of running the tail block.
    """
    cached_model = DummyTransformer()
    apply_first_block_cache(cached_model, FirstBlockCacheConfig(threshold=1.0))
    _set_context(cached_model, "test_context")

    # Step 0: 10.0 is doubled twice -> 40.0. Head residual = 20 - 10 = 10, tail residual = 40 - 20 = 20.
    first_output = cached_model(torch.tensor([[[10.0]]]))
    assert torch.allclose(first_output, torch.tensor([[[40.0]]])), "Step 0 failed"

    # Step 1 with input 11.0:
    #   skipped  -> head output 22 + cached tail residual 20 = 42.0
    #   computed -> 11 * 4 = 44.0
    second_output = cached_model(torch.tensor([[[11.0]]]))
    expected_when_skipped = torch.tensor([[[42.0]]])
    assert torch.allclose(second_output, expected_when_skipped), f"Expected skip (42.0), got {second_output.item()}"


def test_first_block_cache_compute_when_residual_changes():
    """With a strict threshold, a shift in the first-block residual must trigger a full recomputation."""
    cached_model = DummyTransformer()
    strict_config = FirstBlockCacheConfig(threshold=0.01)
    apply_first_block_cache(cached_model, strict_config)
    _set_context(cached_model, "test_context")

    # Warm-up pass; its output is not inspected.
    cached_model(torch.tensor([[[10.0]]]))

    # A full pass over both blocks gives 11 * 4 = 44.0.
    second_output = cached_model(torch.tensor([[[11.0]]]))
    expected_full = torch.tensor([[[44.0]]])
    assert torch.allclose(second_output, expected_full), (
        f"Expected compute (44.0) due to low threshold, got {second_output.item()}"
    )


def test_first_block_cache_tuple_outputs():
    """Check the cache on a model whose blocks return (hidden, encoder_hidden) tuples, like Flux."""
    cached_model = TupleTransformer()
    apply_first_block_cache(cached_model, FirstBlockCacheConfig(threshold=1.0))
    _set_context(cached_model, "test_context")

    encoder_states = torch.tensor([[[1.0]]])

    # Step 0: full pass, 10.0 doubled twice -> 40.0.
    first_hidden, _ = cached_model(torch.tensor([[[10.0]]]), encoder_hidden_states=encoder_states)
    assert torch.allclose(first_hidden, torch.tensor([[[40.0]]]))

    # Step 1: skipped. Input 11.0 -> head output 22 + tail residual 20 = 42.0.
    second_hidden, _ = cached_model(torch.tensor([[[11.0]]]), encoder_hidden_states=encoder_states)
    expected_when_skipped = torch.tensor([[[42.0]]])
    assert torch.allclose(second_hidden, expected_when_skipped), (
        f"Tuple skip failed. Expected 42.0, got {second_hidden.item()}"
    )


def test_first_block_cache_first_pass_always_computes():
    """The very first forward pass has nothing cached, so every block must run."""
    cached_model = DummyTransformer()
    apply_first_block_cache(cached_model, FirstBlockCacheConfig(threshold=1.0))
    _set_context(cached_model, "test_context")

    # Two doubling blocks applied to 5.0 give 20.0.
    first_output = cached_model(torch.tensor([[[5.0]]]))
    expected_full = torch.tensor([[[20.0]]])
    assert torch.allclose(first_output, expected_full)


def test_first_block_cache_large_input_change_recomputes():
    """A large jump in the input exceeds the threshold, so the model recomputes every block."""
    cached_model = DummyTransformer()
    tight_config = FirstBlockCacheConfig(threshold=0.05)
    apply_first_block_cache(cached_model, tight_config)
    _set_context(cached_model, "test_context")

    # Warm-up pass; its output is not inspected.
    cached_model(torch.tensor([[[10.0]]]))

    # Doubling the input to 20.0 should give the uncached result 20 * 4 = 80.0.
    second_output = cached_model(torch.tensor([[[20.0]]]))
    expected_full = torch.tensor([[[80.0]]])
    assert torch.allclose(second_output, expected_full), (
        f"Expected full compute (80.0) for large input change, got {second_output.item()}"
    )
103 changes: 103 additions & 0 deletions tests/others/test_remote_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import unittest
from unittest.mock import Mock

import torch
from PIL import Image

from diffusers.utils.remote_utils import (
check_inputs_decode,
detect_image_type,
postprocess_decode,
prepare_decode,
prepare_encode,
)


class RemoteUtilsTest(unittest.TestCase):
    """Unit tests for the request/response helpers in `diffusers.utils.remote_utils`."""

    def test_detect_image_type(self):
        """Known magic-byte prefixes map to their format name; anything else is "unknown"."""
        cases = (
            (b"\xff\xd8\xff", "jpeg"),
            (b"\x89PNG\r\n\x1a\n", "png"),
            (b"GIF89a", "gif"),
            (b"BM", "bmp"),
            (b"unknown", "unknown"),
        )
        for raw_bytes, expected_type in cases:
            self.assertEqual(detect_image_type(raw_bytes), expected_type)

    def test_check_inputs_decode_packed_latents_requires_hw(self):
        """A 3-D tensor passed with no further arguments is rejected with ValueError."""
        packed_latents = torch.randn(4, 8, 8)
        with self.assertRaises(ValueError):
            check_inputs_decode("http://example.com", packed_latents)

    def test_check_inputs_decode_processor_required(self):
        """`output_type="pt"` with `return_type="pil"` and no processor is rejected with ValueError."""
        latents = torch.randn(1, 4, 8, 8)
        options = {
            "processor": None,
            "output_type": "pt",
            "return_type": "pil",
            "partial_postprocess": False,
        }
        with self.assertRaises(ValueError):
            check_inputs_decode("http://example.com", latents, **options)

    def test_prepare_decode_sets_accept_header_for_jpeg(self):
        """`image_format="jpg"` yields a JPEG Accept header plus output type and shape params."""
        latents = torch.randn(1, 4, 8, 8, dtype=torch.float16)
        request = prepare_decode(latents, output_type="pil", image_format="jpg")
        self.assertEqual(request["headers"]["Accept"], "image/jpeg")
        self.assertEqual(request["params"]["output_type"], "pil")
        self.assertEqual(request["params"]["shape"], list(latents.shape))

    def test_prepare_encode_tensor_includes_shape_and_dtype(self):
        """Encoding a tensor reports its shape, its dtype name and the scaling factor in the params."""
        image_tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16)
        request = prepare_encode(image_tensor, scaling_factor=0.18215)
        self.assertEqual(request["params"]["shape"], list(image_tensor.shape))
        self.assertEqual(request["params"]["dtype"], "float16")
        self.assertEqual(request["params"]["scaling_factor"], 0.18215)

    def test_prepare_encode_pil_image(self):
        """Encoding a PIL image produces data whose first eight bytes contain the PNG marker."""
        red_square = Image.new("RGB", (8, 8), color="red")
        request = prepare_encode(red_square)
        leading_bytes = request["data"][:8]
        self.assertIn(b"PNG", leading_bytes)

    def test_postprocess_decode_pil_without_processor(self):
        """A PNG response body decodes to a PIL image with the original size and format "png"."""
        png_buffer = io.BytesIO()
        blue_square = Image.new("RGB", (4, 4), color="blue")
        blue_square.save(png_buffer, format="PNG")
        fake_response = Mock(content=png_buffer.getvalue())

        decoded = postprocess_decode(fake_response, processor=None, output_type="pil", return_type="pil")
        self.assertIsInstance(decoded, Image.Image)
        self.assertEqual(decoded.size, (4, 4))
        self.assertEqual(decoded.format, "png")

    def test_postprocess_decode_pt_tensor(self):
        """Raw bytes plus shape/dtype headers are decoded back into the original tensor."""
        expected = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2)
        fake_response = Mock(
            content=expected.numpy().tobytes(),
            headers={
                "shape": json.dumps(list(expected.shape)),
                "dtype": "float32",
            },
        )

        decoded = postprocess_decode(
            fake_response,
            processor=None,
            output_type="pt",
            return_type="pt",
            partial_postprocess=False,
        )
        torch.testing.assert_close(decoded, expected)
110 changes: 110 additions & 0 deletions tests/others/test_state_dict_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers.utils.state_dict_utils import (
StateDictType,
convert_all_state_dict_to_peft,
convert_state_dict,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
state_dict_all_zero,
)


class StateDictUtilsTest(unittest.TestCase):
    """Unit tests for the key-conversion helpers in `diffusers.utils.state_dict_utils`."""

    def test_convert_state_dict_applies_first_matching_pattern(self):
        """A matching pattern is replaced in the key and the original key disappears."""
        source = {"layer.processor.weight": torch.ones(1)}
        mapping = {".processor.": "."}
        result = convert_state_dict(source, mapping)
        self.assertIn("layer.weight", result)
        self.assertNotIn("layer.processor.weight", result)

    def test_convert_state_dict_to_peft_auto_infers_diffusers_old(self):
        """Old-style diffusers LoRA keys convert to PEFT, with and without an explicit source type."""
        # No `original_type`: the source format has to be inferred from the keys.
        to_out_keys = (
            "unet.down_blocks.0.attentions.0.to_out_lora.down.weight",
            "unet.down_blocks.0.attentions.0.to_out_lora.up.weight",
        )
        inferred = convert_state_dict_to_peft({key: torch.ones(2, 2) for key in to_out_keys})
        self.assertIn("unet.down_blocks.0.attentions.0.out_proj.lora_A.weight", inferred)

        # Explicit `original_type`.
        to_q_keys = (
            "unet.down_blocks.0.attentions.0.to_q_lora.down.weight",
            "unet.down_blocks.0.attentions.0.to_q_lora.up.weight",
        )
        explicit = convert_state_dict_to_peft(
            {key: torch.ones(2, 2) for key in to_q_keys},
            original_type=StateDictType.DIFFUSERS_OLD,
        )
        for expected_key in (
            "unet.down_blocks.0.attentions.0.q_proj.lora_A.weight",
            "unet.down_blocks.0.attentions.0.q_proj.lora_B.weight",
        ):
            self.assertIn(expected_key, explicit)

    def test_convert_state_dict_to_peft_diffusers(self):
        """Current diffusers `lora_linear_layer` keys become PEFT `lora_A` / `lora_B` keys."""
        source_keys = (
            "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight",
            "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight",
        )
        source = {key: torch.ones(2, 2) for key in source_keys}
        result = convert_state_dict_to_peft(source, original_type=StateDictType.DIFFUSERS)
        for expected_key in (
            "text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight",
            "text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight",
        ):
            self.assertIn(expected_key, result)

    def test_convert_state_dict_to_diffusers_from_peft(self):
        """PEFT `lora_A` / `lora_B` keys become diffusers `lora.down` / `lora.up` keys."""
        source_keys = (
            "unet.down_blocks.0.attentions.0.to_q.lora_A.weight",
            "unet.down_blocks.0.attentions.0.to_q.lora_B.weight",
        )
        source = {key: torch.ones(2, 2) for key in source_keys}
        result = convert_state_dict_to_diffusers(source, original_type=StateDictType.PEFT)
        for expected_key in (
            "unet.down_blocks.0.attentions.0.to_q.lora.down.weight",
            "unet.down_blocks.0.attentions.0.to_q.lora.up.weight",
        ):
            self.assertIn(expected_key, result)

    def test_convert_state_dict_to_diffusers_already_diffusers(self):
        """A state dict already in diffusers format is returned as the very same object."""
        source_keys = (
            "layer.lora_linear_layer.down.weight",
            "layer.lora_linear_layer.up.weight",
        )
        source = {key: torch.ones(2, 2) for key in source_keys}
        result = convert_state_dict_to_diffusers(source)
        self.assertIs(result, source)

    def test_convert_unet_state_dict_to_peft(self):
        """UNet-style `to_q_lora.down/up` keys become `to_q.lora_A/lora_B` keys."""
        source_keys = (
            "down_blocks.0.attentions.0.to_q_lora.down.weight",
            "down_blocks.0.attentions.0.to_q_lora.up.weight",
        )
        source = {key: torch.ones(2, 2) for key in source_keys}
        result = convert_unet_state_dict_to_peft(source)
        for expected_key in (
            "down_blocks.0.attentions.0.to_q.lora_A.weight",
            "down_blocks.0.attentions.0.to_q.lora_B.weight",
        ):
            self.assertIn(expected_key, result)

    def test_convert_all_state_dict_to_peft_falls_back_to_unet(self):
        """Keys without a `unet.` prefix still come out with PEFT `lora_A` / `lora_B` naming."""
        source_keys = (
            "down_blocks.0.attentions.0.to_q_lora.down.weight",
            "down_blocks.0.attentions.0.to_q_lora.up.weight",
        )
        source = {key: torch.ones(2, 2) for key in source_keys}
        result = convert_all_state_dict_to_peft(source)
        has_peft_key = any(marker in key for key in result for marker in ("lora_A", "lora_B"))
        self.assertTrue(has_peft_key)

    def test_state_dict_all_zero(self):
        """True while every tensor is zero; False once one entry becomes non-zero."""
        tensors = {"a": torch.zeros(2, 2), "b": torch.zeros(3)}
        self.assertTrue(state_dict_all_zero(tensors))

        tensors["b"] = torch.ones(3)
        self.assertFalse(state_dict_all_zero(tensors))

    def test_state_dict_all_zero_with_filter(self):
        """With `filter_str="lora"`, the non-zero "bias" entry does not affect the result."""
        tensors = {"lora.down": torch.zeros(2, 2), "bias": torch.ones(3)}
        self.assertTrue(state_dict_all_zero(tensors, filter_str="lora"))
Loading