From 266f1ef777075e0b77d023d81451b04e56c536ce Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sat, 27 Jun 2026 02:02:56 +0000 Subject: [PATCH 1/2] Add regression tests for LoRA utils, remote VAE, schedulers, and FirstBlockCache Cover high-blast-radius paths that lacked dedicated unit tests: - LoRA state dict key conversion (every adapter load) - Remote VAE encode/decode validation and postprocessing - FlowMatch, Helios, and LTX scheduler contracts - FirstBlockCache skip/compute threshold logic Co-authored-by: Simon Lynch --- tests/hooks/test_first_block_cache.py | 138 ++++++++++++++++++ tests/others/test_remote_utils.py | 103 +++++++++++++ tests/others/test_state_dict_utils.py | 110 ++++++++++++++ ...est_scheduler_flow_match_euler_discrete.py | 108 ++++++++++++++ tests/schedulers/test_scheduler_helios.py | 94 ++++++++++++ .../test_scheduler_ltx_euler_ancestral_rf.py | 124 ++++++++++++++++ 6 files changed, 677 insertions(+) create mode 100644 tests/hooks/test_first_block_cache.py create mode 100644 tests/others/test_remote_utils.py create mode 100644 tests/others/test_state_dict_utils.py create mode 100644 tests/schedulers/test_scheduler_flow_match_euler_discrete.py create mode 100644 tests/schedulers/test_scheduler_helios.py create mode 100644 tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py diff --git a/tests/hooks/test_first_block_cache.py b/tests/hooks/test_first_block_cache.py new file mode 100644 index 0000000..0a2eabc --- /dev/null +++ b/tests/hooks/test_first_block_cache.py @@ -0,0 +1,138 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry +from diffusers.hooks.first_block_cache import FirstBlockCacheConfig, apply_first_block_cache +from diffusers.models import ModelMixin + + +class DummyBlock(torch.nn.Module): + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0 + + +class DummyTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states + + +class TupleOutputBlock(torch.nn.Module): + def forward(self, hidden_states, encoder_hidden_states=None, **kwargs): + return hidden_states * 2.0, encoder_hidden_states + + +class TupleTransformer(ModelMixin): + def __init__(self): + super().__init__() + self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock(), TupleOutputBlock()]) + + def forward(self, hidden_states, encoder_hidden_states=None): + for block in self.transformer_blocks: + hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states) + return hidden_states, encoder_hidden_states + + +@pytest.fixture(autouse=True) +def register_dummy_blocks(): + TransformerBlockRegistry.register( + DummyBlock, + TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None), + ) + TransformerBlockRegistry.register( + TupleOutputBlock, + TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1), + ) + + +def test_first_block_cache_skipping_logic(): + """ + FirstBlockCache skips tail blocks when the first-block residual change is below threshold. + """ + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=1.0)) + + # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each). Head residual = 10, tail residual = 20. + input_t0 = torch.tensor([[[10.0]]]) + output_t0 = model(input_t0) + assert torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed" + + # Step 1: Input 11.0. + # If skipped: head output 22 + tail residual 20 = 42.0 + # If computed: 11 * 4 = 44.0 + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + assert torch.allclose(output_t1, torch.tensor([[[42.0]]])), f"Expected skip (42.0), got {output_t1.item()}" + + +def test_first_block_cache_compute_when_residual_changes(): + """A low threshold forces full recomputation when the first-block residual shifts.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.01)) + + model(torch.tensor([[[10.0]]])) + + input_t1 = torch.tensor([[[11.0]]]) + output_t1 = model(input_t1) + assert torch.allclose(output_t1, torch.tensor([[[44.0]]])), ( + f"Expected compute (44.0) due to low threshold, got {output_t1.item()}" + ) + + +def test_first_block_cache_tuple_outputs(): + """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" + model = TupleTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=1.0)) + + input_t0 = torch.tensor([[[10.0]]]) + enc_t0 = torch.tensor([[[1.0]]]) + out_0, _ = model(input_t0, encoder_hidden_states=enc_t0) + assert torch.allclose(out_0, torch.tensor([[[40.0]]])) + + # Step 1: Skip. Input 11.0 -> head output 22 + tail residual 20 = 42.0 + input_t1 = torch.tensor([[[11.0]]]) + out_1, _ = model(input_t1, encoder_hidden_states=enc_t0) + assert torch.allclose(out_1, torch.tensor([[[42.0]]])), f"Tuple skip failed. Expected 42.0, got {out_1.item()}" + + +def test_first_block_cache_first_pass_always_computes(): + """The first forward pass must always run all blocks to populate tail residuals.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=1.0)) + + input_t0 = torch.tensor([[[5.0]]]) + output_t0 = model(input_t0) + assert torch.allclose(output_t0, torch.tensor([[[20.0]]])) + + +def test_first_block_cache_large_input_change_recomputes(): + """Large input changes exceed the threshold and trigger full recomputation.""" + model = DummyTransformer() + apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + + model(torch.tensor([[[10.0]]])) + + input_t1 = torch.tensor([[[20.0]]]) + output_t1 = model(input_t1) + assert torch.allclose(output_t1, torch.tensor([[[80.0]]])), ( + f"Expected full compute (80.0) for large input change, got {output_t1.item()}" + ) diff --git a/tests/others/test_remote_utils.py b/tests/others/test_remote_utils.py new file mode 100644 index 0000000..b79ce78 --- /dev/null +++ b/tests/others/test_remote_utils.py @@ -0,0 +1,103 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import unittest +from unittest.mock import Mock + +import torch +from PIL import Image + +from diffusers.utils.remote_utils import ( + check_inputs_decode, + detect_image_type, + postprocess_decode, + prepare_decode, + prepare_encode, +) + + +class RemoteUtilsTest(unittest.TestCase): + def test_detect_image_type(self): + self.assertEqual(detect_image_type(b"\xff\xd8\xff"), "jpeg") + self.assertEqual(detect_image_type(b"\x89PNG\r\n\x1a\n"), "png") + self.assertEqual(detect_image_type(b"GIF89a"), "gif") + self.assertEqual(detect_image_type(b"BM"), "bmp") + self.assertEqual(detect_image_type(b"unknown"), "unknown") + + def test_check_inputs_decode_packed_latents_requires_hw(self): + tensor = torch.randn(4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode("http://example.com", tensor) + + def test_check_inputs_decode_processor_required(self): + tensor = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode( + "http://example.com", + tensor, + processor=None, + output_type="pt", + return_type="pil", + partial_postprocess=False, + ) + + def test_prepare_decode_sets_accept_header_for_jpeg(self): + tensor = torch.randn(1, 4, 8, 8, dtype=torch.float16) + payload = prepare_decode(tensor, output_type="pil", image_format="jpg") + self.assertEqual(payload["headers"]["Accept"], "image/jpeg") + self.assertEqual(payload["params"]["output_type"], "pil") + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + + def test_prepare_encode_tensor_includes_shape_and_dtype(self): + tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16) + payload = prepare_encode(tensor, scaling_factor=0.18215) + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + self.assertEqual(payload["params"]["dtype"], "float16") + self.assertEqual(payload["params"]["scaling_factor"], 0.18215) + + def test_prepare_encode_pil_image(self): + image = Image.new("RGB", (8, 8), color="red") + payload = prepare_encode(image) + self.assertIn(b"PNG", payload["data"][:8]) + + def test_postprocess_decode_pil_without_processor(self): + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color="blue").save(buffer, format="PNG") + response = Mock() + response.content = buffer.getvalue() + + output = postprocess_decode(response, processor=None, output_type="pil", return_type="pil") + self.assertIsInstance(output, Image.Image) + self.assertEqual(output.size, (4, 4)) + self.assertEqual(output.format, "png") + + def test_postprocess_decode_pt_tensor(self): + tensor = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2) + response = Mock() + response.content = tensor.numpy().tobytes() + response.headers = { + "shape": json.dumps(list(tensor.shape)), + "dtype": "float32", + } + + output = postprocess_decode( + response, + processor=None, + output_type="pt", + return_type="pt", + partial_postprocess=False, + ) + torch.testing.assert_close(output, tensor) diff --git a/tests/others/test_state_dict_utils.py b/tests/others/test_state_dict_utils.py new file mode 100644 index 0000000..0389546 --- /dev/null +++ b/tests/others/test_state_dict_utils.py @@ -0,0 +1,110 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers.utils.state_dict_utils import ( + StateDictType, + convert_all_state_dict_to_peft, + convert_state_dict, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, + state_dict_all_zero, +) + + +class StateDictUtilsTest(unittest.TestCase): + def test_convert_state_dict_applies_first_matching_pattern(self): + state_dict = {"layer.processor.weight": torch.ones(1)} + converted = convert_state_dict(state_dict, {".processor.": "."}) + self.assertIn("layer.weight", converted) + self.assertNotIn("layer.processor.weight", converted) + + def test_convert_state_dict_to_peft_auto_infers_diffusers_old(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_out_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_out_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict) + self.assertIn("unet.down_blocks.0.attentions.0.out_proj.lora_A.weight", converted) + + state_dict = { + "unet.down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS_OLD) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_A.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_peft_diffusers(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight", converted) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_diffusers_from_peft(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict, original_type=StateDictType.PEFT) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.down.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.up.weight", converted) + + def test_convert_state_dict_to_diffusers_already_diffusers(self): + state_dict = { + "layer.lora_linear_layer.down.weight": torch.ones(2, 2), + "layer.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict) + self.assertIs(converted, state_dict) + + def test_convert_unet_state_dict_to_peft(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_unet_state_dict_to_peft(state_dict) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_A.weight", converted) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_B.weight", converted) + + def test_convert_all_state_dict_to_peft_falls_back_to_unet(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_all_state_dict_to_peft(state_dict) + self.assertTrue(any("lora_A" in key or "lora_B" in key for key in converted)) + + def test_state_dict_all_zero(self): + state_dict = { + "a": torch.zeros(2, 2), + "b": torch.zeros(3), + } + self.assertTrue(state_dict_all_zero(state_dict)) + state_dict["b"] = torch.ones(3) + self.assertFalse(state_dict_all_zero(state_dict)) + + def test_state_dict_all_zero_with_filter(self): + state_dict = { + "lora.down": torch.zeros(2, 2), + "bias": torch.ones(3), + } + self.assertTrue(state_dict_all_zero(state_dict, filter_str="lora")) diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 100644 index 0000000..56f7410 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,108 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput + + +class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase): + """ + Contract tests for FlowMatchEulerDiscreteScheduler — shared by SD3, Flux, Wan, and many + flow-matching pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = FlowMatchEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8, 16]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_dynamic_shifting_requires_mu(self): + scheduler = self.scheduler_class(**self.get_default_config(use_dynamic_shifting=True)) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=4) + + def test_set_timesteps_custom_sigmas_and_timesteps_length_mismatch(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.0], timesteps=[900.0, 500.0]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, FlowMatchEulerDiscreteSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, model_output.dtype) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_scale_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[500.0, 250.0]) + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + mid_t = scheduler.timesteps[0:1] + mixed = scheduler.scale_noise(sample, mid_t, noise) + # sigma=0.5 at t=500 with default 1000 training steps + torch.testing.assert_close(mixed, 0.5 * noise, atol=1e-4, rtol=1e-4) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[900.0, 500.0, 500.0, 100.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_helios.py b/tests/schedulers/test_scheduler_helios.py new file mode 100644 index 0000000..6e20cef --- /dev/null +++ b/tests/schedulers/test_scheduler_helios.py @@ -0,0 +1,94 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HeliosScheduler +from diffusers.schedulers.scheduling_helios import HeliosSchedulerOutput + + +class HeliosSchedulerTest(unittest.TestCase): + """ + Unit tests for HeliosScheduler multi-stage scheduling. Pipeline tests only cover stages=1; + these tests lock in per-stage sigma ranges and step behavior used by the default 3-stage config. + """ + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + "stages": 3, + "stage_range": [0, 1 / 3, 2 / 3, 1], + "gamma": 1 / 3, + "prediction_type": "flow_prediction", + "scheduler_type": "euler", + "use_dynamic_shifting": False, + } + config.update(**kwargs) + return config + + def test_multi_stage_init_populates_per_stage_buffers(self): + scheduler = HeliosScheduler(**self.get_default_config()) + self.assertEqual(len(scheduler.timesteps_per_stage), 3) + self.assertEqual(len(scheduler.sigmas_per_stage), 3) + self.assertEqual(len(scheduler.start_sigmas), 3) + self.assertEqual(len(scheduler.end_sigmas), 3) + for stage in range(3): + self.assertGreater(scheduler.start_sigmas[stage], scheduler.end_sigmas[stage]) + + def test_set_timesteps_multi_stage_per_stage_index(self): + scheduler = HeliosScheduler(**self.get_default_config()) + for stage_index in range(3): + scheduler.set_timesteps(num_inference_steps=8, stage_index=stage_index) + self.assertEqual(scheduler.timesteps.shape, (8,)) + self.assertEqual(scheduler.sigmas.shape, (9,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + self.assertGreater(scheduler.timesteps[0].item(), scheduler.timesteps[-1].item()) + + def test_set_timesteps_single_stage(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.sigmas.shape, (5,)) + + def test_step_euler_updates_sample(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + output = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertIsInstance(output, HeliosSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertFalse(torch.allclose(output.prev_sample, sample)) + + def test_convert_model_output_flow_prediction(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + scheduler._step_index = 0 + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + x0 = scheduler.convert_model_output(model_output, sample=sample) + expected = sample - scheduler.sigmas[0] * model_output + torch.testing.assert_close(x0, expected) + + def test_dynamic_shifting_rescales_timesteps(self): + scheduler = HeliosScheduler( + **self.get_default_config(stages=1, stage_range=[0, 1], use_dynamic_shifting=True) + ) + scheduler.set_timesteps(num_inference_steps=4, mu=0.5) + self.assertEqual(scheduler.timesteps.shape, (4,)) diff --git a/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py new file mode 100644 index 0000000..f898d54 --- /dev/null +++ b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXEulerAncestralRFScheduler +from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFSchedulerOutput + + +class LTXEulerAncestralRFSchedulerTest(unittest.TestCase): + """ + Contract tests for the LTX RF ancestral scheduler used by LTX-Video long-form pipelines. + Mirrors the style of `test_scheduler_flow_map_euler_discrete.py` because this scheduler + has a non-standard step API and cannot reuse `SchedulerCommonTest`. + """ + + scheduler_class = LTXEulerAncestralRFScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "eta": 1.0, + "s_noise": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_auto_generates_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_requires_args(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps() + + def test_set_timesteps_explicit_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + custom_sigmas = [1.0, 0.75, 0.5, 0.25, 0.0] + scheduler.set_timesteps(sigmas=custom_sigmas) + self.assertEqual(scheduler.num_inference_steps, 4) + for i, sigma in enumerate(custom_sigmas): + self.assertAlmostEqual(scheduler.sigmas[i].item(), sigma, places=5) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1000.0, places=4) + + def test_set_timesteps_rejects_non_1d_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[[1.0, 0.5], [0.5, 0.0]]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, LTXEulerAncestralRFSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + + scheduler.set_timesteps(num_inference_steps=4) + second = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.5, 0.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) From 21479fce4847077b38bb3e73910a7925e0af8168 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Sat, 27 Jun 2026 02:05:44 +0000 Subject: [PATCH 2/2] Fix FirstBlockCache tests: set cache context before forward passes Co-authored-by: Simon Lynch --- tests/hooks/test_first_block_cache.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/hooks/test_first_block_cache.py b/tests/hooks/test_first_block_cache.py index 0a2eabc..0bbb7c2 100644 --- a/tests/hooks/test_first_block_cache.py +++ b/tests/hooks/test_first_block_cache.py @@ -64,12 +64,19 @@ def register_dummy_blocks(): ) +def _set_context(model, context_name): + for module in model.modules(): + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(context_name) + + def test_first_block_cache_skipping_logic(): """ FirstBlockCache skips tail blocks when the first-block residual change is below threshold. """ model = DummyTransformer() apply_first_block_cache(model, FirstBlockCacheConfig(threshold=1.0)) + _set_context(model, "test_context") # Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each). Head residual = 10, tail residual = 20. input_t0 = torch.tensor([[[10.0]]]) @@ -88,6 +95,7 @@ def test_first_block_cache_compute_when_residual_changes(): """A low threshold forces full recomputation when the first-block residual shifts.""" model = DummyTransformer() apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.01)) + _set_context(model, "test_context") model(torch.tensor([[[10.0]]])) @@ -102,6 +110,7 @@ def test_first_block_cache_tuple_outputs(): """Test compatibility with models returning (hidden, encoder_hidden) like Flux.""" model = TupleTransformer() apply_first_block_cache(model, FirstBlockCacheConfig(threshold=1.0)) + _set_context(model, "test_context") input_t0 = torch.tensor([[[10.0]]]) enc_t0 = torch.tensor([[[1.0]]]) @@ -118,6 +127,7 @@ def test_first_block_cache_first_pass_always_computes(): """The first forward pass must always run all blocks to populate tail residuals.""" model = DummyTransformer() apply_first_block_cache(model, FirstBlockCacheConfig(threshold=1.0)) + _set_context(model, "test_context") input_t0 = torch.tensor([[[5.0]]]) output_t0 = model(input_t0) @@ -128,6 +138,7 @@ def test_first_block_cache_large_input_change_recomputes(): """Large input changes exceed the threshold and trigger full recomputation.""" model = DummyTransformer() apply_first_block_cache(model, FirstBlockCacheConfig(threshold=0.05)) + _set_context(model, "test_context") model(torch.tensor([[[10.0]]]))