Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,37 @@ def prepare_attention_mask(

return attention_mask

def prepare_joint_attention_mask(
    self, attention_mask: torch.Tensor, target_length: int, dtype: torch.dtype
) -> torch.Tensor:
    r"""
    Prepare an additive joint attention mask for MMDiT-style processors that concatenate `[hidden_states,
    encoder_hidden_states]` before attention (SD3 order).

    The input mask covers text tokens only. The missing leading positions (image/latent tokens) are filled in on
    the left with `1.0` (attend) so that the padded mask lines up with the concatenation order.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            Keep-mask whose last dimension indexes the text tokens: `1` means "attend", `0` means "mask out".
            `None` is returned unchanged.
        target_length (`int`):
            Length of the joint key sequence the mask must cover along its last dimension.
        dtype (`torch.dtype`):
            Floating point dtype of the returned mask.

    Returns:
        `torch.Tensor`: Additive mask in `dtype` holding `0` where attention is allowed and
        `torch.finfo(dtype).min` where it is masked. A 2D input is returned as `(batch, 1, 1, target_length)`, a 3D
        input as `(batch, 1, query_length, target_length)`; inputs of any other rank keep their rank.

    Raises:
        ValueError: If the mask is longer than `target_length`. Padding by a negative amount would silently crop
            mask entries instead of extending the mask.
    """
    if attention_mask is None:
        return attention_mask

    remaining_length = target_length - attention_mask.shape[-1]
    if remaining_length < 0:
        raise ValueError(
            f"`attention_mask` has length {attention_mask.shape[-1]} along its last dimension, which exceeds the"
            f" joint sequence length {target_length}."
        )

    # Convert first so that the padding below never has to be written into a bool/integer tensor.
    attention_mask = attention_mask.to(dtype)

    if remaining_length > 0:
        # Concatenation (rather than `F.pad`) behaves the same on every device, including MPS, and works for
        # masks of any rank because only the last dimension is extended.
        padding = attention_mask.new_ones((*attention_mask.shape[:-1], remaining_length))
        attention_mask = torch.cat([padding, attention_mask], dim=-1)

    # Insert singleton axes so the mask broadcasts against (batch, heads, query, key) attention scores.
    if attention_mask.dim() == 3:
        attention_mask = attention_mask[:, None, :, :]
    elif attention_mask.dim() == 2:
        attention_mask = attention_mask[:, None, None, :]

    # keep (1) -> 0 bias, drop (0) -> most negative representable value
    return (1.0 - attention_mask) * torch.finfo(dtype).min

def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
Expand Down Expand Up @@ -1481,7 +1512,12 @@ def __call__(
key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
if attention_mask is not None:
attention_mask = attn.prepare_joint_attention_mask(attention_mask, key.shape[2], key.dtype)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Joint mask on image self-attention

High Severity

In JointAttnProcessor2_0, prepare_joint_attention_mask runs whenever attention_mask is set, even when encoder_hidden_states is None. SD3.5 dual-attention blocks pass the same joint_attention_kwargs (including the text mask) into the second JointAttnProcessor2_0 self-attention pass, so image-only keys get a wrongly padded joint mask and incorrect SDPA masking.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit f41465a. Configure here.


hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

Expand Down Expand Up @@ -1880,7 +1916,12 @@ def __call__(
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
if attention_mask is not None:
attention_mask = attn.prepare_joint_attention_mask(attention_mask, key.shape[2], key.dtype)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

Expand Down
26 changes: 26 additions & 0 deletions tests/models/test_sd3_joint_attention_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# coding=utf-8
import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
from ..testing_utils import enable_full_determinism, torch_device
enable_full_determinism()

class TestSD3JointAttentionMaskProcessor:
    """Processor-level check of the joint attention mask in `JointAttnProcessor2_0`."""

    def test_joint_attention_mask_makes_padding_invariant(self):
        """Zero-padded text that is masked out must give the same outputs as the unpadded text."""
        attn = Attention(
            query_dim=32,
            added_kv_proj_dim=32,
            dim_head=8,
            heads=4,
            out_dim=32,
            context_pre_only=False,
            bias=True,
            processor=JointAttnProcessor2_0(),
            eps=1e-6,
        )
        attn = attn.eval().to(torch_device)

        batch_size, image_seq_len, short_text_len, long_text_len, dim = 1, 64, 20, 40, 32
        pad_len = long_text_len - short_text_len

        # Draw the image tokens first and the text tokens second from one seeded generator.
        gen = torch.Generator(device=torch_device).manual_seed(0)
        hidden_states = torch.randn(batch_size, image_seq_len, dim, generator=gen, device=torch_device)
        encoder_short = torch.randn(batch_size, short_text_len, dim, generator=gen, device=torch_device)

        # Long variant: the same text tokens followed by all-zero padding tokens.
        zero_tokens = torch.zeros(batch_size, pad_len, dim, device=torch_device, dtype=encoder_short.dtype)
        encoder_long = torch.cat([encoder_short, zero_tokens], dim=1)

        # 1.0 marks a real text token, 0.0 marks padding.
        mask_short = torch.ones((batch_size, short_text_len), device=torch_device)
        mask_long = torch.cat([mask_short, torch.zeros((batch_size, pad_len), device=torch_device)], dim=1)

        def max_abs_diff(a, b):
            return torch.max(torch.abs(a - b)).item()

        with torch.no_grad():
            out_s_hs, out_s_enc = attn(
                hidden_states=hidden_states, encoder_hidden_states=encoder_short, attention_mask=mask_short
            )
            out_l_hs, out_l_enc = attn(
                hidden_states=hidden_states, encoder_hidden_states=encoder_long, attention_mask=mask_long
            )

        assert max_abs_diff(out_s_hs, out_l_hs) < 1e-5
        # Only the real text positions of the padded run are comparable.
        assert max_abs_diff(out_s_enc, out_l_enc[:, :short_text_len]) < 1e-5
43 changes: 43 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,46 @@ class TestSD35TransformerBitsAndBytes(SD35TransformerTesterConfig, BitsAndBytesT

class TestSD35TransformerTorchAo(SD35TransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for SD3.5 Transformer."""


class TestSD3JointAttentionMask(SD3TransformerTesterConfig):
    """Model-level check that a text attention mask cancels the effect of zero-padded text tokens."""

    def test_joint_attention_mask_makes_padding_invariant(self):
        """A masked, zero-padded prompt must match the output of the same prompt without padding."""
        init_dict = self.get_init_dict()
        model = SD3Transformer2DModel(**init_dict).to(torch_device)
        model.eval()

        inputs = self.get_dummy_inputs(batch_size=1)
        # Arguments that are identical for both forward passes.
        shared_kwargs = {
            "hidden_states": inputs["hidden_states"],
            "pooled_projections": inputs["pooled_projections"],
            "timestep": inputs["timestep"],
            "return_dict": False,
        }

        content_length, padded_length = 80, 154
        pad_length = padded_length - content_length
        embedding_dim = init_dict["joint_attention_dim"]

        # Real text embeddings, then the same embeddings followed by all-zero padding tokens.
        content = torch.randn(1, content_length, embedding_dim, generator=self.generator, device=torch_device)
        padded = torch.cat([content, content.new_zeros(1, pad_length, embedding_dim)], dim=1)

        # 1.0 for real text tokens, 0.0 for padding.
        attention_mask = torch.cat(
            [
                torch.ones(1, content_length, device=torch_device),
                torch.zeros(1, pad_length, device=torch_device),
            ],
            dim=1,
        )

        with torch.no_grad():
            out_padded = model(
                encoder_hidden_states=padded,
                joint_attention_kwargs={"attention_mask": attention_mask},
                **shared_kwargs,
            )[0]
            out_trimmed = model(encoder_hidden_states=content, **shared_kwargs)[0]

        assert torch.allclose(out_padded, out_trimmed, atol=1e-5, rtol=1e-4)