fix(hooks): context parallel padding for indivisible sequences (#12568) #3

Open
srlynch1 wants to merge 1 commit into main from e2e/diffusers-12568-r2

@srlynch1

Summary

  • Pad tensors in ContextParallelSplitHook before equipartition sharding when sequence length is not divisible by world_size
  • Mask tensors are padded with 0 so that padded tokens do not affect attention
  • ContextParallelGatherHook trims padding after gather via parallel_config._cp_pad_lengths
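
The three steps above can be sketched with plain Python lists standing in for tensors. The helper names (`pad_to_multiple`, `shard`, `gather_and_trim`) are hypothetical and are not the functions in `context_parallel.py`. The sketch only illustrates the arithmetic, assuming padding is appended at the end of the sequence dimension, and reuses the length 7 from the tests with an assumed world_size of 3.

```python
def pad_to_multiple(seq, world_size, pad_value):
    """Append pad_value until len(seq) is divisible by world_size."""
    pad = (-len(seq)) % world_size
    return seq + [pad_value] * pad


def shard(seq, world_size):
    """Split an evenly divisible sequence into world_size equal chunks."""
    chunk = len(seq) // world_size
    return [seq[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def gather_and_trim(shards, original_len):
    """Concatenate per-rank shards and drop the padded tail."""
    full = [tok for s in shards for tok in s]
    return full[:original_len]


world_size = 3                          # assumed for illustration
hidden = [float(i) for i in range(7)]   # sequence length 7, not divisible by 3
mask = [1] * 7                          # 1 marks a real token

padded_hidden = pad_to_multiple(hidden, world_size, 0.0)
padded_mask = pad_to_multiple(mask, world_size, 0)   # padded positions are masked out

shards = shard(padded_hidden, world_size)
restored = gather_and_trim(shards, original_len=len(hidden))
```

Under these assumptions the padded length is 9, each rank holds 3 positions, and trimming to the recorded length recovers the original 7-element sequence.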

Fixes huggingface#12568 (QwenImage context parallel AssertionError).

Test plan

  • ruff check on changed files
  • python utils/check_copies.py
  • pytest tests/hooks/test_hooks.py::ContextParallelHooksTests -q (host env missing torch)

Made with Cursor

Pad tensors before EquipartitionSharder.shard when sequence length is not
divisible by world_size; mask tensors pad with 0; gather hook trims padding.
Fixes huggingface#12568 for QwenImage and other Ring/Unified CP models.

Co-authored-by: Cursor <cursoragent@cursor.com>

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Bugbot Autofix prepared a fix for the issue found in the latest run.

  • ✅ Fixed: Pad lengths collide per dim
    • Pad lengths are now keyed by input tensor name instead of shard dimension, and the gather hook trims using the hidden_states entry so colliding split_dim values no longer corrupt proj_out output.
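
A minimal sketch of the name-keyed bookkeeping described here, using a plain dict in place of `parallel_config._cp_pad_lengths`. The functions `record_pad_length` and `pop_unpad_length` are hypothetical stand-ins, and the sequence lengths are made-up illustration values.

```python
def record_pad_length(pad_lengths, name, seq_len):
    """Remember the pre-pad length under the input's name (skipped if unnamed)."""
    if name:
        pad_lengths[name] = seq_len


def pop_unpad_length(pad_lengths, unpad_key="hidden_states"):
    """Return and remove the length to trim to, or None if nothing was recorded."""
    return pad_lengths.pop(unpad_key, None)


pad_lengths = {}
# Three inputs that share the same split dimension but have different lengths.
record_pad_length(pad_lengths, "hidden_states", 7)
record_pad_length(pad_lengths, "encoder_hidden_states", 5)
record_pad_length(pad_lengths, "encoder_hidden_states_mask", 5)

trim_to = pop_unpad_length(pad_lengths)   # looks up "hidden_states" only
```

Because each input has its own key, the text inputs cannot overwrite the length recorded for `hidden_states`. In this sketch the entries for the other inputs remain in the dict after the pop.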

To push these changes, comment:

@cursor push 28b0ab095c
Preview (28b0ab095c)
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -37,7 +37,7 @@
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
-def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[int, int]:
+def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[str, int]:
     pad_lengths = getattr(parallel_config, "_cp_pad_lengths", None)
     if pad_lengths is None:
         pad_lengths = {}
@@ -244,7 +244,8 @@
         if world_size > 1 and seq_len % world_size != 0:
             pad_value = 0 if "mask" in name.lower() else 0.0
             x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value)
-            _get_cp_pad_lengths(self.parallel_config)[dim] = seq_len
+            if name:
+                _get_cp_pad_lengths(self.parallel_config)[name] = seq_len
 
         return EquipartitionSharder.shard(x, dim, mesh)
 
@@ -282,8 +283,9 @@
                     output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
                 )
 
-            if pad_lengths and cpm.gather_dim in pad_lengths:
-                original_len = pad_lengths.pop(cpm.gather_dim)
+            unpad_key = getattr(cpm, "unpad_key", "hidden_states")
+            if pad_lengths and unpad_key and unpad_key in pad_lengths:
+                original_len = pad_lengths.pop(unpad_key)
                 x = x.narrow(cpm.gather_dim, 0, original_len)
 
             output[i] = x

diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py
--- a/tests/hooks/test_hooks.py
+++ b/tests/hooks/test_hooks.py
@@ -420,7 +420,7 @@
             out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
 
         assert out.shape[1] == 9
-        assert self.parallel_config._cp_pad_lengths[1] == 7
+        assert self.parallel_config._cp_pad_lengths["hidden_states"] == 7
 
     def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
         mask = torch.ones(1, 7, dtype=torch.long)
@@ -443,7 +443,7 @@
         assert not hasattr(self.parallel_config, "_cp_pad_lengths")
 
     def test_gather_hook_trims_padded_output(self):
-        self.parallel_config._cp_pad_lengths = {1: 7}
+        self.parallel_config._cp_pad_lengths = {"hidden_states": 7}
         gather_hook = ContextParallelGatherHook(
             metadata=[ContextParallelOutput(gather_dim=1, expected_dims=3)],
             parallel_config=self.parallel_config,
@@ -454,5 +454,5 @@
             out = gather_hook.post_forward(self.module, x)
 
         assert out.shape[1] == 7
-        assert 1 not in getattr(self.parallel_config, "_cp_pad_lengths", {})
+        assert "hidden_states" not in getattr(self.parallel_config, "_cp_pad_lengths", {})

Reviewed by Cursor Bugbot for commit 1be07ba.

        if world_size > 1 and seq_len % world_size != 0:
            pad_value = 0 if "mask" in name.lower() else 0.0
            x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value)
            _get_cp_pad_lengths(self.parallel_config)[dim] = seq_len

Pad lengths collide per dim

High Severity

_cp_pad_lengths stores one pre-pad length per shard dimension, but several context-parallel inputs (e.g. image hidden_states, text encoder_hidden_states, and encoder_hidden_states_mask in QwenImage) share split_dim=1 with different sequence lengths. Later inputs overwrite earlier entries, so ContextParallelGatherHook may trim proj_out to the wrong length, corrupting output or causing narrow to fail.
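
The overwrite can be reproduced with a plain dict keyed by dimension. This is an illustrative sketch only: the input triples and their lengths are hypothetical, and a list stands in for the gathered `proj_out` tensor.

```python
# Hypothetical (name, split_dim, seq_len) triples; the lengths are illustrative.
inputs = [
    ("hidden_states", 1, 7),
    ("encoder_hidden_states", 1, 5),
    ("encoder_hidden_states_mask", 1, 5),
]

by_dim = {}
for name, dim, seq_len in inputs:
    by_dim[dim] = seq_len          # later inputs overwrite earlier entries

gathered = list(range(9))          # gathered output, padded from 7 to 9
trimmed = gathered[:by_dim[1]]     # trims to 5 instead of the expected 7
```

Only the last length written for dimension 1 survives, so the gathered output loses two real positions in this example.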

Development

Successfully merging this pull request may close these issues.

Context parallel bug when using QwenImage

1 participant