Skip to content

Commit 28b0ab0

Browse files
committed
Fix context parallel pad length collision across shared split dims
Store pre-pad sequence lengths keyed by input name instead of shard dimension, so that tensors such as hidden_states and encoder_hidden_states_mask, which share split_dim=1, no longer overwrite each other's entries. When trimming gathered output, the gather hook looks up the stored pre-pad length under the output's unpad_key, which defaults to "hidden_states".
1 parent 1be07ba commit 28b0ab0

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

src/diffusers/hooks/context_parallel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
logger = get_logger(__name__) # pylint: disable=invalid-name
3939

40-
def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[int, int]:
40+
def _get_cp_pad_lengths(parallel_config: ContextParallelConfig) -> dict[str, int]:
4141
pad_lengths = getattr(parallel_config, "_cp_pad_lengths", None)
4242
if pad_lengths is None:
4343
pad_lengths = {}
@@ -244,7 +244,8 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, nam
244244
if world_size > 1 and seq_len % world_size != 0:
245245
pad_value = 0 if "mask" in name.lower() else 0.0
246246
x = _pad_tensor_for_context_parallel(x, dim, world_size, pad_value=pad_value)
247-
_get_cp_pad_lengths(self.parallel_config)[dim] = seq_len
247+
if name:
248+
_get_cp_pad_lengths(self.parallel_config)[name] = seq_len
248249

249250
return EquipartitionSharder.shard(x, dim, mesh)
250251

@@ -282,8 +283,9 @@ def post_forward(self, module, output):
282283
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
283284
)
284285

285-
if pad_lengths and cpm.gather_dim in pad_lengths:
286-
original_len = pad_lengths.pop(cpm.gather_dim)
286+
unpad_key = getattr(cpm, "unpad_key", "hidden_states")
287+
if pad_lengths and unpad_key and unpad_key in pad_lengths:
288+
original_len = pad_lengths.pop(unpad_key)
287289
x = x.narrow(cpm.gather_dim, 0, original_len)
288290

289291
output[i] = x

tests/hooks/test_hooks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ def test_prepare_cp_input_pads_hidden_states(self):
420420
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
421421

422422
assert out.shape[1] == 9
423-
assert self.parallel_config._cp_pad_lengths[1] == 7
423+
assert self.parallel_config._cp_pad_lengths["hidden_states"] == 7
424424

425425
def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
426426
mask = torch.ones(1, 7, dtype=torch.long)
@@ -443,7 +443,7 @@ def test_prepare_cp_input_no_pad_when_divisible(self):
443443
assert not hasattr(self.parallel_config, "_cp_pad_lengths")
444444

445445
def test_gather_hook_trims_padded_output(self):
446-
self.parallel_config._cp_pad_lengths = {1: 7}
446+
self.parallel_config._cp_pad_lengths = {"hidden_states": 7}
447447
gather_hook = ContextParallelGatherHook(
448448
metadata=[ContextParallelOutput(gather_dim=1, expected_dims=3)],
449449
parallel_config=self.parallel_config,
@@ -454,5 +454,5 @@ def test_gather_hook_trims_padded_output(self):
454454
out = gather_hook.post_forward(self.module, x)
455455

456456
assert out.shape[1] == 7
457-
assert 1 not in getattr(self.parallel_config, "_cp_pad_lengths", {})
457+
assert "hidden_states" not in getattr(self.parallel_config, "_cp_pad_lengths", {})
458458

0 commit comments

Comments
 (0)