diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py index f73e3461c38e..9ebdd659ee4b 100644 --- a/tests/models/unets/test_models_unet_3d_condition.py +++ b/tests/models/unets/test_models_unet_3d_condition.py @@ -13,52 +13,43 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - -import numpy as np import torch -from diffusers.models import ModelMixin, UNet3DConditionModel -from diffusers.utils import logging -from diffusers.utils.import_utils import is_xformers_available +from diffusers import UNet3DConditionModel +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) enable_full_determinism() -logger = logging.get_logger(__name__) - - -@skip_mps -class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNet3DConditionModel - main_input_name = "sample" +class UNet3DConditionModelTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - num_frames = 4 - sizes = (16, 16) + def model_class(self): + return UNet3DConditionModel - noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 4, 16, 16) @property - def output_shape(self): - return (4, 4, 16, 16) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "block_out_channels": (4, 8), "norm_num_groups": 4, "down_block_types": ( @@ -73,111 +64,57 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.enable_xformers_memory_efficient_attention() + def get_dummy_inputs(self) -> dict: + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (16, 16) + noise = randn_tensor( + (batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device + ) + timestep = torch.tensor([10], device=torch_device) + encoder_hidden_states = randn_tensor((batch_size, 4, 8), generator=self.generator, device=torch_device) + return {"sample": noise, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states} - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" - # Overriding to set `norm_num_groups` needs to be different for this model. +class TestUNet3DConditionModel(UNet3DConditionModelTesterConfig, ModelTesterMixin): + # Overridden because UNet3DConditionModel needs a different `norm_num_groups`. def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (32, 64) init_dict["norm_num_groups"] = 32 - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) + output = model(**self.get_dummy_inputs()).sample - if isinstance(output, dict): - output = output.sample + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - # Overriding since the UNet3D outputs a different structure. - def test_determinism(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + def test_feed_forward_chunking(self): + init_dict = self.get_init_dict() + init_dict["block_out_channels"] = (32, 64) + init_dict["norm_num_groups"] = 32 + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - # Warmup pass when using mps (see #372) - if torch_device == "mps" and isinstance(model, ModelMixin): - model(**self.dummy_input) - - first = model(**inputs_dict) - if isinstance(first, dict): - first = first.sample - - second = model(**inputs_dict) - if isinstance(second, dict): - second = second.sample - - out_1 = first.cpu().numpy() - out_2 = second.cpu().numpy() - out_1 = out_1[~np.isnan(out_1)] - out_2 = out_2[~np.isnan(out_2)] - max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + output = model(**self.get_dummy_inputs())[0] - def test_model_attention_slicing(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["block_out_channels"] = (16, 32) - init_dict["attention_head_dim"] = 8 - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - model.set_attention_slice("auto") + model.enable_forward_chunking() with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None + output_2 = model(**self.get_dummy_inputs())[0] - model.set_attention_slice("max") - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None + assert output.shape == output_2.shape, "Shape doesn't match" + assert (output - output_2).abs().max() < 1e-2 - model.set_attention_slice(2) - with torch.no_grad(): - output = model(**inputs_dict) - assert output is not None - def test_feed_forward_chunking(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - init_dict["block_out_channels"] = (32, 64) - init_dict["norm_num_groups"] = 32 +class TestUNet3DConditionModelTraining(UNet3DConditionModelTesterConfig, TrainingTesterMixin): + """Training tests for UNet3DConditionModel.""" - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - with torch.no_grad(): - output = model(**inputs_dict)[0] +class TestUNet3DConditionModelMemory(UNet3DConditionModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNet3DConditionModel.""" - model.enable_forward_chunking() - with torch.no_grad(): - output_2 = model(**inputs_dict)[0] - self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") - assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 +class TestUNet3DConditionModelAttention(UNet3DConditionModelTesterConfig, AttentionTesterMixin): + """Attention processor tests for UNet3DConditionModel.""" diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 40773536df70..a12eb57228da 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -13,59 +13,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - -import numpy as np import torch from torch import nn from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel -from diffusers.utils import logging - -from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from diffusers.utils.torch_utils import randn_tensor +from ...testing_utils import enable_full_determinism, is_flaky, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, +) -logger = logging.get_logger(__name__) enable_full_determinism() -class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetControlNetXSModel - main_input_name = "sample" - +class UNetControlNetXSModelTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - sizes = (16, 16) - conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning - - noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device) - controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device) - conditioning_scale = 1 + def model_class(self): + return UNetControlNetXSModel - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - "controlnet_cond": controlnet_cond, - "conditioning_scale": conditioning_scale, - } + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 16, 16) @property - def output_shape(self): - return (4, 16, 16) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "sample_size": 16, "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"), "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"), @@ -80,11 +65,27 @@ def prepare_init_args_and_inputs_for_common(self): "ctrl_max_norm_num_groups": 2, "ctrl_conditioning_embedding_out_channels": (2, 2), } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict: + batch_size = 4 + num_channels = 4 + sizes = (16, 16) + noise = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device) + timestep = torch.tensor([10], device=torch_device) + encoder_hidden_states = randn_tensor((batch_size, 4, 8), generator=self.generator, device=torch_device) + controlnet_cond = randn_tensor((batch_size, 3, 32, 32), generator=self.generator, device=torch_device) + return { + "sample": noise, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "controlnet_cond": controlnet_cond, + "conditioning_scale": 1, + } + + +class TestUNetControlNetXSModel(UNetControlNetXSModelTesterConfig, ModelTesterMixin): def get_dummy_unet(self): - """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" + """The underlying UNet, used to build the UNetControlNetXSModel from a UNet and a ControlNetXS-Adapter.""" return UNet2DConditionModel( block_out_channels=(4, 8), layers_per_block=2, @@ -99,8 +100,7 @@ def get_dummy_unet(self): ) def get_dummy_controlnet_from_unet(self, unet, **kwargs): - """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter""" - # size_ratio and conditioning_embedding_out_channels chosen to keep model small + """The underlying ControlNetXS-Adapter. size_ratio and conditioning_embedding_out_channels keep the model small.""" return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs) def test_from_unet(self): @@ -114,21 +114,11 @@ def assert_equal_weights(module, weight_dict_prefix): for param_name, param_value in module.named_parameters(): assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value) - # # check unet - # everything expect down,mid,up blocks - modules_from_unet = [ - "time_embedding", - "conv_in", - "conv_norm_out", - "conv_out", - ] + # check unet: everything except down, mid, up blocks + modules_from_unet = ["time_embedding", "conv_in", "conv_norm_out", "conv_out"] for p in modules_from_unet: assert_equal_weights(getattr(unet, p), "base_" + p) - optional_modules_from_unet = [ - "class_embedding", - "add_time_proj", - "add_embedding", - ] + optional_modules_from_unet = ["class_embedding", "add_time_proj", "add_embedding"] for p in optional_modules_from_unet: if hasattr(unet, p) and getattr(unet, p) is not None: assert_equal_weights(getattr(unet, p), "base_" + p) @@ -151,8 +141,7 @@ def assert_equal_weights(module, weight_dict_prefix): if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None: assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers") - # # check controlnet - # everything expect down,mid,up blocks + # check controlnet: everything except down, mid, up blocks modules_from_controlnet = { "controlnet_cond_embedding": "controlnet_cond_embedding", "conv_in": "ctrl_conv_in", @@ -161,7 +150,6 @@ def assert_equal_weights(module, weight_dict_prefix): optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"} for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items(): assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs) - for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items(): if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None: assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs) @@ -193,12 +181,10 @@ def assert_unfrozen(module): for p in module.parameters(): assert p.requires_grad - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = UNetControlNetXSModel(**init_dict) + model = UNetControlNetXSModel(**self.get_init_dict()) model.freeze_unet_params() - # # check unet - # everything expect down,mid,up blocks + # check unet: everything except down, mid, up blocks modules_from_unet = [ model.base_time_embedding, model.base_conv_in, @@ -207,49 +193,39 @@ def assert_unfrozen(module): ] for m in modules_from_unet: assert_frozen(m) - - optional_modules_from_unet = [ - model.base_add_time_proj, - model.base_add_embedding, - ] + optional_modules_from_unet = [model.base_add_time_proj, model.base_add_embedding] for m in optional_modules_from_unet: if m is not None: assert_frozen(m) - # down blocks - for i, d in enumerate(model.down_blocks): + for d in model.down_blocks: assert_frozen(d.base_resnets) if isinstance(d.base_attentions, nn.ModuleList): # attentions can be list of Nones assert_frozen(d.base_attentions) if d.base_downsamplers is not None: assert_frozen(d.base_downsamplers) - # mid block assert_frozen(model.mid_block.base_midblock) - # up blocks - for i, u in enumerate(model.up_blocks): + for u in model.up_blocks: assert_frozen(u.resnets) if isinstance(u.attentions, nn.ModuleList): # attentions can be list of Nones assert_frozen(u.attentions) if u.upsamplers is not None: assert_frozen(u.upsamplers) - # # check controlnet - # everything expect down,mid,up blocks + # check controlnet: everything except down, mid, up blocks modules_from_controlnet = [ model.controlnet_cond_embedding, model.ctrl_conv_in, model.control_to_base_for_conv_in, ] optional_modules_from_controlnet = [model.ctrl_time_embedding] - for m in modules_from_controlnet: assert_unfrozen(m) for m in optional_modules_from_controlnet: if m is not None: assert_unfrozen(m) - # down blocks for d in model.down_blocks: assert_unfrozen(d.ctrl_resnets) @@ -267,36 +243,24 @@ def assert_unfrozen(module): for u in model.up_blocks: assert_unfrozen(u.ctrl_to_base) - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "Transformer2DModel", - "UNetMidBlock2DCrossAttn", - "ControlNetXSCrossAttnDownBlock2D", - "ControlNetXSCrossAttnMidBlock2D", - "ControlNetXSCrossAttnUpBlock2D", - } - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - @is_flaky def test_forward_no_control(self): unet = self.get_dummy_unet() controlnet = self.get_dummy_controlnet_from_unet(unet) model = UNetControlNetXSModel.from_unet(unet, controlnet) - unet = unet.to(torch_device) model = model.to(torch_device) - input_ = self.dummy_input - + inputs = self.get_dummy_inputs() control_specific_input = ["controlnet_cond", "conditioning_scale"] - input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input} + input_for_unet = {k: v for k, v in inputs.items() if k not in control_specific_input} with torch.no_grad(): unet_output = unet(**input_for_unet).sample.cpu() - unet_controlnet_output = model(**input_, apply_control=False).sample.cpu() + unet_controlnet_output = model(**inputs, apply_control=False).sample.cpu() - assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4 + assert (unet_output.flatten() - unet_controlnet_output.flatten()).abs().max() < 3e-4 def test_time_embedding_mixing(self): unet = self.get_dummy_unet() @@ -305,22 +269,34 @@ def test_time_embedding_mixing(self): unet, time_embedding_mix=0.5, learn_time_embedding=True ) - model = UNetControlNetXSModel.from_unet(unet, controlnet) - model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time) - - unet = unet.to(torch_device) - model = model.to(torch_device) - model_mix_time = model_mix_time.to(torch_device) - - input_ = self.dummy_input + model = UNetControlNetXSModel.from_unet(unet, controlnet).to(torch_device) + model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time).to(torch_device) + inputs = self.get_dummy_inputs() with torch.no_grad(): - output = model(**input_).sample - output_mix_time = model_mix_time(**input_).sample + output = model(**inputs).sample + output_mix_time = model_mix_time(**inputs).sample assert output.shape == output_mix_time.shape - @unittest.skip("Test not supported.") - def test_forward_with_norm_groups(self): - # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. - pass + +class TestUNetControlNetXSModelTraining(UNetControlNetXSModelTesterConfig, TrainingTesterMixin): + """Training tests for UNetControlNetXSModel.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "Transformer2DModel", + "UNetMidBlock2DCrossAttn", + "ControlNetXSCrossAttnDownBlock2D", + "ControlNetXSCrossAttnMidBlock2D", + "ControlNetXSCrossAttnUpBlock2D", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNetControlNetXSModelMemory(UNetControlNetXSModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNetControlNetXSModel.""" + + +class TestUNetControlNetXSModelAttention(UNetControlNetXSModelTesterConfig, AttentionTesterMixin): + """Attention processor tests for UNetControlNetXSModel.""" diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index d931b345fd09..21bc9bd62de9 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -15,56 +15,44 @@ import copy import os -import tempfile -import unittest -import numpy as np import torch from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel -from diffusers.utils import logging -from diffusers.utils.import_utils import is_xformers_available - -from ...testing_utils import ( - enable_full_determinism, - floats_tensor, - torch_device, +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin - -logger = logging.get_logger(__name__) enable_full_determinism() -class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetMotionModel - main_input_name = "sample" - +class UNetMotionModelTesterConfig(BaseModelTesterConfig): @property - def dummy_input(self): - batch_size = 4 - num_channels = 4 - num_frames = 4 - sizes = (16, 16) - - noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size * num_frames, 4, 16)).to(torch_device) + def model_class(self): + return UNetMotionModel - return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + @property + def main_input_name(self) -> str: + return "sample" @property - def input_shape(self): + def output_shape(self) -> tuple: return (4, 4, 16, 16) @property - def output_shape(self): - return (4, 4, 16, 16) + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "block_out_channels": (16, 32), "norm_num_groups": 16, "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"), @@ -76,9 +64,23 @@ def prepare_init_args_and_inputs_for_common(self): "layers_per_block": 1, "sample_size": 16, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict + def get_dummy_inputs(self) -> dict: + batch_size = 4 + num_channels = 4 + num_frames = 4 + sizes = (16, 16) + noise = randn_tensor( + (batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device + ) + timestep = torch.tensor([10], device=torch_device) + encoder_hidden_states = randn_tensor( + (batch_size * num_frames, 4, 16), generator=self.generator, device=torch_device + ) + return {"sample": noise, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states} + + +class TestUNetMotionModel(UNetMotionModelTesterConfig, ModelTesterMixin): def test_from_unet2d(self): torch.manual_seed(0) unet2d = UNet2DConditionModel() @@ -88,19 +90,17 @@ def test_from_unet2d(self): model_state_dict = model.state_dict() for param_name, param_value in unet2d.named_parameters(): - self.assertTrue(torch.equal(model_state_dict[param_name], param_value)) + assert torch.equal(model_state_dict[param_name], param_value) def test_freeze_unet2d(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) + model = self.model_class(**self.get_init_dict()) model.freeze_unet2d_params() for param_name, param_value in model.named_parameters(): if "motion_modules" not in param_name: - self.assertFalse(param_value.requires_grad) - + assert not param_value.requires_grad else: - self.assertTrue(param_value.requires_grad) + assert param_value.requires_grad def test_loading_motion_adapter(self): model = self.model_class() @@ -110,210 +110,113 @@ def test_loading_motion_adapter(self): for idx, down_block in enumerate(model.down_blocks): adapter_state_dict = adapter.down_blocks[idx].motion_modules.state_dict() for param_name, param_value in down_block.motion_modules.named_parameters(): - self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value)) + assert torch.equal(adapter_state_dict[param_name], param_value) for idx, up_block in enumerate(model.up_blocks): adapter_state_dict = adapter.up_blocks[idx].motion_modules.state_dict() for param_name, param_value in up_block.motion_modules.named_parameters(): - self.assertTrue(torch.equal(adapter_state_dict[param_name], param_value)) + assert torch.equal(adapter_state_dict[param_name], param_value) mid_block_adapter_state_dict = adapter.mid_block.motion_modules.state_dict() for param_name, param_value in model.mid_block.motion_modules.named_parameters(): - self.assertTrue(torch.equal(mid_block_adapter_state_dict[param_name], param_value)) + assert torch.equal(mid_block_adapter_state_dict[param_name], param_value) - def test_saving_motion_modules(self): + def test_saving_motion_modules(self, tmp_path): torch.manual_seed(0) - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_motion_modules(tmpdirname) - self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors"))) + init_dict = self.get_init_dict() + model = self.model_class(**init_dict).to(torch_device) - adapter_loaded = MotionAdapter.from_pretrained(tmpdirname) + model.save_motion_modules(tmp_path) + assert os.path.isfile(os.path.join(tmp_path, "diffusion_pytorch_model.safetensors")) - torch.manual_seed(0) - model_loaded = self.model_class(**init_dict) - model_loaded.load_motion_modules(adapter_loaded) - model_loaded.to(torch_device) + adapter_loaded = MotionAdapter.from_pretrained(tmp_path) + torch.manual_seed(0) + model_loaded = self.model_class(**init_dict) + model_loaded.load_motion_modules(adapter_loaded) + model_loaded.to(torch_device) with torch.no_grad(): - output = model(**inputs_dict)[0] - output_loaded = model_loaded(**inputs_dict)[0] - - max_diff = (output - output_loaded).abs().max().item() - self.assertLessEqual(max_diff, 1e-4, "Models give different forward passes") + output = model(**self.get_dummy_inputs())[0] + output_loaded = model_loaded(**self.get_dummy_inputs())[0] - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - - model.enable_xformers_memory_efficient_attention() - - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" - - def test_gradient_checkpointing_is_applied(self): - expected_set = { - "CrossAttnUpBlockMotion", - "CrossAttnDownBlockMotion", - "UNetMidBlockCrossAttnMotion", - "UpBlockMotion", - "Transformer2DModel", - "DownBlockMotion", - } - super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + assert (output - output_loaded).abs().max().item() <= 1e-4, "Models give different forward passes" def test_feed_forward_chunking(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() init_dict["block_out_channels"] = (32, 64) init_dict["norm_num_groups"] = 32 - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict)[0] + output = model(**self.get_dummy_inputs())[0] model.enable_forward_chunking() with torch.no_grad(): - output_2 = model(**inputs_dict)[0] + output_2 = model(**self.get_dummy_inputs())[0] - self.assertEqual(output.shape, output_2.shape, "Shape doesn't match") - assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2 + assert output.shape == output_2.shape, "Shape doesn't match" + assert (output - output_2).abs().max() < 1e-2 def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) + model = self.model_class(**self.get_init_dict()).to(torch_device) with torch.no_grad(): - sample = model(**inputs_dict).sample + sample = model(**self.get_dummy_inputs()).sample sample_copy = copy.copy(sample) - assert (sample - sample_copy).abs().max() < 1e-4 - def test_from_save_pretrained(self, expected_max_diff=5e-5): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, safe_serialization=False) - torch.manual_seed(0) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - with torch.no_grad(): - image = model(**inputs_dict) - if isinstance(image, dict): - image = image.to_tuple()[0] - - new_image = new_model(**inputs_dict) - - if isinstance(new_image, dict): - new_image = new_image.to_tuple()[0] - - max_diff = (image - new_image).abs().max().item() - self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") - - def test_from_save_pretrained_variant(self, expected_max_diff=5e-5): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - torch.manual_seed(0) - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False) - - torch.manual_seed(0) - new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") - # non-variant cannot be loaded - with self.assertRaises(OSError) as error_context: - self.model_class.from_pretrained(tmpdirname) - - # make sure that error message states what keys are missing - assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(error_context.exception) - - new_model.to(torch_device) - - with torch.no_grad(): - image = model(**inputs_dict) - if isinstance(image, dict): - image = image.to_tuple()[0] - - new_image = new_model(**inputs_dict) - - if isinstance(new_image, dict): - new_image = new_image.to_tuple()[0] - - max_diff = (image - new_image).abs().max().item() - self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes") - def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["norm_num_groups"] = 16 init_dict["block_out_channels"] = (16, 32) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) + output = model(**self.get_dummy_inputs()).sample - if isinstance(output, dict): - output = output.to_tuple()[0] - - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" def test_asymmetric_motion_model(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["layers_per_block"] = (2, 3) init_dict["transformer_layers_per_block"] = ((1, 2), (3, 4, 5)) init_dict["reverse_transformer_layers_per_block"] = ((7, 6, 7, 4), (4, 2, 2)) - init_dict["temporal_transformer_layers_per_block"] = ((2, 5), (2, 3, 5)) init_dict["reverse_temporal_transformer_layers_per_block"] = ((5, 4, 3, 4), (3, 2, 2)) - init_dict["num_attention_heads"] = (2, 4) init_dict["motion_num_attention_heads"] = (4, 4) init_dict["reverse_motion_num_attention_heads"] = (2, 2) - init_dict["use_motion_mid_block"] = True init_dict["mid_block_layers"] = 2 init_dict["transformer_layers_per_mid_block"] = (1, 5) init_dict["temporal_transformer_layers_per_mid_block"] = (2, 4) - - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) + output = model(**self.get_dummy_inputs()).sample + + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" + + +class TestUNetMotionModelTraining(UNetMotionModelTesterConfig, TrainingTesterMixin): + """Training tests for UNetMotionModel.""" + + def test_gradient_checkpointing_is_applied(self): + expected_set = { + "CrossAttnUpBlockMotion", + "CrossAttnDownBlockMotion", + "UNetMidBlockCrossAttnMotion", + "UpBlockMotion", + "Transformer2DModel", + "DownBlockMotion", + } + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + + +class TestUNetMotionModelMemory(UNetMotionModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNetMotionModel.""" - if isinstance(output, dict): - output = output.to_tuple()[0] - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") +class TestUNetMotionModelAttention(UNetMotionModelTesterConfig, AttentionTesterMixin): + """Attention processor tests for UNetMotionModel."""