Add multi ControlNetXS adapter support (eval r2 #8434)#7
Conversation
Resolves huggingface#8434 by introducing UNetMultiControlNetXSModel and allowing controlnet=[adapter, ...] in ControlNet-XS pipelines with per-adapter conditioning scales and combined bidirectional fusion. Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes using default effort and found 8 potential issues.
Bugbot Autofix prepared fixes for all 8 issues found in the latest run.
- ✅ Fixed: None time embedding
.tocrash- Removed the unconditional
.to()call on optional time-embedding modules and only invoke learned embeddings when the module is not None.
- Removed the unconditional
- ✅ Fixed: Hardcoded CUDA in mid block
- Removed hardcoded
device="cuda"casts in the multi mid block forward so fusion and control modules run on the activation device.
- Removed hardcoded
- ✅ Fixed: Latents forced to CUDA
- Dropped the forced CUDA cast in the SDXL denoising loop and only cast latents to the UNet dtype, leaving them on the pipeline device.
- ✅ Fixed: Wrong multi image batch repeat
- Updated
prepare_image_listin both SD and SDXL pipelines to deriverepeat_byfrom each image tensor's batch dimension, matchingprepare_image.
- Updated
- ✅ Fixed: Wrong downsampler attribute name
- Changed
freeze_base_paramsto referenceself.base_downsampler, which is the attribute actually defined on multi down blocks.
- Changed
- ✅ Fixed: Control modules not registered
- Registered learned time embeddings via
add_moduleand converted multi mid-blockbase_to_ctrl,ctrl_midblocks, andctrl_to_basetonn.ModuleList.
- Registered learned time embeddings via
- ✅ Fixed: Pipeline list config overwrite
- Updated
register_modulesto accumulate a list of(library, class_name)tuples for list-valued modules instead of overwriting with the last entry.
- Updated
- ✅ Fixed: Unused per-adapter time embeddings
- Computed per-adapter
ctrl_temb_listand passed it through down and mid blocks so each control branch receives its own timestep conditioning.
- Computed per-adapter
Or push these changes by commenting:
@cursor push 4983dcda0a
Preview (4983dcda0a)
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -1921,10 +1921,14 @@
self.ctrl_time_embeddings = []
self.ctrl_time_embedding_mix = []
- for controlnet in controlnets:
+ for index, controlnet in enumerate(controlnets):
self.ctrl_time_embedding_mix.append(controlnet.get("time_embedding_mix"))
if controlnet.get("ctrl_learn_time_embedding"):
- self.ctrl_time_embeddings.append(TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim))
+ ctrl_time_embedding = TimestepEmbedding(
+ in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+ )
+ self.ctrl_time_embeddings.append(ctrl_time_embedding)
+ self.add_module(f"ctrl_time_embedding_{index}", ctrl_time_embedding)
else:
self.ctrl_time_embeddings.append(None)
@@ -2194,19 +2198,19 @@
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
- time_embedding_list = []
+ base_temb = self.base_time_embedding(t_emb, timestep_cond)
+ base_temb = base_temb + aug_emb if aug_emb is not None else base_temb
+
+ ctrl_temb_list = []
for index in range(len(self.ctrl_time_embeddings)):
- base_temb = self.base_time_embedding(t_emb, timestep_cond)
temb_module = self.ctrl_time_embeddings[index]
- temb_module.to(t_emb.device, t_emb.dtype)
if temb_module is not None and apply_control:
ctrl_temb = temb_module(t_emb, timestep_cond)
- interpolation_param = self.ctrl_time_embedding_mix[index]**0.3
- temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
+ interpolation_param = self.ctrl_time_embedding_mix[index] ** 0.3
+ ctrl_temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
else:
- temb = base_temb
- temb = temb + aug_emb if aug_emb is not None else temb
- time_embedding_list.append(temb)
+ ctrl_temb = base_temb
+ ctrl_temb_list.append(ctrl_temb)
# text embeddings
@@ -2220,7 +2224,8 @@
h_base, h_ctrls, residual_hb, residual_hc = down(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrls,
- temb=temb,
+ temb=base_temb,
+ ctrl_temb_list=ctrl_temb_list,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
@@ -2236,7 +2241,8 @@
h_base, h_ctrls = self.mid_block(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrls,
- temb=temb,
+ temb=base_temb,
+ ctrl_temb_list=ctrl_temb_list,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
@@ -2255,7 +2261,7 @@
hidden_states=h_base,
res_hidden_states_tuple_base=skips_hb,
res_hidden_states_tuple_ctrl=skips_hc,
- temb=temb,
+ temb=base_temb,
encoder_hidden_states=cemb,
conditioning_scale=conditioning_scale,
cross_attention_kwargs=cross_attention_kwargs,
@@ -2527,8 +2533,8 @@
base_parts = [self.base_resnets]
if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
base_parts.append(self.base_attentions)
- if self.base_downsamplers is not None:
- base_parts.append(self.base_downsamplers)
+ if self.base_downsampler is not None:
+ base_parts.append(self.base_downsampler)
for part in base_parts:
for param in part.parameters():
param.requires_grad = False
@@ -2539,6 +2545,7 @@
temb: Tensor,
encoder_hidden_states: Optional[Tensor] = None, # text embedding
hidden_states_ctrl: Optional[List[Tensor]] = None, # ctrl embedding
+ ctrl_temb_list: Optional[List[Tensor]] = None,
conditioning_scale: Optional[List[float]] = None,
attention_mask: Optional[Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -2596,10 +2603,11 @@
if apply_control:
for index in range(len(h_ctrl_list)):
c_res, c_attn = ctrl_blocks_list[index][layer_index]
+ ctrl_temb = ctrl_temb_list[index] if ctrl_temb_list is not None else temb
if torch.is_grad_enabled() and self.gradient_checkpointing:
- h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], temb)
+ h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], ctrl_temb)
else:
- h_ctrl_list[index] = c_res(h_ctrl_list[index], temb)
+ h_ctrl_list[index] = c_res(h_ctrl_list[index], ctrl_temb)
if c_attn is not None:
h_ctrl_list[index] = c_attn(
h_ctrl_list[index],
@@ -2665,9 +2673,9 @@
# Before the midblock application, information is concatted from base to control.
# Concat doesn't require change in number of channels
- self.base_to_ctrl = []
- for i in range(len(ctrl_channels_list)):
- self.base_to_ctrl.append(make_zero_conv(base_channels, base_channels))
+ self.base_to_ctrl = nn.ModuleList(
+ [make_zero_conv(base_channels, base_channels) for _ in range(len(ctrl_channels_list))]
+ )
self.base_midblock = UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
@@ -2680,7 +2688,7 @@
upcast_attention=upcast_attention,
)
- self.ctrl_midblocks = []
+ ctrl_midblocks = []
for i in range(len(ctrl_channels_list)):
ctrl_channels = ctrl_channels_list[i]
ctrl_max_norm_num_groups = ctrl_max_norm_num_groups_list[i]
@@ -2701,14 +2709,15 @@
upcast_attention=upcast_attention,
)
- self.ctrl_midblocks.append(ctrl_midblock)
+ ctrl_midblocks.append(ctrl_midblock)
+ self.ctrl_midblocks = nn.ModuleList(ctrl_midblocks)
+
# After the midblock application, information is added from control to base
# Addition requires change in number of channels
- self.ctrl_to_base = []
- for i in range(len(ctrl_channels_list)):
- ctrl_channels = ctrl_channels_list[i]
- self.ctrl_to_base.append(make_zero_conv(ctrl_channels, base_channels))
+ self.ctrl_to_base = nn.ModuleList(
+ [make_zero_conv(ctrl_channels_list[i], base_channels) for i in range(len(ctrl_channels_list))]
+ )
self.gradient_checkpointing = False
@@ -2787,6 +2796,7 @@
encoder_hidden_states: Tensor,
hidden_states_ctrl: List[Tensor],
conditioning_scale: List[float],
+ ctrl_temb_list: Optional[List[Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[Tensor] = None,
encoder_attention_mask: Optional[Tensor] = None,
@@ -2810,15 +2820,15 @@
if apply_control:
for i in range(len(h_ctrl_list)):
- b2c = self.base_to_ctrl[i].to(device="cuda", dtype=h_base.dtype)
- h_ctrl_list[i] = torch.cat([h_ctrl_list[i], b2c(h_base)], dim=1)
+ h_ctrl_list[i] = torch.cat([h_ctrl_list[i], self.base_to_ctrl[i](h_base)], dim=1)
h_base = self.base_midblock(h_base, **joint_args)
if apply_control:
for i in range(len(h_ctrl_list)):
- mid = self.ctrl_midblocks[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
- h_ctrl_list[i] = mid(h_ctrl_list[i], **joint_args) # apply ctrl mid block
- c2b = self.ctrl_to_base[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
- h_base = h_base + c2b(h_ctrl_list[i]) * conditioning_scale[i] # add ctrl -> base
+ ctrl_joint_args = dict(joint_args)
+ if ctrl_temb_list is not None:
+ ctrl_joint_args["temb"] = ctrl_temb_list[i]
+ h_ctrl_list[i] = self.ctrl_midblocks[i](h_ctrl_list[i], **ctrl_joint_args) # apply ctrl mid block
+ h_base = h_base + self.ctrl_to_base[i](h_ctrl_list[i]) * conditioning_scale[i] # add ctrl -> base
return h_base, h_ctrl_list
diff --git a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py
--- a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py
@@ -633,8 +633,12 @@
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_list.append(image)
- repeat_by = batch_size if len(image_list) == 1 else num_images_per_prompt
for index in range(len(image_list)):
+ image_batch_size = image_list[index].shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ repeat_by = num_images_per_prompt
image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
image_list[index] = image_list[index].to(device=device, dtype=dtype)
if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
--- a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -680,15 +680,13 @@
for image in images:
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_list.append(image)
- image_batch_size = len(image_list)
- if image_batch_size == 1:
- repeat_by = batch_size
- else:
- # image batch size is the same as prompt batch size
- repeat_by = num_images_per_prompt
-
for index in range(len(image_list)):
+ image_batch_size = image_list[index].shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ repeat_by = num_images_per_prompt
image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
image_list[index] = image_list[index].to(device=device, dtype=dtype)
if do_classifier_free_guidance:
@@ -1104,7 +1102,7 @@
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
- latent_model_input = latent_model_input.to(device="cuda", dtype=self.unet.dtype)
+ latent_model_input = latent_model_input.to(dtype=self.unet.dtype)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -213,9 +213,7 @@
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
register_dict = {name: (None, None)}
elif isinstance(module, list) and module[0] is not None:
- for mod in module:
- library, class_name = _fetch_class_library_tuple(mod)
- register_dict = {name: (library, class_name)}
+ register_dict = {name: [_fetch_class_library_tuple(mod) for mod in module]}
else:
library, class_name = _fetch_class_library_tuple(module)
register_dict = {name: (library, class_name)}You can send follow-ups to the cloud agent here.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| for index in range(len(self.ctrl_time_embeddings)): | ||
| base_temb = self.base_time_embedding(t_emb, timestep_cond) | ||
| temb_module = self.ctrl_time_embeddings[index] | ||
| temb_module.to(t_emb.device, t_emb.dtype) |
There was a problem hiding this comment.
None time embedding .to crash
High Severity
In UNetMultiControlNetXSModel.forward, temb_module.to(...) runs before checking whether temb_module is None. Adapters default to learn_time_embedding=False, so entries in ctrl_time_embeddings are often None, which raises AttributeError on every forward pass.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| for i in range(len(h_ctrl_list)): | ||
| mid = self.ctrl_midblocks[i].to(device="cuda", dtype=h_ctrl_list[i].dtype) | ||
| h_ctrl_list[i] = mid(h_ctrl_list[i], **joint_args) # apply ctrl mid block | ||
| c2b = self.ctrl_to_base[i].to(device="cuda", dtype=h_ctrl_list[i].dtype) |
There was a problem hiding this comment.
Hardcoded CUDA in mid block
High Severity
MultiControlNetXSCrossAttnMidBlock2D.forward moves base_to_ctrl, ctrl_midblocks, and ctrl_to_base with device="cuda" each step. Runs fail on CPU/MPS and break when latents live on another device, unlike the single-adapter mid block.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents | ||
| latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | ||
|
|
||
| latent_model_input = latent_model_input.to(device="cuda", dtype=self.unet.dtype) |
There was a problem hiding this comment.
Latents forced to CUDA
High Severity
The SDXL ControlNet-XS denoising loop casts latent_model_input to device="cuda" before the UNet call. Non-CUDA runs and CPU-offload setups error or misplace tensors relative to the rest of the pipeline.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
|
|
||
| repeat_by = batch_size if len(image_list) == 1 else num_images_per_prompt | ||
| for index in range(len(image_list)): | ||
| image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0) |
There was a problem hiding this comment.
Wrong multi image batch repeat
High Severity
prepare_image_list picks repeat_by from the number of control images (SD) or equates list length with batch size (SDXL), instead of each tensor’s shape[0] like prepare_image. Multi-adapter runs get the wrong batch dimension versus latents and prompts.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones | ||
| base_parts.append(self.base_attentions) | ||
| if self.base_downsamplers is not None: | ||
| base_parts.append(self.base_downsamplers) |
There was a problem hiding this comment.
Wrong downsampler attribute name
Medium Severity
MultiControlNetXSCrossAttnDownBlock2D.freeze_base_params references self.base_downsamplers, but the block defines self.base_downsampler. Calling freeze_unet_params on a multi-adapter model raises AttributeError on down blocks with downsampling.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| if controlnet.get("ctrl_learn_time_embedding"): | ||
| self.ctrl_time_embeddings.append(TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)) | ||
| else: | ||
| self.ctrl_time_embeddings.append(None) |
There was a problem hiding this comment.
Control modules not registered
Medium Severity
Learned control time embeddings and multi mid-block layers (ctrl_time_embeddings, base_to_ctrl, ctrl_midblocks, ctrl_to_base) live in plain Python lists, not nn.ModuleList. They are omitted from parameters(), .to(), and state_dict() unless manually moved.
Additional Locations (1)
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| elif isinstance(module, list) and module[0] is not None: | ||
| for mod in module: | ||
| library, class_name = _fetch_class_library_tuple(mod) | ||
| register_dict = {name: (library, class_name)} |
There was a problem hiding this comment.
Pipeline list config overwrite
Medium Severity
When register_modules receives a non-empty list (e.g. several ControlNetXSAdapter instances), the loop overwrites register_dict each iteration and only the last module’s (library, class_name) is stored in model_index.json.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.
| h_base, h_ctrls, residual_hb, residual_hc = down( | ||
| hidden_states_base=h_base, | ||
| hidden_states_ctrl=h_ctrls, | ||
| temb=temb, |
There was a problem hiding this comment.
Unused per-adapter time embeddings
Medium Severity
UNetMultiControlNetXSModel builds time_embedding_list per adapter but passes only the loop’s final temb into down, mid, and up blocks. Adapters with different time_embedding_mix or learned embeddings do not get distinct timestep conditioning in those blocks.
Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.



Summary
UNetMultiControlNetXSModelfor fusing multiple ControlNet-XS adapters into SD/SDXL UNet pathscontrolnet=[adapter, ...]inStableDiffusionControlNetXSPipelineandStableDiffusionXLControlNetXSPipelineResolves huggingface#8434 (eval run 2026-06-21-r2).
Test plan
python utils/check_copies.pytests/models/unets/test_models_unet_controlnetxs.py(torch absent locally)Made with Cursor