Skip to content

Add multi ControlNetXS adapter support (eval r2 #8434)#7

Open
srlynch1 wants to merge 1 commit into
mainfrom
e2e/2026-06-21-r2-diffusers-8434
Open

Add multi ControlNetXS adapter support (eval r2 #8434)#7
srlynch1 wants to merge 1 commit into
mainfrom
e2e/2026-06-21-r2-diffusers-8434

Conversation

@srlynch1

Copy link
Copy Markdown
Owner

Summary

  • Add UNetMultiControlNetXSModel for fusing multiple ControlNet-XS adapters into SD/SDXL UNet paths
  • Allow controlnet=[adapter, ...] in StableDiffusionControlNetXSPipeline and StableDiffusionXLControlNetXSPipeline
  • Preserve single-adapter API; document bidirectional fusion behavior when composing adapters

Resolves huggingface#8434 (eval run 2026-06-21-r2).

Test plan

  • ruff on touched files
  • python utils/check_copies.py
  • pytest tests/models/unets/test_models_unet_controlnetxs.py (torch absent locally)

Made with Cursor

Resolves huggingface#8434 by introducing UNetMultiControlNetXSModel and allowing
controlnet=[adapter, ...] in ControlNet-XS pipelines with per-adapter
conditioning scales and combined bidirectional fusion.

Co-authored-by: Cursor <cursoragent@cursor.com>

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes using default effort and found 8 potential issues.

Fix All in Cursor

Bugbot Autofix prepared fixes for all 8 issues found in the latest run.

  • ✅ Fixed: None time embedding .to crash
    • Removed the unconditional .to() call on optional time-embedding modules and only invoke learned embeddings when the module is not None.
  • ✅ Fixed: Hardcoded CUDA in mid block
    • Removed hardcoded device="cuda" casts in the multi mid block forward so fusion and control modules run on the activation device.
  • ✅ Fixed: Latents forced to CUDA
    • Dropped the forced CUDA cast in the SDXL denoising loop and only cast latents to the UNet dtype, leaving them on the pipeline device.
  • ✅ Fixed: Wrong multi image batch repeat
    • Updated prepare_image_list in both SD and SDXL pipelines to derive repeat_by from each image tensor's batch dimension, matching prepare_image.
  • ✅ Fixed: Wrong downsampler attribute name
    • Changed freeze_base_params to reference self.base_downsampler, which is the attribute actually defined on multi down blocks.
  • ✅ Fixed: Control modules not registered
    • Registered learned time embeddings via add_module and converted multi mid-block base_to_ctrl, ctrl_midblocks, and ctrl_to_base to nn.ModuleList.
  • ✅ Fixed: Pipeline list config overwrite
    • Updated register_modules to accumulate a list of (library, class_name) tuples for list-valued modules instead of overwriting with the last entry.
  • ✅ Fixed: Unused per-adapter time embeddings
    • Computed per-adapter ctrl_temb_list and passed it through down and mid blocks so each control branch receives its own timestep conditioning.

Create PR

Or push these changes by commenting:

@cursor push 4983dcda0a
Preview (4983dcda0a)
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -1921,10 +1921,14 @@
 
         self.ctrl_time_embeddings = []
         self.ctrl_time_embedding_mix = []
-        for controlnet in controlnets:
+        for index, controlnet in enumerate(controlnets):
             self.ctrl_time_embedding_mix.append(controlnet.get("time_embedding_mix"))
             if controlnet.get("ctrl_learn_time_embedding"):
-                self.ctrl_time_embeddings.append(TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim))
+                ctrl_time_embedding = TimestepEmbedding(
+                    in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
+                )
+                self.ctrl_time_embeddings.append(ctrl_time_embedding)
+                self.add_module(f"ctrl_time_embedding_{index}", ctrl_time_embedding)
             else:
                 self.ctrl_time_embeddings.append(None)
 
@@ -2194,19 +2198,19 @@
         # there might be better ways to encapsulate this.
         t_emb = t_emb.to(dtype=sample.dtype)
 
-        time_embedding_list = []
+        base_temb = self.base_time_embedding(t_emb, timestep_cond)
+        base_temb = base_temb + aug_emb if aug_emb is not None else base_temb
+
+        ctrl_temb_list = []
         for index in range(len(self.ctrl_time_embeddings)):
-            base_temb = self.base_time_embedding(t_emb, timestep_cond)
             temb_module = self.ctrl_time_embeddings[index]
-            temb_module.to(t_emb.device, t_emb.dtype)
             if temb_module is not None and apply_control:
                 ctrl_temb = temb_module(t_emb, timestep_cond)
-                interpolation_param = self.ctrl_time_embedding_mix[index]**0.3
-                temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
+                interpolation_param = self.ctrl_time_embedding_mix[index] ** 0.3
+                ctrl_temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
             else:
-                temb = base_temb
-            temb = temb + aug_emb if aug_emb is not None else temb
-            time_embedding_list.append(temb)
+                ctrl_temb = base_temb
+            ctrl_temb_list.append(ctrl_temb)
 
 
         # text embeddings
@@ -2220,7 +2224,8 @@
             h_base, h_ctrls, residual_hb, residual_hc = down(
                 hidden_states_base=h_base,
                 hidden_states_ctrl=h_ctrls,
-                temb=temb,
+                temb=base_temb,
+                ctrl_temb_list=ctrl_temb_list,
                 encoder_hidden_states=cemb,
                 conditioning_scale=conditioning_scale,
                 cross_attention_kwargs=cross_attention_kwargs,
@@ -2236,7 +2241,8 @@
         h_base, h_ctrls = self.mid_block(
             hidden_states_base=h_base,
             hidden_states_ctrl=h_ctrls,
-            temb=temb,
+            temb=base_temb,
+            ctrl_temb_list=ctrl_temb_list,
             encoder_hidden_states=cemb,
             conditioning_scale=conditioning_scale,
             cross_attention_kwargs=cross_attention_kwargs,
@@ -2255,7 +2261,7 @@
                 hidden_states=h_base,
                 res_hidden_states_tuple_base=skips_hb,
                 res_hidden_states_tuple_ctrl=skips_hc,
-                temb=temb,
+                temb=base_temb,
                 encoder_hidden_states=cemb,
                 conditioning_scale=conditioning_scale,
                 cross_attention_kwargs=cross_attention_kwargs,
@@ -2527,8 +2533,8 @@
         base_parts = [self.base_resnets]
         if isinstance(self.base_attentions, nn.ModuleList):  # attentions can be a list of Nones
             base_parts.append(self.base_attentions)
-        if self.base_downsamplers is not None:
-            base_parts.append(self.base_downsamplers)
+        if self.base_downsampler is not None:
+            base_parts.append(self.base_downsampler)
         for part in base_parts:
             for param in part.parameters():
                 param.requires_grad = False
@@ -2539,6 +2545,7 @@
         temb: Tensor,
         encoder_hidden_states: Optional[Tensor] = None,  # text embedding
         hidden_states_ctrl: Optional[List[Tensor]] = None, # ctrl embedding
+        ctrl_temb_list: Optional[List[Tensor]] = None,
         conditioning_scale: Optional[List[float]] = None,
         attention_mask: Optional[Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -2596,10 +2603,11 @@
             if apply_control:
                 for index in range(len(h_ctrl_list)):
                     c_res, c_attn = ctrl_blocks_list[index][layer_index]
+                    ctrl_temb = ctrl_temb_list[index] if ctrl_temb_list is not None else temb
                     if torch.is_grad_enabled() and self.gradient_checkpointing:
-                        h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], temb)
+                        h_ctrl_list[index] = self._gradient_checkpointing_func(c_res, h_ctrl_list[index], ctrl_temb)
                     else:
-                        h_ctrl_list[index] = c_res(h_ctrl_list[index], temb)
+                        h_ctrl_list[index] = c_res(h_ctrl_list[index], ctrl_temb)
                     if c_attn is not None:
                         h_ctrl_list[index]  = c_attn(
                             h_ctrl_list[index],
@@ -2665,9 +2673,9 @@
 
         # Before the midblock application, information is concatted from base to control.
         # Concat doesn't require change in number of channels
-        self.base_to_ctrl = []
-        for i in range(len(ctrl_channels_list)):
-            self.base_to_ctrl.append(make_zero_conv(base_channels, base_channels))
+        self.base_to_ctrl = nn.ModuleList(
+            [make_zero_conv(base_channels, base_channels) for _ in range(len(ctrl_channels_list))]
+        )
 
         self.base_midblock = UNetMidBlock2DCrossAttn(
             transformer_layers_per_block=transformer_layers_per_block,
@@ -2680,7 +2688,7 @@
             upcast_attention=upcast_attention,
         )
 
-        self.ctrl_midblocks = []
+        ctrl_midblocks = []
         for i in range(len(ctrl_channels_list)):
             ctrl_channels = ctrl_channels_list[i]
             ctrl_max_norm_num_groups = ctrl_max_norm_num_groups_list[i]
@@ -2701,14 +2709,15 @@
                 upcast_attention=upcast_attention,
             )
 
-            self.ctrl_midblocks.append(ctrl_midblock)
+            ctrl_midblocks.append(ctrl_midblock)
 
+        self.ctrl_midblocks = nn.ModuleList(ctrl_midblocks)
+
         # After the midblock application, information is added from control to base
         # Addition requires change in number of channels
-        self.ctrl_to_base = []
-        for i in range(len(ctrl_channels_list)):
-            ctrl_channels = ctrl_channels_list[i]
-            self.ctrl_to_base.append(make_zero_conv(ctrl_channels, base_channels))
+        self.ctrl_to_base = nn.ModuleList(
+            [make_zero_conv(ctrl_channels_list[i], base_channels) for i in range(len(ctrl_channels_list))]
+        )
 
         self.gradient_checkpointing = False
 
@@ -2787,6 +2796,7 @@
         encoder_hidden_states: Tensor,
         hidden_states_ctrl: List[Tensor],
         conditioning_scale: List[float],
+        ctrl_temb_list: Optional[List[Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         attention_mask: Optional[Tensor] = None,
         encoder_attention_mask: Optional[Tensor] = None,
@@ -2810,15 +2820,15 @@
 
         if apply_control:
             for i in range(len(h_ctrl_list)):
-                b2c = self.base_to_ctrl[i].to(device="cuda", dtype=h_base.dtype)
-                h_ctrl_list[i] = torch.cat([h_ctrl_list[i], b2c(h_base)], dim=1)
+                h_ctrl_list[i] = torch.cat([h_ctrl_list[i], self.base_to_ctrl[i](h_base)], dim=1)
         h_base = self.base_midblock(h_base, **joint_args)
         if apply_control:
             for i in range(len(h_ctrl_list)):
-                mid = self.ctrl_midblocks[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
-                h_ctrl_list[i] = mid(h_ctrl_list[i], **joint_args)  # apply ctrl mid block
-                c2b = self.ctrl_to_base[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
-                h_base = h_base + c2b(h_ctrl_list[i]) * conditioning_scale[i]  # add ctrl -> base
+                ctrl_joint_args = dict(joint_args)
+                if ctrl_temb_list is not None:
+                    ctrl_joint_args["temb"] = ctrl_temb_list[i]
+                h_ctrl_list[i] = self.ctrl_midblocks[i](h_ctrl_list[i], **ctrl_joint_args)  # apply ctrl mid block
+                h_base = h_base + self.ctrl_to_base[i](h_ctrl_list[i]) * conditioning_scale[i]  # add ctrl -> base
 
         return h_base, h_ctrl_list
 

diff --git a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py
--- a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs.py
@@ -633,8 +633,12 @@
             image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
             image_list.append(image)
 
-        repeat_by = batch_size if len(image_list) == 1 else num_images_per_prompt
         for index in range(len(image_list)):
+            image_batch_size = image_list[index].shape[0]
+            if image_batch_size == 1:
+                repeat_by = batch_size
+            else:
+                repeat_by = num_images_per_prompt
             image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
             image_list[index] = image_list[index].to(device=device, dtype=dtype)
             if do_classifier_free_guidance:

diff --git a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
--- a/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/deprecated/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -680,15 +680,13 @@
         for image in images:
           image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
           image_list.append(image)
-        image_batch_size = len(image_list)
 
-        if image_batch_size == 1:
-            repeat_by = batch_size
-        else:
-            # image batch size is the same as prompt batch size
-            repeat_by = num_images_per_prompt
-
         for index in range(len(image_list)):
+          image_batch_size = image_list[index].shape[0]
+          if image_batch_size == 1:
+              repeat_by = batch_size
+          else:
+              repeat_by = num_images_per_prompt
           image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)
           image_list[index] = image_list[index].to(device=device, dtype=dtype)
           if do_classifier_free_guidance:
@@ -1104,7 +1102,7 @@
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                latent_model_input = latent_model_input.to(device="cuda", dtype=self.unet.dtype)
+                latent_model_input = latent_model_input.to(dtype=self.unet.dtype)
 
                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -213,9 +213,7 @@
             if module is None or isinstance(module, (tuple, list)) and module[0] is None:
                 register_dict = {name: (None, None)}
             elif isinstance(module, list) and module[0] is not None:
-                for mod in module:
-                  library, class_name = _fetch_class_library_tuple(mod)
-                  register_dict = {name: (library, class_name)}
+                register_dict = {name: [_fetch_class_library_tuple(mod) for mod in module]}
             else:
                 library, class_name = _fetch_class_library_tuple(module)
                 register_dict = {name: (library, class_name)}

You can send follow-ups to the cloud agent here.

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

for index in range(len(self.ctrl_time_embeddings)):
base_temb = self.base_time_embedding(t_emb, timestep_cond)
temb_module = self.ctrl_time_embeddings[index]
temb_module.to(t_emb.device, t_emb.dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None time embedding .to crash

High Severity

In UNetMultiControlNetXSModel.forward, temb_module.to(...) runs before checking whether temb_module is None. Adapters default to learn_time_embedding=False, so entries in ctrl_time_embeddings are often None, which raises AttributeError on every forward pass.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

for i in range(len(h_ctrl_list)):
mid = self.ctrl_midblocks[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)
h_ctrl_list[i] = mid(h_ctrl_list[i], **joint_args) # apply ctrl mid block
c2b = self.ctrl_to_base[i].to(device="cuda", dtype=h_ctrl_list[i].dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hardcoded CUDA in mid block

High Severity

MultiControlNetXSCrossAttnMidBlock2D.forward moves base_to_ctrl, ctrl_midblocks, and ctrl_to_base with device="cuda" each step. Runs fail on CPU/MPS and break when latents live on another device, unlike the single-adapter mid block.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

latent_model_input = latent_model_input.to(device="cuda", dtype=self.unet.dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Latents forced to CUDA

High Severity

The SDXL ControlNet-XS denoising loop casts latent_model_input to device="cuda" before the UNet call. Non-CUDA runs and CPU-offload setups error or misplace tensors relative to the rest of the pipeline.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.


repeat_by = batch_size if len(image_list) == 1 else num_images_per_prompt
for index in range(len(image_list)):
image_list[index] = image_list[index].repeat_interleave(repeat_by, dim=0)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong multi image batch repeat

High Severity

prepare_image_list picks repeat_by from the number of control images (SD) or equates list length with batch size (SDXL), instead of each tensor’s shape[0] like prepare_image. Multi-adapter runs get the wrong batch dimension versus latents and prompts.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
base_parts.append(self.base_attentions)
if self.base_downsamplers is not None:
base_parts.append(self.base_downsamplers)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong downsampler attribute name

Medium Severity

MultiControlNetXSCrossAttnDownBlock2D.freeze_base_params references self.base_downsamplers, but the block defines self.base_downsampler. Calling freeze_unet_params on a multi-adapter model raises AttributeError on down blocks with downsampling.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

if controlnet.get("ctrl_learn_time_embedding"):
self.ctrl_time_embeddings.append(TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim))
else:
self.ctrl_time_embeddings.append(None)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Control modules not registered

Medium Severity

Learned control time embeddings and multi mid-block layers (ctrl_time_embeddings, base_to_ctrl, ctrl_midblocks, ctrl_to_base) live in plain Python lists, not nn.ModuleList. They are omitted from parameters(), .to(), and state_dict() unless manually moved.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

elif isinstance(module, list) and module[0] is not None:
for mod in module:
library, class_name = _fetch_class_library_tuple(mod)
register_dict = {name: (library, class_name)}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pipeline list config overwrite

Medium Severity

When register_modules receives a non-empty list (e.g. several ControlNetXSAdapter instances), the loop overwrites register_dict each iteration and only the last module’s (library, class_name) is stored in model_index.json.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

h_base, h_ctrls, residual_hb, residual_hc = down(
hidden_states_base=h_base,
hidden_states_ctrl=h_ctrls,
temb=temb,

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused per-adapter time embeddings

Medium Severity

UNetMultiControlNetXSModel builds time_embedding_list per adapter but passes only the loop’s final temb into down, mid, and up blocks. Adapters with different time_embedding_mix or learned embeddings do not get distinct timestep conditioning in those blocks.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 6c2a47f. Configure here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support multiple control nets in the StableDiffusionControlNetXSPipeline/StableDiffusionXLControlNetXSPipeline

1 participant