From c5c006a8f3ef9e5d40881dde4bb272dc0b75f539 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Mon, 8 Jun 2026 22:48:09 -0700 Subject: [PATCH] refactor unet_spatiotemporal tests --- .../unets/test_models_unet_spatiotemporal.py | 210 ++++++------------ 1 file changed, 67 insertions(+), 143 deletions(-) diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py index 7df868c9e95b..fc2ad7dec8dc 100644 --- a/tests/models/unets/test_models_unet_spatiotemporal.py +++ b/tests/models/unets/test_models_unet_spatiotemporal.py @@ -14,77 +14,46 @@ # limitations under the License. import copy -import unittest import torch from diffusers import UNetSpatioTemporalConditionModel -from diffusers.utils import logging -from diffusers.utils.import_utils import is_xformers_available - -from ...testing_utils import ( - enable_full_determinism, - floats_tensor, - skip_mps, - torch_device, +from diffusers.utils.torch_utils import randn_tensor + +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import ( + AttentionTesterMixin, + BaseModelTesterConfig, + MemoryTesterMixin, + ModelTesterMixin, + TrainingTesterMixin, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin -logger = logging.get_logger(__name__) - enable_full_determinism() -@skip_mps -class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): - model_class = UNetSpatioTemporalConditionModel - main_input_name = "sample" +class UNetSpatioTemporalConditionModelTesterConfig(BaseModelTesterConfig): + addition_time_embed_dim = 32 @property - def dummy_input(self): - batch_size = 2 - num_frames = 2 - num_channels = 4 - sizes = (32, 32) - - noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device) - time_step = torch.tensor([10]).to(torch_device) - encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device) - - return { - "sample": noise, - "timestep": time_step, - "encoder_hidden_states": encoder_hidden_states, - "added_time_ids": self._get_add_time_ids(), - } + def model_class(self): + return UNetSpatioTemporalConditionModel @property - def input_shape(self): - return (2, 2, 4, 32, 32) + def main_input_name(self) -> str: + return "sample" @property - def output_shape(self): + def output_shape(self) -> tuple: return (4, 32, 32) @property - def fps(self): - return 6 - - @property - def motion_bucket_id(self): - return 127 - - @property - def noise_aug_strength(self): - return 0.02 + def generator(self): + return torch.Generator("cpu").manual_seed(0) - @property - def addition_time_embed_dim(self): - return 32 - - def prepare_init_args_and_inputs_for_common(self): - init_dict = { + def get_init_dict(self) -> dict: + return { "block_out_channels": (32, 64), "down_block_types": ( "CrossAttnDownBlockSpatioTemporal", @@ -103,98 +72,62 @@ def prepare_init_args_and_inputs_for_common(self): "projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3, "addition_time_embed_dim": self.addition_time_embed_dim, } - inputs_dict = self.dummy_input - return init_dict, inputs_dict - - def _get_add_time_ids(self, do_classifier_free_guidance=True): - add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength] - - passed_add_embed_dim = self.addition_time_embed_dim * len(add_time_ids) - expected_add_embed_dim = self.addition_time_embed_dim * 3 - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], device=torch_device) - add_time_ids = add_time_ids.repeat(1, 1) - if do_classifier_free_guidance: - add_time_ids = torch.cat([add_time_ids, add_time_ids]) - - return add_time_ids - - @unittest.skip("Number of Norm Groups is not configurable") - def test_forward_with_norm_groups(self): - pass - - @unittest.skip("Deprecated functionality") - def test_model_attention_slicing(self): - pass - - @unittest.skip("Not supported") - def test_model_with_use_linear_projection(self): - pass - - @unittest.skip("Not supported") - def test_model_with_simple_projection(self): - pass - - @unittest.skip("Not supported") - def test_model_with_class_embeddings_concat(self): - pass - - @unittest.skipIf( - torch_device != "cuda" or not is_xformers_available(), - reason="XFormers attention is only available with CUDA and `xformers` installed", - ) - def test_xformers_enable_works(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.enable_xformers_memory_efficient_attention() + def get_dummy_inputs(self) -> dict: + batch_size = 2 + num_frames = 2 + num_channels = 4 + sizes = (32, 32) + noise = randn_tensor( + (batch_size, num_frames, num_channels, *sizes), generator=self.generator, device=torch_device + ) + timestep = torch.tensor([10], device=torch_device) + encoder_hidden_states = randn_tensor((batch_size, 1, 32), generator=self.generator, device=torch_device) + add_time_ids = torch.tensor([[6, 127, 0.02]], device=torch_device) + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + return { + "sample": noise, + "timestep": timestep, + "encoder_hidden_states": encoder_hidden_states, + "added_time_ids": add_time_ids, + } - assert ( - model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__ - == "XFormersAttnProcessor" - ), "xformers is not enabled" +class TestUNetSpatioTemporalConditionModel(UNetSpatioTemporalConditionModelTesterConfig, ModelTesterMixin): def test_model_with_num_attention_heads_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["num_attention_heads"] = (8, 16) - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + model = self.model_class(**init_dict).to(torch_device).eval() with torch.no_grad(): - output = model(**inputs_dict) - - if isinstance(output, dict): - output = output.sample + output = model(**self.get_dummy_inputs()).sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" def test_model_with_cross_attention_dim_tuple(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - + init_dict = self.get_init_dict() init_dict["cross_attention_dim"] = (32, 32) + model = self.model_class(**init_dict).to(torch_device).eval() + + with torch.no_grad(): + output = model(**self.get_dummy_inputs()).sample - model = self.model_class(**init_dict) - model.to(torch_device) - model.eval() + assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match" + + def test_pickle(self): + init_dict = self.get_init_dict() + init_dict["num_attention_heads"] = (8, 16) + model = self.model_class(**init_dict).to(torch_device) with torch.no_grad(): - output = model(**inputs_dict) + sample = model(**self.get_dummy_inputs()).sample + + sample_copy = copy.copy(sample) + assert (sample - sample_copy).abs().max() < 1e-4 - if isinstance(output, dict): - output = output.sample - self.assertIsNotNone(output) - expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") +class TestUNetSpatioTemporalConditionModelTraining(UNetSpatioTemporalConditionModelTesterConfig, TrainingTesterMixin): + """Training tests for UNetSpatioTemporalConditionModel.""" def test_gradient_checkpointing_is_applied(self): expected_set = { @@ -205,23 +138,14 @@ def test_gradient_checkpointing_is_applied(self): "CrossAttnUpBlockSpatioTemporal", "UNetMidBlockSpatioTemporal", } - num_attention_heads = (8, 16) - super().test_gradient_checkpointing_is_applied( - expected_set=expected_set, num_attention_heads=num_attention_heads - ) + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - def test_pickle(self): - # enable deterministic behavior for gradient checkpointing - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - init_dict["num_attention_heads"] = (8, 16) - model = self.model_class(**init_dict) - model.to(torch_device) +class TestUNetSpatioTemporalConditionModelMemory(UNetSpatioTemporalConditionModelTesterConfig, MemoryTesterMixin): + """Memory optimization tests for UNetSpatioTemporalConditionModel.""" - with torch.no_grad(): - sample = model(**inputs_dict).sample - sample_copy = copy.copy(sample) - - assert (sample - sample_copy).abs().max() < 1e-4 +class TestUNetSpatioTemporalConditionModelAttention( + UNetSpatioTemporalConditionModelTesterConfig, AttentionTesterMixin +): + """Attention processor tests for UNetSpatioTemporalConditionModel."""