# Patch summary (commentary only; this text sits before the first "diff --git" header).
#
# Purpose: keep pipeline construction from crashing when a scalar (bool, int, float or str)
# ends up where a model component is expected, e.g. a bool passed positionally into the
# image_encoder slot by a subclass that calls super().__init__() with an outdated signature.
# See https://github.com/huggingface/diffusers/issues/6969
#
# src/diffusers/pipelines/pipeline_loading_utils.py
#   _fetch_class_library_tuple(module) now raises TypeError up front when `module` is a
#   bool/int/float/str, before the lazy `importlib.import_module` lookup runs.
#
# src/diffusers/pipelines/pipeline_utils.py
#   register_modules(): a scalar value produces an empty `register_dict`, and
#   `register_to_config` is only called when `register_dict` is non-empty. The
#   `setattr(self, name, module)` call is unchanged, so the scalar is still stored on the pipeline.
#   __setattr__(): when the config entry is being overwritten with a scalar, the recorded
#   library tuple becomes (None, None) instead of calling _fetch_class_library_tuple(value).
#
# tests/pipelines/test_register_modules_legacy_init.py (new file, 122 lines)
#   Four tests: register_modules with a bool, the TypeError raised by
#   _fetch_class_library_tuple(True), a legacy positional-init StableDiffusionPipeline subclass,
#   and the current keyword signature with image_encoder=None.
#
# NOTE(review): `bool` is a subclass of `int` in Python, so listing both in the isinstance
#   tuples is redundant at runtime; presumably it is spelled out for readability.
# NOTE(review): the test module imports `torch` at module level, so `@require_torch` cannot
#   skip the class if that import fails - confirm this matches the other test modules.
# NOTE(review): `CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")`
#   presumably needs Hub access or a warm cache - confirm for offline CI.
# NOTE(review): indentation below was reconstructed from a whitespace-collapsed copy of the
#   patch; the hunk line counts match, but verify the whitespace against the original commit.
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index d695f5e7284d..f8790c313a72 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -952,6 +952,12 @@ def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) ->
 
 
 def _fetch_class_library_tuple(module):
+    if isinstance(module, (bool, int, float, str)):
+        raise TypeError(
+            f"Expected a model module, but got scalar value of type {type(module).__name__}. "
+            "This can happen when a subclass calls super().__init__() with a mismatched signature."
+        )
+
     # import it here to avoid circular import
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     pipelines = getattr(diffusers_module, "pipelines")
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 1fa4db90d995..bccbe257f89a 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -212,12 +212,17 @@ def register_modules(self, **kwargs):
             # retrieve library
             if module is None or isinstance(module, (tuple, list)) and module[0] is None:
                 register_dict = {name: (None, None)}
+            elif isinstance(module, (bool, int, float, str)):
+                # Scalar values (e.g. bool mis-passed as image_encoder from legacy super().__init__) are not
+                # model components — skip library tuple registration. See https://github.com/huggingface/diffusers/issues/6969
+                register_dict = {}
             else:
                 library, class_name = _fetch_class_library_tuple(module)
                 register_dict = {name: (library, class_name)}
 
             # save model index config
-            self.register_to_config(**register_dict)
+            if register_dict:
+                self.register_to_config(**register_dict)
 
             # set models
             setattr(self, name, module)
@@ -227,7 +232,10 @@ def __setattr__(self, name: str, value: Any):
             # We need to overwrite the config if name exists in config
             if isinstance(getattr(self.config, name), (tuple, list)):
                 if value is not None and self.config[name][0] is not None:
-                    class_library_tuple = _fetch_class_library_tuple(value)
+                    if isinstance(value, (bool, int, float, str)):
+                        class_library_tuple = (None, None)
+                    else:
+                        class_library_tuple = _fetch_class_library_tuple(value)
                 else:
                     class_library_tuple = (None, None)
 
diff --git a/tests/pipelines/test_register_modules_legacy_init.py b/tests/pipelines/test_register_modules_legacy_init.py
new file mode 100644
index 000000000000..1e64a3357fbb
--- /dev/null
+++ b/tests/pipelines/test_register_modules_legacy_init.py
@@ -0,0 +1,122 @@
+import unittest
+
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple
+
+from ..testing_utils import require_torch
+
+
+def _get_dummy_sd_components():
+    cross_attention_dim = 8
+    torch.manual_seed(0)
+    unet = UNet2DConditionModel(
+        block_out_channels=(4, 8),
+        layers_per_block=1,
+        sample_size=32,
+        in_channels=4,
+        out_channels=4,
+        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+        cross_attention_dim=cross_attention_dim,
+        norm_num_groups=2,
+    )
+    scheduler = DDIMScheduler(
+        beta_start=0.00085,
+        beta_end=0.012,
+        beta_schedule="scaled_linear",
+        clip_sample=False,
+        set_alpha_to_one=False,
+    )
+    torch.manual_seed(0)
+    vae = AutoencoderKL(
+        block_out_channels=[4, 8],
+        in_channels=3,
+        out_channels=3,
+        down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+        up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+        latent_channels=4,
+        norm_num_groups=2,
+    )
+    torch.manual_seed(0)
+    text_encoder_config = CLIPTextConfig(
+        bos_token_id=0,
+        eos_token_id=2,
+        hidden_size=cross_attention_dim,
+        intermediate_size=16,
+        layer_norm_eps=1e-05,
+        num_attention_heads=2,
+        num_hidden_layers=2,
+        pad_token_id=1,
+        vocab_size=1000,
+    )
+    text_encoder = CLIPTextModel(text_encoder_config)
+    tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+    return {
+        "unet": unet,
+        "scheduler": scheduler,
+        "vae": vae,
+        "text_encoder": text_encoder,
+        "tokenizer": tokenizer,
+        "safety_checker": None,
+        "feature_extractor": None,
+    }
+
+
+class LegacyStableDiffusionPipeline(StableDiffusionPipeline):
+    """Simulates community pipelines with old positional super().__init__ signature."""
+
+    def __init__(
+        self,
+        vae,
+        text_encoder,
+        tokenizer,
+        unet,
+        scheduler,
+        safety_checker,
+        feature_extractor,
+        requires_safety_checker=True,
+    ):
+        super().__init__(
+            vae,
+            text_encoder,
+            tokenizer,
+            unet,
+            scheduler,
+            safety_checker,
+            feature_extractor,
+            requires_safety_checker,
+        )
+
+
+@require_torch
+class RegisterModulesLegacyInitTests(unittest.TestCase):
+    def test_register_modules_scalar_bool_no_crash(self):
+        class DummyPipeline(DiffusionPipeline):
+            def __init__(self):
+                super().__init__()
+
+        pipe = DummyPipeline()
+        pipe.register_modules(image_encoder=True)
+        self.assertIs(pipe.image_encoder, True)
+
+    def test_fetch_class_library_tuple_scalar_raises_type_error(self):
+        with self.assertRaises(TypeError):
+            _fetch_class_library_tuple(True)
+
+    def test_legacy_sd_pipeline_positional_init(self):
+        components = _get_dummy_sd_components()
+        # Legacy positional super().__init__ must not crash (#6969)
+        pipe = LegacyStableDiffusionPipeline(**components, requires_safety_checker=False)
+        self.assertIsNotNone(pipe.unet)
+        # bool lands on image_encoder attribute due to signature mismatch, but init completes
+        self.assertIs(pipe.image_encoder, False)
+
+    def test_new_signature_image_encoder_none(self):
+        components = _get_dummy_sd_components()
+        components["image_encoder"] = None
+        pipe = StableDiffusionPipeline(**components, requires_safety_checker=False)
+        self.assertIsNone(pipe.image_encoder)
+        self.assertFalse(pipe.config.requires_safety_checker)