Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,9 @@ def __call__(
controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning(
conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype
)
if self.do_classifier_free_guidance and not guess_mode:
controlnet_cond = torch.cat([controlnet_cond] * 2)
controlnet_cond_mask = torch.cat([controlnet_cond_mask] * 2)

# 6. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down Expand Up @@ -976,6 +979,13 @@ def __call__(
return_dict=False,
)

if guess_mode and self.do_classifier_free_guidance:
# Copied from pipeline_controlnet: apply controlnet residuals only to the conditional batch.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat(
[torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
Expand Down
41 changes: 41 additions & 0 deletions tests/pipelines/animatediff/test_animatediff_sparsectrl.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,44 @@ def test_encode_prompt_works_in_isolation(self):
"do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
}
return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)

def test_conditioning_image_affects_output_with_unrelated_prompt(self):
"""Regression for #9508: conditioning RGB should influence output even when prompt is unrelated.

Regression for #9508 after CFG conditioning batch alignment and guess_mode residual parity.
"""
components = self.get_dummy_components()
pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)

video_size = (32, 32)
red_frame = Image.new("RGB", video_size, color=(255, 0, 0))
black_frame = Image.new("RGB", video_size, color=(0, 0, 0))

base_inputs = self.get_dummy_inputs(torch_device, num_frames=2)
base_inputs.update(
{
"prompt": "a sky with clouds",
"num_inference_steps": 4,
"controlnet_frame_indices": [0],
"output_type": "pt",
}
)

red_inputs = {**base_inputs, "conditioning_frames": [red_frame, red_frame]}
black_inputs = {**base_inputs, "conditioning_frames": [black_frame, black_frame]}

red_output = pipe(**red_inputs).frames[0]
black_output = pipe(**black_inputs).frames[0]

red_mean = to_np(red_output).mean()
black_mean = to_np(black_output).mean()
diff = abs(red_mean - black_mean)

self.assertGreater(
diff,
1e-2,
f"Conditioning image should affect output with unrelated prompt (diff={diff:.4f}, "
f"red_mean={red_mean:.4f}, black_mean={black_mean:.4f})",
)