From 05475755257e13384f08f705bcf615921a096f6f Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Tue, 23 Jun 2026 02:05:18 +0000 Subject: [PATCH 1/2] Add scheduler regression tests for LTX, Helios, and FlowMatch Cover high-blast-radius schedulers that previously had no dedicated unit tests: LTXEulerAncestralRFScheduler (LTX long-form video), HeliosScheduler multi-stage scheduling, and FlowMatchEulerDiscreteScheduler (SD3/Flux/Wan). Co-authored-by: Simon Lynch --- ...est_scheduler_flow_match_euler_discrete.py | 108 +++++++++++++++ tests/schedulers/test_scheduler_helios.py | 94 +++++++++++++ .../test_scheduler_ltx_euler_ancestral_rf.py | 124 ++++++++++++++++++ 3 files changed, 326 insertions(+) create mode 100644 tests/schedulers/test_scheduler_flow_match_euler_discrete.py create mode 100644 tests/schedulers/test_scheduler_helios.py create mode 100644 tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py diff --git a/tests/schedulers/test_scheduler_flow_match_euler_discrete.py b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py new file mode 100644 index 0000000..56f7410 --- /dev/null +++ b/tests/schedulers/test_scheduler_flow_match_euler_discrete.py @@ -0,0 +1,108 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteSchedulerOutput + + +class FlowMatchEulerDiscreteSchedulerTest(unittest.TestCase): + """ + Contract tests for FlowMatchEulerDiscreteScheduler — shared by SD3, Flux, Wan, and many + flow-matching pipelines. No dedicated scheduler test file existed previously. + """ + + scheduler_class = FlowMatchEulerDiscreteScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_endpoints(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8, 16]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_dynamic_shifting_requires_mu(self): + scheduler = self.scheduler_class(**self.get_default_config(use_dynamic_shifting=True)) + with self.assertRaises(ValueError): + scheduler.set_timesteps(num_inference_steps=4) + + def test_set_timesteps_custom_sigmas_and_timesteps_length_mismatch(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.0], timesteps=[900.0, 500.0]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, FlowMatchEulerDiscreteSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, model_output.dtype) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_scale_noise_interpolates_sample_and_noise(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[500.0, 250.0]) + sample = torch.zeros(1, 4, 4, 4) + noise = torch.ones_like(sample) + mid_t = scheduler.timesteps[0:1] + mixed = scheduler.scale_noise(sample, mid_t, noise) + # sigma=0.5 at t=500 with default 1000 training steps + torch.testing.assert_close(mixed, 0.5 * noise, atol=1e-4, rtol=1e-4) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(timesteps=[900.0, 500.0, 500.0, 100.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) diff --git a/tests/schedulers/test_scheduler_helios.py b/tests/schedulers/test_scheduler_helios.py new file mode 100644 index 0000000..6e20cef --- /dev/null +++ b/tests/schedulers/test_scheduler_helios.py @@ -0,0 +1,94 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import HeliosScheduler +from diffusers.schedulers.scheduling_helios import HeliosSchedulerOutput + + +class HeliosSchedulerTest(unittest.TestCase): + """ + Unit tests for HeliosScheduler multi-stage scheduling. Pipeline tests only cover stages=1; + these tests lock in per-stage sigma ranges and step behavior used by the default 3-stage config. + """ + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "shift": 1.0, + "stages": 3, + "stage_range": [0, 1 / 3, 2 / 3, 1], + "gamma": 1 / 3, + "prediction_type": "flow_prediction", + "scheduler_type": "euler", + "use_dynamic_shifting": False, + } + config.update(**kwargs) + return config + + def test_multi_stage_init_populates_per_stage_buffers(self): + scheduler = HeliosScheduler(**self.get_default_config()) + self.assertEqual(len(scheduler.timesteps_per_stage), 3) + self.assertEqual(len(scheduler.sigmas_per_stage), 3) + self.assertEqual(len(scheduler.start_sigmas), 3) + self.assertEqual(len(scheduler.end_sigmas), 3) + for stage in range(3): + self.assertGreater(scheduler.start_sigmas[stage], scheduler.end_sigmas[stage]) + + def test_set_timesteps_multi_stage_per_stage_index(self): + scheduler = HeliosScheduler(**self.get_default_config()) + for stage_index in range(3): + scheduler.set_timesteps(num_inference_steps=8, stage_index=stage_index) + self.assertEqual(scheduler.timesteps.shape, (8,)) + self.assertEqual(scheduler.sigmas.shape, (9,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + self.assertGreater(scheduler.timesteps[0].item(), scheduler.timesteps[-1].item()) + + def test_set_timesteps_single_stage(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + self.assertEqual(scheduler.timesteps.shape, (4,)) + self.assertEqual(scheduler.sigmas.shape, (5,)) + + def test_step_euler_updates_sample(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1])) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 4, 8, 8) + model_output = torch.randn_like(sample) + output = scheduler.step(model_output, scheduler.timesteps[0], sample) + self.assertIsInstance(output, HeliosSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertFalse(torch.allclose(output.prev_sample, sample)) + + def test_convert_model_output_flow_prediction(self): + scheduler = HeliosScheduler(**self.get_default_config(stages=1, stage_range=[0, 1], scheduler_type="unipc")) + scheduler.set_timesteps(num_inference_steps=4) + scheduler._step_index = 0 + + sample = torch.randn(1, 4, 4, 4) + model_output = torch.randn_like(sample) + x0 = scheduler.convert_model_output(model_output, sample=sample) + expected = sample - scheduler.sigmas[0] * model_output + torch.testing.assert_close(x0, expected) + + def test_dynamic_shifting_rescales_timesteps(self): + scheduler = HeliosScheduler( + **self.get_default_config(stages=1, stage_range=[0, 1], use_dynamic_shifting=True) + ) + scheduler.set_timesteps(num_inference_steps=4, mu=0.5) + self.assertEqual(scheduler.timesteps.shape, (4,)) diff --git a/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py new file mode 100644 index 0000000..f898d54 --- /dev/null +++ b/tests/schedulers/test_scheduler_ltx_euler_ancestral_rf.py @@ -0,0 +1,124 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXEulerAncestralRFScheduler +from diffusers.schedulers.scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFSchedulerOutput + + +class LTXEulerAncestralRFSchedulerTest(unittest.TestCase): + """ + Contract tests for the LTX RF ancestral scheduler used by LTX-Video long-form pipelines. + Mirrors the style of `test_scheduler_flow_map_euler_discrete.py` because this scheduler + has a non-standard step API and cannot reuse `SchedulerCommonTest`. + """ + + scheduler_class = LTXEulerAncestralRFScheduler + + def get_default_config(self, **kwargs): + config = { + "num_train_timesteps": 1000, + "eta": 1.0, + "s_noise": 1.0, + } + config.update(**kwargs) + return config + + def test_set_timesteps_auto_generates_schedule(self): + scheduler = self.scheduler_class(**self.get_default_config()) + for nfe in [1, 2, 4, 8]: + scheduler.set_timesteps(num_inference_steps=nfe) + self.assertEqual(scheduler.num_inference_steps, nfe) + self.assertEqual(scheduler.timesteps.shape, (nfe,)) + self.assertEqual(scheduler.sigmas.shape, (nfe + 1,)) + self.assertAlmostEqual(scheduler.sigmas[-1].item(), 0.0, places=6) + + def test_set_timesteps_requires_args(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps() + + def test_set_timesteps_explicit_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + custom_sigmas = [1.0, 0.75, 0.5, 0.25, 0.0] + scheduler.set_timesteps(sigmas=custom_sigmas) + self.assertEqual(scheduler.num_inference_steps, 4) + for i, sigma in enumerate(custom_sigmas): + self.assertAlmostEqual(scheduler.sigmas[i].item(), sigma, places=5) + self.assertAlmostEqual(scheduler.timesteps[0].item(), 1000.0, places=4) + + def test_set_timesteps_rejects_non_1d_sigmas(self): + scheduler = self.scheduler_class(**self.get_default_config()) + with self.assertRaises(ValueError): + scheduler.set_timesteps(sigmas=[[1.0, 0.5], [0.5, 0.0]]) + + def test_step_shape_preserved(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(2, 8, 4, 4) + model_output = torch.randn_like(sample) + timestep = scheduler.timesteps[0:1] + + output = scheduler.step(model_output, timestep, sample) + self.assertIsInstance(output, LTXEulerAncestralRFSchedulerOutput) + self.assertEqual(output.prev_sample.shape, sample.shape) + self.assertEqual(output.prev_sample.dtype, sample.dtype) + + def test_step_deterministic_when_eta_zero(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + model_output = torch.randn_like(sample) + generator = torch.Generator().manual_seed(0) + + first = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + + scheduler.set_timesteps(num_inference_steps=4) + second = scheduler.step(model_output, scheduler.timesteps[0:1], sample, generator=generator).prev_sample + torch.testing.assert_close(first, second) + + def test_step_rejects_integer_timestep(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + sample = torch.randn(1, 4, 4, 4) + with self.assertRaises(ValueError): + scheduler.step(torch.randn_like(sample), 0, sample) + + def test_index_for_timestep_duplicate_handling(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(sigmas=[1.0, 0.5, 0.5, 0.0]) + duplicate = scheduler.timesteps[1] + self.assertEqual(scheduler.index_for_timestep(duplicate), 2) + + def test_set_begin_index_anchors_step_index(self): + scheduler = self.scheduler_class(**self.get_default_config()) + scheduler.set_timesteps(num_inference_steps=4) + scheduler.set_begin_index(2) + sample = torch.randn(1, 4, 4, 4) + scheduler.step(torch.randn_like(sample), scheduler.timesteps[0], sample) + self.assertEqual(scheduler.step_index, 3) + + def test_full_denoising_loop(self): + scheduler = self.scheduler_class(**self.get_default_config(eta=0.0)) + scheduler.set_timesteps(num_inference_steps=4) + + sample = torch.randn(1, 4, 8, 8) + for t in scheduler.timesteps: + sample = scheduler.step(torch.randn_like(sample), t, sample).prev_sample + self.assertEqual(sample.shape, (1, 4, 8, 8)) From 0fc96ff5520e19eb568aa5fa80d9bb3eafe80482 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Tue, 23 Jun 2026 02:05:22 +0000 Subject: [PATCH 2/2] Add unit tests for state_dict_utils and remote_utils validation Cover LoRA key conversion paths used on every adapter load and input validation / postprocessing branches in remote VAE encode-decode helpers. Co-authored-by: Simon Lynch --- tests/others/test_remote_utils.py | 104 ++++++++++++++++++++++++ tests/others/test_state_dict_utils.py | 110 ++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 tests/others/test_remote_utils.py create mode 100644 tests/others/test_state_dict_utils.py diff --git a/tests/others/test_remote_utils.py b/tests/others/test_remote_utils.py new file mode 100644 index 0000000..501b94b --- /dev/null +++ b/tests/others/test_remote_utils.py @@ -0,0 +1,104 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import json +import unittest +from unittest.mock import Mock + +import torch +from PIL import Image + +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.remote_utils import ( + check_inputs_decode, + detect_image_type, + postprocess_decode, + prepare_decode, + prepare_encode, +) + + +class RemoteUtilsTest(unittest.TestCase): + def test_detect_image_type(self): + self.assertEqual(detect_image_type(b"\xff\xd8\xff"), "jpeg") + self.assertEqual(detect_image_type(b"\x89PNG\r\n\x1a\n"), "png") + self.assertEqual(detect_image_type(b"GIF89a"), "gif") + self.assertEqual(detect_image_type(b"BM"), "bmp") + self.assertEqual(detect_image_type(b"unknown"), "unknown") + + def test_check_inputs_decode_packed_latents_requires_hw(self): + tensor = torch.randn(4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode("http://example.com", tensor) + + def test_check_inputs_decode_processor_required(self): + tensor = torch.randn(1, 4, 8, 8) + with self.assertRaises(ValueError): + check_inputs_decode( + "http://example.com", + tensor, + processor=None, + output_type="pt", + return_type="pil", + partial_postprocess=False, + ) + + def test_prepare_decode_sets_accept_header_for_jpeg(self): + tensor = torch.randn(1, 4, 8, 8, dtype=torch.float16) + payload = prepare_decode(tensor, output_type="pil", image_format="jpg") + self.assertEqual(payload["headers"]["Accept"], "image/jpeg") + self.assertEqual(payload["params"]["output_type"], "pil") + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + + def test_prepare_encode_tensor_includes_shape_and_dtype(self): + tensor = torch.randn(1, 3, 8, 8, dtype=torch.float16) + payload = prepare_encode(tensor, scaling_factor=0.18215) + self.assertEqual(payload["params"]["shape"], list(tensor.shape)) + self.assertEqual(payload["params"]["dtype"], "float16") + self.assertEqual(payload["params"]["scaling_factor"], 0.18215) + + def test_prepare_encode_pil_image(self): + image = Image.new("RGB", (8, 8), color="red") + payload = prepare_encode(image) + self.assertIn(b"PNG", payload["data"][:8]) + + def test_postprocess_decode_pil_without_processor(self): + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color="blue").save(buffer, format="PNG") + response = Mock() + response.content = buffer.getvalue() + + output = postprocess_decode(response, processor=None, output_type="pil", return_type="pil") + self.assertIsInstance(output, Image.Image) + self.assertEqual(output.size, (4, 4)) + self.assertEqual(output.format, "png") + + def test_postprocess_decode_pt_tensor(self): + tensor = torch.arange(16, dtype=torch.float32).reshape(1, 4, 2, 2) + response = Mock() + response.content = tensor.numpy().tobytes() + response.headers = { + "shape": json.dumps(list(tensor.shape)), + "dtype": "float32", + } + + output = postprocess_decode( + response, + processor=None, + output_type="pt", + return_type="pt", + partial_postprocess=False, + ) + torch.testing.assert_close(output, tensor) diff --git a/tests/others/test_state_dict_utils.py b/tests/others/test_state_dict_utils.py new file mode 100644 index 0000000..0389546 --- /dev/null +++ b/tests/others/test_state_dict_utils.py @@ -0,0 +1,110 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers.utils.state_dict_utils import ( + StateDictType, + convert_all_state_dict_to_peft, + convert_state_dict, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, + convert_unet_state_dict_to_peft, + state_dict_all_zero, +) + + +class StateDictUtilsTest(unittest.TestCase): + def test_convert_state_dict_applies_first_matching_pattern(self): + state_dict = {"layer.processor.weight": torch.ones(1)} + converted = convert_state_dict(state_dict, {".processor.": "."}) + self.assertIn("layer.weight", converted) + self.assertNotIn("layer.processor.weight", converted) + + def test_convert_state_dict_to_peft_auto_infers_diffusers_old(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_out_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_out_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict) + self.assertIn("unet.down_blocks.0.attentions.0.out_proj.lora_A.weight", converted) + + state_dict = { + "unet.down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS_OLD) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_A.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_peft_diffusers(self): + state_dict = { + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight": torch.ones(2, 2), + "text_encoder.encoder.layers.0.self_attn.q_proj.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_peft(state_dict, original_type=StateDictType.DIFFUSERS) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_A.weight", converted) + self.assertIn("text_encoder.encoder.layers.0.self_attn.q_proj.lora_B.weight", converted) + + def test_convert_state_dict_to_diffusers_from_peft(self): + state_dict = { + "unet.down_blocks.0.attentions.0.to_q.lora_A.weight": torch.ones(2, 2), + "unet.down_blocks.0.attentions.0.to_q.lora_B.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict, original_type=StateDictType.PEFT) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.down.weight", converted) + self.assertIn("unet.down_blocks.0.attentions.0.to_q.lora.up.weight", converted) + + def test_convert_state_dict_to_diffusers_already_diffusers(self): + state_dict = { + "layer.lora_linear_layer.down.weight": torch.ones(2, 2), + "layer.lora_linear_layer.up.weight": torch.ones(2, 2), + } + converted = convert_state_dict_to_diffusers(state_dict) + self.assertIs(converted, state_dict) + + def test_convert_unet_state_dict_to_peft(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_unet_state_dict_to_peft(state_dict) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_A.weight", converted) + self.assertIn("down_blocks.0.attentions.0.to_q.lora_B.weight", converted) + + def test_convert_all_state_dict_to_peft_falls_back_to_unet(self): + state_dict = { + "down_blocks.0.attentions.0.to_q_lora.down.weight": torch.ones(2, 2), + "down_blocks.0.attentions.0.to_q_lora.up.weight": torch.ones(2, 2), + } + converted = convert_all_state_dict_to_peft(state_dict) + self.assertTrue(any("lora_A" in key or "lora_B" in key for key in converted)) + + def test_state_dict_all_zero(self): + state_dict = { + "a": torch.zeros(2, 2), + "b": torch.zeros(3), + } + self.assertTrue(state_dict_all_zero(state_dict)) + state_dict["b"] = torch.ones(3) + self.assertFalse(state_dict_all_zero(state_dict)) + + def test_state_dict_all_zero_with_filter(self): + state_dict = { + "lora.down": torch.zeros(2, 2), + "bias": torch.ones(3), + } + self.assertTrue(state_dict_all_zero(state_dict, filter_str="lora"))