Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 54 additions & 117 deletions tests/models/unets/test_models_unet_3d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,52 +13,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch

from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers import UNet3DConditionModel
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)


enable_full_determinism()

logger = logging.get_logger(__name__)


@skip_mps
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet3DConditionModel
main_input_name = "sample"

class UNet3DConditionModelTesterConfig(BaseModelTesterConfig):
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
def model_class(self):
return UNet3DConditionModel

noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def main_input_name(self) -> str:
return "sample"

@property
def input_shape(self):
def output_shape(self) -> tuple:
return (4, 4, 16, 16)

@property
def output_shape(self):
return (4, 4, 16, 16)
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict:
return {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": (
Expand All @@ -73,111 +64,57 @@ def prepare_init_args_and_inputs_for_common(self):
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)

model.enable_xformers_memory_efficient_attention()
def get_dummy_inputs(self) -> dict:
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
noise = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
timestep = torch.tensor([10], device=torch_device)
encoder_hidden_states = randn_tensor((batch_size, 4, 8), generator=self.generator, device=torch_device)
return {"sample": noise, "timestep": timestep, "encoder_hidden_states": encoder_hidden_states}

assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"

# Overriding to set `norm_num_groups` needs to be different for this model.
class TestUNet3DConditionModel(UNet3DConditionModelTesterConfig, ModelTesterMixin):
# Overridden because UNet3DConditionModel needs a different `norm_num_groups`.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model = self.model_class(**init_dict).to(torch_device).eval()

with torch.no_grad():
output = model(**inputs_dict)
output = model(**self.get_dummy_inputs()).sample

if isinstance(output, dict):
output = output.sample
assert output.shape == self.get_dummy_inputs()["sample"].shape, "Input and output shapes do not match"

self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

# Overriding since the UNet3D outputs a different structure.
def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
def test_feed_forward_chunking(self):
init_dict = self.get_init_dict()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict).to(torch_device).eval()

with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)

first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample

second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample

out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
output = model(**self.get_dummy_inputs())[0]

def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = 8

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

model.set_attention_slice("auto")
model.enable_forward_chunking()
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
output_2 = model(**self.get_dummy_inputs())[0]

model.set_attention_slice("max")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
assert output.shape == output_2.shape, "Shape doesn't match"
assert (output - output_2).abs().max() < 1e-2

model.set_attention_slice(2)
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None

def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
class TestUNet3DConditionModelTraining(UNet3DConditionModelTesterConfig, TrainingTesterMixin):
"""Training tests for UNet3DConditionModel."""

model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()

with torch.no_grad():
output = model(**inputs_dict)[0]
class TestUNet3DConditionModelMemory(UNet3DConditionModelTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for UNet3DConditionModel."""

model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]

self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
class TestUNet3DConditionModelAttention(UNet3DConditionModelTesterConfig, AttentionTesterMixin):
"""Attention processor tests for UNet3DConditionModel."""
Loading
Loading