diff --git a/tests/hooks/test_first_block_cache.py b/tests/hooks/test_first_block_cache.py new file mode 100644 index 0000000..aa21afe --- /dev/null +++ b/tests/hooks/test_first_block_cache.py @@ -0,0 +1,125 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache +from diffusers.models import ModelMixin + + +class DummyBlock(torch.nn.Module): + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0 + + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + + +class TupleOutputBlock(torch.nn.Module): + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0, encoder_hidden_states + + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock(), TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states, encoder_hidden_states + + +def _set_context(model, context_name): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + +@pytest.fixture(autouse=True) +def register_dummy_blocks(): + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), + ) + TransformerBlockRegistry.register( + TupleOutputBlock, + TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), + ) + + +def test_first_block_cache_skips_when_residual_is_stable(): + """When head-block residuals are similar, tail blocks should be skipped.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + assert torch.allclose(output_t0, torch.tensor([[[40.0]]])) + + # Identical input -> residual diff is 0 -> skip tail block (40.0, not 44.0). + output_t1 = model(input_t0) + assert torch.allclose(output_t1, torch.tensor([[[40.0]]])) + + +def test_first_block_cache_recomputes_when_residual_changes(): + """When residuals exceed the threshold, the full block stack must run.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + model(torch.tensor([[[10.0]]])) + + output_t1 = model(torch.tensor([[[11.0]]])) + assert torch.allclose(output_t1, torch.tensor([[[44.0]]])) + + +def test_first_block_cache_tuple_outputs(): + """First Block Cache must support tuple block outputs (Flux-style).""" + model = TupleTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + assert torch.allclose(out_0, torch.tensor([[[40.0]]])) + + out_1, _ = model(input_t0, encoder_hidden_states=enc_t0) + assert torch.allclose(out_1, torch.tensor([[[40.0]]])) + + +def test_first_block_cache_recomputes_after_skip_when_input_changes(): + """A large input change after a cached step must trigger full recomputation.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + model(torch.tensor([[[10.0]]])) + model(torch.tensor([[[10.0]]])) + + output = model(torch.tensor([[[12.0]]])) + assert torch.allclose(output, torch.tensor([[[48.0]]])) diff --git a/tests/others/test_loading_utils.py b/tests/others/test_loading_utils.py new file mode 100644 index 0000000..1f8b815 --- /dev/null +++ b/tests/others/test_loading_utils.py @@ -0,0 +1,119 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import tempfile +import unittest +from unittest.mock import Mock, patch + +import PIL.Image +import torch +from torch import nn + +from diffusers.utils.loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video + + +class LoadingUtilsTest(unittest.TestCase): + def test_load_image_pil_passthrough_converts_rgb(self): + image = PIL.Image.new("RGBA", (4, 4), color=(255, 0, 0, 128)) + loaded = load_image(image) + self.assertEqual(loaded.mode, "RGB") + self.assertEqual(loaded.size, (4, 4)) + + def test_load_image_local_path(self): + image = PIL.Image.new("RGB", (8, 8), color="green") + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + image.save(tmp.name) + path = tmp.name + try: + loaded = load_image(path) + self.assertEqual(loaded.size, (8, 8)) + self.assertEqual(loaded.mode, "RGB") + finally: + os.remove(path) + + def test_load_image_invalid_path_raises(self): + with self.assertRaises(ValueError): + load_image("/path/that/does/not/exist.png") + + def test_load_image_invalid_scheme_raises(self): + with self.assertRaises(ValueError): + load_image("ftp://example.com/image.png") + + def test_load_image_invalid_type_raises(self): + with self.assertRaises(ValueError): + load_image(123) + + def test_load_image_custom_convert_method(self): + image = PIL.Image.new("RGB", (4, 4), color="blue") + + def to_grayscale(img): + return img.convert("L") + + loaded = load_image(image, convert_method=to_grayscale) + self.assertEqual(loaded.mode, "L") + + @patch("diffusers.utils.loading_utils.requests.get") + def test_load_image_from_url(self, mock_get): + buffer = io.BytesIO() + PIL.Image.new("RGB", (6, 6), color="red").save(buffer, format="PNG") + buffer.seek(0) + mock_response = Mock() + mock_response.raw = buffer + mock_get.return_value = mock_response + + loaded = load_image("https://example.com/image.png") + self.assertEqual(loaded.size, (6, 6)) + self.assertEqual(loaded.mode, "RGB") + + def test_load_video_invalid_path_raises(self): + with self.assertRaises(ValueError): + load_video("/path/that/does/not/exist.mp4") + + def test_load_video_gif_frames(self): + frames = [PIL.Image.new("RGB", (4, 4), color=(i * 40, 0, 0)) for i in range(3)] + with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp: + path = tmp.name + try: + frames[0].save(path, save_all=True, append_images=frames[1:], duration=100, loop=0) + loaded = load_video(path) + self.assertEqual(len(loaded), 3) + self.assertEqual(loaded[0].size, (4, 4)) + finally: + os.remove(path) + + def test_get_module_from_name_nested(self): + module = nn.Sequential(nn.Linear(4, 4), nn.ReLU()) + found, name = get_module_from_name(module, "0.weight") + self.assertIsInstance(found, nn.Linear) + self.assertEqual(name, "weight") + + def test_get_module_from_name_missing_attribute_raises(self): + module = nn.Linear(4, 4) + with self.assertRaises(AttributeError): + get_module_from_name(module, "missing.weight") + + def test_get_submodule_by_name_modulelist_index(self): + module = nn.ModuleList([nn.Linear(2, 2), nn.Linear(3, 3)]) + found = get_submodule_by_name(module, "1") + self.assertIsInstance(found, nn.Linear) + self.assertEqual(found.in_features, 3) + + def test_get_submodule_by_name_dotted_path(self): + module = nn.Sequential( + nn.ModuleDict({"block": nn.Linear(4, 4)}), + ) + found = get_submodule_by_name(module, "0.block") + self.assertIsInstance(found, nn.Linear) diff --git a/tests/others/test_remote_utils.py b/tests/others/test_remote_utils.py new file mode 100644 index 0000000..501b94b --- /dev/null +++ b/tests/others/test_remote_utils.py @@ -0,0 +1,104 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import unittest +from unittest.mock import Mock + +import torch +from PIL import Image + +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.remote_utils import ( + check_inputs_decode, + detect_image_type, + postprocess_decode, + prepare_decode, + prepare_encode, +) + + +class RemoteUtilsTest(unittest.TestCase): + def test_detect_image_type(self): + self.assertEqual(detect_image_type(b"\xff\xd8\xff"), "jpeg") + self.assertEqual(detect_image_type(b"\x89PNG\r\n\x1a\n"), "png") + self.assertEqual(detect_image_type(b"GIF89a"), "gif") + self.assertEqual(detect_image_type(b"BM"), "bmp") + self.assertEqual(detect_image_type(b"unknown"), "unknown") + + def test_check_inputs_decode_packed_latents_requires_hw(self): + tensor = torch.randn(4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode("http://example.com", tensor) + + def test_check_inputs_decode_processor_required(self): + tensor = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode( + "http://example.com", + tensor, + processor=None, + output_type="pt", + return_type="pil", + partial_postprocess=False, + ) + + def test_prepare_decode_sets_accept_header_for_jpeg(self): + tensor = torch.randn(1, 4, 8, 8, dtype=torch.float16) + payload = prepare_decode(tensor, output_type="pil", image_format="jpg") + self.assertEqual(payload["headers"]["Accept"], "image/jpeg") + self.assertEqual(payload["params"]["output_type"], "pil") + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + + def test_prepare_encode_tensor_includes_shape_and_dtype(self): + tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16) + payload = prepare_encode(tensor, scaling_factor=0.18215) + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + self.assertEqual(payload["params"]["dtype"], "float16") + self.assertEqual(payload["params"]["scaling_factor"], 0.18215) + + def test_prepare_encode_pil_image(self): + image = Image.new("RGB", (8, 8), color="red") + payload = prepare_encode(image) + self.assertIn(b"PNG", payload["data"][:8]) + + def test_postprocess_decode_pil_without_processor(self): + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color="blue").save(buffer, format="PNG") + response = Mock() + response.content = buffer.getvalue() + + output = postprocess_decode(response, processor=None, output_type="pil", return_type="pil") + self.assertIsInstance(output, Image.Image) + self.assertEqual(output.size, (4, 4)) + self.assertEqual(output.format, "png") + + def test_postprocess_decode_pt_tensor(self): + tensor = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2) + response = Mock() + response.content = tensor.numpy().tobytes() + response.headers = { + "shape": json.dumps(list(tensor.shape)), + "dtype": "float32", + } + + output = postprocess_decode( + response, + processor=None, + output_type="pt", + return_type="pt", + partial_postprocess=False, + ) + torch.testing.assert_close(output, tensor) diff --git a/tests/others/test_state_dict_utils.py b/tests/others/test_state_dict_utils.py new file mode 100644 index 0000000..ac399d9 --- /dev/null +++ b/tests/others/test_state_dict_utils.py @@ -0,0 +1,137 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers.utils.state_dict_utils import ( + StateDictType, + convert_all_state_dict_to_peft, + convert_state_dict, + convert_state_dict_to_diffusers, + convert_state_dict_to_kohya, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, + state_dict_all_zero, +) + + +class StateDictUtilsTest(unittest.TestCase): + def test_convert_state_dict_applies_first_matching_pattern(self): + state_dict = {"layer.processor.weight": torch.ones(1)} + converted = convert_state_dict(state_dict, {".processor.": "."}) + self.assertIn("layer.weight", converted) + self.assertNotIn("layer.processor.weight", converted) + + def test_convert_state_dict_to_peft_auto_infers_diffusers_old(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_out_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_out_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict) + self.assertIn("unet.down_blocks.0.attentions.0.out_proj.lora_A.weight", converted) + + state_dict = { + "unet.down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS_OLD) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_A.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_peft_diffusers(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight", converted) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_diffusers_from_peft(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict, original_type=StateDictType.PEFT) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.down.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.up.weight", converted) + + def test_convert_state_dict_to_diffusers_already_diffusers(self): + state_dict = { + "layer.lora_linear_layer.down.weight": torch.ones(2, 2), + "layer.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict) + self.assertIs(converted, state_dict) + + def test_convert_unet_state_dict_to_peft(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_unet_state_dict_to_peft(state_dict) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_A.weight", converted) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_B.weight", converted) + + def test_convert_all_state_dict_to_peft_falls_back_to_unet(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_all_state_dict_to_peft(state_dict) + self.assertTrue(any("lora_A" in key or "lora_B" in key for key in converted)) + + def test_state_dict_all_zero(self): + state_dict = { + "a": torch.zeros(2, 2), + "b": torch.zeros(3), + } + self.assertTrue(state_dict_all_zero(state_dict)) + state_dict["b"] = torch.ones(3) + self.assertFalse(state_dict_all_zero(state_dict)) + + def test_state_dict_all_zero_with_filter(self): + state_dict = { + "lora.down": torch.zeros(2, 2), + "bias": torch.ones(3), + } + self.assertTrue(state_dict_all_zero(state_dict, filter_str="lora")) + + def test_convert_state_dict_to_kohya_remaps_component_prefixes(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.ones(2, 2), + "text_encoder_2.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.ones(2, 2), + "text_encoder_2.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + kohya = convert_state_dict_to_kohya(state_dict) + self.assertTrue(any(key.startswith("lora_te1_") for key in kohya)) + self.assertTrue(any(key.startswith("lora_te2_") for key in kohya)) + self.assertTrue(any("lora_unet" in key for key in kohya)) + self.assertTrue(any("alpha" in key for key in kohya)) + + def test_convert_state_dict_to_kohya_adds_alpha_for_down_weights(self): + down_weight = torch.ones(4, 2) + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": down_weight, + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 4), + } + kohya = convert_state_dict_to_kohya(state_dict) + alpha_keys = [key for key in kohya if key.endswith(".alpha")] + self.assertEqual(len(alpha_keys), 1) + self.assertEqual(kohya[alpha_keys[0]].item(), len(down_weight)) diff --git a/tests/schedulers/test_scheduler_amused.py b/tests/schedulers/test_scheduler_amused.py new file mode 100644 index 0000000..1e402a3 --- /dev/null +++ b/tests/schedulers/test_scheduler_amused.py @@ -0,0 +1,136 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AmusedScheduler +from diffusers.schedulers.scheduling_amused import ( + AmusedSchedulerOutput, + gumbel_noise, + mask_by_random_topk, +) + + +class AmusedSchedulerTest(unittest.TestCase): + """ + Unit tests for AmusedScheduler masked-token denoising. This scheduler has a bespoke + step API (discrete token IDs, Gumbel masking) and is not covered by SchedulerCommonTest. + """ + + mask_token_id = 0 + + def get_default_config(self, **kwargs): + config = { + "mask_token_id": self.mask_token_id, + "masking_schedule": "cosine", + } + config.update(**kwargs) + return config + + def test_gumbel_noise_matches_input_shape(self): + probs = torch.ones(2, 4) + noise = gumbel_noise(probs) + self.assertEqual(noise.shape, probs.shape) + + def test_gumbel_noise_is_reproducible_with_generator(self): + probs = torch.ones(2, 4) + generator = torch.Generator().manual_seed(0) + noise_a = gumbel_noise(probs, generator=generator) + generator = torch.Generator().manual_seed(0) + noise_b = gumbel_noise(probs, generator=generator) + self.assertTrue(torch.allclose(noise_a, noise_b)) + + def test_mask_by_random_topk_masks_requested_count(self): + probs = torch.tensor([[0.9, 0.8, 0.1, 0.05]]) + mask_len = torch.tensor([[2]]) + masking = mask_by_random_topk(mask_len, probs, temperature=0.0) + self.assertEqual(masking.sum().item(), 2) + + def test_set_timesteps_builds_temperature_schedule(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4, temperature=(2, 0)) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.temperatures.shape, (4,)) + self.assertAlmostEqual(scheduler.temperatures[0].item(), 2.0, places=5) + self.assertAlmostEqual(scheduler.temperatures[-1].item(), 0.0, places=5) + + def test_step_final_timestep_unmasks_all_tokens(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=2) + + sample = torch.tensor([[self.mask_token_id, self.mask_token_id, 5, 6]]) + model_output = torch.zeros(1, 4, 8) + model_output[0, :, 3] = 10.0 # token id 3 wins for every position + + output = scheduler.step(model_output, timestep=0, sample=sample, generator=torch.Generator().manual_seed(0)) + self.assertIsInstance(output, AmusedSchedulerOutput) + self.assertFalse((output.prev_sample == self.mask_token_id).any()) + + def test_step_intermediate_timestep_keeps_some_masks(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4, temperature=0.0) + + sample = torch.full((1, 6), self.mask_token_id) + model_output = torch.randn(1, 6, 4) + + timestep = scheduler.timesteps[1].item() + output = scheduler.step( + model_output, + timestep=timestep, + sample=sample, + generator=torch.Generator().manual_seed(0), + ) + self.assertTrue((output.prev_sample == self.mask_token_id).any()) + self.assertFalse((output.prev_sample == self.mask_token_id).all()) + + def test_step_2d_input_is_reshaped_and_restored(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=2) + + sample = torch.full((1, 2, 2), self.mask_token_id) + model_output = torch.zeros(1, 4, 2, 2) + model_output[:, :, 0, 0] = 10.0 + + output = scheduler.step(model_output, timestep=0, sample=sample, generator=torch.Generator().manual_seed(0)) + self.assertEqual(output.prev_sample.shape, (1, 2, 2)) + + def test_add_noise_respects_mask_ratio(self): + scheduler = AmusedScheduler(**self.get_default_config(masking_schedule="linear")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.arange(1, 17).reshape(1, 4, 4) + # Early timesteps mask most tokens; use the second step for a partial mask. + timestep = scheduler.timesteps[1].item() + masked = scheduler.add_noise(sample, timesteps=timestep, generator=torch.Generator().manual_seed(0)) + self.assertTrue((masked == self.mask_token_id).any()) + self.assertFalse((masked == self.mask_token_id).all()) + + def test_unknown_masking_schedule_raises_in_step(self): + scheduler = AmusedScheduler(**self.get_default_config(masking_schedule="invalid")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.tensor([[self.mask_token_id, 1, 2, 3]]) + model_output = torch.randn(1, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(model_output, timestep=scheduler.timesteps[1].item(), sample=sample) + + def test_unknown_masking_schedule_raises_in_add_noise(self): + scheduler = AmusedScheduler(**self.get_default_config(masking_schedule="invalid")) + scheduler.set_timesteps(num_inference_steps=2) + + sample = torch.tensor([[1, 2, 3, 4]]) + with self.assertRaises(ValueError): + scheduler.add_noise(sample, timesteps=scheduler.timesteps[0].item()) diff --git a/tests/schedulers/test_scheduler_cogvideox.py b/tests/schedulers/test_scheduler_cogvideox.py new file mode 100644 index 0000000..e7b503f --- /dev/null +++ b/tests/schedulers/test_scheduler_cogvideox.py @@ -0,0 +1,199 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.schedulers.scheduling_ddim_cogvideox import DDIMSchedulerOutput + + +class CogVideoXDDIMSchedulerTest(unittest.TestCase): + """ + Contract tests for CogVideoXDDIMScheduler — used by CogView3+ and shares the SNR-shifted + schedule with CogVideoX pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = CogVideoXDDIMScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.0120, + "beta_schedule": "scaled_linear", + "timestep_spacing": "leading", + } + config.update(**kwargs) + return config + + def test_snr_shift_modifies_alphas_cumprod(self): + shifted = self.scheduler_class(**self.get_default_config(snr_shift_scale=3.0)) + unshifted = self.scheduler_class(**self.get_default_config(snr_shift_scale=1.0)) + self.assertFalse(torch.allclose(shifted.alphas_cumprod, unshifted.alphas_cumprod)) + + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_produces_expected_count(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 4, 10, 50]: + scheduler.set_timesteps(nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + output = scheduler.step(model_output, timestep, sample, eta=0.0) + self.assertIsInstance(output, DDIMSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + timestep = scheduler.timesteps[0] + + first = scheduler.step(model_output, timestep, sample, eta=0.0, generator=generator).prev_sample + generator = torch.Generator().manual_seed(0) + second = scheduler.step(model_output, timestep, sample, eta=0.0, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + generator = torch.Generator().manual_seed(0) + for t in scheduler.timesteps: + sample = scheduler.step( + torch.randn_like(sample), t, sample, eta=0.0, generator=generator + ).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) + + +class CogVideoXDPMSchedulerTest(unittest.TestCase): + """ + Contract tests for CogVideoXDPMScheduler — CogVideoX pipelines branch on its multi-step API. + """ + + scheduler_class = CogVideoXDPMScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.0120, + "beta_schedule": "scaled_linear", + "timestep_spacing": "trailing", + } + config.update(**kwargs) + return config + + def test_step_requires_set_timesteps(self): + scheduler = self.scheduler_class(**self.get_default_config()) + sample = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + scheduler.step( + torch.randn_like(sample), + None, + scheduler.timesteps[0], + None, + sample, + return_dict=False, + ) + + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_first_step_returns_pred_original_sample(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + prev_sample, pred_original_sample = scheduler.step( + model_output, + None, + timestep, + None, + sample, + return_dict=False, + ) + self.assertEqual(prev_sample.shape, sample.shape) + self.assertEqual(pred_original_sample.shape, sample.shape) + + def test_second_step_uses_old_pred_original_sample(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + _, old_pred = scheduler.step( + model_output, + None, + scheduler.timesteps[0], + None, + sample, + return_dict=False, + generator=generator, + ) + + prev_sample, _ = scheduler.step( + model_output, + old_pred, + scheduler.timesteps[1], + scheduler.timesteps[0], + sample, + return_dict=False, + generator=generator, + ) + self.assertEqual(prev_sample.shape, sample.shape) + + def test_get_variables_and_get_mult(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + t = scheduler.timesteps[1] + t_prev = t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + t_back = scheduler.timesteps[0] + + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t_prev] + alpha_prod_t_back = scheduler.alphas_cumprod[t_back] + + h, r, lamb, lamb_next = scheduler.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = scheduler.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + self.assertEqual(len(mult), 4) + self.assertTrue(torch.isfinite(h)) + self.assertTrue(torch.isfinite(lamb)) diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index 13b353a..dc6e688 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -73,6 +73,20 @@ def test_timestep_spacing(self): for timestep_spacing in ["trailing", "leading"]: self.check_over_configs(timestep_spacing=timestep_spacing) + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + # Guard against inverted comparison (num_inference_steps < num_train_timesteps) which would + # reject all valid inference schedules and accept invalid ones. + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_num_inference_steps_at_limit_succeeds(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + self.assertEqual(scheduler.num_inference_steps, scheduler.config.num_train_timesteps) + def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py index 056b5d8..b42b2a5 100644 --- a/tests/schedulers/test_scheduler_ddpm.py +++ b/tests/schedulers/test_scheduler_ddpm.py @@ -72,6 +72,18 @@ def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_num_inference_steps_at_limit_succeeds(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + self.assertEqual(scheduler.num_inference_steps, scheduler.config.num_train_timesteps) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() diff --git a/tests/schedulers/test_scheduler_ddpm_wuerstchen.py b/tests/schedulers/test_scheduler_ddpm_wuerstchen.py new file mode 100644 index 0000000..36ea793 --- /dev/null +++ b/tests/schedulers/test_scheduler_ddpm_wuerstchen.py @@ -0,0 +1,110 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import DDPMWuerstchenScheduler +from diffusers.schedulers.scheduling_ddpm_wuerstchen import DDPMWuerstchenSchedulerOutput + + +class DDPMWuerstchenSchedulerTest(unittest.TestCase): + """ + Contract tests for DDPMWuerstchenScheduler — Stable Cascade prior/decoder use float timesteps + in [1, 0] rather than integer indices. Pipeline tests only exercised it indirectly. + """ + + scheduler_class = DDPMWuerstchenScheduler + + def get_default_config(self, **kwargs): + config = {"scaler": 1.0, "s": 0.008} + config.update(**kwargs) + return config + + def test_set_timesteps_float_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 4, 10]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1.0, places=5) + self.assertAlmostEqual(scheduler.timesteps[-1].item(), 0.0, places=5) + + def test_scaler_modifies_alpha_cumprod(self): + default = self.scheduler_class(**self.get_default_config()) + scaled = self.scheduler_class(**self.get_default_config(scaler=2.0)) + t = torch.tensor([0.5]) + default_alpha = default._alpha_cumprod(t, device="cpu") + scaled_alpha = scaled._alpha_cumprod(t, device="cpu") + self.assertFalse(torch.allclose(default_alpha, scaled_alpha)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1].expand(sample.shape[0]) + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, DDPMWuerstchenSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_with_generator(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, timestep, sample, generator=generator).prev_sample + generator = torch.Generator().manual_seed(0) + second = scheduler.step(model_output, timestep, sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_previous_timestep_advances_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + current = scheduler.timesteps[0:1] + prev = scheduler.previous_timestep(current) + self.assertAlmostEqual(prev[0].item(), scheduler.timesteps[1].item(), places=5) + + def test_add_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + timesteps = scheduler.timesteps[0:1] + mixed = scheduler.add_noise(sample, noise, timesteps) + self.assertTrue(mixed.min() >= 0.0) + self.assertTrue(mixed.max() <= 1.0) + self.assertFalse(torch.allclose(mixed, sample)) + self.assertFalse(torch.allclose(mixed, noise)) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + generator = torch.Generator().manual_seed(0) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps[:-1]: + batch_t = t.expand(sample.shape[0]) + sample = scheduler.step( + torch.randn_like(sample), batch_t, sample, generator=generator + ).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 100644 index 0000000..56f7410 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,108 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput + + +class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase): + """ + Contract tests for FlowMatchEulerDiscreteScheduler — shared by SD3, Flux, Wan, and many + flow-matching pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = FlowMatchEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8, 16]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_dynamic_shifting_requires_mu(self): + scheduler = self.scheduler_class(**self.get_default_config(use_dynamic_shifting=True)) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=4) + + def test_set_timesteps_custom_sigmas_and_timesteps_length_mismatch(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.0], timesteps=[900.0, 500.0]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, FlowMatchEulerDiscreteSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, model_output.dtype) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_scale_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[500.0, 250.0]) + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + mid_t = scheduler.timesteps[0:1] + mixed = scheduler.scale_noise(sample, mid_t, noise) + # sigma=0.5 at t=500 with default 1000 training steps + torch.testing.assert_close(mixed, 0.5 * noise, atol=1e-4, rtol=1e-4) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[900.0, 500.0, 500.0, 100.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_helios.py b/tests/schedulers/test_scheduler_helios.py new file mode 100644 index 0000000..6fac77b --- /dev/null +++ b/tests/schedulers/test_scheduler_helios.py @@ -0,0 +1,107 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch + +from diffusers import HeliosScheduler +from diffusers.schedulers.scheduling_helios import HeliosSchedulerOutput + + +class HeliosSchedulerTest(unittest.TestCase): + """ + Unit tests for HeliosScheduler multi-stage scheduling. Pipeline tests only cover stages=1; + these tests lock in per-stage sigma ranges and step behavior used by the default 3-stage config. + """ + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + "stages": 3, + "stage_range": [0, 1 / 3, 2 / 3, 1], + "gamma": 1 / 3, + "prediction_type": "flow_prediction", + "scheduler_type": "euler", + "use_dynamic_shifting": False, + } + config.update(**kwargs) + return config + + def test_multi_stage_init_populates_per_stage_buffers(self): + scheduler = HeliosScheduler(**self.get_default_config()) + self.assertEqual(len(scheduler.timesteps_per_stage), 3) + self.assertEqual(len(scheduler.sigmas_per_stage), 3) + self.assertEqual(len(scheduler.start_sigmas), 3) + self.assertEqual(len(scheduler.end_sigmas), 3) + for stage in range(3): + self.assertGreater(scheduler.start_sigmas[stage], scheduler.end_sigmas[stage]) + + def test_set_timesteps_multi_stage_per_stage_index(self): + scheduler = HeliosScheduler(**self.get_default_config()) + for stage_index in range(3): + scheduler.set_timesteps(num_inference_steps=8, stage_index=stage_index) + self.assertEqual(scheduler.timesteps.shape, (8,)) + self.assertEqual(scheduler.sigmas.shape, (9,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + self.assertGreater(scheduler.timesteps[0].item(), scheduler.timesteps[-1].item()) + + def test_set_timesteps_single_stage(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.sigmas.shape, (5,)) + + def test_step_euler_updates_sample(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + output = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertIsInstance(output, HeliosSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertFalse(torch.allclose(output.prev_sample, sample)) + + def test_convert_model_output_flow_prediction(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + scheduler._step_index = 0 + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + x0 = scheduler.convert_model_output(model_output, sample=sample) + expected = sample - scheduler.sigmas[0] * model_output + torch.testing.assert_close(x0, expected) + + def test_dynamic_shifting_rescales_timesteps(self): + scheduler = HeliosScheduler( + **self.get_default_config(stages=1, stage_range=[0, 1], use_dynamic_shifting=True) + ) + scheduler.set_timesteps(num_inference_steps=4, mu=0.5) + self.assertEqual(scheduler.timesteps.shape, (4,)) + + def test_step_unipc_invokes_corrector_on_second_step(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + scheduler.step(model_output, scheduler.timesteps[0], sample) + + with patch.object(scheduler, "multistep_uni_c_bh_update", wraps=scheduler.multistep_uni_c_bh_update) as corrector: + scheduler.step(model_output, scheduler.timesteps[1], sample) + corrector.assert_called_once() diff --git a/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py new file mode 100644 index 0000000..f898d54 --- /dev/null +++ b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXEulerAncestralRFScheduler +from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFSchedulerOutput + + +class LTXEulerAncestralRFSchedulerTest(unittest.TestCase): + """ + Contract tests for the LTX RF ancestral scheduler used by LTX-Video long-form pipelines. + Mirrors the style of `test_scheduler_flow_map_euler_discrete.py` because this scheduler + has a non-standard step API and cannot reuse `SchedulerCommonTest`. + """ + + scheduler_class = LTXEulerAncestralRFScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "eta": 1.0, + "s_noise": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_auto_generates_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_requires_args(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps() + + def test_set_timesteps_explicit_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + custom_sigmas = [1.0, 0.75, 0.5, 0.25, 0.0] + scheduler.set_timesteps(sigmas=custom_sigmas) + self.assertEqual(scheduler.num_inference_steps, 4) + for i, sigma in enumerate(custom_sigmas): + self.assertAlmostEqual(scheduler.sigmas[i].item(), sigma, places=5) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1000.0, places=4) + + def test_set_timesteps_rejects_non_1d_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[[1.0, 0.5], [0.5, 0.0]]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, LTXEulerAncestralRFSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + + scheduler.set_timesteps(num_inference_steps=4) + second = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.5, 0.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/utils/check_copies.py b/utils/check_copies.py index 001366c..711cd8d 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -18,6 +18,7 @@ import os import re import subprocess +import sys # All paths are set with the intent you should run this script from the root of the repo with the command @@ -94,12 +95,24 @@ def get_indent(code): def run_ruff(code): - command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] + command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) stdout, _ = process.communicate(input=code.encode()) return stdout.decode() +def _get_definition_header(lines, start_index): + idx = start_index + while idx < len(lines): + stripped = lines[idx].strip() + if stripped.startswith(("def ", "class ", "@dataclass")): + return lines[idx] + if stripped and not stripped.startswith("#") and not stripped.startswith("@"): + break + idx += 1 + return lines[start_index - 1] if start_index > 0 else "" + + def stylify(code: str) -> str: """ Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`. @@ -176,16 +189,24 @@ def is_copy_consistent(filename, overwrite=False): theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code) # stylify after replacement. To be able to do that, we need the header (class or function definition) - # from the previous line - theoretical_code = stylify(lines[start_index - 1] + theoretical_code) - theoretical_code = theoretical_code[len(lines[start_index - 1]) :] + # from the copied block, not the `# Copied from` comment line. + header_line = _get_definition_header(lines, start_index) + theoretical_code = stylify(header_line + theoretical_code) + theoretical_code = theoretical_code[len(header_line) :] + observed_code = stylify(header_line + observed_code) + observed_code = observed_code[len(header_line) :] # Test for a diff and act accordingly. if observed_code != theoretical_code: diffs.append([object_name, start_index]) if overwrite: - lines = lines[:start_index] + [theoretical_code] + lines[line_index:] - line_index = start_index + 1 + replacement_lines = theoretical_code.splitlines(keepends=True) + if replacement_lines and not replacement_lines[-1].endswith("\n"): + replacement_lines[-1] += "\n" + if lines[start_index].strip().startswith(("def ", "class ")): + replacement_lines = [lines[start_index]] + replacement_lines + lines = lines[:start_index] + replacement_lines + lines[line_index:] + line_index = start_index + len(replacement_lines) if overwrite and len(diffs) > 0: # Warn the user a file has been modified.