diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 81b36e113df4..5b5c0acb785a 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -500,6 +500,8 @@ "Ideogram4ModularPipeline", "LTXAutoBlocks", "LTXModularPipeline", + "PixArtAlphaAutoBlocks", + "PixArtAlphaModularPipeline", "QwenImageAutoBlocks", "QwenImageEditAutoBlocks", "QwenImageEditModularPipeline", @@ -1347,6 +1349,8 @@ Ideogram4ModularPipeline, LTXAutoBlocks, LTXModularPipeline, + PixArtAlphaAutoBlocks, + PixArtAlphaModularPipeline, QwenImageAutoBlocks, QwenImageEditAutoBlocks, QwenImageEditModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 4b36994aef07..356f8c054dbb 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -113,6 +113,7 @@ "ZImageAutoBlocks", "ZImageModularPipeline", ] + _import_structure["pixart_alpha"] = ["PixArtAlphaAutoBlocks", "PixArtAlphaModularPipeline"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -162,6 +163,7 @@ SequentialPipelineBlocks, ) from .modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, InsertableDict, OutputParam + from .pixart_alpha import PixArtAlphaAutoBlocks, PixArtAlphaModularPipeline from .qwenimage import ( QwenImageAutoBlocks, QwenImageEditAutoBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index a121553b7588..957d06cde27c 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -131,6 +131,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")), ("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")), ("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")), + ("pixart-alpha", _create_default_map_fn("PixArtAlphaModularPipeline")), ("anima", _create_default_map_fn("AnimaModularPipeline")), ("z-image", _create_default_map_fn("ZImageModularPipeline")), ("helios", _create_default_map_fn("HeliosModularPipeline")), diff --git a/src/diffusers/modular_pipelines/pixart_alpha/__init__.py b/src/diffusers/modular_pipelines/pixart_alpha/__init__.py new file mode 100644 index 000000000000..23ed865b48e8 --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/__init__.py @@ -0,0 +1,47 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_pixart_alpha"] = ["PixArtAlphaAutoBlocks"] + _import_structure["modular_pipeline"] = ["PixArtAlphaModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_pixart_alpha import PixArtAlphaAutoBlocks + from .modular_pipeline import PixArtAlphaModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/pixart_alpha/before_denoise.py b/src/diffusers/modular_pipelines/pixart_alpha/before_denoise.py new file mode 100644 index 000000000000..648198943aaa --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/before_denoise.py @@ -0,0 +1,512 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect + +import torch + +from ...models import PixArtTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import PixArtAlphaModularPipeline + + +logger = logging.get_logger(__name__) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`list[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`list[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +# text input step (per-prompt expansion) + + +# auto_docstring +class PixArtAlphaTextInputStep(ModularPipelineBlocks): + """ + Input step that expands the text embeddings to the final batch size (batch_size * num_images_per_prompt). + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + + Outputs: + batch_size (`int`): + The number of prompts. + dtype (`dtype`): + The dtype of the text embeddings, used for the latents. + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return ( + "Input step that expands the text embeddings to the final batch size (batch_size * num_images_per_prompt)." + ) + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_images_per_prompt"), + InputParam.template("prompt_embeds"), + InputParam.template("prompt_embeds_mask"), + InputParam.template("negative_prompt_embeds"), + InputParam.template("negative_prompt_embeds_mask"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam("batch_size", type_hint=int, description="The number of prompts."), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="The dtype of the text embeddings, used for the latents.", + ), + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + batch_size = block_state.batch_size + num_images_per_prompt = block_state.num_images_per_prompt + + block_state.prompt_embeds = repeat_tensor_to_batch_size( + "prompt_embeds", block_state.prompt_embeds, batch_size, num_images_per_prompt + ) + block_state.prompt_embeds_mask = repeat_tensor_to_batch_size( + "prompt_embeds_mask", block_state.prompt_embeds_mask, batch_size, num_images_per_prompt + ) + + if block_state.negative_prompt_embeds is not None: + block_state.negative_prompt_embeds = repeat_tensor_to_batch_size( + "negative_prompt_embeds", block_state.negative_prompt_embeds, batch_size, num_images_per_prompt + ) + block_state.negative_prompt_embeds_mask = repeat_tensor_to_batch_size( + "negative_prompt_embeds_mask", + block_state.negative_prompt_embeds_mask, + batch_size, + num_images_per_prompt, + ) + + self.set_block_state(state, block_state) + return components, state + + +# set timesteps step + + +# auto_docstring +class PixArtAlphaSetTimestepsStep(ModularPipelineBlocks): + """ + Step that sets the scheduler's timesteps for the denoising process. + + Components: + scheduler (`DPMSolverMultistepScheduler`) + + Inputs: + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + + Outputs: + timesteps (`Tensor`): + The timestep schedule for the denoising loop. + num_inference_steps (`int`): + The number of denoising steps. + """ + + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for the denoising process." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", DPMSolverMultistepScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("num_inference_steps"), + InputParam.template("timesteps"), + InputParam.template("sigmas"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "timesteps", type_hint=torch.Tensor, description="The timestep schedule for the denoising loop." + ), + OutputParam("num_inference_steps", type_hint=int, description="The number of denoising steps."), + ] + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps( + components.scheduler, + block_state.num_inference_steps, + components._execution_device, + block_state.timesteps, + block_state.sigmas, + ) + components.scheduler.set_begin_index(0) + + self.set_block_state(state, block_state) + return components, state + + +# prepare latents step + + +# auto_docstring +class PixArtAlphaPrepareLatentsStep(ModularPipelineBlocks): + """ + Step that prepares the initial random noise latents for the denoising process. + + Components: + scheduler (`DPMSolverMultistepScheduler`) + + Inputs: + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + latents (`Tensor`): + The initial noisy latents for the denoising loop. + """ + + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return "Step that prepares the initial random noise latents for the denoising process." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("scheduler", DPMSolverMultistepScheduler), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("num_images_per_prompt"), + InputParam.template("batch_size", required=True), + InputParam.template("generator"), + InputParam.template("latents"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The initial noisy latents for the denoising loop.", + ), + ] + + @staticmethod + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % vae_scale_factor != 0: + raise ValueError(f"`height` must be divisible by {vae_scale_factor} but is {height}.") + if width is not None and width % vae_scale_factor != 0: + raise ValueError(f"`width` must be divisible by {vae_scale_factor} but is {width}.") + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(block_state.height, block_state.width, components.vae_scale_factor) + + device = components._execution_device + height = block_state.height or components.default_height + width = block_state.width or components.default_width + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + shape = ( + batch_size, + components.num_channels_latents, + int(height) // components.vae_scale_factor, + int(width) // components.vae_scale_factor, + ) + + if block_state.latents is None: + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=device, dtype=block_state.dtype + ) + else: + block_state.latents = block_state.latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + block_state.latents = block_state.latents * components.scheduler.init_noise_sigma + + self.set_block_state(state, block_state) + return components, state + + +# prepare micro-conditions step (resolution / aspect ratio) + + +# auto_docstring +class PixArtAlphaPrepareMicroConditionsStep(ModularPipelineBlocks): + """ + Step that prepares the `resolution` and `aspect_ratio` micro-conditions consumed by the transformer. + + Components: + transformer (`PixArtTransformer2DModel`) + + Inputs: + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + batch_size (`int`): + Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can + be generated in input step. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + dtype (`dtype`, *optional*, defaults to torch.float32): + The dtype of the model inputs, can be generated in input step. + + Outputs: + resolution (`Tensor`): + The resolution micro-condition, or None when unused by the checkpoint. + aspect_ratio (`Tensor`): + The aspect-ratio micro-condition, or None when unused by the checkpoint. + """ + + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return "Step that prepares the `resolution` and `aspect_ratio` micro-conditions consumed by the transformer." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("transformer", PixArtTransformer2DModel), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("height"), + InputParam.template("width"), + InputParam.template("batch_size", required=True), + InputParam.template("num_images_per_prompt"), + InputParam.template("dtype"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "resolution", + type_hint=torch.Tensor, + description="The resolution micro-condition, or None when unused by the checkpoint.", + ), + OutputParam( + "aspect_ratio", + type_hint=torch.Tensor, + description="The aspect-ratio micro-condition, or None when unused by the checkpoint.", + ), + ] + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.resolution = None + block_state.aspect_ratio = None + + if components.transformer.config.sample_size == 128: + device = components._execution_device + height = block_state.height or components.default_height + width = block_state.width or components.default_width + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + resolution = torch.tensor([height, width]).repeat(batch_size, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1) + block_state.resolution = resolution.to(dtype=block_state.dtype, device=device) + block_state.aspect_ratio = aspect_ratio.to(dtype=block_state.dtype, device=device) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/pixart_alpha/decoders.py b/src/diffusers/modular_pipelines/pixart_alpha/decoders.py new file mode 100644 index 000000000000..e997f0685ab1 --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/decoders.py @@ -0,0 +1,158 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import PixArtAlphaModularPipeline + + +logger = logging.get_logger(__name__) + + +# decode step + + +# auto_docstring +class PixArtAlphaDecodeStep(ModularPipelineBlocks): + """ + Step that decodes the denoised latents into an image tensor with the VAE. + + Components: + vae (`AutoencoderKL`) + + Inputs: + latents (`Tensor`): + The denoised latents to decode, can be generated in the denoise step. + + Outputs: + images (`list`): + Generated images. (tensor output of the vae decoder.) + """ + + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return "Step that decodes the denoised latents into an image tensor with the VAE." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("vae", AutoencoderKL), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to decode, can be generated in the denoise step.", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images", note="tensor output of the vae decoder.")] + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + latents = block_state.latents / components.vae.config.scaling_factor + block_state.images = components.vae.decode(latents, return_dict=False)[0] + + self.set_block_state(state, block_state) + return components, state + + +# postprocess the decoded images + + +# auto_docstring +class PixArtAlphaProcessImagesOutputStep(ModularPipelineBlocks): + """ + Step that postprocesses the decoded image tensor into the requested output format. + + Components: + image_processor (`VaeImageProcessor`) + + Inputs: + images (`Tensor`): + The image tensor from the decode step. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return "Step that postprocesses the decoded image tensor into the requested output format." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "images", + required=True, + type_hint=torch.Tensor, + description="The image tensor from the decode step.", + ), + InputParam.template("output_type"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam.template("images")] + + @staticmethod + def check_inputs(output_type): + if output_type not in ["pil", "np", "pt"]: + raise ValueError(f"Invalid output_type: {output_type}") + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + self.check_inputs(block_state.output_type) + + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/pixart_alpha/denoise.py b/src/diffusers/modular_pipelines/pixart_alpha/denoise.py new file mode 100644 index 000000000000..2ebd7ebffbe6 --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/denoise.py @@ -0,0 +1,246 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...models import PixArtTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import PixArtAlphaModularPipeline + + +logger = logging.get_logger(__name__) + + +class PixArtAlphaLoopDenoiser(ModularPipelineBlocks): + model_name = "pixart-alpha" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.5}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", PixArtTransformer2DModel), + ComponentSpec("scheduler", DPMSolverMultistepScheduler), + ] + + @property + def description(self) -> str: + return "Step within the denoising loop that predicts the noise with classifier-free guidance." + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to denoise.", + ), + InputParam( + "num_inference_steps", + required=True, + type_hint=int, + description="The number of denoising steps.", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="The prompt embeddings used to guide the denoising.", + ), + InputParam( + "prompt_embeds_mask", + required=True, + type_hint=torch.Tensor, + description="The attention mask for the prompt embeddings.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="The negative prompt embeddings used for classifier-free guidance.", + ), + InputParam( + "negative_prompt_embeds_mask", + type_hint=torch.Tensor, + description="The attention mask for the negative prompt embeddings.", + ), + InputParam( + "resolution", + type_hint=torch.Tensor, + description="The resolution micro-condition, or None when unused by the checkpoint.", + ), + InputParam( + "aspect_ratio", + type_hint=torch.Tensor, + description="The aspect-ratio micro-condition, or None when unused by the checkpoint.", + ), + ] + + @torch.no_grad() + def __call__( + self, + components: PixArtAlphaModularPipeline, + block_state: BlockState, + i: int, + t: torch.Tensor, + ) -> PipelineState: + do_cfg = block_state.negative_prompt_embeds is not None + + guider_inputs = { + "hidden_states": (block_state.latents, block_state.latents) if do_cfg else block_state.latents, + "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds) + if do_cfg + else block_state.prompt_embeds, + "encoder_attention_mask": (block_state.prompt_embeds_mask, block_state.negative_prompt_embeds_mask) + if do_cfg + else block_state.prompt_embeds_mask, + } + + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + added_cond_kwargs = {"resolution": block_state.resolution, "aspect_ratio": block_state.aspect_ratio} + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + + latent_model_input = components.scheduler.scale_model_input(guider_state_batch.hidden_states, t) + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=guider_state_batch.encoder_hidden_states, + encoder_attention_mask=guider_state_batch.encoder_attention_mask, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # PixArt transformers with a learned sigma predict both the noise and the variance; keep only the noise. + if components.transformer.config.out_channels // 2 == components.transformer.config.in_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + guider_state_batch.noise_pred = noise_pred + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + block_state.noise_pred = guider_output.pred + + return components, block_state + + +class PixArtAlphaLoopAfterDenoiser(ModularPipelineBlocks): + model_name = "pixart-alpha" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", DPMSolverMultistepScheduler)] + + @property + def description(self) -> str: + return "Step within the denoising loop that updates the latents with the scheduler." + + @property + def inputs(self) -> list[InputParam]: + return [InputParam.template("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The denoised latents.", + ) + ] + + @torch.no_grad() + def __call__( + self, + components: PixArtAlphaModularPipeline, + block_state: BlockState, + i: int, + t: torch.Tensor, + ): + block_state.latents = components.scheduler.step( + block_state.noise_pred, t, block_state.latents, generator=block_state.generator, return_dict=False + )[0] + + return components, block_state + + +class PixArtAlphaDenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "pixart-alpha" + + @property + def description(self) -> str: + return "Pipeline block that iteratively denoises the latents over the scheduler's timesteps." + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.5}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", PixArtTransformer2DModel), + ComponentSpec("scheduler", DPMSolverMultistepScheduler), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return [ + InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam("num_inference_steps", required=True, type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, + 0, + ) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class PixArtAlphaDenoiseStep(PixArtAlphaDenoiseLoopWrapper): + block_classes = [PixArtAlphaLoopDenoiser, PixArtAlphaLoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/pixart_alpha/encoders.py b/src/diffusers/modular_pipelines/pixart_alpha/encoders.py new file mode 100644 index 000000000000..faf4c4d0757e --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/encoders.py @@ -0,0 +1,335 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import re +import urllib.parse as ul + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance +from ...utils import BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import PixArtAlphaModularPipeline + + +logger = logging.get_logger(__name__) + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +def _get_pixart_prompt_embeds(text_encoder, tokenizer, prompt, max_sequence_length, device): + """Tokenize an already-preprocessed prompt and encode it with the T5 text encoder. + + Returns the per-prompt embeddings and attention mask without any `num_images_per_prompt` expansion — that expansion + is the responsibility of the input step in the core denoise sequence. + """ + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0] + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + +# text encoder step + + +# auto_docstring +class PixArtAlphaTextEncoderStep(ModularPipelineBlocks): + """ + Text Encoder step that encodes the prompt into T5 hidden states to guide the image generation. + + Components: + text_encoder (`T5EncoderModel`) tokenizer (`T5Tokenizer`) guider (`ClassifierFreeGuidance`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 120): + Maximum sequence length for prompt encoding. + clean_caption (`bool`, *optional*, defaults to True): + Whether to clean the caption before encoding (requires the `bs4` and `ftfy` packages). + + Outputs: + prompt_embeds (`Tensor`): + The prompt embeddings. + prompt_embeds_mask (`Tensor`): + The encoder attention mask. + negative_prompt_embeds (`Tensor`): + The negative prompt embeddings. + negative_prompt_embeds_mask (`Tensor`): + The negative prompt embeddings mask. + """ + + model_name = "pixart-alpha" + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + @property + def description(self) -> str: + return "Text Encoder step that encodes the prompt into T5 hidden states to guide the image generation." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ + ComponentSpec("text_encoder", T5EncoderModel), + ComponentSpec("tokenizer", T5Tokenizer), + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 4.5}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam.template("prompt"), + InputParam.template("negative_prompt"), + InputParam.template("max_sequence_length", default=120), + InputParam( + "clean_caption", + default=True, + type_hint=bool, + description="Whether to clean the caption before encoding (requires the `bs4` and `ftfy` packages).", + ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam.template("prompt_embeds"), + OutputParam.template("prompt_embeds_mask"), + OutputParam.template("negative_prompt_embeds"), + OutputParam.template("negative_prompt_embeds_mask"), + ] + + @staticmethod + def check_inputs(prompt, negative_prompt): + if not isinstance(prompt, str) and not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if ( + negative_prompt is not None + and not isinstance(negative_prompt, str) + and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + @torch.no_grad() + def __call__(self, components: PixArtAlphaModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + self.check_inputs(block_state.prompt, block_state.negative_prompt) + + prompt = self._text_preprocessing(block_state.prompt, clean_caption=block_state.clean_caption) + block_state.prompt_embeds, block_state.prompt_embeds_mask = _get_pixart_prompt_embeds( + components.text_encoder, components.tokenizer, prompt, block_state.max_sequence_length, device + ) + + block_state.negative_prompt_embeds = None + block_state.negative_prompt_embeds_mask = None + if components.requires_unconditional_embeds: + negative_prompt = block_state.negative_prompt or "" + negative_prompt = self._text_preprocessing(negative_prompt, clean_caption=block_state.clean_caption) + block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = _get_pixart_prompt_embeds( + components.text_encoder, components.tokenizer, negative_prompt, block_state.max_sequence_length, device + ) + + self.set_block_state(state, block_state) + return components, state diff --git a/src/diffusers/modular_pipelines/pixart_alpha/modular_blocks_pixart_alpha.py b/src/diffusers/modular_pipelines/pixart_alpha/modular_blocks_pixart_alpha.py new file mode 100644 index 000000000000..33587b9ca438 --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/modular_blocks_pixart_alpha.py @@ -0,0 +1,149 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + PixArtAlphaPrepareLatentsStep, + PixArtAlphaPrepareMicroConditionsStep, + PixArtAlphaSetTimestepsStep, + PixArtAlphaTextInputStep, +) +from .decoders import PixArtAlphaDecodeStep, PixArtAlphaProcessImagesOutputStep +from .denoise import PixArtAlphaDenoiseStep +from .encoders import PixArtAlphaTextEncoderStep + + +logger = logging.get_logger(__name__) + + +# auto_docstring +class PixArtAlphaCoreDenoiseStep(SequentialPipelineBlocks): + """ + Components: + scheduler (`DPMSolverMultistepScheduler`) transformer (`PixArtTransformer2DModel`) guider + (`ClassifierFreeGuidance`) + + Inputs: + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + text embeddings used to guide the image generation. Can be generated from text_encoder step. + prompt_embeds_mask (`Tensor`): + mask for the text embeddings. Can be generated from text_encoder step. + negative_prompt_embeds (`Tensor`, *optional*): + negative text embeddings used to guide the image generation. Can be generated from text_encoder step. + negative_prompt_embeds_mask (`Tensor`, *optional*): + mask for the negative text embeddings. Can be generated from text_encoder step. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + + model_name = "pixart-alpha" + block_classes = [ + PixArtAlphaTextInputStep(), + PixArtAlphaSetTimestepsStep(), + PixArtAlphaPrepareLatentsStep(), + PixArtAlphaPrepareMicroConditionsStep(), + PixArtAlphaDenoiseStep(), + ] + block_names = ["text_inputs", "set_timesteps", "prepare_latents", "prepare_micro_conditions", "denoise"] + + @property + def outputs(self): + return [OutputParam.template("latents")] + + +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", PixArtAlphaTextEncoderStep()), + ("denoise", PixArtAlphaCoreDenoiseStep()), + ("decode", PixArtAlphaDecodeStep()), + ("postprocess", PixArtAlphaProcessImagesOutputStep()), + ] +) + + +# auto_docstring +class PixArtAlphaAutoBlocks(SequentialPipelineBlocks): + """ + Supported workflows: + - `text2image`: requires `prompt` + + Components: + text_encoder (`T5EncoderModel`) tokenizer (`T5Tokenizer`) guider (`ClassifierFreeGuidance`) scheduler + (`DPMSolverMultistepScheduler`) transformer (`PixArtTransformer2DModel`) vae (`AutoencoderKL`) + image_processor (`VaeImageProcessor`) + + Inputs: + prompt (`str`): + The prompt or prompts to guide image generation. + negative_prompt (`str`, *optional*): + The prompt or prompts not to guide the image generation. + max_sequence_length (`int`, *optional*, defaults to 120): + Maximum sequence length for prompt encoding. + clean_caption (`bool`, *optional*, defaults to True): + Whether to clean the caption before encoding (requires the `bs4` and `ftfy` packages). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`Tensor`, *optional*): + Timesteps for the denoising process. + sigmas (`list`, *optional*): + Custom sigmas for the denoising process. + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + generator (`Generator`, *optional*): + Torch generator for deterministic generation. + latents (`Tensor`, *optional*): + Pre-generated noisy latents for image generation. + output_type (`str`, *optional*, defaults to pil): + Output format: 'pil', 'np', 'pt'. + + Outputs: + images (`list`): + Generated images. + """ + + model_name = "pixart-alpha" + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2image": {"prompt": True}, + } + + @property + def outputs(self): + return [OutputParam.template("images")] diff --git a/src/diffusers/modular_pipelines/pixart_alpha/modular_pipeline.py b/src/diffusers/modular_pipelines/pixart_alpha/modular_pipeline.py new file mode 100644 index 000000000000..1b90b8070a77 --- /dev/null +++ b/src/diffusers/modular_pipelines/pixart_alpha/modular_pipeline.py @@ -0,0 +1,65 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...utils import logging +from ..modular_pipeline import ModularPipeline + + +logger = logging.get_logger(__name__) + + +class PixArtAlphaModularPipeline(ModularPipeline): + """ + A ModularPipeline for PixArt-Alpha. + + > [!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "PixArtAlphaAutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.sample_size + return 128 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.in_channels + return 4 + + @property + def requires_unconditional_embeds(self): + requires_unconditional_embeds = False + + if hasattr(self, "guider") and self.guider is not None: + requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1 + + return requires_unconditional_embeds diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index ed4e13a57eb1..9615b2dd862b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -392,6 +392,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class PixArtAlphaAutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class PixArtAlphaModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class QwenImageAutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/pixart_alpha/__init__.py b/tests/modular_pipelines/pixart_alpha/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/pixart_alpha/test_modular_pipeline_pixart_alpha.py b/tests/modular_pipelines/pixart_alpha/test_modular_pipeline_pixart_alpha.py new file mode 100644 index 000000000000..3419300dbad6 --- /dev/null +++ b/tests/modular_pipelines/pixart_alpha/test_modular_pipeline_pixart_alpha.py @@ -0,0 +1,65 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from diffusers.modular_pipelines.pixart_alpha import ( + PixArtAlphaAutoBlocks, + PixArtAlphaModularPipeline, +) + +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +PIXART_ALPHA_TEXT2IMAGE_WORKFLOWS = { + "text2image": [ + ("text_encoder", "PixArtAlphaTextEncoderStep"), + ("denoise.text_inputs", "PixArtAlphaTextInputStep"), + ("denoise.set_timesteps", "PixArtAlphaSetTimestepsStep"), + ("denoise.prepare_latents", "PixArtAlphaPrepareLatentsStep"), + ("denoise.prepare_micro_conditions", "PixArtAlphaPrepareMicroConditionsStep"), + ("denoise.denoise", "PixArtAlphaDenoiseStep"), + ("decode", "PixArtAlphaDecodeStep"), + ("postprocess", "PixArtAlphaProcessImagesOutputStep"), + ] +} + + +class TestPixArtAlphaModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = PixArtAlphaModularPipeline + pipeline_blocks_class = PixArtAlphaAutoBlocks + pretrained_model_name_or_path = "idealclx/tiny-pixart-alpha-modular" + + params = frozenset(["prompt", "height", "width"]) + batch_params = frozenset(["prompt"]) + expected_workflow_blocks = PIXART_ALPHA_TEXT2IMAGE_WORKFLOWS + + def test_pipeline_call_signature(self): + # Override to prevent signature check failure for guider configurations + # (guidance_scale) which are intentionally omitted from pipeline inputs. + pass + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + return { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "height": 64, + "width": 64, + "max_sequence_length": 48, + "output_type": "pt", + } + + def test_float16_inference(self): + super().test_float16_inference(9e-2)