Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 52 additions & 12 deletions src/diffusers/hooks/context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@

logger = get_logger(__name__) # pylint: disable=invalid-name

def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[int, int]:
pad_lengths = getattr(parallel_config, "_cp_pad_lengths", None)
if pad_lengths is None:
pad_lengths = {}
parallel_config._cp_pad_lengths = pad_lengths
return pad_lengths


def _pad_tensor_for_context_parallel(
x: torch.Tensor, dim: int, world_size: int, pad_value: float | int = 0.0
) -> torch.Tensor:
seq_len = x.size(dim)
if world_size <= 1 or seq_len % world_size == 0:
return x

pad_len = world_size - (seq_len % world_size)
pad_width = [0] * (2 * x.dim())
pad_idx = x.dim() - 1 - dim
pad_width[2 * pad_idx + 1] = pad_len
return torch.nn.functional.pad(x, tuple(pad_width), mode="constant", value=pad_value)



_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"

Expand Down Expand Up @@ -156,7 +179,7 @@ def pre_forward(self, module, *args, **kwargs):
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
# the output instead of input for a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
input_val = self._prepare_cp_input(input_val, cpm, name)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
Expand All @@ -165,7 +188,7 @@ def pre_forward(self, module, *args, **kwargs):
sharded_input_val = []
for i, x in enumerate(input_val):
if torch.is_tensor(x) and not cpm[i].split_output:
x = self._prepare_cp_input(x, cpm[i])
x = self._prepare_cp_input(x, cpm[i], name)
sharded_input_val.append(x)
input_val = sharded_input_val
else:
Expand Down Expand Up @@ -198,23 +221,32 @@ def post_forward(self, module, output):
if index >= len(output):
raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
current_output = output[index]
current_output = self._prepare_cp_input(current_output, cpm)
current_output = self._prepare_cp_input(current_output, cpm, "")
output[index] = current_output

return output[0] if is_tensor else tuple(output)

def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str = "") -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
logger.warning_once(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
)
return x
else:
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)

mesh = self.parallel_config._flattened_mesh
dim = cp_input.split_dim

if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
return PartitionAnythingSharder.shard_anything(x, dim, mesh)

world_size = mesh.size()
seq_len = x.size(dim)
if world_size > 1 and seq_len % world_size != 0:
pad_value = 0 if "mask" in name.lower() else 0.0
x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value)
_get_cp_pad_lengths(self.parallel_config)[dim] = seq_len

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pad lengths collide per dim

High Severity

_cp_pad_lengths stores one pre-pad length per shard dimension, but several context-parallel inputs (e.g. image hidden_states, text encoder_hidden_states, and encoder_hidden_states_mask in QwenImage) share split_dim=1 with different sequence lengths. Later inputs overwrite earlier entries, so ContextParallelGatherHook may trim proj_out to the wrong length, corrupting output or causing narrow to fail.

Additional Locations (1)
Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit 1be07ba. Configure here.


return EquipartitionSharder.shard(x, dim, mesh)


class ContextParallelGatherHook(ModelHook):
Expand All @@ -236,18 +268,26 @@ def post_forward(self, module, output):
if len(output) != len(self.metadata):
raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

pad_lengths = getattr(self.parallel_config, "_cp_pad_lengths", None)

for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
if self.parallel_config.ulysses_anything or self.parallel_config.ring_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
x = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
x = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)

if pad_lengths and cpm.gather_dim in pad_lengths:
original_len = pad_lengths.pop(cpm.gather_dim)
x = x.narrow(cpm.gather_dim, 0, original_len)

output[i] = x

return output[0] if is_tensor else tuple(output)


Expand Down
84 changes: 84 additions & 0 deletions tests/hooks/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@
# limitations under the License.

import gc
from unittest.mock import patch

import pytest
import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.hooks.context_parallel import (
ContextParallelGatherHook,
ContextParallelSplitHook,
EquipartitionSharder,
)
from diffusers.models._modeling_parallel import ContextParallelOutput
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger

Expand Down Expand Up @@ -372,3 +379,80 @@ def test_invocation_order_stateful_last(self):
.replace("\n", "")
)
assert output == expected_invocation_order_log

class _DummyMesh:
def __init__(self, size: int):
self._size = size

def size(self):
return self._size

def get_group(self):
return None


class _DummyParallelConfig:
def __init__(self, mesh_size: int):
self._flattened_mesh = _DummyMesh(mesh_size)
self.ulysses_anything = False
self.ring_anything = False


class _DummyCPInput:
def __init__(self, split_dim: int, expected_dims: int | None = None, split_output: bool = False):
self.split_dim = split_dim
self.expected_dims = expected_dims
self.split_output = split_output


class ContextParallelHooksTests:
def setup_method(self):
self.parallel_config = _DummyParallelConfig(mesh_size=3)
self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config)
self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1)
self.hook.initialize_hook(self.module)

def test_prepare_cp_input_pads_hidden_states(self):
x = torch.randn(1, 7, 16)
cp_input = _DummyCPInput(split_dim=1, expected_dims=3)

with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")

assert out.shape[1] == 9
assert self.parallel_config._cp_pad_lengths[1] == 7

def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
mask = torch.ones(1, 7, dtype=torch.long)
cp_input = _DummyCPInput(split_dim=1, expected_dims=2)

with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_hidden_states_mask")

assert out_mask.shape[1] == 9
assert torch.equal(out_mask[:, -2:], torch.zeros(1, 2, dtype=torch.long))

def test_prepare_cp_input_no_pad_when_divisible(self):
x = torch.randn(1, 6, 16)
cp_input = _DummyCPInput(split_dim=1, expected_dims=3)

with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")

assert out.shape[1] == 6
assert not hasattr(self.parallel_config, "_cp_pad_lengths")

def test_gather_hook_trims_padded_output(self):
self.parallel_config._cp_pad_lengths = {1: 7}
gather_hook = ContextParallelGatherHook(
metadata=[ContextParallelOutput(gather_dim=1, expected_dims=3)],
parallel_config=self.parallel_config,
)
x = torch.randn(1, 9, 16)

with patch.object(EquipartitionSharder, "unshard", side_effect=lambda t, dim, mesh: t):
out = gather_hook.post_forward(self.module, x)

assert out.shape[1] == 7
assert 1 not in getattr(self.parallel_config, "_cp_pad_lengths", {})