diff --git a/src/diffusers/pipelines/ltx2/connectors.py b/src/diffusers/pipelines/ltx2/connectors.py index 8a00a0c6b452..0b0b333209aa 100644 --- a/src/diffusers/pipelines/ltx2/connectors.py +++ b/src/diffusers/pipelines/ltx2/connectors.py @@ -302,13 +302,17 @@ def forward( if binary_attn_mask.ndim == 4: binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L] - # Replace padding positions with learned registers using vectorized masking - mask = binary_attn_mask.unsqueeze(-1) # [B, L, 1] + # Move the valid tokens to the front in their original order and fill the tail + # with registers indexed by absolute position, matching the original LTX + # implementation (`_replace_padded_with_learnable_registers`). A stable argsort + # of the inverted mask gathers valid tokens first while preserving their order. + order = torch.argsort(1 - binary_attn_mask, dim=1, stable=True) # [B, L] + front_aligned = torch.gather(hidden_states, 1, order.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])) + num_valid = binary_attn_mask.sum(dim=1, keepdim=True) # [B, 1] + positions = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0) # [1, L] + front_mask = (positions < num_valid).unsqueeze(-1) # [B, L, 1] registers_expanded = registers.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, D] - hidden_states = mask * hidden_states + (1 - mask) * registers_expanded - - # Flip sequence: embeddings move to front, registers to back (from left padding layout) - hidden_states = torch.flip(hidden_states, dims=[1]) + hidden_states = torch.where(front_mask, front_aligned, registers_expanded.to(hidden_states.dtype)) # Overwrite attention_mask with an all-zeros mask if using registers. attention_mask = torch.zeros_like(attention_mask) diff --git a/tests/pipelines/ltx2/test_ltx2_connectors.py b/tests/pipelines/ltx2/test_ltx2_connectors.py new file mode 100644 index 000000000000..f8209ea75e3f --- /dev/null +++ b/tests/pipelines/ltx2/test_ltx2_connectors.py @@ -0,0 +1,99 @@ +# Copyright 2026 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers.pipelines.ltx2.connectors import LTX2ConnectorTransformer1d + +from ...testing_utils import enable_full_determinism + + +enable_full_determinism() + + +class LTX2ConnectorRegisterLayoutTests(unittest.TestCase): + """The connector must lay out its sequence exactly like the original LTX + implementation (``ltx_core`` ``_replace_padded_with_learnable_registers``, + also matched by ComfyUI): the valid tokens move to the front *in their + original order*, and the tail is filled with the tiled learnable registers + indexed by *absolute position*. The connector blocks apply RoPE, so any + deviation (e.g. reversed token order) produces embeddings the DiT was + never trained on. + """ + + num_registers = 4 + seq_len = 12 + num_heads = 2 + head_dim = 4 + + def get_connector(self): + # num_layers=0 keeps the forward to layout + final RMSNorm, so the + # register layout can be checked exactly. + return LTX2ConnectorTransformer1d( + num_attention_heads=self.num_heads, + attention_head_dim=self.head_dim, + num_layers=0, + num_learnable_registers=self.num_registers, + ).eval() + + def get_inputs(self, valid_lengths): + dim = self.num_heads * self.head_dim + batch_size = len(valid_lengths) + hidden_states = torch.randn(batch_size, self.seq_len, dim) + # Left padding, like the Gemma tokenization in the LTX2 pipelines. + binary_mask = torch.zeros(batch_size, self.seq_len, dtype=torch.int64) + for i, n in enumerate(valid_lengths): + binary_mask[i, self.seq_len - n :] = 1 + additive_mask = (binary_mask - 1).to(hidden_states.dtype) + additive_mask = additive_mask.reshape(batch_size, 1, 1, self.seq_len) + additive_mask = additive_mask * torch.finfo(hidden_states.dtype).max + return hidden_states, binary_mask, additive_mask + + def reference_layout(self, connector, hidden_states, binary_mask): + # Reference semantics: front-align valid tokens (order preserved), + # fill the tail with the register tile by absolute position. + batch_size, seq_len, _ = hidden_states.shape + registers = connector.learnable_registers.detach() + tiled = registers.repeat(seq_len // self.num_registers, 1) + expected = torch.empty_like(hidden_states) + for i in range(batch_size): + valid = hidden_states[i, binary_mask[i].bool()] + expected[i, : valid.shape[0]] = valid + expected[i, valid.shape[0] :] = tiled[valid.shape[0] :] + # The forward ends with a non-affine RMSNorm. + return expected * torch.rsqrt(expected.pow(2).mean(-1, keepdim=True) + 1e-6) + + def check_layout(self, valid_lengths): + connector = self.get_connector() + hidden_states, binary_mask, additive_mask = self.get_inputs(valid_lengths) + with torch.no_grad(): + output, _ = connector(hidden_states, additive_mask) + expected = self.reference_layout(connector, hidden_states, binary_mask) + self.assertTrue(torch.allclose(output, expected, atol=1e-5)) + + def test_register_layout_left_padded(self): + self.check_layout([5]) + + def test_register_layout_mixed_lengths_batch(self): + # The pipelines concatenate negative and positive prompts of different + # lengths into one batch; the layout must be computed per row. + self.check_layout([5, 2]) + + def test_register_layout_fully_valid(self): + self.check_layout([self.seq_len]) + + def test_register_layout_single_token(self): + self.check_layout([1])