diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 0423b7287193..096865f4f198 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -516,6 +516,8 @@ def enable_layerwise_casting( apply_layerwise_casting( self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking ) + # Casting hooks change the reported dtype without flowing through `_apply`, so invalidate the cache here. + self.__dict__.pop("_cached_dtype", None) def enable_group_offload( self, @@ -1903,6 +1905,7 @@ def device(self) -> torch.device: `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). """ + # Not cached: with group offloading, the effective device changes per-forward as groups onload/offload. return get_parameter_device(self) @property @@ -1910,7 +1913,23 @@ def dtype(self) -> torch.dtype: """ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). """ - return get_parameter_dtype(self) + cached = self.__dict__.get("_cached_dtype") + if cached is not None: + return cached + cached = get_parameter_dtype(self) + self.__dict__["_cached_dtype"] = cached + return cached + + def _apply(self, fn, *args, **kwargs): + # Invalidate cached dtype since `.to()`, `.cpu()`, `.cuda()`, `.half()`, etc. all flow through `_apply`. + self.__dict__.pop("_cached_dtype", None) + return super()._apply(fn, *args, **kwargs) + + def register_parameter(self, name, param): + # Some modules change dtype by reassigning parameters (e.g. a custom `.to()`) instead of going through + # `_apply`, so invalidate here too. + self.__dict__.pop("_cached_dtype", None) + return super().register_parameter(name, param) def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: """ diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index dc961c70c0fe..26357d155d3b 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -218,6 +218,36 @@ class ModelUtilsTest(unittest.TestCase): def tearDown(self): super().tearDown() + def _get_dummy_unet(self): + return UNet2DConditionModel( + block_out_channels=(4, 8), + layers_per_block=1, + sample_size=16, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=8, + norm_num_groups=4, + ) + + def test_dtype_cache_reflects_to_and_half(self): + # `dtype` is cached and invalidated through `_apply`, which `.to()` / `.half()` flow through. + model = self._get_dummy_unet() + assert model.dtype == torch.float32 + model.half() + assert model.dtype == torch.float16 + model.to(torch.float32) + assert model.dtype == torch.float32 + + def test_dtype_cache_invalidated_by_layerwise_casting(self): + # Layerwise casting reports `compute_dtype` via a hook without flowing through `_apply`, + # so the cached dtype must be invalidated when it is enabled. + model = self._get_dummy_unet() + assert model.dtype == torch.float32 # populate the cache + model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float16) + assert model.dtype == torch.float16 + def test_missing_key_loading_warning_message(self): with self.assertLogs("diffusers.models.modeling_utils", level="WARNING") as logs: UNet2DConditionModel.from_pretrained("hf-internal-testing/stable-diffusion-broken", subfolder="unet")