diff --git a/tests/others/test_loading_utils.py b/tests/others/test_loading_utils.py new file mode 100644 index 0000000..1f8b815 --- /dev/null +++ b/tests/others/test_loading_utils.py @@ -0,0 +1,119 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import tempfile +import unittest +from unittest.mock import Mock, patch + +import PIL.Image +import torch +from torch import nn + +from diffusers.utils.loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video + + +class LoadingUtilsTest(unittest.TestCase): + def test_load_image_pil_passthrough_converts_rgb(self): + image = PIL.Image.new("RGBA", (4, 4), color=(255, 0, 0, 128)) + loaded = load_image(image) + self.assertEqual(loaded.mode, "RGB") + self.assertEqual(loaded.size, (4, 4)) + + def test_load_image_local_path(self): + image = PIL.Image.new("RGB", (8, 8), color="green") + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + image.save(tmp.name) + path = tmp.name + try: + loaded = load_image(path) + self.assertEqual(loaded.size, (8, 8)) + self.assertEqual(loaded.mode, "RGB") + finally: + os.remove(path) + + def test_load_image_invalid_path_raises(self): + with self.assertRaises(ValueError): + load_image("/path/that/does/not/exist.png") + + def test_load_image_invalid_scheme_raises(self): + with self.assertRaises(ValueError): + load_image("ftp://example.com/image.png") + + def test_load_image_invalid_type_raises(self): + with self.assertRaises(ValueError): + load_image(123) + + def test_load_image_custom_convert_method(self): + image = PIL.Image.new("RGB", (4, 4), color="blue") + + def to_grayscale(img): + return img.convert("L") + + loaded = load_image(image, convert_method=to_grayscale) + self.assertEqual(loaded.mode, "L") + + @patch("diffusers.utils.loading_utils.requests.get") + def test_load_image_from_url(self, mock_get): + buffer = io.BytesIO() + PIL.Image.new("RGB", (6, 6), color="red").save(buffer, format="PNG") + buffer.seek(0) + mock_response = Mock() + mock_response.raw = buffer + mock_get.return_value = mock_response + + loaded = load_image("https://example.com/image.png") + self.assertEqual(loaded.size, (6, 6)) + self.assertEqual(loaded.mode, "RGB") + + def test_load_video_invalid_path_raises(self): + with self.assertRaises(ValueError): + load_video("/path/that/does/not/exist.mp4") + + def test_load_video_gif_frames(self): + frames = [PIL.Image.new("RGB", (4, 4), color=(i * 40, 0, 0)) for i in range(3)] + with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp: + path = tmp.name + try: + frames[0].save(path, save_all=True, append_images=frames[1:], duration=100, loop=0) + loaded = load_video(path) + self.assertEqual(len(loaded), 3) + self.assertEqual(loaded[0].size, (4, 4)) + finally: + os.remove(path) + + def test_get_module_from_name_nested(self): + module = nn.Sequential(nn.Linear(4, 4), nn.ReLU()) + found, name = get_module_from_name(module, "0.weight") + self.assertIsInstance(found, nn.Linear) + self.assertEqual(name, "weight") + + def test_get_module_from_name_missing_attribute_raises(self): + module = nn.Linear(4, 4) + with self.assertRaises(AttributeError): + get_module_from_name(module, "missing.weight") + + def test_get_submodule_by_name_modulelist_index(self): + module = nn.ModuleList([nn.Linear(2, 2), nn.Linear(3, 3)]) + found = get_submodule_by_name(module, "1") + self.assertIsInstance(found, nn.Linear) + self.assertEqual(found.in_features, 3) + + def test_get_submodule_by_name_dotted_path(self): + module = nn.Sequential( + nn.ModuleDict({"block": nn.Linear(4, 4)}), + ) + found = get_submodule_by_name(module, "0.block") + self.assertIsInstance(found, nn.Linear) diff --git a/tests/others/test_remote_utils.py b/tests/others/test_remote_utils.py new file mode 100644 index 0000000..501b94b --- /dev/null +++ b/tests/others/test_remote_utils.py @@ -0,0 +1,104 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import unittest +from unittest.mock import Mock + +import torch +from PIL import Image + +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.remote_utils import ( + check_inputs_decode, + detect_image_type, + postprocess_decode, + prepare_decode, + prepare_encode, +) + + +class RemoteUtilsTest(unittest.TestCase): + def test_detect_image_type(self): + self.assertEqual(detect_image_type(b"\xff\xd8\xff"), "jpeg") + self.assertEqual(detect_image_type(b"\x89PNG\r\n\x1a\n"), "png") + self.assertEqual(detect_image_type(b"GIF89a"), "gif") + self.assertEqual(detect_image_type(b"BM"), "bmp") + self.assertEqual(detect_image_type(b"unknown"), "unknown") + + def test_check_inputs_decode_packed_latents_requires_hw(self): + tensor = torch.randn(4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode("http://example.com", tensor) + + def test_check_inputs_decode_processor_required(self): + tensor = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode( + "http://example.com", + tensor, + processor=None, + output_type="pt", + return_type="pil", + partial_postprocess=False, + ) + + def test_prepare_decode_sets_accept_header_for_jpeg(self): + tensor = torch.randn(1, 4, 8, 8, dtype=torch.float16) + payload = prepare_decode(tensor, output_type="pil", image_format="jpg") + self.assertEqual(payload["headers"]["Accept"], "image/jpeg") + self.assertEqual(payload["params"]["output_type"], "pil") + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + + def test_prepare_encode_tensor_includes_shape_and_dtype(self): + tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16) + payload = prepare_encode(tensor, scaling_factor=0.18215) + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + self.assertEqual(payload["params"]["dtype"], "float16") + self.assertEqual(payload["params"]["scaling_factor"], 0.18215) + + def test_prepare_encode_pil_image(self): + image = Image.new("RGB", (8, 8), color="red") + payload = prepare_encode(image) + self.assertIn(b"PNG", payload["data"][:8]) + + def test_postprocess_decode_pil_without_processor(self): + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color="blue").save(buffer, format="PNG") + response = Mock() + response.content = buffer.getvalue() + + output = postprocess_decode(response, processor=None, output_type="pil", return_type="pil") + self.assertIsInstance(output, Image.Image) + self.assertEqual(output.size, (4, 4)) + self.assertEqual(output.format, "png") + + def test_postprocess_decode_pt_tensor(self): + tensor = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2) + response = Mock() + response.content = tensor.numpy().tobytes() + response.headers = { + "shape": json.dumps(list(tensor.shape)), + "dtype": "float32", + } + + output = postprocess_decode( + response, + processor=None, + output_type="pt", + return_type="pt", + partial_postprocess=False, + ) + torch.testing.assert_close(output, tensor) diff --git a/tests/others/test_state_dict_utils.py b/tests/others/test_state_dict_utils.py new file mode 100644 index 0000000..216c101 --- /dev/null +++ b/tests/others/test_state_dict_utils.py @@ -0,0 +1,176 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch + +from diffusers.utils.state_dict_utils import ( + StateDictType, + convert_all_state_dict_to_peft, + convert_state_dict, + convert_state_dict_to_diffusers, + convert_state_dict_to_kohya, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, + state_dict_all_zero, +) + + +class StateDictUtilsTest(unittest.TestCase): + def test_convert_state_dict_applies_first_matching_pattern(self): + state_dict = {"layer.processor.weight": torch.ones(1)} + converted = convert_state_dict(state_dict, {".processor.": "."}) + self.assertIn("layer.weight", converted) + self.assertNotIn("layer.processor.weight", converted) + + def test_convert_state_dict_to_peft_auto_infers_diffusers_old(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_out_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_out_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict) + self.assertIn("unet.down_blocks.0.attentions.0.out_proj.lora_A.weight", converted) + + state_dict = { + "unet.down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS_OLD) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_A.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_peft_diffusers(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight", converted) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_diffusers_from_peft(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict, original_type=StateDictType.PEFT) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.down.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.up.weight", converted) + + def test_convert_state_dict_to_diffusers_already_diffusers(self): + state_dict = { + "layer.lora_linear_layer.down.weight": torch.ones(2, 2), + "layer.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict) + self.assertIs(converted, state_dict) + + def test_convert_unet_state_dict_to_peft(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_unet_state_dict_to_peft(state_dict) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_A.weight", converted) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_B.weight", converted) + + def test_convert_all_state_dict_to_peft_falls_back_to_unet(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_all_state_dict_to_peft(state_dict) + self.assertTrue(any("lora_A" in key or "lora_B" in key for key in converted)) + + def test_convert_all_state_dict_to_peft_raises_when_conversion_produces_no_lora_keys(self): + state_dict = {"layer.weight": torch.ones(2, 2)} + with self.assertRaisesRegex(ValueError, "Your LoRA was not converted to PEFT"): + convert_all_state_dict_to_peft(state_dict) + + @patch("diffusers.utils.state_dict_utils.convert_state_dict_to_peft") + def test_convert_all_state_dict_to_peft_reraises_non_infer_errors(self, mock_convert): + mock_convert.side_effect = ValueError("Some other error") + with self.assertRaisesRegex(ValueError, "Some other error"): + convert_all_state_dict_to_peft({"layer.weight": torch.ones(2, 2)}) + + def test_state_dict_all_zero(self): + state_dict = { + "a": torch.zeros(2, 2), + "b": torch.zeros(3), + } + self.assertTrue(state_dict_all_zero(state_dict)) + state_dict["b"] = torch.ones(3) + self.assertFalse(state_dict_all_zero(state_dict)) + + def test_state_dict_all_zero_with_filter(self): + state_dict = { + "lora.down": torch.zeros(2, 2), + "bias": torch.ones(3), + } + self.assertTrue(state_dict_all_zero(state_dict, filter_str="lora")) + + def test_convert_state_dict_to_kohya_remaps_component_prefixes(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.ones(2, 2), + "text_encoder_2.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.ones(2, 2), + "text_encoder_2.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + kohya = convert_state_dict_to_kohya(state_dict) + self.assertTrue(any(key.startswith("lora_te1_") for key in kohya)) + self.assertTrue(any(key.startswith("lora_te2_") for key in kohya)) + self.assertTrue(any("lora_unet" in key for key in kohya)) + self.assertTrue(any("alpha" in key for key in kohya)) + + def test_convert_state_dict_to_kohya_adds_alpha_for_down_weights(self): + down_weight = torch.ones(4, 2) + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": down_weight, + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 4), + } + kohya = convert_state_dict_to_kohya(state_dict) + alpha_keys = [key for key in kohya if key.endswith(".alpha")] + self.assertEqual(len(alpha_keys), 1) + self.assertEqual(kohya[alpha_keys[0]].item(), len(down_weight)) + + def test_convert_state_dict_to_peft_unsupported_type_raises(self): + state_dict = {"layer.weight": torch.ones(2, 2)} + with self.assertRaises(ValueError): + convert_state_dict_to_peft(state_dict, original_type=StateDictType.PEFT) + + def test_peft_unet_round_trip_preserves_weight_values(self): + weight_a = torch.randn(2, 4) + weight_b = torch.randn(4, 2) + peft_state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": weight_a.clone(), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": weight_b.clone(), + } + diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict, original_type=StateDictType.PEFT) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.down.weight", diffusers_state_dict) + self.assertTrue( + torch.allclose( + diffusers_state_dict["unet.down_blocks.0.attentions.0.to_q.lora.down.weight"], + weight_a, + ) + ) + self.assertTrue( + torch.allclose( + diffusers_state_dict["unet.down_blocks.0.attentions.0.to_q.lora.up.weight"], + weight_b, + ) + ) diff --git a/tests/pipelines/ltx/test_ltx_i2v_long_multi_prompt.py b/tests/pipelines/ltx/test_ltx_i2v_long_multi_prompt.py new file mode 100644 index 0000000..6f04714 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_i2v_long_multi_prompt.py @@ -0,0 +1,272 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoConfig, AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + LTXEulerAncestralRFScheduler, + LTXI2VLongMultiPromptPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.pipelines.ltx.pipeline_ltx_i2v_long_multi_prompt import ( + adain_normalize_latents, + build_video_coords_for_window, + get_latent_coords, + inject_prev_tail_latents, + linear_overlap_fuse, + parse_prompt_segments, + split_into_temporal_windows, +) + +from ...testing_utils import enable_full_determinism, torch_device + + +enable_full_determinism() + + +class LTXI2VLongMultiPromptHelpersTest(unittest.TestCase): + def test_split_into_temporal_windows_single_and_overlapping(self): + windows = split_into_temporal_windows(latent_len=10, temporal_tile_size=4, temporal_overlap=1, compression=8) + self.assertEqual(windows, [(0, 4), (3, 7), (6, 10)]) + + single = split_into_temporal_windows(latent_len=3, temporal_tile_size=8, temporal_overlap=2, compression=8) + self.assertEqual(single, [(0, 3)]) + + def test_split_into_temporal_windows_rejects_invalid_tile_size(self): + with self.assertRaises(ValueError): + split_into_temporal_windows(latent_len=5, temporal_tile_size=0, temporal_overlap=0, compression=8) + + def test_parse_prompt_segments_bar_split(self): + parts = parse_prompt_segments("walk | eat | sleep", prompt_segments=None) + self.assertEqual(parts, ["walk", "eat", "sleep"]) + + def test_parse_prompt_segments_window_mapping(self): + segments = [ + {"start_window": 0, "end_window": 1, "text": "intro"}, + {"start_window": 2, "end_window": 3, "text": "outro"}, + ] + texts = parse_prompt_segments("", prompt_segments=segments) + self.assertEqual(texts, ["intro", "intro", "outro", "outro"]) + + def test_parse_prompt_segments_fills_gaps_from_previous(self): + segments = [ + {"start_window": 0, "end_window": 0, "text": "first"}, + {"start_window": 2, "end_window": 2, "text": "third"}, + ] + texts = parse_prompt_segments("", prompt_segments=segments) + self.assertEqual(texts, ["first", "first", "third"]) + + def test_linear_overlap_fuse_blends_overlap_region(self): + prev = torch.ones(1, 2, 4, 2, 2) + new = torch.zeros(1, 2, 4, 2, 2) + fused = linear_overlap_fuse(prev, new, overlap=2) + self.assertEqual(fused.shape[2], 6) + self.assertTrue(torch.all(fused[:, :, :2] == 1)) + self.assertTrue(torch.all(fused[:, :, -2:] == 0)) + self.assertTrue(torch.all(fused[:, :, 2:4] > 0)) + self.assertTrue(torch.all(fused[:, :, 2:4] < 1)) + + def test_linear_overlap_fuse_no_overlap_concatenates(self): + prev = torch.ones(1, 2, 2, 2, 2) + new = torch.zeros(1, 2, 3, 2, 2) + fused = linear_overlap_fuse(prev, new, overlap=1) + self.assertEqual(fused.shape[2], 5) + + def test_adain_normalize_latents_noop_without_reference(self): + latents = torch.randn(1, 4, 2, 2, 2) + self.assertIs(adain_normalize_latents(latents, None, factor=0.5), latents) + + def test_adain_normalize_latents_matches_reference_stats(self): + curr = torch.randn(1, 2, 3, 4, 4) + ref = torch.randn(1, 2, 3, 4, 4) * 2 + 1 + normalized = adain_normalize_latents(curr, ref, factor=1.0) + self.assertAlmostEqual(normalized.mean().item(), ref.mean().item(), places=4) + self.assertAlmostEqual(normalized.std().item(), ref.std().item(), places=4) + + def test_inject_prev_tail_latents_appends_tail_and_mask(self): + window = torch.randn(1, 4, 3, 2, 2) + tail = torch.randn(1, 4, 2, 2, 2) + mask = torch.ones(1, 1, 3, 2, 2) + updated, updated_mask, overlap_len = inject_prev_tail_latents( + window, tail, mask, overlap_lat=2, strength=0.75, prev_overlap_len=0 + ) + self.assertEqual(updated.shape[2], 5) + self.assertEqual(updated_mask.shape[2], 5) + self.assertEqual(overlap_len, 2) + self.assertTrue(torch.all(updated[:, :, -2:] == tail)) + self.assertTrue(torch.all(updated_mask[:, :, -2:] == 0.25)) + + def test_get_latent_coords_shape_and_time_shift(self): + coords = get_latent_coords( + latent_num_frames=2, + latent_height=2, + latent_width=2, + batch_size=1, + device=torch.device("cpu"), + rope_interpolation_scale=(8, 4, 4), + latent_idx=1, + ) + self.assertEqual(coords.shape, (1, 3, 8)) + self.assertGreater(coords[0, 0, 0].item(), 0) + + def test_build_video_coords_for_window_applies_frame_rate(self): + latents = torch.zeros(1, 4, 2, 2, 2) + rope_scale = (8, 4, 4) + coords = build_video_coords_for_window( + latents=latents, + overlap_len=0, + guiding_len=0, + negative_len=0, + rope_interpolation_scale=rope_scale, + frame_rate=25, + ) + self.assertEqual(coords.shape, (1, 3, 8)) + self.assertAlmostEqual(coords[0, 0, 0].item(), 0.0, places=6) + + +class LTXI2VLongMultiPromptPipelineFastTests(unittest.TestCase): + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + in_channels=3, + out_channels=3, + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + decoder_block_out_channels=(8, 8, 8, 8), + layers_per_block=(1, 1, 1, 1, 1), + decoder_layers_per_block=(1, 1, 1, 1, 1), + spatio_temporal_scaling=(True, True, False, False), + decoder_spatio_temporal_scaling=(True, True, False, False), + decoder_inject_noise=(False, False, False, False, False), + upsample_residual=(False, False, False, False), + upsample_factor=(1, 1, 1, 1), + timestep_conditioning=False, + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = LTXEulerAncestralRFScheduler(num_train_timesteps=1000) + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5") + text_encoder = T5EncoderModel(config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + return { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + + def test_inference_latent_output_single_window(self): + device = torch_device + pipe = LTXI2VLongMultiPromptPipeline(**self.get_dummy_components()) + pipe.to(device) + + generator = torch.Generator(device=device).manual_seed(0) + output = pipe( + prompt="segment one | segment two", + generator=generator, + num_inference_steps=2, + height=32, + width=32, + num_frames=9, + temporal_tile_size=80, + temporal_overlap=0, + guidance_scale=1.0, + output_type="latent", + max_sequence_length=16, + ) + + frames = output.frames + self.assertIsInstance(frames, torch.Tensor) + self.assertEqual(frames.ndim, 5) + self.assertEqual(frames.shape[0], 1) + self.assertEqual(frames.shape[3], 32 // pipe.vae_spatial_compression_ratio) + self.assertEqual(frames.shape[4], 32 // pipe.vae_spatial_compression_ratio) + + def test_inference_rejects_non_divisible_dimensions(self): + pipe = LTXI2VLongMultiPromptPipeline(**self.get_dummy_components()) + pipe.to(torch_device) + + with self.assertRaises(ValueError): + pipe( + prompt="test", + height=30, + width=32, + num_frames=9, + num_inference_steps=1, + output_type="latent", + ) + + def test_inference_latent_output_multi_window_with_overlap(self): + device = torch_device + pipe = LTXI2VLongMultiPromptPipeline(**self.get_dummy_components()) + pipe.to(device) + + common_kwargs = dict( + prompt="segment one | segment two | segment three", + num_inference_steps=2, + height=32, + width=32, + num_frames=25, + guidance_scale=1.0, + output_type="latent", + max_sequence_length=16, + ) + + single_window = pipe( + **common_kwargs, + generator=torch.Generator(device=device).manual_seed(0), + temporal_tile_size=80, + temporal_overlap=0, + ) + multi_window = pipe( + **common_kwargs, + generator=torch.Generator(device=device).manual_seed(0), + temporal_tile_size=16, + temporal_overlap=8, + ) + + single_frames = single_window.frames + multi_frames = multi_window.frames + self.assertIsInstance(multi_frames, torch.Tensor) + self.assertEqual(multi_frames.ndim, 5) + self.assertEqual(multi_frames.shape[0], 1) + self.assertEqual(multi_frames.shape[3], 32 // pipe.vae_spatial_compression_ratio) + self.assertEqual(multi_frames.shape[4], 32 // pipe.vae_spatial_compression_ratio) + self.assertNotEqual(multi_frames.shape[2], single_frames.shape[2]) + self.assertGreater(multi_frames.shape[2], single_frames.shape[2]) diff --git a/tests/schedulers/test_scheduler_cogvideox.py b/tests/schedulers/test_scheduler_cogvideox.py new file mode 100644 index 0000000..e7b503f --- /dev/null +++ b/tests/schedulers/test_scheduler_cogvideox.py @@ -0,0 +1,199 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.schedulers.scheduling_ddim_cogvideox import DDIMSchedulerOutput + + +class CogVideoXDDIMSchedulerTest(unittest.TestCase): + """ + Contract tests for CogVideoXDDIMScheduler — used by CogView3+ and shares the SNR-shifted + schedule with CogVideoX pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = CogVideoXDDIMScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.0120, + "beta_schedule": "scaled_linear", + "timestep_spacing": "leading", + } + config.update(**kwargs) + return config + + def test_snr_shift_modifies_alphas_cumprod(self): + shifted = self.scheduler_class(**self.get_default_config(snr_shift_scale=3.0)) + unshifted = self.scheduler_class(**self.get_default_config(snr_shift_scale=1.0)) + self.assertFalse(torch.allclose(shifted.alphas_cumprod, unshifted.alphas_cumprod)) + + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_produces_expected_count(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 4, 10, 50]: + scheduler.set_timesteps(nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + output = scheduler.step(model_output, timestep, sample, eta=0.0) + self.assertIsInstance(output, DDIMSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + timestep = scheduler.timesteps[0] + + first = scheduler.step(model_output, timestep, sample, eta=0.0, generator=generator).prev_sample + generator = torch.Generator().manual_seed(0) + second = scheduler.step(model_output, timestep, sample, eta=0.0, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + generator = torch.Generator().manual_seed(0) + for t in scheduler.timesteps: + sample = scheduler.step( + torch.randn_like(sample), t, sample, eta=0.0, generator=generator + ).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) + + +class CogVideoXDPMSchedulerTest(unittest.TestCase): + """ + Contract tests for CogVideoXDPMScheduler — CogVideoX pipelines branch on its multi-step API. + """ + + scheduler_class = CogVideoXDPMScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.0120, + "beta_schedule": "scaled_linear", + "timestep_spacing": "trailing", + } + config.update(**kwargs) + return config + + def test_step_requires_set_timesteps(self): + scheduler = self.scheduler_class(**self.get_default_config()) + sample = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + scheduler.step( + torch.randn_like(sample), + None, + scheduler.timesteps[0], + None, + sample, + return_dict=False, + ) + + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_first_step_returns_pred_original_sample(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + prev_sample, pred_original_sample = scheduler.step( + model_output, + None, + timestep, + None, + sample, + return_dict=False, + ) + self.assertEqual(prev_sample.shape, sample.shape) + self.assertEqual(pred_original_sample.shape, sample.shape) + + def test_second_step_uses_old_pred_original_sample(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + _, old_pred = scheduler.step( + model_output, + None, + scheduler.timesteps[0], + None, + sample, + return_dict=False, + generator=generator, + ) + + prev_sample, _ = scheduler.step( + model_output, + old_pred, + scheduler.timesteps[1], + scheduler.timesteps[0], + sample, + return_dict=False, + generator=generator, + ) + self.assertEqual(prev_sample.shape, sample.shape) + + def test_get_variables_and_get_mult(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + t = scheduler.timesteps[1] + t_prev = t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + t_back = scheduler.timesteps[0] + + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t_prev] + alpha_prod_t_back = scheduler.alphas_cumprod[t_back] + + h, r, lamb, lamb_next = scheduler.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = scheduler.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + self.assertEqual(len(mult), 4) + self.assertTrue(torch.isfinite(h)) + self.assertTrue(torch.isfinite(lamb)) diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index 13b353a..dc6e688 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -73,6 +73,20 @@ def test_timestep_spacing(self): for timestep_spacing in ["trailing", "leading"]: self.check_over_configs(timestep_spacing=timestep_spacing) + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + # Guard against inverted comparison (num_inference_steps < num_train_timesteps) which would + # reject all valid inference schedules and accept invalid ones. + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_num_inference_steps_at_limit_succeeds(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + self.assertEqual(scheduler.num_inference_steps, scheduler.config.num_train_timesteps) + def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py index 056b5d8..b42b2a5 100644 --- a/tests/schedulers/test_scheduler_ddpm.py +++ b/tests/schedulers/test_scheduler_ddpm.py @@ -72,6 +72,18 @@ def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_num_inference_steps_at_limit_succeeds(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + self.assertEqual(scheduler.num_inference_steps, scheduler.config.num_train_timesteps) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() diff --git a/tests/schedulers/test_scheduler_ddpm_wuerstchen.py b/tests/schedulers/test_scheduler_ddpm_wuerstchen.py new file mode 100644 index 0000000..36ea793 --- /dev/null +++ b/tests/schedulers/test_scheduler_ddpm_wuerstchen.py @@ -0,0 +1,110 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import DDPMWuerstchenScheduler +from diffusers.schedulers.scheduling_ddpm_wuerstchen import DDPMWuerstchenSchedulerOutput + + +class DDPMWuerstchenSchedulerTest(unittest.TestCase): + """ + Contract tests for DDPMWuerstchenScheduler — Stable Cascade prior/decoder use float timesteps + in [1, 0] rather than integer indices. Pipeline tests only exercised it indirectly. + """ + + scheduler_class = DDPMWuerstchenScheduler + + def get_default_config(self, **kwargs): + config = {"scaler": 1.0, "s": 0.008} + config.update(**kwargs) + return config + + def test_set_timesteps_float_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 4, 10]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1.0, places=5) + self.assertAlmostEqual(scheduler.timesteps[-1].item(), 0.0, places=5) + + def test_scaler_modifies_alpha_cumprod(self): + default = self.scheduler_class(**self.get_default_config()) + scaled = self.scheduler_class(**self.get_default_config(scaler=2.0)) + t = torch.tensor([0.5]) + default_alpha = default._alpha_cumprod(t, device="cpu") + scaled_alpha = scaled._alpha_cumprod(t, device="cpu") + self.assertFalse(torch.allclose(default_alpha, scaled_alpha)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1].expand(sample.shape[0]) + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, DDPMWuerstchenSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_with_generator(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, timestep, sample, generator=generator).prev_sample + generator = torch.Generator().manual_seed(0) + second = scheduler.step(model_output, timestep, sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_previous_timestep_advances_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + current = scheduler.timesteps[0:1] + prev = scheduler.previous_timestep(current) + self.assertAlmostEqual(prev[0].item(), scheduler.timesteps[1].item(), places=5) + + def test_add_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + timesteps = scheduler.timesteps[0:1] + mixed = scheduler.add_noise(sample, noise, timesteps) + self.assertTrue(mixed.min() >= 0.0) + self.assertTrue(mixed.max() <= 1.0) + self.assertFalse(torch.allclose(mixed, sample)) + self.assertFalse(torch.allclose(mixed, noise)) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + generator = torch.Generator().manual_seed(0) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps[:-1]: + batch_t = t.expand(sample.shape[0]) + sample = scheduler.step( + torch.randn_like(sample), batch_t, sample, generator=generator + ).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 100644 index 0000000..56f7410 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,108 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput + + +class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase): + """ + Contract tests for FlowMatchEulerDiscreteScheduler — shared by SD3, Flux, Wan, and many + flow-matching pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = FlowMatchEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8, 16]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_dynamic_shifting_requires_mu(self): + scheduler = self.scheduler_class(**self.get_default_config(use_dynamic_shifting=True)) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=4) + + def test_set_timesteps_custom_sigmas_and_timesteps_length_mismatch(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.0], timesteps=[900.0, 500.0]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, FlowMatchEulerDiscreteSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, model_output.dtype) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_scale_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[500.0, 250.0]) + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + mid_t = scheduler.timesteps[0:1] + mixed = scheduler.scale_noise(sample, mid_t, noise) + # sigma=0.5 at t=500 with default 1000 training steps + torch.testing.assert_close(mixed, 0.5 * noise, atol=1e-4, rtol=1e-4) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[900.0, 500.0, 500.0, 100.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_helios.py b/tests/schedulers/test_scheduler_helios.py new file mode 100644 index 0000000..6fac77b --- /dev/null +++ b/tests/schedulers/test_scheduler_helios.py @@ -0,0 +1,107 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch + +from diffusers import HeliosScheduler +from diffusers.schedulers.scheduling_helios import HeliosSchedulerOutput + + +class HeliosSchedulerTest(unittest.TestCase): + """ + Unit tests for HeliosScheduler multi-stage scheduling. Pipeline tests only cover stages=1; + these tests lock in per-stage sigma ranges and step behavior used by the default 3-stage config. + """ + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + "stages": 3, + "stage_range": [0, 1 / 3, 2 / 3, 1], + "gamma": 1 / 3, + "prediction_type": "flow_prediction", + "scheduler_type": "euler", + "use_dynamic_shifting": False, + } + config.update(**kwargs) + return config + + def test_multi_stage_init_populates_per_stage_buffers(self): + scheduler = HeliosScheduler(**self.get_default_config()) + self.assertEqual(len(scheduler.timesteps_per_stage), 3) + self.assertEqual(len(scheduler.sigmas_per_stage), 3) + self.assertEqual(len(scheduler.start_sigmas), 3) + self.assertEqual(len(scheduler.end_sigmas), 3) + for stage in range(3): + self.assertGreater(scheduler.start_sigmas[stage], scheduler.end_sigmas[stage]) + + def test_set_timesteps_multi_stage_per_stage_index(self): + scheduler = HeliosScheduler(**self.get_default_config()) + for stage_index in range(3): + scheduler.set_timesteps(num_inference_steps=8, stage_index=stage_index) + self.assertEqual(scheduler.timesteps.shape, (8,)) + self.assertEqual(scheduler.sigmas.shape, (9,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + self.assertGreater(scheduler.timesteps[0].item(), scheduler.timesteps[-1].item()) + + def test_set_timesteps_single_stage(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.sigmas.shape, (5,)) + + def test_step_euler_updates_sample(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + output = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertIsInstance(output, HeliosSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertFalse(torch.allclose(output.prev_sample, sample)) + + def test_convert_model_output_flow_prediction(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + scheduler._step_index = 0 + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + x0 = scheduler.convert_model_output(model_output, sample=sample) + expected = sample - scheduler.sigmas[0] * model_output + torch.testing.assert_close(x0, expected) + + def test_dynamic_shifting_rescales_timesteps(self): + scheduler = HeliosScheduler( + **self.get_default_config(stages=1, stage_range=[0, 1], use_dynamic_shifting=True) + ) + scheduler.set_timesteps(num_inference_steps=4, mu=0.5) + self.assertEqual(scheduler.timesteps.shape, (4,)) + + def test_step_unipc_invokes_corrector_on_second_step(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + scheduler.step(model_output, scheduler.timesteps[0], sample) + + with patch.object(scheduler, "multistep_uni_c_bh_update", wraps=scheduler.multistep_uni_c_bh_update) as corrector: + scheduler.step(model_output, scheduler.timesteps[1], sample) + corrector.assert_called_once() diff --git a/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py new file mode 100644 index 0000000..f898d54 --- /dev/null +++ b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXEulerAncestralRFScheduler +from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFSchedulerOutput + + +class LTXEulerAncestralRFSchedulerTest(unittest.TestCase): + """ + Contract tests for the LTX RF ancestral scheduler used by LTX-Video long-form pipelines. + Mirrors the style of `test_scheduler_flow_map_euler_discrete.py` because this scheduler + has a non-standard step API and cannot reuse `SchedulerCommonTest`. + """ + + scheduler_class = LTXEulerAncestralRFScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "eta": 1.0, + "s_noise": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_auto_generates_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_requires_args(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps() + + def test_set_timesteps_explicit_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + custom_sigmas = [1.0, 0.75, 0.5, 0.25, 0.0] + scheduler.set_timesteps(sigmas=custom_sigmas) + self.assertEqual(scheduler.num_inference_steps, 4) + for i, sigma in enumerate(custom_sigmas): + self.assertAlmostEqual(scheduler.sigmas[i].item(), sigma, places=5) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1000.0, places=4) + + def test_set_timesteps_rejects_non_1d_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[[1.0, 0.5], [0.5, 0.0]]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, LTXEulerAncestralRFSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + + scheduler.set_timesteps(num_inference_steps=4) + second = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.5, 0.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/utils/check_copies.py b/utils/check_copies.py index 001366c..da30eb7 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -18,6 +18,7 @@ import os import re import subprocess +import sys # All paths are set with the intent you should run this script from the root of the repo with the command @@ -94,7 +95,7 @@ def get_indent(code): def run_ruff(code): - command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] + command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) stdout, _ = process.communicate(input=code.encode()) return stdout.decode() @@ -195,7 +196,23 @@ def is_copy_consistent(filename, overwrite=False): return diffs +def _has_src_diffusers_changes() -> bool: + """Return True when src/diffusers differs from HEAD, index, or main.""" + for cmd in ( + ["git", "diff", "--name-only", "main", "--", "src/diffusers"], + ["git", "diff", "--name-only", "--cached", "--", "src/diffusers"], + ["git", "diff", "--name-only", "--", "src/diffusers"], + ): + result = subprocess.run(cmd, capture_output=True, text=True, check=False) + if result.stdout.strip(): + return True + return False + + def check_copies(overwrite: bool = False): + if not overwrite and not _has_src_diffusers_changes(): + return + all_files = glob.glob(os.path.join(DIFFUSERS_PATH, "**/*.py"), recursive=True) diffs = [] for filename in all_files: