Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/others/test_loading_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
import tempfile
import unittest
from unittest.mock import Mock, patch

import PIL.Image
import torch
from torch import nn

from diffusers.utils.loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video


class LoadingUtilsTest(unittest.TestCase):
def test_load_image_pil_passthrough_converts_rgb(self):
image = PIL.Image.new("RGBA", (4, 4), color=(255, 0, 0, 128))
loaded = load_image(image)
self.assertEqual(loaded.mode, "RGB")
self.assertEqual(loaded.size, (4, 4))

def test_load_image_local_path(self):
image = PIL.Image.new("RGB", (8, 8), color="green")
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
image.save(tmp.name)
path = tmp.name
try:
loaded = load_image(path)
self.assertEqual(loaded.size, (8, 8))
self.assertEqual(loaded.mode, "RGB")
finally:
os.remove(path)

def test_load_image_invalid_path_raises(self):
with self.assertRaises(ValueError):
load_image("/path/that/does/not/exist.png")

def test_load_image_invalid_scheme_raises(self):
with self.assertRaises(ValueError):
load_image("ftp://example.com/image.png")

def test_load_image_invalid_type_raises(self):
with self.assertRaises(ValueError):
load_image(123)

def test_load_image_custom_convert_method(self):
image = PIL.Image.new("RGB", (4, 4), color="blue")

def to_grayscale(img):
return img.convert("L")

loaded = load_image(image, convert_method=to_grayscale)
self.assertEqual(loaded.mode, "L")

@patch("diffusers.utils.loading_utils.requests.get")
def test_load_image_from_url(self, mock_get):
buffer = io.BytesIO()
PIL.Image.new("RGB", (6, 6), color="red").save(buffer, format="PNG")
buffer.seek(0)
mock_response = Mock()
mock_response.raw = buffer
mock_get.return_value = mock_response

loaded = load_image("https://example.com/image.png")
self.assertEqual(loaded.size, (6, 6))
self.assertEqual(loaded.mode, "RGB")

def test_load_video_invalid_path_raises(self):
with self.assertRaises(ValueError):
load_video("/path/that/does/not/exist.mp4")

def test_load_video_gif_frames(self):
frames = [PIL.Image.new("RGB", (4, 4), color=(i * 40, 0, 0)) for i in range(3)]
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
path = tmp.name
try:
frames[0].save(path, save_all=True, append_images=frames[1:], duration=100, loop=0)
loaded = load_video(path)
self.assertEqual(len(loaded), 3)
self.assertEqual(loaded[0].size, (4, 4))
finally:
os.remove(path)

def test_get_module_from_name_nested(self):
module = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
found, name = get_module_from_name(module, "0.weight")
self.assertIsInstance(found, nn.Linear)
self.assertEqual(name, "weight")

def test_get_module_from_name_missing_attribute_raises(self):
module = nn.Linear(4, 4)
with self.assertRaises(AttributeError):
get_module_from_name(module, "missing.weight")

def test_get_submodule_by_name_modulelist_index(self):
module = nn.ModuleList([nn.Linear(2, 2), nn.Linear(3, 3)])
found = get_submodule_by_name(module, "1")
self.assertIsInstance(found, nn.Linear)
self.assertEqual(found.in_features, 3)

def test_get_submodule_by_name_dotted_path(self):
module = nn.Sequential(
nn.ModuleDict({"block": nn.Linear(4, 4)}),
)
found = get_submodule_by_name(module, "0.block")
self.assertIsInstance(found, nn.Linear)
104 changes: 104 additions & 0 deletions tests/others/test_remote_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import json
import unittest
from unittest.mock import Mock

import torch
from PIL import Image

from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.remote_utils import (
check_inputs_decode,
detect_image_type,
postprocess_decode,
prepare_decode,
prepare_encode,
)


class RemoteUtilsTest(unittest.TestCase):
def test_detect_image_type(self):
self.assertEqual(detect_image_type(b"\xff\xd8\xff"), "jpeg")
self.assertEqual(detect_image_type(b"\x89PNG\r\n\x1a\n"), "png")
self.assertEqual(detect_image_type(b"GIF89a"), "gif")
self.assertEqual(detect_image_type(b"BM"), "bmp")
self.assertEqual(detect_image_type(b"unknown"), "unknown")

def test_check_inputs_decode_packed_latents_requires_hw(self):
tensor = torch.randn(4, 8, 8)
with self.assertRaises(ValueError):
check_inputs_decode("http://example.com", tensor)

def test_check_inputs_decode_processor_required(self):
tensor = torch.randn(1, 4, 8, 8)
with self.assertRaises(ValueError):
check_inputs_decode(
"http://example.com",
tensor,
processor=None,
output_type="pt",
return_type="pil",
partial_postprocess=False,
)

def test_prepare_decode_sets_accept_header_for_jpeg(self):
tensor = torch.randn(1, 4, 8, 8, dtype=torch.float16)
payload = prepare_decode(tensor, output_type="pil", image_format="jpg")
self.assertEqual(payload["headers"]["Accept"], "image/jpeg")
self.assertEqual(payload["params"]["output_type"], "pil")
self.assertEqual(payload["params"]["shape"], list(tensor.shape))

def test_prepare_encode_tensor_includes_shape_and_dtype(self):
tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16)
payload = prepare_encode(tensor, scaling_factor=0.18215)
self.assertEqual(payload["params"]["shape"], list(tensor.shape))
self.assertEqual(payload["params"]["dtype"], "float16")
self.assertEqual(payload["params"]["scaling_factor"], 0.18215)

def test_prepare_encode_pil_image(self):
image = Image.new("RGB", (8, 8), color="red")
payload = prepare_encode(image)
self.assertIn(b"PNG", payload["data"][:8])

def test_postprocess_decode_pil_without_processor(self):
buffer = io.BytesIO()
Image.new("RGB", (4, 4), color="blue").save(buffer, format="PNG")
response = Mock()
response.content = buffer.getvalue()

output = postprocess_decode(response, processor=None, output_type="pil", return_type="pil")
self.assertIsInstance(output, Image.Image)
self.assertEqual(output.size, (4, 4))
self.assertEqual(output.format, "png")

def test_postprocess_decode_pt_tensor(self):
tensor = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2)
response = Mock()
response.content = tensor.numpy().tobytes()
response.headers = {
"shape": json.dumps(list(tensor.shape)),
"dtype": "float32",
}

output = postprocess_decode(
response,
processor=None,
output_type="pt",
return_type="pt",
partial_postprocess=False,
)
torch.testing.assert_close(output, tensor)
Loading