fix(hooks): context parallel padding for indivisible sequences (#12568) #3
srlynch1 wants to merge 1 commit into
Conversation
Pad tensors before EquipartitionSharder.shard when the sequence length is not divisible by world_size. Mask tensors are padded with 0, and the gather hook trims the padding. Fixes huggingface#12568 for QwenImage and other Ring/Unified CP models.

Co-authored-by: Cursor <cursoragent@cursor.com>
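As a rough illustration of why mask tensors are padded with 0, the sketch below pads a sequence and its mask up to a multiple of world_size and then computes a masked mean. It uses plain Python lists instead of tensors and assumes the common convention that a mask value of 1 means "attend" and 0 means "ignore". `pad_to_multiple` and `masked_mean` are hypothetical helpers for this example, not diffusers functions.

```python
def pad_to_multiple(values, world_size, pad_value):
    """Right-pad a list so its length is divisible by world_size."""
    remainder = len(values) % world_size
    if remainder == 0:
        return list(values)
    return list(values) + [pad_value] * (world_size - remainder)


def masked_mean(values, mask):
    """Average only the positions whose mask entry is 1."""
    kept = [v for v, m in zip(values, mask) if m == 1]
    return sum(kept) / len(kept)


world_size = 3
hidden = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]  # sequence length 7
mask = [1] * 7                                 # every real token is valid

padded_hidden = pad_to_multiple(hidden, world_size, pad_value=0.0)
padded_mask = pad_to_multiple(mask, world_size, pad_value=0)

# Length 7 is rounded up to 9, the next multiple of 3.
assert len(padded_hidden) == len(padded_mask) == 9
# Padded positions carry mask value 0, so the masked result is unchanged.
assert masked_mean(padded_hidden, padded_mask) == masked_mean(hidden, mask)
```

If the mask were padded with 1 instead, the padded zeros in the sequence would be counted as real tokens and the masked result would change.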
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
Bugbot Autofix prepared a fix for the issue found in the latest run.
- ✅ Fixed: Pad lengths collide per dim
- Pad lengths are now keyed by input tensor name instead of shard dimension, and the gather hook trims using the hidden_states entry so colliding split_dim values no longer corrupt proj_out output.
Or push these changes by commenting:
@cursor push 28b0ab095c
Preview (28b0ab095c)
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -37,7 +37,7 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
-def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[int, int]:
+def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[str, int]:
     pad_lengths = getattr(parallel_config, "_cp_pad_lengths", None)
     if pad_lengths is None:
         pad_lengths = {}
@@ -244,7 +244,8 @@
         if world_size > 1 and seq_len % world_size != 0:
             pad_value = 0 if "mask" in name.lower() else 0.0
             x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value)
-            _get_cp_pad_lengths(self.parallel_config)[dim] = seq_len
+            if name:
+                _get_cp_pad_lengths(self.parallel_config)[name] = seq_len
         return EquipartitionSharder.shard(x, dim, mesh)
@@ -282,8 +283,9 @@
                 output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
             )
-            if pad_lengths and cpm.gather_dim in pad_lengths:
-                original_len = pad_lengths.pop(cpm.gather_dim)
+            unpad_key = getattr(cpm, "unpad_key", "hidden_states")
+            if pad_lengths and unpad_key and unpad_key in pad_lengths:
+                original_len = pad_lengths.pop(unpad_key)
                 x = x.narrow(cpm.gather_dim, 0, original_len)
             output[i] = x
diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py
--- a/tests/hooks/test_hooks.py
+++ b/tests/hooks/test_hooks.py
@@ -420,7 +420,7 @@
         out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
         assert out.shape[1] == 9
-        assert self.parallel_config._cp_pad_lengths[1] == 7
+        assert self.parallel_config._cp_pad_lengths["hidden_states"] == 7
     def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
         mask = torch.ones(1, 7, dtype=torch.long)
@@ -443,7 +443,7 @@
         assert not hasattr(self.parallel_config, "_cp_pad_lengths")
     def test_gather_hook_trims_padded_output(self):
-        self.parallel_config._cp_pad_lengths = {1: 7}
+        self.parallel_config._cp_pad_lengths = {"hidden_states": 7}
         gather_hook = ContextParallelGatherHook(
             metadata=[ContextParallelOutput(gather_dim=1, expected_dims=3)],
             parallel_config=self.parallel_config,
@@ -454,5 +454,5 @@
         out = gather_hook.post_forward(self.module, x)
         assert out.shape[1] == 7
-        assert 1 not in getattr(self.parallel_config, "_cp_pad_lengths", {})
+        assert "hidden_states" not in getattr(self.parallel_config, "_cp_pad_lengths", {})
Reviewed by Cursor Bugbot for commit 1be07ba. Configure here.
    if world_size > 1 and seq_len % world_size != 0:
        pad_value = 0 if "mask" in name.lower() else 0.0
        x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value)
        _get_cp_pad_lengths(self.parallel_config)[dim] = seq_len
Pad lengths collide per dim
High Severity
_cp_pad_lengths stores one pre-pad length per shard dimension, but several context-parallel inputs (e.g. image hidden_states, text encoder_hidden_states, and encoder_hidden_states_mask in QwenImage) share split_dim=1 with different sequence lengths. Later inputs overwrite earlier entries, so ContextParallelGatherHook may trim proj_out to the wrong length, corrupting output or causing narrow to fail.
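The collision can be reproduced with a small dictionary experiment. This is a minimal sketch, not the hook code: the sequence lengths and world_size below are made-up illustrative values, and only the input names and the shared split_dim=1 come from the report above.

```python
# (input name, split_dim, sequence length); lengths are illustrative only.
cp_inputs = [
    ("hidden_states", 1, 7),
    ("encoder_hidden_states", 1, 5),
    ("encoder_hidden_states_mask", 1, 5),
]
world_size = 2

pad_lengths_by_dim = {}   # keying used by the reviewed commit
pad_lengths_by_name = {}  # keying proposed by the autofix

for name, split_dim, seq_len in cp_inputs:
    if seq_len % world_size != 0:
        # Every input shares split_dim == 1, so later inputs overwrite earlier ones.
        pad_lengths_by_dim[split_dim] = seq_len
        # Names are unique, so each input keeps its own pre-pad length.
        pad_lengths_by_name[name] = seq_len

# hidden_states is padded from 7 to 8 and gathered back at length 8.
gathered_len = 8
wrong_trim = pad_lengths_by_dim[1]                 # length of the last input
right_trim = pad_lengths_by_name["hidden_states"]  # length of the image tokens

assert pad_lengths_by_dim == {1: 5}
assert wrong_trim == 5 and right_trim == 7
assert right_trim <= gathered_len
```

With per-dim keys the output would be narrowed to 5 tokens instead of 7, which silently drops two real positions. With different illustrative lengths the stale entry could instead exceed the gathered length and make narrow fail.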



Summary
- Pad tensors in ContextParallelSplitHook before equipartition sharding when sequence length is not divisible by world_size
- ContextParallelGatherHook trims padding after gather via parallel_config._cp_pad_lengths

Fixes huggingface#12568 (QwenImage context parallel AssertionError).
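The pad, shard, gather, trim round trip summarized above can be sketched without a distributed setup. The example below simulates the ranks with plain Python lists; `pad_sequence`, `shard`, and `gather` are hypothetical stand-ins for the padding helper and EquipartitionSharder, and doubling each value stands in for the per-rank model computation.

```python
def pad_sequence(seq, world_size, pad_value=0.0):
    """Right-pad seq to the next multiple of world_size."""
    pad = (-len(seq)) % world_size
    return list(seq) + [pad_value] * pad


def shard(seq, world_size):
    """Split seq into world_size equal, contiguous chunks (one per rank)."""
    assert len(seq) % world_size == 0, "sequence must be divisible by world_size"
    chunk = len(seq) // world_size
    return [seq[rank * chunk:(rank + 1) * chunk] for rank in range(world_size)]


def gather(shards):
    """Concatenate the per-rank chunks back into one sequence."""
    return [value for chunk in shards for value in chunk]


world_size = 3
tokens = [10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]  # length 7, not divisible by 3
original_len = len(tokens)                            # remembered for trimming

padded = pad_sequence(tokens, world_size)             # length 9
shards = shard(padded, world_size)                    # three chunks of length 3

# Stand-in for each rank running the model on its own chunk.
processed = [[value * 2 for value in chunk] for chunk in shards]

gathered = gather(processed)                          # length 9, includes padding
trimmed = gathered[:original_len]                     # drop the padded tail

assert len(padded) == 9 and all(len(chunk) == 3 for chunk in shards)
assert trimmed == [value * 2 for value in tokens]
```

Calling `shard(tokens, world_size)` directly on the unpadded length-7 list trips the divisibility assertion, which loosely mirrors the AssertionError this PR addresses.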
Test plan
- ruff check on changed files
- python utils/check_copies.py
- pytest tests/hooks/test_hooks.py::ContextParallelHooksTests -q (host env missing torch)

Made with Cursor