Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/flow_factory/models/flux/flux1.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def inference(


# 6. Decode images
images = self.decode_latents(latents, height, width)
images = self.decode_latents(latents, height, width, output_type='pt')

# 7. Create samples
# Transpose `extra_call_back_res` tensors to have batch dimension first
Expand Down
2 changes: 1 addition & 1 deletion src/flow_factory/models/flux/flux1_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def inference(


# 7. Prepare output images
generated_images = self.decode_latents(latents, height, width)
generated_images = self.decode_latents(latents, height, width, output_type='pt')

# 8. Create samples
# Transpose `extra_call_back_res` tensors to have batch dimension first
Expand Down
44 changes: 27 additions & 17 deletions src/flow_factory/models/flux/flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Flux2Sample(I2ISample):
image_latent_ids : Optional[torch.Tensor] = None


CONDITION_IMAGE_SIZE = 1024 * 1024
CONDITION_IMAGE_SIZE = (1024, 1024)

class Flux2Adapter(BaseAdapter):
def __init__(self, config: Arguments, accelerator : Accelerator):
Expand Down Expand Up @@ -111,7 +111,7 @@ def _get_mistral_3_small_prompt_embeds(
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
system_message: str = SYSTEM_MESSAGE,
hidden_states_layers: List[int] = (10, 20, 30),
hidden_states_layers: Tuple[int, ...] = (10, 20, 30),
):
dtype = self.pipeline.text_encoder.dtype if dtype is None else dtype
device = self.pipeline.text_encoder.device if device is None else device
Expand Down Expand Up @@ -159,7 +159,7 @@ def encode_prompt(
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
text_encoder_out_layers: List[int] = (10, 20, 30),
text_encoder_out_layers: Tuple[int, ...] = (10, 20, 30),
) -> Dict[str, torch.Tensor]:
"""Encode prompt(s) into embeddings using the Flux.2 text encoder."""
device = device or self.pipeline.text_encoder.device
Expand Down Expand Up @@ -267,14 +267,16 @@ def _resize_condition_images(
condition_image_size : Union[int, Tuple[int, int]] = CONDITION_IMAGE_SIZE,
) -> List[torch.Tensor]:
"""Preprocess condition images for Flux.2 model."""
if isinstance(condition_image_size, int):
condition_image_size = (condition_image_size, condition_image_size)

if isinstance(condition_images, Image.Image):
condition_images = [condition_images]

for img in condition_images:
self.pipeline.image_processor.check_image_input(img)

if isinstance(condition_image_size, int):
condition_image_size = (condition_image_size, condition_image_size)
condition_images = self._standardize_image_input(
condition_images,
output_type='pil',
)

max_area = condition_image_size[0] * condition_image_size[1]

Expand Down Expand Up @@ -331,7 +333,7 @@ def _standardize_image_input(
self,
images: Union[ImageSingle, ImageBatch],
output_type: Literal['pil', 'pt', 'np'] = 'pil',
):
) -> ImageBatch:
"""
Standardize image input to desired output type.
"""
Expand Down Expand Up @@ -367,18 +369,23 @@ def decode_latents(self, latents: torch.Tensor, latent_ids, output_type: Literal
def preprocess_func(
self,
prompt: List[str],
images: Optional[Union[List[Optional[Image.Image]], List[List[Optional[Image.Image]]]]] = None,
images: Optional[MultiImageBatch] = None,
caption_upsample_temperature: Optional[float] = None,
**kwargs
condition_image_size: Union[int, Tuple[int, int]] = CONDITION_IMAGE_SIZE,
max_sequence_length: int = 512,
text_encoder_out_layers: Tuple[int, ...] = (10, 20, 30),
generator: Optional[torch.Generator] = None,
) -> Dict[str, Union[List[Any], torch.Tensor]]:
"""
Preprocess inputs for Flux.2 model (batched processing).

Args:
prompt: List of text prompts
images: Optional images in various formats
images: Optional images in various formats (MultiImageBatch)
caption_upsample_temperature: Temperature for prompt upsampling
**kwargs: Additional arguments for encoding
max_sequence_length: Max sequence length for text encoder
text_encoder_out_layers: Layers to extract from text encoder
generator: Random generator forwarded to `encode_image` when condition images are provided

Returns:
Dictionary with all encoded data in list format for consistency
Expand Down Expand Up @@ -410,14 +417,17 @@ def preprocess_func(
# 3: Batch encode prompts
batch = self.encode_prompt(
prompt=final_prompts,
**filter_kwargs(self.encode_prompt, **kwargs)
max_sequence_length=max_sequence_length,
text_encoder_out_layers=text_encoder_out_layers,
)

# 4: Batch encode images if present
if has_images:
image_dict = self.encode_image(
images=images,
**filter_kwargs(self.encode_image, **kwargs)
condition_image_size=condition_image_size,
device=self.device,
generator=generator,
)
# image_dict already returns lists, so directly merge
batch.update(image_dict)
Expand Down Expand Up @@ -575,8 +585,8 @@ def _inference(
extra_call_back_res[key].append(val)

# 5. Decode latents to images
decoded_images = self.decode_latents(latents, latent_ids)
# decoded_condition_images = self.decode_latents(image_latents, image_latent_ids) if image_latents is not None else None
decoded_images = self.decode_latents(latents, latent_ids, output_type='pt')
# decoded_condition_images = self.decode_latents(image_latents, image_latent_ids, output_type='pt') if image_latents is not None else None

# 6. Create samples

Expand Down
4 changes: 2 additions & 2 deletions src/flow_factory/models/flux/flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Flux2KleinSample(I2ISample):
image_latent_ids : Optional[torch.Tensor] = None


CONDITION_IMAGE_SIZE = 1024 * 1024
CONDITION_IMAGE_SIZE = (1024, 1024)

class Flux2KleinAdapter(BaseAdapter):
def __init__(self, config: Arguments, accelerator : Accelerator):
Expand Down Expand Up @@ -530,7 +530,7 @@ def _inference(
extra_call_back_res[key].append(val)

# 6. Decode latents to images
decoded_images = self.decode_latents(latents, latent_ids)
decoded_images = self.decode_latents(latents, latent_ids, output_type='pt')

# 7. Prepare samples

Expand Down
8 changes: 4 additions & 4 deletions src/flow_factory/models/qwen_image/qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import os
from typing import Union, List, Dict, Any, Optional, Tuple, ClassVar
from typing import Union, List, Dict, Any, Optional, Tuple, ClassVar, Literal
from dataclasses import dataclass
import logging
from collections import defaultdict
Expand Down Expand Up @@ -187,7 +187,7 @@ def encode_video(self, video: Union[torch.Tensor, List[torch.Tensor]]):
"""Not needed for Qwen-Image text-to-image models."""
pass

def decode_latents(self, latents: torch.Tensor, height: int, width: int, **kwargs) -> List[Image.Image]:
def decode_latents(self, latents: torch.Tensor, height: int, width: int, output_type: Literal['pil', 'pt', 'np'] = 'pil') -> List[Image.Image]:
"""Decode latents to images using VAE."""

latents = self.pipeline._unpack_latents(latents, height, width, self.pipeline.vae_scale_factor)
Expand All @@ -202,7 +202,7 @@ def decode_latents(self, latents: torch.Tensor, height: int, width: int, **kwarg
)
latents = latents / latents_std + latents_mean
images = self.pipeline.vae.decode(latents, return_dict=False)[0][:, :, 0]
images = self.pipeline.image_processor.postprocess(images, output_type='pil')
images = self.pipeline.image_processor.postprocess(images, output_type=output_type)

return images

Expand Down Expand Up @@ -419,7 +419,7 @@ def inference(
extra_call_back_res[key].append(val)

# 6. Decode latents to images
decoded_images = self.decode_latents(latents, height, width)
decoded_images = self.decode_latents(latents, height, width, output_type='pt')

# 7. Prepare output samples

Expand Down
6 changes: 3 additions & 3 deletions src/flow_factory/models/qwen_image/qwen_image_edit_plus.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def encode_video(self, videos: Union[torch.Tensor, List[torch.Tensor]]):
pass

# ---------------------------------------- Image Decoding ---------------------------------- #
def decode_latents(self, latents: torch.Tensor, height: int, width: int, **kwargs) -> List[Image.Image]:
def decode_latents(self, latents: torch.Tensor, height: int, width: int, output_type: Literal['pil', 'pt', 'np'] = 'pil') -> List[Image.Image]:
"""Decode latents to images using VAE."""

latents = self.pipeline._unpack_latents(latents, height, width, self.pipeline.vae_scale_factor)
Expand All @@ -480,7 +480,7 @@ def decode_latents(self, latents: torch.Tensor, height: int, width: int, **kwarg
)
latents = latents / latents_std + latents_mean
images = self.pipeline.vae.decode(latents, return_dict=False)[0][:, :, 0]
images = self.pipeline.image_processor.postprocess(images, output_type='pil')
images = self.pipeline.image_processor.postprocess(images, output_type=output_type)

return images

Expand Down Expand Up @@ -816,7 +816,7 @@ def _inference(
extra_call_back_res[key].append(val)

# 7. Post-process results
generated_images = self.decode_latents(latents, height, width)
generated_images = self.decode_latents(latents, height, width, output_type='pt')

# Transpose `extra_call_back_res` tensors to have batch dimension first
# (T, B, ...) -> (B, T, ...)
Expand Down
7 changes: 3 additions & 4 deletions src/flow_factory/models/stable_diffusion/sd3_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import os
from typing import Union, List, Dict, Any, Optional, Tuple, ClassVar
from typing import Union, List, Dict, Any, Optional, Tuple, ClassVar, Literal
from dataclasses import dataclass
from collections import defaultdict

Expand Down Expand Up @@ -153,8 +153,7 @@ def encode_video(self, videos: Union[torch.Tensor, List[torch.Tensor]]):
def decode_latents(
self,
latents: torch.Tensor,
output_type: str = "pil",
**kwargs
output_type: Literal['pil', 'pt', 'np'] = "pil",
) -> torch.Tensor:
latents = latents.to(self.pipeline.vae.dtype)
latents = (latents / self.pipeline.vae.config.scaling_factor) + self.pipeline.vae.config.shift_factor
Expand Down Expand Up @@ -297,7 +296,7 @@ def inference(
extra_call_back_res[key].append(val)

# 7. Decode latents
images = self.decode_latents(latents=latents)
images = self.decode_latents(latents=latents, output_type='pt')

# 8. Create samples
all_latents = latent_collector.get_result()
Expand Down
2 changes: 1 addition & 1 deletion src/flow_factory/models/wan/wan2_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def inference(
# 7. Decode latents to videos (requested with output_type='pt')
if self.pipeline.config.expand_timesteps:
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
decoded_videos = self.decode_latents(latents, output_type='pil')
decoded_videos = self.decode_latents(latents, output_type='pt')

# 8. Prepare output samples

Expand Down
2 changes: 1 addition & 1 deletion src/flow_factory/models/wan/wan2_t2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def inference(
self.pipeline._current_timestep = None

# 7. Decode latents to videos (requested with output_type='pt')
decoded_videos = self.decode_latents(latents, output_type='pil')
decoded_videos = self.decode_latents(latents, output_type='pt')

# 8. Prepare output samples

Expand Down
2 changes: 1 addition & 1 deletion src/flow_factory/models/wan/wan2_v2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,7 @@ def inference(
self._current_timestep = None

# 7. Decode latents to videos (requested with output_type='pt')
decoded_videos = self.decode_latents(latents, output_type='pil')
decoded_videos = self.decode_latents(latents, output_type='pt')

# 8. Prepare output samples

Expand Down
8 changes: 4 additions & 4 deletions src/flow_factory/models/z_image/z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import os
from typing import Union, List, Dict, Any, Optional, Tuple, ClassVar
from typing import Union, List, Dict, Any, Optional, Tuple, ClassVar, Literal
from dataclasses import dataclass
from PIL import Image
from collections import defaultdict
Expand Down Expand Up @@ -177,13 +177,13 @@ def encode_video(
def decode_latents(
self,
latents: torch.Tensor,
**kwargs
output_type: Literal['pil', 'pt', 'np'] = 'pil',
) -> torch.Tensor:
latents = latents.to(self.pipeline.vae.dtype)
latents = (latents / self.pipeline.vae.config.scaling_factor) + self.pipeline.vae.config.shift_factor

images = self.pipeline.vae.decode(latents, return_dict=False)[0]
images = self.pipeline.image_processor.postprocess(images, output_type="pil")
images = self.pipeline.image_processor.postprocess(images, output_type=output_type)

return images

Expand Down Expand Up @@ -302,7 +302,7 @@ def inference(
extra_call_back_res[key].append(val)

# Decode latents to images
images = self.decode_latents(latents)
images = self.decode_latents(latents, output_type='pt')

# Create samples

Expand Down