Skip to content

Commit 4983dcd

Browse files
committed
Fix multi-adapter ControlNet-XS bugs in model and pipelines
Guard optional time embeddings before device moves, remove hardcoded CUDA casts, pass per-adapter timestep conditioning, register control modules with ModuleList/add_module, fix downsampler freeze attribute, correct multi-image batch repeat logic, and preserve all adapter types in pipeline config when registering module lists.
1 parent 6c2a47f commit 4983dcd

4 files changed

Lines changed: 54 additions & 44 deletions

File tree

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,10 +1921,14 @@ def __init__(
19211921

19221922
self.ctrl_time_embeddings = []
19231923
self.ctrl_time_embedding_mix = []
1924-
for controlnet in controlnets:
1924+
for index, controlnet in enumerate(controlnets):
19251925
self.ctrl_time_embedding_mix.append(controlnet.get("time_embedding_mix"))
19261926
if controlnet.get("ctrl_learn_time_embedding"):
1927-
self.ctrl_time_embeddings.append(TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim))
1927+
ctrl_time_embedding = TimestepEmbedding(
1928+
in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
1929+
)
1930+
self.ctrl_time_embeddings.append(ctrl_time_embedding)
1931+
self.add_module(f"ctrl_time_embedding_{index}", ctrl_time_embedding)
19281932
else:
19291933
self.ctrl_time_embeddings.append(None)
19301934

@@ -2194,19 +2198,19 @@ def forward(
21942198
# there might be better ways to encapsulate this.
21952199
t_emb = t_emb.to(dtype=sample.dtype)
21962200

2197-
time_embedding_list = []
2201+
base_temb = self.base_time_embedding(t_emb, timestep_cond)
2202+
base_temb = base_temb + aug_emb if aug_emb is not None else base_temb
2203+
2204+
ctrl_temb_list = []
21982205
for index in range(len(self.ctrl_time_embeddings)):
2199-
base_temb = self.base_time_embedding(t_emb, timestep_cond)
22002206
temb_module = self.ctrl_time_embeddings[index]
2201-
temb_module.to(t_emb.device, t_emb.dtype)
22022207
if temb_module is not None and apply_control:
22032208
ctrl_temb = temb_module(t_emb, timestep_cond)
2204-
interpolation_param = self.ctrl_time_embedding_mix[index]**0.3
2205-
temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
2209+
interpolation_param = self.ctrl_time_embedding_mix[index] ** 0.3
2210+
ctrl_temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
22062211
else:
2207-
temb = base_temb
2208-
temb = temb + aug_emb if aug_emb is not None else temb
2209-
time_embedding_list.append(temb)
2212+
ctrl_temb = base_temb
2213+
ctrl_temb_list.append(ctrl_temb)
22102214

22112215

22122216
# text embeddings
@@ -2220,7 +2224,8 @@ def forward(
22202224
h_base, h_ctrls, residual_hb, residual_hc = down(
22212225
hidden_states_base=h_base,
22222226
hidden_states_ctrl=h_ctrls,
2223-
temb=temb,
2227+
temb=base_temb,
2228+
ctrl_temb_list=ctrl_temb_list,
22242229
encoder_hidden_states=cemb,
22252230
conditioning_scale=conditioning_scale,
22262231
cross_attention_kwargs=cross_attention_kwargs,
@@ -2236,7 +2241,8 @@ def forward(
22362241
h_base, h_ctrls = self.mid_block(
22372242
hidden_states_base=h_base,
22382243
hidden_states_ctrl=h_ctrls,
2239-
temb=temb,
2244+
temb=base_temb,
2245+
ctrl_temb_list=ctrl_temb_list,
22402246
encoder_hidden_states=cemb,
22412247
conditioning_scale=conditioning_scale,
22422248
cross_attention_kwargs=cross_attention_kwargs,
@@ -2255,7 +2261,7 @@ def forward(
22552261
hidden_states=h_base,
22562262
res_hidden_states_tuple_base=skips_hb,
22572263
res_hidden_states_tuple_ctrl=skips_hc,
2258-
temb=temb,
2264+
temb=base_temb,
22592265
encoder_hidden_states=cemb,
22602266
conditioning_scale=conditioning_scale,
22612267
cross_attention_kwargs=cross_attention_kwargs,
@@ -2527,8 +2533,8 @@ def freeze_base_params(self) -> None:
25272533
base_parts = [self.base_resnets]
25282534
if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
25292535
base_parts.append(self.base_attentions)
2530-
if self.base_downsamplers is not None:
2531-
base_parts.append(self.base_downsamplers)
2536+
if self.base_downsampler is not None:
2537+
base_parts.append(self.base_downsampler)
25322538
for part in base_parts:
25332539
for param in part.parameters():
25342540
param.requires_grad = False
@@ -2539,6 +2545,7 @@ def forward(
25392545
temb: Tensor,
25402546
encoder_hidden_states: Optional[Tensor] = None, # text embedding
25412547
hidden_states_ctrl: Optional[List[Tensor]] = None, # ctrl embedding
2548+
ctrl_temb_list: Optional[List[Tensor]] = None,
25422549
conditioning_scale: Optional[List[float]] = None,
25432550
attention_mask: Optional[Tensor] = None,
25442551
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -2596,10 +2603,11 @@ def forward(
25962603
if apply_control:
25972604
for index in range(len(h_ctrl_list)):
25982605
c_res, c_attn = ctrl_blocks_list[index][layer_index]
2606+
ctrl_temb = ctrl_temb_list[index] if ctrl_temb_list is not None else temb
25992607
if torch.is_grad_enabled() and self.gradient_checkpointing:
2600-
h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], temb)
2608+
h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], ctrl_temb)
26012609
else:
2602-
h_ctrl_list[index] = c_res(h_ctrl_list[index], temb)
2610+
h_ctrl_list[index] = c_res(h_ctrl_list[index], ctrl_temb)
26032611
if c_attn is not None:
26042612
h_ctrl_list[index] = c_attn(
26052613
h_ctrl_list[index],
@@ -2665,9 +2673,9 @@ def __init__(
26652673

26662674
# Before the midblock application, information is concatted from base to control.
26672675
# Concat doesn't require change in number of channels
2668-
self.base_to_ctrl = []
2669-
for i in range(len(ctrl_channels_list)):
2670-
self.base_to_ctrl.append(make_zero_conv(base_channels, base_channels))
2676+
self.base_to_ctrl = nn.ModuleList(
2677+
[make_zero_conv(base_channels, base_channels) for _ in range(len(ctrl_channels_list))]
2678+
)
26712679

26722680
self.base_midblock = UNetMidBlock2DCrossAttn(
26732681
transformer_layers_per_block=transformer_layers_per_block,
@@ -2680,7 +2688,7 @@ def __init__(
26802688
upcast_attention=upcast_attention,
26812689
)
26822690

2683-
self.ctrl_midblocks = []
2691+
ctrl_midblocks = []
26842692
for i in range(len(ctrl_channels_list)):
26852693
ctrl_channels = ctrl_channels_list[i]
26862694
ctrl_max_norm_num_groups = ctrl_max_norm_num_groups_list[i]
@@ -2701,14 +2709,15 @@ def __init__(
27012709
upcast_attention=upcast_attention,
27022710
)
27032711

2704-
self.ctrl_midblocks.append(ctrl_midblock)
2712+
ctrl_midblocks.append(ctrl_midblock)
2713+
2714+
self.ctrl_midblocks = nn.ModuleList(ctrl_midblocks)
27052715

27062716
# After the midblock application, information is added from control to base
27072717
# Addition requires change in number of channels
2708-
self.ctrl_to_base = []
2709-
for i in range(len(ctrl_channels_list)):
2710-
ctrl_channels = ctrl_channels_list[i]
2711-
self.ctrl_to_base.append(make_zero_conv(ctrl_channels, base_channels))
2718+
self.ctrl_to_base = nn.ModuleList(
2719+
[make_zero_conv(ctrl_channels_list[i], base_channels) for i in range(len(ctrl_channels_list))]
2720+
)
27122721

27132722
self.gradient_checkpointing = False
27142723

@@ -2787,6 +2796,7 @@ def forward(
27872796
encoder_hidden_states: Tensor,
27882797
hidden_states_ctrl: List[Tensor],
27892798
conditioning_scale: List[float],
2799+
ctrl_temb_list: Optional[List[Tensor]] = None,
27902800
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
27912801
attention_mask: Optional[Tensor] = None,
27922802
encoder_attention_mask: Optional[Tensor] = None,
@@ -2810,15 +2820,15 @@ def forward(
28102820

28112821
if apply_control:
28122822
for i in range(len(h_ctrl_list)):
2813-
b2c = self.base_to_ctrl[i].to(device="cuda", dtype=h_base.dtype)
2814-
h_ctrl_list[i] = torch.cat([h_ctrl_list[i], b2c(h_base)], dim=1)
2823+
h_ctrl_list[i] = torch.cat([h_ctrl_list[i], self.base_to_ctrl[i](h_base)], dim=1)
28152824
h_base = self.base_midblock(h_base, **joint_args)
28162825
if apply_control:
28172826
for i in range(len(h_ctrl_list)):
2818-
mid = self.ctrl_midblocks[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
2819-
h_ctrl_list[i] = mid(h_ctrl_list[i], **joint_args) # apply ctrl mid block
2820-
c2b = self.ctrl_to_base[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
2821-
h_base = h_base + c2b(h_ctrl_list[i]) * conditioning_scale[i] # add ctrl -> base
2827+
ctrl_joint_args = dict(joint_args)
2828+
if ctrl_temb_list is not None:
2829+
ctrl_joint_args["temb"] = ctrl_temb_list[i]
2830+
h_ctrl_list[i] = self.ctrl_midblocks[i](h_ctrl_list[i], **ctrl_joint_args) # apply ctrl mid block
2831+
h_base = h_base + self.ctrl_to_base[i](h_ctrl_list[i]) * conditioning_scale[i] # add ctrl -> base
28222832

28232833
return h_base, h_ctrl_list
28242834

src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,12 @@ def prepare_image_list(
633633
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
634634
image_list.append(image)
635635

636-
repeat_by = batch_size if len(image_list) == 1 else num_images_per_prompt
637636
for index in range(len(image_list)):
637+
image_batch_size = image_list[index].shape[0]
638+
if image_batch_size == 1:
639+
repeat_by = batch_size
640+
else:
641+
repeat_by = num_images_per_prompt
638642
image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
639643
image_list[index] = image_list[index].to(device=device, dtype=dtype)
640644
if do_classifier_free_guidance:

src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -680,15 +680,13 @@ def prepare_image_list(
680680
for image in images:
681681
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
682682
image_list.append(image)
683-
image_batch_size = len(image_list)
684-
685-
if image_batch_size == 1:
686-
repeat_by = batch_size
687-
else:
688-
# image batch size is the same as prompt batch size
689-
repeat_by = num_images_per_prompt
690683

691684
for index in range(len(image_list)):
685+
image_batch_size = image_list[index].shape[0]
686+
if image_batch_size == 1:
687+
repeat_by = batch_size
688+
else:
689+
repeat_by = num_images_per_prompt
692690
image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
693691
image_list[index] = image_list[index].to(device=device, dtype=dtype)
694692
if do_classifier_free_guidance:
@@ -1104,7 +1102,7 @@ def __call__(
11041102
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
11051103
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
11061104

1107-
latent_model_input = latent_model_input.to(device="cuda", dtype=self.unet.dtype)
1105+
latent_model_input = latent_model_input.to(dtype=self.unet.dtype)
11081106

11091107
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
11101108

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,7 @@ def register_modules(self, **kwargs):
213213
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
214214
register_dict = {name: (None, None)}
215215
elif isinstance(module, list) and module[0] is not None:
216-
for mod in module:
217-
library, class_name = _fetch_class_library_tuple(mod)
218-
register_dict = {name: (library, class_name)}
216+
register_dict = {name: [_fetch_class_library_tuple(mod) for mod in module]}
219217
else:
220218
library, class_name = _fetch_class_library_tuple(module)
221219
register_dict = {name: (library, class_name)}

0 commit comments

Comments
 (0)