Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions tests/hooks/test_first_block_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from diffusers.models import ModelMixin


class DummyBlock(torch.nn.Module):
    """Minimal stand-in for a transformer block that returns a single tensor.

    The forward pass doubles ``hidden_states``. ``encoder_hidden_states`` and any
    extra keyword arguments are accepted but not used.
    """

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        doubled = hidden_states * 2.0
        return doubled


class DummyTransformer(ModelMixin):
    """Toy model that chains two ``DummyBlock`` modules stored in ``transformer_blocks``."""

    def __init__(self):
        super().__init__()
        blocks = [DummyBlock(), DummyBlock()]
        self.transformer_blocks = torch.nn.ModuleList(blocks)

    def forward(self, hidden_states, encoder_hidden_states=None):
        # Thread the running activation through every block, in order.
        current = hidden_states
        for layer in self.transformer_blocks:
            current = layer(current, encoder_hidden_states=encoder_hidden_states)
        return current


class TupleOutputBlock(torch.nn.Module):
    """Minimal stand-in for a transformer block that returns a 2-tuple.

    The forward pass returns ``(hidden_states * 2.0, encoder_hidden_states)``; the
    encoder states are passed through untouched and extra keyword arguments are
    not used.
    """

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        doubled = hidden_states * 2.0
        return doubled, encoder_hidden_states


class TupleTransformer(ModelMixin):
    """Toy model that chains two ``TupleOutputBlock`` modules stored in ``transformer_blocks``."""

    def __init__(self):
        super().__init__()
        blocks = [TupleOutputBlock(), TupleOutputBlock()]
        self.transformer_blocks = torch.nn.ModuleList(blocks)

    def forward(self, hidden_states, encoder_hidden_states=None):
        # Each block returns a (hidden, encoder) pair that feeds the next block.
        state = (hidden_states, encoder_hidden_states)
        for layer in self.transformer_blocks:
            state = layer(state[0], encoder_hidden_states=state[1])
        hidden_out, encoder_out = state
        return hidden_out, encoder_out


def _set_context(model, context_name):
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(context_name)


@pytest.fixture(autouse=True)
def register_dummy_blocks():
    """Register block metadata for ``DummyBlock`` and ``TupleOutputBlock`` before every test.

    ``DummyBlock`` returns a bare tensor, so both output indices are ``None``;
    ``TupleOutputBlock`` returns ``(hidden_states, encoder_hidden_states)``, so the
    indices are 0 and 1.
    """
    registrations = (
        (DummyBlock, {"return_hidden_states_index": None, "return_encoder_hidden_states_index": None}),
        (TupleOutputBlock, {"return_hidden_states_index": 0, "return_encoder_hidden_states_index": 1}),
    )
    for block_cls, output_indices in registrations:
        TransformerBlockRegistry.register(block_cls, TransformerBlockMetadata(**output_indices))


def test_first_block_cache_skips_when_residual_is_stable():
    """When head-block residuals are similar, tail blocks should be skipped."""
    model = DummyTransformer()
    apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05))
    _set_context(model, "test_context")

    # First call: both blocks run, 10 -> 20 -> 40.
    input_t0 = torch.tensor([[[10.0]]])
    output_t0 = model(input_t0)
    assert torch.allclose(output_t0, torch.tensor([[[40.0]]]))

    # Identical input: the result is 40.0 whether or not the tail block is
    # skipped, so this only checks that a replay stays consistent.
    output_t1 = model(input_t0)
    assert torch.allclose(output_t1, torch.tensor([[[40.0]]]))

    # Slightly perturbed input, which is what actually distinguishes a skip from
    # a recompute. Expected behaviour (assumes the relative head-residual change
    # is compared against `threshold`): head residual 10.2 vs. cached 10.0 is a
    # 2% change (< 5%), so the tail is skipped and the cached tail residual
    # (40 - 20 = 20) is added to the fresh head output: 20.4 + 20.0 = 40.4.
    # A full recompute would instead yield 40.8.
    output_t2 = model(torch.tensor([[[10.2]]]))
    assert torch.allclose(output_t2, torch.tensor([[[40.4]]]))
    assert not torch.allclose(output_t2, torch.tensor([[[40.8]]]))


def test_first_block_cache_recomputes_when_residual_changes():
    """A head-residual change above the threshold must run the whole block stack again."""
    transformer = DummyTransformer()
    cache_config = FirstBlockCacheConfig(threshold=0.05)
    apply_first_block_cache(transformer, cache_config)
    _set_context(transformer, "test_context")

    # Initial call with 10.0; its output is not checked here.
    transformer(torch.tensor([[[10.0]]]))

    # 11.0 doubled twice is 44.0, i.e. the result of running both blocks.
    expected = torch.tensor([[[44.0]]])
    actual = transformer(torch.tensor([[[11.0]]]))
    assert torch.allclose(actual, expected)


def test_first_block_cache_tuple_outputs():
    """First Block Cache must support tuple block outputs (Flux-style)."""
    model = TupleTransformer()
    apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05))
    _set_context(model, "test_context")

    # First call: both blocks run, 10 -> 20 -> 40.
    input_t0 = torch.tensor([[[10.0]]])
    enc_t0 = torch.tensor([[[1.0]]])
    out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
    assert torch.allclose(out_0, torch.tensor([[[40.0]]]))

    # Identical input: 40.0 is produced whether or not the tail block is
    # skipped, so this only checks that a replay stays consistent.
    out_1, _ = model(input_t0, encoder_hidden_states=enc_t0)
    assert torch.allclose(out_1, torch.tensor([[[40.0]]]))

    # Slightly perturbed input, which is what actually distinguishes a skip from
    # a recompute. Expected behaviour (assumes the relative head-residual change
    # is compared against `threshold`): 10.2 vs. cached 10.0 is a 2% change
    # (< 5%), so the tail is skipped and the cached tail residual (20.0) is
    # added to the fresh head output: 20.4 + 20.0 = 40.4. A full recompute
    # would instead yield 40.8.
    out_2, _ = model(torch.tensor([[[10.2]]]), encoder_hidden_states=enc_t0)
    assert torch.allclose(out_2, torch.tensor([[[40.4]]]))
    assert not torch.allclose(out_2, torch.tensor([[[40.8]]]))


def test_first_block_cache_recomputes_after_skip_when_input_changes():
    """After repeated identical inputs, a large input change must produce the fully recomputed output."""
    transformer = DummyTransformer()
    apply_first_block_cache(transformer, FirstBlockCacheConfig(threshold=0.05))
    _set_context(transformer, "test_context")

    # Two calls with the same input come first; their outputs are not checked.
    for _ in range(2):
        transformer(torch.tensor([[[10.0]]]))

    # 12.0 doubled twice is 48.0, i.e. the result of running both blocks.
    result = transformer(torch.tensor([[[12.0]]]))
    expected = torch.tensor([[[48.0]]])
    assert torch.allclose(result, expected)
119 changes: 119 additions & 0 deletions tests/others/test_loading_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import unittest
from unittest.mock import Mock, patch

import PIL.Image
import torch
from torch import nn

from diffusers.utils.loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video


class LoadingUtilsTest(unittest.TestCase):
    """Tests for ``load_image``, ``load_video``, ``get_module_from_name`` and ``get_submodule_by_name``."""

    def test_load_image_pil_passthrough_converts_rgb(self):
        # With no convert_method, a PIL input is expected to come back as RGB.
        image = PIL.Image.new("RGBA", (4, 4), color=(255, 0, 0, 128))
        loaded = load_image(image)
        self.assertEqual(loaded.mode, "RGB")
        self.assertEqual(loaded.size, (4, 4))

    def test_load_image_local_path(self):
        image = PIL.Image.new("RGB", (8, 8), color="green")
        # A TemporaryDirectory avoids writing to a path whose NamedTemporaryFile
        # handle is still open (not portable to Windows) and cleans up on exit,
        # even when an assertion fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "image.png")
            image.save(path)
            loaded = load_image(path)
            self.assertEqual(loaded.size, (8, 8))
            self.assertEqual(loaded.mode, "RGB")

    def test_load_image_invalid_path_raises(self):
        with self.assertRaises(ValueError):
            load_image("/path/that/does/not/exist.png")

    def test_load_image_invalid_scheme_raises(self):
        with self.assertRaises(ValueError):
            load_image("ftp://example.com/image.png")

    def test_load_image_invalid_type_raises(self):
        with self.assertRaises(ValueError):
            load_image(123)

    def test_load_image_custom_convert_method(self):
        image = PIL.Image.new("RGB", (4, 4), color="blue")

        def to_grayscale(img):
            return img.convert("L")

        # Mode "L" shows the custom converter was applied.
        loaded = load_image(image, convert_method=to_grayscale)
        self.assertEqual(loaded.mode, "L")

    @patch("diffusers.utils.loading_utils.requests.get")
    def test_load_image_from_url(self, mock_get):
        # Serve an in-memory PNG through the patched `requests.get`; the image
        # bytes are exposed via the response's `raw` attribute.
        buffer = io.BytesIO()
        PIL.Image.new("RGB", (6, 6), color="red").save(buffer, format="PNG")
        buffer.seek(0)
        mock_response = Mock()
        mock_response.raw = buffer
        mock_get.return_value = mock_response

        loaded = load_image("https://example.com/image.png")
        # Make sure the image really came through the patched HTTP call.
        mock_get.assert_called_once()
        self.assertEqual(loaded.size, (6, 6))
        self.assertEqual(loaded.mode, "RGB")

    def test_load_video_invalid_path_raises(self):
        with self.assertRaises(ValueError):
            load_video("/path/that/does/not/exist.mp4")

    def test_load_video_gif_frames(self):
        # Each frame has a distinct colour so the three frames stay distinguishable.
        frames = [PIL.Image.new("RGB", (4, 4), color=(i * 40, 0, 0)) for i in range(3)]
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "clip.gif")
            frames[0].save(path, save_all=True, append_images=frames[1:], duration=100, loop=0)
            loaded = load_video(path)
            self.assertEqual(len(loaded), 3)
            self.assertEqual(loaded[0].size, (4, 4))

    def test_get_module_from_name_nested(self):
        # "0.weight" should resolve to the Linear at index 0 plus the leaf name "weight".
        module = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
        found, name = get_module_from_name(module, "0.weight")
        self.assertIsInstance(found, nn.Linear)
        self.assertEqual(name, "weight")

    def test_get_module_from_name_missing_attribute_raises(self):
        module = nn.Linear(4, 4)
        with self.assertRaises(AttributeError):
            get_module_from_name(module, "missing.weight")

    def test_get_submodule_by_name_modulelist_index(self):
        # in_features == 3 identifies the second Linear in the list.
        module = nn.ModuleList([nn.Linear(2, 2), nn.Linear(3, 3)])
        found = get_submodule_by_name(module, "1")
        self.assertIsInstance(found, nn.Linear)
        self.assertEqual(found.in_features, 3)

    def test_get_submodule_by_name_dotted_path(self):
        # The path mixes a numeric index ("0") with an attribute name ("block").
        module = nn.Sequential(
            nn.ModuleDict({"block": nn.Linear(4, 4)}),
        )
        found = get_submodule_by_name(module, "0.block")
        self.assertIsInstance(found, nn.Linear)
self.assertIsInstance(found, nn.Linear)
104 changes: 104 additions & 0 deletions tests/others/test_remote_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import unittest
from unittest.mock import Mock

import torch
from PIL import Image

from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.remote_utils import (
check_inputs_decode,
detect_image_type,
postprocess_decode,
prepare_decode,
prepare_encode,
)


class RemoteUtilsTest(unittest.TestCase):
    """Tests for the request-building and response-parsing helpers in ``remote_utils``."""

    def test_detect_image_type(self):
        # (leading bytes, expected label) pairs, checked in order.
        signatures = [
            (b"\xff\xd8\xff", "jpeg"),
            (b"\x89PNG\r\n\x1a\n", "png"),
            (b"GIF89a", "gif"),
            (b"BM", "bmp"),
            (b"unknown", "unknown"),
        ]
        for header_bytes, expected_label in signatures:
            self.assertEqual(detect_image_type(header_bytes), expected_label)

    def test_check_inputs_decode_packed_latents_requires_hw(self):
        # A 3-D tensor is passed without height/width.
        packed_latents = torch.randn(4, 8, 8)
        self.assertRaises(ValueError, check_inputs_decode, "http://example.com", packed_latents)

    def test_check_inputs_decode_processor_required(self):
        latents = torch.randn(1, 4, 8, 8)
        decode_options = {
            "processor": None,
            "output_type": "pt",
            "return_type": "pil",
            "partial_postprocess": False,
        }
        with self.assertRaises(ValueError):
            check_inputs_decode("http://example.com", latents, **decode_options)

    def test_prepare_decode_sets_accept_header_for_jpeg(self):
        latents = torch.randn(1, 4, 8, 8, dtype=torch.float16)
        request = prepare_decode(latents, output_type="pil", image_format="jpg")
        headers, params = request["headers"], request["params"]
        self.assertEqual(headers["Accept"], "image/jpeg")
        self.assertEqual(params["output_type"], "pil")
        self.assertEqual(params["shape"], list(latents.shape))

    def test_prepare_encode_tensor_includes_shape_and_dtype(self):
        pixels = torch.randn(1, 3, 8, 8, dtype=torch.float16)
        params = prepare_encode(pixels, scaling_factor=0.18215)["params"]
        self.assertEqual(params["shape"], list(pixels.shape))
        self.assertEqual(params["dtype"], "float16")
        self.assertEqual(params["scaling_factor"], 0.18215)

    def test_prepare_encode_pil_image(self):
        red_square = Image.new("RGB", (8, 8), color="red")
        request = prepare_encode(red_square)
        # The first eight bytes of the payload should contain the PNG marker.
        leading_bytes = request["data"][:8]
        self.assertIn(b"PNG", leading_bytes)

    def test_postprocess_decode_pil_without_processor(self):
        png_stream = io.BytesIO()
        Image.new("RGB", (4, 4), color="blue").save(png_stream, format="PNG")
        fake_response = Mock(content=png_stream.getvalue())

        decoded = postprocess_decode(fake_response, processor=None, output_type="pil", return_type="pil")
        self.assertIsInstance(decoded, Image.Image)
        self.assertEqual(decoded.size, (4, 4))
        self.assertEqual(decoded.format, "png")

    def test_postprocess_decode_pt_tensor(self):
        expected = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2)
        # The response carries raw tensor bytes, with shape and dtype in the headers.
        fake_response = Mock(
            content=expected.numpy().tobytes(),
            headers={
                "shape": json.dumps(list(expected.shape)),
                "dtype": "float32",
            },
        )

        decoded = postprocess_decode(
            fake_response,
            processor=None,
            output_type="pt",
            return_type="pt",
            partial_postprocess=False,
        )
        torch.testing.assert_close(decoded, expected)
Loading