diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
index 9c65999e3a17..8a701e088806 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -910,6 +910,11 @@ def __call__(
         controlnet_cond, controlnet_cond_mask = self.prepare_sparse_control_conditioning(
             conditioning_frames, num_frames, controlnet_frame_indices, device, controlnet.dtype
         )
+        # With classifier-free guidance (and no guess mode), duplicate the conditioning and its mask along
+        # dim 0 - presumably to match the doubled latent batch passed to the controlnet (TODO confirm).
+        if self.do_classifier_free_guidance and not guess_mode:
+            controlnet_cond = torch.cat([controlnet_cond] * 2)
+            controlnet_cond_mask = torch.cat([controlnet_cond_mask] * 2)
 
         # 6. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -976,6 +981,14 @@ def __call__(
                     return_dict=False,
                 )
 
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Same approach as pipeline_controlnet: the residuals are meant for the conditional batch only,
+                    # so zeros are prepended for the other (assumed unconditional) half of the CFG batch.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat(
+                        [torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
+                    )
+
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
diff --git a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
index 6b9f672cc4a1..f0a8efb62e93 100644
--- a/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
+++ b/tests/pipelines/animatediff/test_animatediff_sparsectrl.py
@@ -492,3 +492,48 @@ def test_encode_prompt_works_in_isolation(self):
             "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
         }
         return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
+
+    def test_conditioning_image_affects_output_with_unrelated_prompt(self):
+        """Regression for #9508: conditioning RGB should influence output even when prompt is unrelated.
+
+        Runs the pipeline once with all-red and once with all-black conditioning frames, keeping every other
+        input the same, and asserts that the mean of the generated frames differs by more than 1e-2.
+        """
+        components = self.get_dummy_components()
+        pipe: AnimateDiffSparseControlNetPipeline = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+        pipe.to(torch_device)
+
+        video_size = (32, 32)
+        red_frame = Image.new("RGB", video_size, color=(255, 0, 0))
+        black_frame = Image.new("RGB", video_size, color=(0, 0, 0))
+
+        def run_with_frame(frame):
+            # Build fresh dummy inputs per run instead of sharing one dict: a shallow copy would share any
+            # stateful value (presumably a seeded generator - confirm in get_dummy_inputs), so the second run
+            # would start from different noise and the outputs could differ regardless of the conditioning.
+            inputs = self.get_dummy_inputs(torch_device, num_frames=2)
+            inputs.update(
+                {
+                    "prompt": "a sky with clouds",
+                    "num_inference_steps": 4,
+                    "controlnet_frame_indices": [0],
+                    "output_type": "pt",
+                    "conditioning_frames": [frame, frame],
+                }
+            )
+            return pipe(**inputs).frames[0]
+
+        red_output = run_with_frame(red_frame)
+        black_output = run_with_frame(black_frame)
+
+        red_mean = to_np(red_output).mean()
+        black_mean = to_np(black_output).mean()
+        diff = abs(red_mean - black_mean)
+
+        self.assertGreater(
+            diff,
+            1e-2,
+            f"Conditioning image should affect output with unrelated prompt (diff={diff:.4f}, "
+            f"red_mean={red_mean:.4f}, black_mean={black_mean:.4f})",
+        )