From 5ad14b67cbfb9fe4948d68361e43bcafe62fa6cd Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sun, 28 Jun 2026 02:02:51 +0000 Subject: [PATCH 1/4] Add regression tests for high-blast-radius schedulers and utilities Cover FlowMatchEulerDiscreteScheduler, HeliosScheduler, and LTXEulerAncestralRFScheduler contract behavior, DDIM/DDPM set_timesteps validation guards, state_dict LoRA conversion paths, and remote_utils encode/decode helpers. Co-authored-by: Simon Lynch --- tests/others/test_remote_utils.py | 104 +++++++++++++ tests/others/test_state_dict_utils.py | 137 ++++++++++++++++++ tests/schedulers/test_scheduler_ddim.py | 14 ++ tests/schedulers/test_scheduler_ddpm.py | 12 ++ ...est_scheduler_flow_match_euler_discrete.py | 108 ++++++++++++++ tests/schedulers/test_scheduler_helios.py | 107 ++++++++++++++ .../test_scheduler_ltx_euler_ancestral_rf.py | 124 ++++++++++++++++ 7 files changed, 606 insertions(+) create mode 100644 tests/others/test_remote_utils.py create mode 100644 tests/others/test_state_dict_utils.py create mode 100644 tests/schedulers/test_scheduler_flow_match_euler_discrete.py create mode 100644 tests/schedulers/test_scheduler_helios.py create mode 100644 tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py diff --git a/tests/others/test_remote_utils.py b/tests/others/test_remote_utils.py new file mode 100644 index 0000000..501b94b --- /dev/null +++ b/tests/others/test_remote_utils.py @@ -0,0 +1,104 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import unittest +from unittest.mock import Mock + +import torch +from PIL import Image + +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.remote_utils import ( + check_inputs_decode, + detect_image_type, + postprocess_decode, + prepare_decode, + prepare_encode, +) + + +class RemoteUtilsTest(unittest.TestCase): + def test_detect_image_type(self): + self.assertEqual(detect_image_type(b"\xff\xd8\xff"), "jpeg") + self.assertEqual(detect_image_type(b"\x89PNG\r\n\x1a\n"), "png") + self.assertEqual(detect_image_type(b"GIF89a"), "gif") + self.assertEqual(detect_image_type(b"BM"), "bmp") + self.assertEqual(detect_image_type(b"unknown"), "unknown") + + def test_check_inputs_decode_packed_latents_requires_hw(self): + tensor = torch.randn(4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode("http://example.com", tensor) + + def test_check_inputs_decode_processor_required(self): + tensor = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode( + "http://example.com", + tensor, + processor=None, + output_type="pt", + return_type="pil", + partial_postprocess=False, + ) + + def test_prepare_decode_sets_accept_header_for_jpeg(self): + tensor = torch.randn(1, 4, 8, 8, dtype=torch.float16) + payload = prepare_decode(tensor, output_type="pil", image_format="jpg") + self.assertEqual(payload["headers"]["Accept"], "image/jpeg") + self.assertEqual(payload["params"]["output_type"], "pil") + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + + def test_prepare_encode_tensor_includes_shape_and_dtype(self): + tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16) + payload = prepare_encode(tensor, scaling_factor=0.18215) + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + self.assertEqual(payload["params"]["dtype"], "float16") + self.assertEqual(payload["params"]["scaling_factor"], 0.18215) + + def test_prepare_encode_pil_image(self): + image = Image.new("RGB", (8, 8), color="red") + payload = prepare_encode(image) + self.assertIn(b"PNG", payload["data"][:8]) + + def test_postprocess_decode_pil_without_processor(self): + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color="blue").save(buffer, format="PNG") + response = Mock() + response.content = buffer.getvalue() + + output = postprocess_decode(response, processor=None, output_type="pil", return_type="pil") + self.assertIsInstance(output, Image.Image) + self.assertEqual(output.size, (4, 4)) + self.assertEqual(output.format, "png") + + def test_postprocess_decode_pt_tensor(self): + tensor = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2) + response = Mock() + response.content = tensor.numpy().tobytes() + response.headers = { + "shape": json.dumps(list(tensor.shape)), + "dtype": "float32", + } + + output = postprocess_decode( + response, + processor=None, + output_type="pt", + return_type="pt", + partial_postprocess=False, + ) + torch.testing.assert_close(output, tensor) diff --git a/tests/others/test_state_dict_utils.py b/tests/others/test_state_dict_utils.py new file mode 100644 index 0000000..ac399d9 --- /dev/null +++ b/tests/others/test_state_dict_utils.py @@ -0,0 +1,137 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers.utils.state_dict_utils import ( + StateDictType, + convert_all_state_dict_to_peft, + convert_state_dict, + convert_state_dict_to_diffusers, + convert_state_dict_to_kohya, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, + state_dict_all_zero, +) + + +class StateDictUtilsTest(unittest.TestCase): + def test_convert_state_dict_applies_first_matching_pattern(self): + state_dict = {"layer.processor.weight": torch.ones(1)} + converted = convert_state_dict(state_dict, {".processor.": "."}) + self.assertIn("layer.weight", converted) + self.assertNotIn("layer.processor.weight", converted) + + def test_convert_state_dict_to_peft_auto_infers_diffusers_old(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_out_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_out_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict) + self.assertIn("unet.down_blocks.0.attentions.0.out_proj.lora_A.weight", converted) + + state_dict = { + "unet.down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS_OLD) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_A.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_peft_diffusers(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight", converted) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_diffusers_from_peft(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict, original_type=StateDictType.PEFT) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.down.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.up.weight", converted) + + def test_convert_state_dict_to_diffusers_already_diffusers(self): + state_dict = { + "layer.lora_linear_layer.down.weight": torch.ones(2, 2), + "layer.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict) + self.assertIs(converted, state_dict) + + def test_convert_unet_state_dict_to_peft(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_unet_state_dict_to_peft(state_dict) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_A.weight", converted) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_B.weight", converted) + + def test_convert_all_state_dict_to_peft_falls_back_to_unet(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_all_state_dict_to_peft(state_dict) + self.assertTrue(any("lora_A" in key or "lora_B" in key for key in converted)) + + def test_state_dict_all_zero(self): + state_dict = { + "a": torch.zeros(2, 2), + "b": torch.zeros(3), + } + self.assertTrue(state_dict_all_zero(state_dict)) + state_dict["b"] = torch.ones(3) + self.assertFalse(state_dict_all_zero(state_dict)) + + def test_state_dict_all_zero_with_filter(self): + state_dict = { + "lora.down": torch.zeros(2, 2), + "bias": torch.ones(3), + } + self.assertTrue(state_dict_all_zero(state_dict, filter_str="lora")) + + def test_convert_state_dict_to_kohya_remaps_component_prefixes(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.ones(2, 2), + "text_encoder_2.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.ones(2, 2), + "text_encoder_2.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + kohya = convert_state_dict_to_kohya(state_dict) + self.assertTrue(any(key.startswith("lora_te1_") for key in kohya)) + self.assertTrue(any(key.startswith("lora_te2_") for key in kohya)) + self.assertTrue(any("lora_unet" in key for key in kohya)) + self.assertTrue(any("alpha" in key for key in kohya)) + + def test_convert_state_dict_to_kohya_adds_alpha_for_down_weights(self): + down_weight = torch.ones(4, 2) + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": down_weight, + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 4), + } + kohya = convert_state_dict_to_kohya(state_dict) + alpha_keys = [key for key in kohya if key.endswith(".alpha")] + self.assertEqual(len(alpha_keys), 1) + self.assertEqual(kohya[alpha_keys[0]].item(), len(down_weight)) diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index 13b353a..dc6e688 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -73,6 +73,20 @@ def test_timestep_spacing(self): for timestep_spacing in ["trailing", "leading"]: self.check_over_configs(timestep_spacing=timestep_spacing) + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + # Guard against inverted comparison (num_inference_steps < num_train_timesteps) which would + # reject all valid inference schedules and accept invalid ones. + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_num_inference_steps_at_limit_succeeds(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + self.assertEqual(scheduler.num_inference_steps, scheduler.config.num_train_timesteps) + def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) diff --git a/tests/schedulers/test_scheduler_ddpm.py b/tests/schedulers/test_scheduler_ddpm.py index 056b5d8..b42b2a5 100644 --- a/tests/schedulers/test_scheduler_ddpm.py +++ b/tests/schedulers/test_scheduler_ddpm.py @@ -72,6 +72,18 @@ def test_rescale_betas_zero_snr(self): for rescale_betas_zero_snr in [True, False]: self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_num_inference_steps_at_limit_succeeds(self): + scheduler_class = self.scheduler_classes[0] + scheduler = scheduler_class(**self.get_scheduler_config()) + scheduler.set_timesteps(scheduler.config.num_train_timesteps) + self.assertEqual(scheduler.num_inference_steps, scheduler.config.num_train_timesteps) + def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 100644 index 0000000..56f7410 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,108 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput + + +class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase): + """ + Contract tests for FlowMatchEulerDiscreteScheduler — shared by SD3, Flux, Wan, and many + flow-matching pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = FlowMatchEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8, 16]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_dynamic_shifting_requires_mu(self): + scheduler = self.scheduler_class(**self.get_default_config(use_dynamic_shifting=True)) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=4) + + def test_set_timesteps_custom_sigmas_and_timesteps_length_mismatch(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.0], timesteps=[900.0, 500.0]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, FlowMatchEulerDiscreteSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, model_output.dtype) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_scale_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[500.0, 250.0]) + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + mid_t = scheduler.timesteps[0:1] + mixed = scheduler.scale_noise(sample, mid_t, noise) + # sigma=0.5 at t=500 with default 1000 training steps + torch.testing.assert_close(mixed, 0.5 * noise, atol=1e-4, rtol=1e-4) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[900.0, 500.0, 500.0, 100.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_helios.py b/tests/schedulers/test_scheduler_helios.py new file mode 100644 index 0000000..6fac77b --- /dev/null +++ b/tests/schedulers/test_scheduler_helios.py @@ -0,0 +1,107 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch + +from diffusers import HeliosScheduler +from diffusers.schedulers.scheduling_helios import HeliosSchedulerOutput + + +class HeliosSchedulerTest(unittest.TestCase): + """ + Unit tests for HeliosScheduler multi-stage scheduling. Pipeline tests only cover stages=1; + these tests lock in per-stage sigma ranges and step behavior used by the default 3-stage config. + """ + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + "stages": 3, + "stage_range": [0, 1 / 3, 2 / 3, 1], + "gamma": 1 / 3, + "prediction_type": "flow_prediction", + "scheduler_type": "euler", + "use_dynamic_shifting": False, + } + config.update(**kwargs) + return config + + def test_multi_stage_init_populates_per_stage_buffers(self): + scheduler = HeliosScheduler(**self.get_default_config()) + self.assertEqual(len(scheduler.timesteps_per_stage), 3) + self.assertEqual(len(scheduler.sigmas_per_stage), 3) + self.assertEqual(len(scheduler.start_sigmas), 3) + self.assertEqual(len(scheduler.end_sigmas), 3) + for stage in range(3): + self.assertGreater(scheduler.start_sigmas[stage], scheduler.end_sigmas[stage]) + + def test_set_timesteps_multi_stage_per_stage_index(self): + scheduler = HeliosScheduler(**self.get_default_config()) + for stage_index in range(3): + scheduler.set_timesteps(num_inference_steps=8, stage_index=stage_index) + self.assertEqual(scheduler.timesteps.shape, (8,)) + self.assertEqual(scheduler.sigmas.shape, (9,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + self.assertGreater(scheduler.timesteps[0].item(), scheduler.timesteps[-1].item()) + + def test_set_timesteps_single_stage(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.sigmas.shape, (5,)) + + def test_step_euler_updates_sample(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + output = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertIsInstance(output, HeliosSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertFalse(torch.allclose(output.prev_sample, sample)) + + def test_convert_model_output_flow_prediction(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + scheduler._step_index = 0 + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + x0 = scheduler.convert_model_output(model_output, sample=sample) + expected = sample - scheduler.sigmas[0] * model_output + torch.testing.assert_close(x0, expected) + + def test_dynamic_shifting_rescales_timesteps(self): + scheduler = HeliosScheduler( + **self.get_default_config(stages=1, stage_range=[0, 1], use_dynamic_shifting=True) + ) + scheduler.set_timesteps(num_inference_steps=4, mu=0.5) + self.assertEqual(scheduler.timesteps.shape, (4,)) + + def test_step_unipc_invokes_corrector_on_second_step(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + scheduler.step(model_output, scheduler.timesteps[0], sample) + + with patch.object(scheduler, "multistep_uni_c_bh_update", wraps=scheduler.multistep_uni_c_bh_update) as corrector: + scheduler.step(model_output, scheduler.timesteps[1], sample) + corrector.assert_called_once() diff --git a/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py new file mode 100644 index 0000000..f898d54 --- /dev/null +++ b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXEulerAncestralRFScheduler +from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFSchedulerOutput + + +class LTXEulerAncestralRFSchedulerTest(unittest.TestCase): + """ + Contract tests for the LTX RF ancestral scheduler used by LTX-Video long-form pipelines. + Mirrors the style of `test_scheduler_flow_map_euler_discrete.py` because this scheduler + has a non-standard step API and cannot reuse `SchedulerCommonTest`. + """ + + scheduler_class = LTXEulerAncestralRFScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "eta": 1.0, + "s_noise": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_auto_generates_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_requires_args(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps() + + def test_set_timesteps_explicit_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + custom_sigmas = [1.0, 0.75, 0.5, 0.25, 0.0] + scheduler.set_timesteps(sigmas=custom_sigmas) + self.assertEqual(scheduler.num_inference_steps, 4) + for i, sigma in enumerate(custom_sigmas): + self.assertAlmostEqual(scheduler.sigmas[i].item(), sigma, places=5) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1000.0, places=4) + + def test_set_timesteps_rejects_non_1d_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[[1.0, 0.5], [0.5, 0.0]]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, LTXEulerAncestralRFSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + + scheduler.set_timesteps(num_inference_steps=4) + second = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.5, 0.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) From 6d969814ae81a5305f93a5492263b5adbc053fd2 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 29 Jun 2026 02:08:55 +0000 Subject: [PATCH 2/4] Add regression tests for CogVideoX schedulers, Wuerstchen, and loading_utils Cover high-blast-radius code paths that previously only had indirect pipeline coverage: - CogVideoXDDIMScheduler SNR shift and set_timesteps guard - CogVideoXDPMScheduler multi-step API (old_pred_original_sample) - DDPMWuerstchenScheduler float timestep schedule and add_noise - loading_utils path/URL validation and submodule resolution Use python -m ruff in check_copies so the commit hook works when ruff is installed as a Python package but not on PATH. Co-authored-by: Simon Lynch --- tests/others/test_loading_utils.py | 119 +++++++++++ tests/schedulers/test_scheduler_cogvideox.py | 199 ++++++++++++++++++ .../test_scheduler_ddpm_wuerstchen.py | 110 ++++++++++ utils/check_copies.py | 3 +- 4 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 tests/others/test_loading_utils.py create mode 100644 tests/schedulers/test_scheduler_cogvideox.py create mode 100644 tests/schedulers/test_scheduler_ddpm_wuerstchen.py diff --git a/tests/others/test_loading_utils.py b/tests/others/test_loading_utils.py new file mode 100644 index 0000000..1f8b815 --- /dev/null +++ b/tests/others/test_loading_utils.py @@ -0,0 +1,119 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import tempfile +import unittest +from unittest.mock import Mock, patch + +import PIL.Image +import torch +from torch import nn + +from diffusers.utils.loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video + + +class LoadingUtilsTest(unittest.TestCase): + def test_load_image_pil_passthrough_converts_rgb(self): + image = PIL.Image.new("RGBA", (4, 4), color=(255, 0, 0, 128)) + loaded = load_image(image) + self.assertEqual(loaded.mode, "RGB") + self.assertEqual(loaded.size, (4, 4)) + + def test_load_image_local_path(self): + image = PIL.Image.new("RGB", (8, 8), color="green") + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + image.save(tmp.name) + path = tmp.name + try: + loaded = load_image(path) + self.assertEqual(loaded.size, (8, 8)) + self.assertEqual(loaded.mode, "RGB") + finally: + os.remove(path) + + def test_load_image_invalid_path_raises(self): + with self.assertRaises(ValueError): + load_image("/path/that/does/not/exist.png") + + def test_load_image_invalid_scheme_raises(self): + with self.assertRaises(ValueError): + load_image("ftp://example.com/image.png") + + def test_load_image_invalid_type_raises(self): + with self.assertRaises(ValueError): + load_image(123) + + def test_load_image_custom_convert_method(self): + image = PIL.Image.new("RGB", (4, 4), color="blue") + + def to_grayscale(img): + return img.convert("L") + + loaded = load_image(image, convert_method=to_grayscale) + self.assertEqual(loaded.mode, "L") + + @patch("diffusers.utils.loading_utils.requests.get") + def test_load_image_from_url(self, mock_get): + buffer = io.BytesIO() + PIL.Image.new("RGB", (6, 6), color="red").save(buffer, format="PNG") + buffer.seek(0) + mock_response = Mock() + mock_response.raw = buffer + mock_get.return_value = mock_response + + loaded = load_image("https://example.com/image.png") + self.assertEqual(loaded.size, (6, 6)) + self.assertEqual(loaded.mode, "RGB") + + def test_load_video_invalid_path_raises(self): + with self.assertRaises(ValueError): + load_video("/path/that/does/not/exist.mp4") + + def test_load_video_gif_frames(self): + frames = [PIL.Image.new("RGB", (4, 4), color=(i * 40, 0, 0)) for i in range(3)] + with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp: + path = tmp.name + try: + frames[0].save(path, save_all=True, append_images=frames[1:], duration=100, loop=0) + loaded = load_video(path) + self.assertEqual(len(loaded), 3) + self.assertEqual(loaded[0].size, (4, 4)) + finally: + os.remove(path) + + def test_get_module_from_name_nested(self): + module = nn.Sequential(nn.Linear(4, 4), nn.ReLU()) + found, name = get_module_from_name(module, "0.weight") + self.assertIsInstance(found, nn.Linear) + self.assertEqual(name, "weight") + + def test_get_module_from_name_missing_attribute_raises(self): + module = nn.Linear(4, 4) + with self.assertRaises(AttributeError): + get_module_from_name(module, "missing.weight") + + def test_get_submodule_by_name_modulelist_index(self): + module = nn.ModuleList([nn.Linear(2, 2), nn.Linear(3, 3)]) + found = get_submodule_by_name(module, "1") + self.assertIsInstance(found, nn.Linear) + self.assertEqual(found.in_features, 3) + + def test_get_submodule_by_name_dotted_path(self): + module = nn.Sequential( + nn.ModuleDict({"block": nn.Linear(4, 4)}), + ) + found = get_submodule_by_name(module, "0.block") + self.assertIsInstance(found, nn.Linear) diff --git a/tests/schedulers/test_scheduler_cogvideox.py b/tests/schedulers/test_scheduler_cogvideox.py new file mode 100644 index 0000000..e7b503f --- /dev/null +++ b/tests/schedulers/test_scheduler_cogvideox.py @@ -0,0 +1,199 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.schedulers.scheduling_ddim_cogvideox import DDIMSchedulerOutput + + +class CogVideoXDDIMSchedulerTest(unittest.TestCase): + """ + Contract tests for CogVideoXDDIMScheduler — used by CogView3+ and shares the SNR-shifted + schedule with CogVideoX pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = CogVideoXDDIMScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.0120, + "beta_schedule": "scaled_linear", + "timestep_spacing": "leading", + } + config.update(**kwargs) + return config + + def test_snr_shift_modifies_alphas_cumprod(self): + shifted = self.scheduler_class(**self.get_default_config(snr_shift_scale=3.0)) + unshifted = self.scheduler_class(**self.get_default_config(snr_shift_scale=1.0)) + self.assertFalse(torch.allclose(shifted.alphas_cumprod, unshifted.alphas_cumprod)) + + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_set_timesteps_produces_expected_count(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 4, 10, 50]: + scheduler.set_timesteps(nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + output = scheduler.step(model_output, timestep, sample, eta=0.0) + self.assertIsInstance(output, DDIMSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + timestep = scheduler.timesteps[0] + + first = scheduler.step(model_output, timestep, sample, eta=0.0, generator=generator).prev_sample + generator = torch.Generator().manual_seed(0) + second = scheduler.step(model_output, timestep, sample, eta=0.0, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + generator = torch.Generator().manual_seed(0) + for t in scheduler.timesteps: + sample = scheduler.step( + torch.randn_like(sample), t, sample, eta=0.0, generator=generator + ).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) + + +class CogVideoXDPMSchedulerTest(unittest.TestCase): + """ + Contract tests for CogVideoXDPMScheduler — CogVideoX pipelines branch on its multi-step API. + """ + + scheduler_class = CogVideoXDPMScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "beta_start": 0.00085, + "beta_end": 0.0120, + "beta_schedule": "scaled_linear", + "timestep_spacing": "trailing", + } + config.update(**kwargs) + return config + + def test_step_requires_set_timesteps(self): + scheduler = self.scheduler_class(**self.get_default_config()) + sample = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + scheduler.step( + torch.randn_like(sample), + None, + scheduler.timesteps[0], + None, + sample, + return_dict=False, + ) + + def test_set_timesteps_num_inference_steps_exceeds_train_timesteps_raises(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(scheduler.config.num_train_timesteps + 1) + + def test_first_step_returns_pred_original_sample(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0] + + prev_sample, pred_original_sample = scheduler.step( + model_output, + None, + timestep, + None, + sample, + return_dict=False, + ) + self.assertEqual(prev_sample.shape, sample.shape) + self.assertEqual(pred_original_sample.shape, sample.shape) + + def test_second_step_uses_old_pred_original_sample(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + _, old_pred = scheduler.step( + model_output, + None, + scheduler.timesteps[0], + None, + sample, + return_dict=False, + generator=generator, + ) + + prev_sample, _ = scheduler.step( + model_output, + old_pred, + scheduler.timesteps[1], + scheduler.timesteps[0], + sample, + return_dict=False, + generator=generator, + ) + self.assertEqual(prev_sample.shape, sample.shape) + + def test_get_variables_and_get_mult(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + t = scheduler.timesteps[1] + t_prev = t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps + t_back = scheduler.timesteps[0] + + alpha_prod_t = scheduler.alphas_cumprod[t] + alpha_prod_t_prev = scheduler.alphas_cumprod[t_prev] + alpha_prod_t_back = scheduler.alphas_cumprod[t_back] + + h, r, lamb, lamb_next = scheduler.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = scheduler.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + self.assertEqual(len(mult), 4) + self.assertTrue(torch.isfinite(h)) + self.assertTrue(torch.isfinite(lamb)) diff --git a/tests/schedulers/test_scheduler_ddpm_wuerstchen.py b/tests/schedulers/test_scheduler_ddpm_wuerstchen.py new file mode 100644 index 0000000..36ea793 --- /dev/null +++ b/tests/schedulers/test_scheduler_ddpm_wuerstchen.py @@ -0,0 +1,110 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import DDPMWuerstchenScheduler +from diffusers.schedulers.scheduling_ddpm_wuerstchen import DDPMWuerstchenSchedulerOutput + + +class DDPMWuerstchenSchedulerTest(unittest.TestCase): + """ + Contract tests for DDPMWuerstchenScheduler — Stable Cascade prior/decoder use float timesteps + in [1, 0] rather than integer indices. Pipeline tests only exercised it indirectly. + """ + + scheduler_class = DDPMWuerstchenScheduler + + def get_default_config(self, **kwargs): + config = {"scaler": 1.0, "s": 0.008} + config.update(**kwargs) + return config + + def test_set_timesteps_float_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 4, 10]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1.0, places=5) + self.assertAlmostEqual(scheduler.timesteps[-1].item(), 0.0, places=5) + + def test_scaler_modifies_alpha_cumprod(self): + default = self.scheduler_class(**self.get_default_config()) + scaled = self.scheduler_class(**self.get_default_config(scaler=2.0)) + t = torch.tensor([0.5]) + default_alpha = default._alpha_cumprod(t, device="cpu") + scaled_alpha = scaled._alpha_cumprod(t, device="cpu") + self.assertFalse(torch.allclose(default_alpha, scaled_alpha)) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1].expand(sample.shape[0]) + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, DDPMWuerstchenSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_with_generator(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, timestep, sample, generator=generator).prev_sample + generator = torch.Generator().manual_seed(0) + second = scheduler.step(model_output, timestep, sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_previous_timestep_advances_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + current = scheduler.timesteps[0:1] + prev = scheduler.previous_timestep(current) + self.assertAlmostEqual(prev[0].item(), scheduler.timesteps[1].item(), places=5) + + def test_add_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + timesteps = scheduler.timesteps[0:1] + mixed = scheduler.add_noise(sample, noise, timesteps) + self.assertTrue(mixed.min() >= 0.0) + self.assertTrue(mixed.max() <= 1.0) + self.assertFalse(torch.allclose(mixed, sample)) + self.assertFalse(torch.allclose(mixed, noise)) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + generator = torch.Generator().manual_seed(0) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps[:-1]: + batch_t = t.expand(sample.shape[0]) + sample = scheduler.step( + torch.randn_like(sample), batch_t, sample, generator=generator + ).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/utils/check_copies.py b/utils/check_copies.py index 001366c..338a25d 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -18,6 +18,7 @@ import os import re import subprocess +import sys # All paths are set with the intent you should run this script from the root of the repo with the command @@ -94,7 +95,7 @@ def get_indent(code): def run_ruff(code): - command = ["ruff", "format", "-", "--config", "pyproject.toml", "--silent"] + command = [sys.executable, "-m", "ruff", "format", "-", "--config", "pyproject.toml", "--silent"] process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE) stdout, _ = process.communicate(input=code.encode()) return stdout.decode() From 4947ac878feca1b6df5678754b41ee7cd9a13e91 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 1 Jul 2026 02:05:58 +0000 Subject: [PATCH 3/4] Add regression tests for FirstBlockCache hooks and AmusedScheduler - Unit-test FBC skip/recompute logic, tuple outputs, and threshold behavior - Cover AmusedScheduler Gumbel masking, 2D reshaping, and schedule validation Co-authored-by: Simon Lynch --- tests/hooks/test_first_block_cache.py | 125 ++++++++++++++++++++ tests/schedulers/test_scheduler_amused.py | 136 ++++++++++++++++++++++ 2 files changed, 261 insertions(+) create mode 100644 tests/hooks/test_first_block_cache.py create mode 100644 tests/schedulers/test_scheduler_amused.py diff --git a/tests/hooks/test_first_block_cache.py b/tests/hooks/test_first_block_cache.py new file mode 100644 index 0000000..aa21afe --- /dev/null +++ b/tests/hooks/test_first_block_cache.py @@ -0,0 +1,125 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache +from diffusers.models import ModelMixin + + +class DummyBlock(torch.nn.Module): + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0 + + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + + +class TupleOutputBlock(torch.nn.Module): + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0, encoder_hidden_states + + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock(), TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states, encoder_hidden_states + + +def _set_context(model, context_name): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + +@pytest.fixture(autouse=True) +def register_dummy_blocks(): + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), + ) + TransformerBlockRegistry.register( + TupleOutputBlock, + TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), + ) + + +def test_first_block_cache_skips_when_residual_is_stable(): + """When head-block residuals are similar, tail blocks should be skipped.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + assert torch.allclose(output_t0, torch.tensor([[[40.0]]])) + + # Identical input -> residual diff is 0 -> skip tail block (40.0, not 44.0). + output_t1 = model(input_t0) + assert torch.allclose(output_t1, torch.tensor([[[40.0]]])) + + +def test_first_block_cache_recomputes_when_residual_changes(): + """When residuals exceed the threshold, the full block stack must run.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + model(torch.tensor([[[10.0]]])) + + output_t1 = model(torch.tensor([[[11.0]]])) + assert torch.allclose(output_t1, torch.tensor([[[44.0]]])) + + +def test_first_block_cache_tuple_outputs(): + """First Block Cache must support tuple block outputs (Flux-style).""" + model = TupleTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + assert torch.allclose(out_0, torch.tensor([[[40.0]]])) + + out_1, _ = model(input_t0, encoder_hidden_states=enc_t0) + assert torch.allclose(out_1, torch.tensor([[[40.0]]])) + + +def test_first_block_cache_recomputes_after_skip_when_input_changes(): + """A large input change after a cached step must trigger full recomputation.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") + + model(torch.tensor([[[10.0]]])) + model(torch.tensor([[[10.0]]])) + + output = model(torch.tensor([[[12.0]]])) + assert torch.allclose(output, torch.tensor([[[48.0]]])) diff --git a/tests/schedulers/test_scheduler_amused.py b/tests/schedulers/test_scheduler_amused.py new file mode 100644 index 0000000..1e402a3 --- /dev/null +++ b/tests/schedulers/test_scheduler_amused.py @@ -0,0 +1,136 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AmusedScheduler +from diffusers.schedulers.scheduling_amused import ( + AmusedSchedulerOutput, + gumbel_noise, + mask_by_random_topk, +) + + +class AmusedSchedulerTest(unittest.TestCase): + """ + Unit tests for AmusedScheduler masked-token denoising. This scheduler has a bespoke + step API (discrete token IDs, Gumbel masking) and is not covered by SchedulerCommonTest. + """ + + mask_token_id = 0 + + def get_default_config(self, **kwargs): + config = { + "mask_token_id": self.mask_token_id, + "masking_schedule": "cosine", + } + config.update(**kwargs) + return config + + def test_gumbel_noise_matches_input_shape(self): + probs = torch.ones(2, 4) + noise = gumbel_noise(probs) + self.assertEqual(noise.shape, probs.shape) + + def test_gumbel_noise_is_reproducible_with_generator(self): + probs = torch.ones(2, 4) + generator = torch.Generator().manual_seed(0) + noise_a = gumbel_noise(probs, generator=generator) + generator = torch.Generator().manual_seed(0) + noise_b = gumbel_noise(probs, generator=generator) + self.assertTrue(torch.allclose(noise_a, noise_b)) + + def test_mask_by_random_topk_masks_requested_count(self): + probs = torch.tensor([[0.9, 0.8, 0.1, 0.05]]) + mask_len = torch.tensor([[2]]) + masking = mask_by_random_topk(mask_len, probs, temperature=0.0) + self.assertEqual(masking.sum().item(), 2) + + def test_set_timesteps_builds_temperature_schedule(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4, temperature=(2, 0)) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.temperatures.shape, (4,)) + self.assertAlmostEqual(scheduler.temperatures[0].item(), 2.0, places=5) + self.assertAlmostEqual(scheduler.temperatures[-1].item(), 0.0, places=5) + + def test_step_final_timestep_unmasks_all_tokens(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=2) + + sample = torch.tensor([[self.mask_token_id, self.mask_token_id, 5, 6]]) + model_output = torch.zeros(1, 4, 8) + model_output[0, :, 3] = 10.0 # token id 3 wins for every position + + output = scheduler.step(model_output, timestep=0, sample=sample, generator=torch.Generator().manual_seed(0)) + self.assertIsInstance(output, AmusedSchedulerOutput) + self.assertFalse((output.prev_sample == self.mask_token_id).any()) + + def test_step_intermediate_timestep_keeps_some_masks(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4, temperature=0.0) + + sample = torch.full((1, 6), self.mask_token_id) + model_output = torch.randn(1, 6, 4) + + timestep = scheduler.timesteps[1].item() + output = scheduler.step( + model_output, + timestep=timestep, + sample=sample, + generator=torch.Generator().manual_seed(0), + ) + self.assertTrue((output.prev_sample == self.mask_token_id).any()) + self.assertFalse((output.prev_sample == self.mask_token_id).all()) + + def test_step_2d_input_is_reshaped_and_restored(self): + scheduler = AmusedScheduler(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=2) + + sample = torch.full((1, 2, 2), self.mask_token_id) + model_output = torch.zeros(1, 4, 2, 2) + model_output[:, :, 0, 0] = 10.0 + + output = scheduler.step(model_output, timestep=0, sample=sample, generator=torch.Generator().manual_seed(0)) + self.assertEqual(output.prev_sample.shape, (1, 2, 2)) + + def test_add_noise_respects_mask_ratio(self): + scheduler = AmusedScheduler(**self.get_default_config(masking_schedule="linear")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.arange(1, 17).reshape(1, 4, 4) + # Early timesteps mask most tokens; use the second step for a partial mask. + timestep = scheduler.timesteps[1].item() + masked = scheduler.add_noise(sample, timesteps=timestep, generator=torch.Generator().manual_seed(0)) + self.assertTrue((masked == self.mask_token_id).any()) + self.assertFalse((masked == self.mask_token_id).all()) + + def test_unknown_masking_schedule_raises_in_step(self): + scheduler = AmusedScheduler(**self.get_default_config(masking_schedule="invalid")) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.tensor([[self.mask_token_id, 1, 2, 3]]) + model_output = torch.randn(1, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(model_output, timestep=scheduler.timesteps[1].item(), sample=sample) + + def test_unknown_masking_schedule_raises_in_add_noise(self): + scheduler = AmusedScheduler(**self.get_default_config(masking_schedule="invalid")) + scheduler.set_timesteps(num_inference_steps=2) + + sample = torch.tensor([[1, 2, 3, 4]]) + with self.assertRaises(ValueError): + scheduler.add_noise(sample, timesteps=scheduler.timesteps[0].item()) From 97cdb2f0b760bbaeae9a304d5d1f608954f733f9 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 1 Jul 2026 02:10:36 +0000 Subject: [PATCH 4/4] Fix check_copies comparison to stylify both sides consistently Compare observed and theoretical copied blocks after applying the same ruff formatting to each, avoiding false drift when only one side was previously stylified. Also fix multi-line overwrite and use python -m ruff. Co-authored-by: Simon Lynch --- utils/check_copies.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/utils/check_copies.py b/utils/check_copies.py index 338a25d..711cd8d 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -101,6 +101,18 @@ def run_ruff(code): return stdout.decode() +def _get_definition_header(lines, start_index): + idx = start_index + while idx < len(lines): + stripped = lines[idx].strip() + if stripped.startswith(("def ", "class ", "@dataclass")): + return lines[idx] + if stripped and not stripped.startswith("#") and not stripped.startswith("@"): + break + idx += 1 + return lines[start_index - 1] if start_index > 0 else "" + + def stylify(code: str) -> str: """ Applies the ruff part of our `make style` command to some code. This formats the code using `ruff format`. @@ -177,16 +189,24 @@ def is_copy_consistent(filename, overwrite=False): theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code) # stylify after replacement. To be able to do that, we need the header (class or function definition) - # from the previous line - theoretical_code = stylify(lines[start_index - 1] + theoretical_code) - theoretical_code = theoretical_code[len(lines[start_index - 1]) :] + # from the copied block, not the `# Copied from` comment line. + header_line = _get_definition_header(lines, start_index) + theoretical_code = stylify(header_line + theoretical_code) + theoretical_code = theoretical_code[len(header_line) :] + observed_code = stylify(header_line + observed_code) + observed_code = observed_code[len(header_line) :] # Test for a diff and act accordingly. if observed_code != theoretical_code: diffs.append([object_name, start_index]) if overwrite: - lines = lines[:start_index] + [theoretical_code] + lines[line_index:] - line_index = start_index + 1 + replacement_lines = theoretical_code.splitlines(keepends=True) + if replacement_lines and not replacement_lines[-1].endswith("\n"): + replacement_lines[-1] += "\n" + if lines[start_index].strip().startswith(("def ", "class ")): + replacement_lines = [lines[start_index]] + replacement_lines + lines = lines[:start_index] + replacement_lines + lines[line_index:] + line_index = start_index + len(replacement_lines) if overwrite and len(diffs) > 0: # Warn the user a file has been modified.