Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@ def enable_layerwise_casting(
apply_layerwise_casting(
self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
)
# Casting hooks change the reported dtype without flowing through `_apply`, so invalidate the cache here.
self.__dict__.pop("_cached_dtype", None)

def enable_group_offload(
self,
Expand Down Expand Up @@ -1903,14 +1905,31 @@ def device(self) -> torch.device:
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
# Not cached: with group offloading, the effective device changes per-forward as groups onload/offload.
return get_parameter_device(self)

@property
def dtype(self) -> torch.dtype:
    """
    `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).

    The value is computed once with `get_parameter_dtype` and memoized in the instance `__dict__` under the
    `"_cached_dtype"` key. The entry is dropped by `_apply`, `register_parameter` and
    `enable_layerwise_casting`, so the next access recomputes it.
    """
    # The cache is read and written through `__dict__` directly (presumably to bypass the attribute hooks of
    # `torch.nn.Module` -- confirm against the enclosing class).
    # NOTE(review): invalidation only happens on this module's own `_apply` / `register_parameter` /
    # `enable_layerwise_casting`; a dtype change applied directly to a child module would not clear this
    # entry -- confirm that is acceptable for callers.
    cached = self.__dict__.get("_cached_dtype")
    if cached is not None:
        return cached
    cached = get_parameter_dtype(self)
    self.__dict__["_cached_dtype"] = cached
    return cached

def _apply(self, fn, *args, **kwargs):
    """
    Drop the memoized dtype, then delegate to the parent `_apply` with the same arguments.

    `.to()`, `.cpu()`, `.cuda()`, `.half()`, etc. all flow through `_apply`, so this is where the
    `"_cached_dtype"` entry has to be discarded.
    """
    # Pop with a default so a missing cache entry is not an error.
    self.__dict__.pop("_cached_dtype", None)
    return super()._apply(fn, *args, **kwargs)

def register_parameter(self, name, param):
    """
    Drop the memoized dtype, then delegate to the parent `register_parameter`.

    Some modules change dtype by reassigning parameters (e.g. a custom `.to()`) instead of going through
    `_apply`, so the `"_cached_dtype"` entry is discarded here as well.
    """
    # Pop with a default so a missing cache entry is not an error.
    self.__dict__.pop("_cached_dtype", None)
    return super().register_parameter(name, param)

def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Expand Down
30 changes: 30 additions & 0 deletions tests/models/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,36 @@ class ModelUtilsTest(unittest.TestCase):
def tearDown(self):
    # No cleanup of its own; only the parent class teardown runs.
    super().tearDown()

def _get_dummy_unet(self):
    # Small `UNet2DConditionModel` configuration shared by the dtype-cache tests.
    unet_config = {
        "block_out_channels": (4, 8),
        "layers_per_block": 1,
        "sample_size": 16,
        "in_channels": 4,
        "out_channels": 4,
        "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
        "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
        "cross_attention_dim": 8,
        "norm_num_groups": 4,
    }
    return UNet2DConditionModel(**unet_config)

def test_dtype_cache_reflects_to_and_half(self):
    # `.half()` and `.to()` flow through `_apply`, which drops the cached dtype, so each read below
    # must report the dtype that was just applied rather than the previously cached one.
    unet = self._get_dummy_unet()

    # First read populates the cache.
    assert unet.dtype == torch.float32

    unet.half()
    assert unet.dtype == torch.float16

    unet.to(torch.float32)
    assert unet.dtype == torch.float32

def test_dtype_cache_invalidated_by_layerwise_casting(self):
    # Enabling layerwise casting does not flow through `_apply`; the reported dtype becomes
    # `compute_dtype`, so the cached value has to be invalidated when casting is enabled.
    unet = self._get_dummy_unet()

    # Populate the cache before enabling casting.
    assert unet.dtype == torch.float32

    unet.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float16)
    assert unet.dtype == torch.float16

def test_missing_key_loading_warning_message(self):
with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs:
UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")
Expand Down
Loading