From 1be07ba931c9aac23eba9072b3aafc15e339f1fc Mon Sep 17 00:00:00 2001 From: Simon Lynch Date: Sun, 21 Jun 2026 21:45:40 +1000 Subject: [PATCH] fix(hooks): pad indivisible sequences for context parallel sharding Pad tensors before EquipartitionSharder.shard when sequence length is not divisible by world_size; mask tensors pad with 0; gather hook trims padding. Fixes #12568 for QwenImage and other Ring/Unified CP models. Co-authored-by: Cursor --- src/diffusers/hooks/context_parallel.py | 64 +++++++++++++++---- tests/hooks/test_hooks.py | 84 +++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 12 deletions(-) diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py index cfc812509a01..18dcda6a9f2b 100644 --- a/src/diffusers/hooks/context_parallel.py +++ b/src/diffusers/hooks/context_parallel.py @@ -37,6 +37,29 @@ logger = get_logger(__name__) # pylint: disable=invalid-name +def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[int, int]: + pad_lengths = getattr(parallel_config, "_cp_pad_lengths", None) + if pad_lengths is None: + pad_lengths = {} + parallel_config._cp_pad_lengths = pad_lengths + return pad_lengths + + +def _pad_tensor_for_context_parallel( + x: torch.Tensor, dim: int, world_size: int, pad_value: float | int = 0.0 +) -> torch.Tensor: + seq_len = x.size(dim) + if world_size <= 1 or seq_len % world_size == 0: + return x + + pad_len = world_size - (seq_len % world_size) + pad_width = [0] * (2 * x.dim()) + pad_idx = x.dim() - 1 - dim + pad_width[2 * pad_idx + 1] = pad_len + return torch.nn.functional.pad(x, tuple(pad_width), mode="constant", value=pad_value) + + + _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}" _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}" @@ -156,7 +179,7 @@ def pre_forward(self, module, *args, **kwargs): # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard # the output instead of input for a particular layer by setting split_output=True if isinstance(input_val, torch.Tensor): - input_val = self._prepare_cp_input(input_val, cpm) + input_val = self._prepare_cp_input(input_val, cpm, name) elif isinstance(input_val, (list, tuple)): if len(input_val) != len(cpm): raise ValueError( @@ -165,7 +188,7 @@ def pre_forward(self, module, *args, **kwargs): sharded_input_val = [] for i, x in enumerate(input_val): if torch.is_tensor(x) and not cpm[i].split_output: - x = self._prepare_cp_input(x, cpm[i]) + x = self._prepare_cp_input(x, cpm[i], name) sharded_input_val.append(x) input_val = sharded_input_val else: @@ -198,23 +221,32 @@ def post_forward(self, module, output): if index >= len(output): raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.") current_output = output[index] - current_output = self._prepare_cp_input(current_output, cpm) + current_output = self._prepare_cp_input(current_output, cpm, "") output[index] = current_output return output[0] if is_tensor else tuple(output) - def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor: + def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str = "") -> torch.Tensor: if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims: logger.warning_once( f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied." ) return x - else: - if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything: - return PartitionAnythingSharder.shard_anything( - x, cp_input.split_dim, self.parallel_config._flattened_mesh - ) - return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh) + + mesh = self.parallel_config._flattened_mesh + dim = cp_input.split_dim + + if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything: + return PartitionAnythingSharder.shard_anything(x, dim, mesh) + + world_size = mesh.size() + seq_len = x.size(dim) + if world_size > 1 and seq_len % world_size != 0: + pad_value = 0 if "mask" in name.lower() else 0.0 + x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value) + _get_cp_pad_lengths(self.parallel_config)[dim] = seq_len + + return EquipartitionSharder.shard(x, dim, mesh) class ContextParallelGatherHook(ModelHook): @@ -236,18 +268,26 @@ def post_forward(self, module, output): if len(output) != len(self.metadata): raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.") + pad_lengths = getattr(self.parallel_config, "_cp_pad_lengths", None) + for i, cpm in enumerate(self.metadata): if cpm is None: continue if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything: - output[i] = PartitionAnythingSharder.unshard_anything( + x = PartitionAnythingSharder.unshard_anything( output[i], cpm.gather_dim, self.parallel_config._flattened_mesh ) else: - output[i] = EquipartitionSharder.unshard( + x = EquipartitionSharder.unshard( output[i], cpm.gather_dim, self.parallel_config._flattened_mesh ) + if pad_lengths and cpm.gather_dim in pad_lengths: + original_len = pad_lengths.pop(cpm.gather_dim) + x = x.narrow(cpm.gather_dim, 0, original_len) + + output[i] = x + return output[0] if is_tensor else tuple(output) diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 26418adfddee..d05068324945 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -13,11 +13,18 @@ # limitations under the License. import gc +from unittest.mock import patch import pytest import torch from diffusers.hooks import HookRegistry, ModelHook +from diffusers.hooks.context_parallel import ( + ContextParallelGatherHook, + ContextParallelSplitHook, + EquipartitionSharder, +) +from diffusers.models._modeling_parallel import ContextParallelOutput from diffusers.training_utils import free_memory from diffusers.utils.logging import get_logger @@ -372,3 +379,80 @@ def test_invocation_order_stateful_last(self): .replace("\n", "") ) assert output == expected_invocation_order_log + +class _DummyMesh: + def __init__(self, size: int): + self._size = size + + def size(self): + return self._size + + def get_group(self): + return None + + +class _DummyParallelConfig: + def __init__(self, mesh_size: int): + self._flattened_mesh = _DummyMesh(mesh_size) + self.ulysses_anything = False + self.ring_anything = False + + +class _DummyCPInput: + def __init__(self, split_dim: int, expected_dims: int | None = None, split_output: bool = False): + self.split_dim = split_dim + self.expected_dims = expected_dims + self.split_output = split_output + + +class ContextParallelHooksTests: + def setup_method(self): + self.parallel_config = _DummyParallelConfig(mesh_size=3) + self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config) + self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1) + self.hook.initialize_hook(self.module) + + def test_prepare_cp_input_pads_hidden_states(self): + x = torch.randn(1, 7, 16) + cp_input = _DummyCPInput(split_dim=1, expected_dims=3) + + with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t): + out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states") + + assert out.shape[1] == 9 + assert self.parallel_config._cp_pad_lengths[1] == 7 + + def test_prepare_cp_input_pads_attention_mask_with_zeros(self): + mask = torch.ones(1, 7, dtype=torch.long) + cp_input = _DummyCPInput(split_dim=1, expected_dims=2) + + with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t): + out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_hidden_states_mask") + + assert out_mask.shape[1] == 9 + assert torch.equal(out_mask[:, -2:], torch.zeros(1, 2, dtype=torch.long)) + + def test_prepare_cp_input_no_pad_when_divisible(self): + x = torch.randn(1, 6, 16) + cp_input = _DummyCPInput(split_dim=1, expected_dims=3) + + with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t): + out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states") + + assert out.shape[1] == 6 + assert not hasattr(self.parallel_config, "_cp_pad_lengths") + + def test_gather_hook_trims_padded_output(self): + self.parallel_config._cp_pad_lengths = {1: 7} + gather_hook = ContextParallelGatherHook( + metadata=[ContextParallelOutput(gather_dim=1, expected_dims=3)], + parallel_config=self.parallel_config, + ) + x = torch.randn(1, 9, 16) + + with patch.object(EquipartitionSharder, "unshard", side_effect=lambda t, dim, mesh: t): + out = gather_hook.post_forward(self.module, x) + + assert out.shape[1] == 7 + assert 1 not in getattr(self.parallel_config, "_cp_pad_lengths", {}) +