Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions invokeai/backend/model_manager/load/model_loaders/flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,43 @@ def _load_model(
if is_bfl_format:
sd = self._convert_flux2_vae_bfl_to_diffusers(sd)

# FLUX.2 VAE configuration (32 latent channels)
# Based on the official FLUX.2 VAE architecture
# Use default config - AutoencoderKLFlux2 has built-in defaults
# FLUX.2 VAE configuration (32 latent channels).
# The standard FLUX.2 VAE uses block_out_channels=(128,256,512,512) for both
# encoder and decoder. The "small decoder" variant from
# black-forest-labs/FLUX.2-small-decoder keeps the full encoder but uses a
# narrower decoder with channels (96,192,384,384). AutoencoderKLFlux2 only
# exposes a single block_out_channels, so we build the model with the
# encoder's channels and, if the decoder differs, replace just the decoder
# submodule with a matching one before loading the state dict.
encoder_block_out_channels = (128, 256, 512, 512)
decoder_block_out_channels = encoder_block_out_channels
if "encoder.conv_in.weight" in sd and "encoder.conv_norm_out.weight" in sd:
enc_last = int(sd["encoder.conv_norm_out.weight"].shape[0])
enc_first = int(sd["encoder.conv_in.weight"].shape[0])
encoder_block_out_channels = (enc_first, enc_first * 2, enc_last, enc_last)
if "decoder.conv_in.weight" in sd and "decoder.conv_norm_out.weight" in sd:
dec_last = int(sd["decoder.conv_in.weight"].shape[0])
dec_first = int(sd["decoder.conv_norm_out.weight"].shape[0])
decoder_block_out_channels = (dec_first, dec_first * 2, dec_last, dec_last)

with SilenceWarnings():
with accelerate.init_empty_weights():
model = AutoencoderKLFlux2()
model = AutoencoderKLFlux2(block_out_channels=encoder_block_out_channels)
if decoder_block_out_channels != encoder_block_out_channels:
# Rebuild the decoder with the smaller channel widths.
from diffusers.models.autoencoders.vae import Decoder

cfg = model.config
model.decoder = Decoder(
in_channels=cfg.latent_channels,
out_channels=cfg.out_channels,
up_block_types=cfg.up_block_types,
block_out_channels=decoder_block_out_channels,
layers_per_block=cfg.layers_per_block,
norm_num_groups=cfg.norm_num_groups,
act_fn=cfg.act_fn,
mid_block_add_attention=cfg.mid_block_add_attention,
)

# Convert to bfloat16 and load
for k in sd.keys():
Expand Down
Loading