@@ -1921,10 +1921,14 @@ def __init__(
19211921
19221922 self .ctrl_time_embeddings = []
19231923 self .ctrl_time_embedding_mix = []
1924- for controlnet in controlnets :
1924+ for index , controlnet in enumerate ( controlnets ) :
19251925 self .ctrl_time_embedding_mix .append (controlnet .get ("time_embedding_mix" ))
19261926 if controlnet .get ("ctrl_learn_time_embedding" ):
1927- self .ctrl_time_embeddings .append (TimestepEmbedding (in_channels = time_embed_input_dim , time_embed_dim = time_embed_dim ))
1927+ ctrl_time_embedding = TimestepEmbedding (
1928+ in_channels = time_embed_input_dim , time_embed_dim = time_embed_dim
1929+ )
1930+ self .ctrl_time_embeddings .append (ctrl_time_embedding )
1931+ self .add_module (f"ctrl_time_embedding_{ index } " , ctrl_time_embedding )
19281932 else :
19291933 self .ctrl_time_embeddings .append (None )
19301934
@@ -2194,19 +2198,19 @@ def forward(
21942198 # there might be better ways to encapsulate this.
21952199 t_emb = t_emb .to (dtype = sample .dtype )
21962200
2197- time_embedding_list = []
2201+ base_temb = self .base_time_embedding (t_emb , timestep_cond )
2202+ base_temb = base_temb + aug_emb if aug_emb is not None else base_temb
2203+
2204+ ctrl_temb_list = []
21982205 for index in range (len (self .ctrl_time_embeddings )):
2199- base_temb = self .base_time_embedding (t_emb , timestep_cond )
22002206 temb_module = self .ctrl_time_embeddings [index ]
2201- temb_module .to (t_emb .device , t_emb .dtype )
22022207 if temb_module is not None and apply_control :
22032208 ctrl_temb = temb_module (t_emb , timestep_cond )
2204- interpolation_param = self .ctrl_time_embedding_mix [index ]** 0.3
2205- temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param )
2209+ interpolation_param = self .ctrl_time_embedding_mix [index ] ** 0.3
2210+ ctrl_temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param )
22062211 else :
2207- temb = base_temb
2208- temb = temb + aug_emb if aug_emb is not None else temb
2209- time_embedding_list .append (temb )
2212+ ctrl_temb = base_temb
2213+ ctrl_temb_list .append (ctrl_temb )
22102214
22112215
22122216 # text embeddings
@@ -2220,7 +2224,8 @@ def forward(
22202224 h_base , h_ctrls , residual_hb , residual_hc = down (
22212225 hidden_states_base = h_base ,
22222226 hidden_states_ctrl = h_ctrls ,
2223- temb = temb ,
2227+ temb = base_temb ,
2228+ ctrl_temb_list = ctrl_temb_list ,
22242229 encoder_hidden_states = cemb ,
22252230 conditioning_scale = conditioning_scale ,
22262231 cross_attention_kwargs = cross_attention_kwargs ,
@@ -2236,7 +2241,8 @@ def forward(
22362241 h_base , h_ctrls = self .mid_block (
22372242 hidden_states_base = h_base ,
22382243 hidden_states_ctrl = h_ctrls ,
2239- temb = temb ,
2244+ temb = base_temb ,
2245+ ctrl_temb_list = ctrl_temb_list ,
22402246 encoder_hidden_states = cemb ,
22412247 conditioning_scale = conditioning_scale ,
22422248 cross_attention_kwargs = cross_attention_kwargs ,
@@ -2255,7 +2261,7 @@ def forward(
22552261 hidden_states = h_base ,
22562262 res_hidden_states_tuple_base = skips_hb ,
22572263 res_hidden_states_tuple_ctrl = skips_hc ,
2258- temb = temb ,
2264+ temb = base_temb ,
22592265 encoder_hidden_states = cemb ,
22602266 conditioning_scale = conditioning_scale ,
22612267 cross_attention_kwargs = cross_attention_kwargs ,
@@ -2527,8 +2533,8 @@ def freeze_base_params(self) -> None:
25272533 base_parts = [self .base_resnets ]
25282534 if isinstance (self .base_attentions , nn .ModuleList ): # attentions can be a list of Nones
25292535 base_parts .append (self .base_attentions )
2530- if self .base_downsamplers is not None :
2531- base_parts .append (self .base_downsamplers )
2536+ if self .base_downsampler is not None :
2537+ base_parts .append (self .base_downsampler )
25322538 for part in base_parts :
25332539 for param in part .parameters ():
25342540 param .requires_grad = False
@@ -2539,6 +2545,7 @@ def forward(
25392545 temb : Tensor ,
25402546 encoder_hidden_states : Optional [Tensor ] = None , # text embedding
25412547 hidden_states_ctrl : Optional [List [Tensor ]] = None , # ctrl embedding
2548+ ctrl_temb_list : Optional [List [Tensor ]] = None ,
25422549 conditioning_scale : Optional [List [float ]] = None ,
25432550 attention_mask : Optional [Tensor ] = None ,
25442551 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
@@ -2596,10 +2603,11 @@ def forward(
25962603 if apply_control :
25972604 for index in range (len (h_ctrl_list )):
25982605 c_res , c_attn = ctrl_blocks_list [index ][layer_index ]
2606+ ctrl_temb = ctrl_temb_list [index ] if ctrl_temb_list is not None else temb
25992607 if torch .is_grad_enabled () and self .gradient_checkpointing :
2600- h_ctrl_list [index ] = self ._gradient_checkpointing_func (c_res , h_ctrl_list [index ], temb )
2608+ h_ctrl_list [index ] = self ._gradient_checkpointing_func (c_res , h_ctrl_list [index ], ctrl_temb )
26012609 else :
2602- h_ctrl_list [index ] = c_res (h_ctrl_list [index ], temb )
2610+ h_ctrl_list [index ] = c_res (h_ctrl_list [index ], ctrl_temb )
26032611 if c_attn is not None :
26042612 h_ctrl_list [index ] = c_attn (
26052613 h_ctrl_list [index ],
@@ -2665,9 +2673,9 @@ def __init__(
26652673
26662674 # Before the midblock application, information is concatted from base to control.
26672675 # Concat doesn't require change in number of channels
2668- self .base_to_ctrl = []
2669- for i in range (len (ctrl_channels_list )):
2670- self . base_to_ctrl . append ( make_zero_conv ( base_channels , base_channels ) )
2676+ self .base_to_ctrl = nn . ModuleList (
2677+ [ make_zero_conv ( base_channels , base_channels ) for _ in range (len (ctrl_channels_list ))]
2678+ )
26712679
26722680 self .base_midblock = UNetMidBlock2DCrossAttn (
26732681 transformer_layers_per_block = transformer_layers_per_block ,
@@ -2680,7 +2688,7 @@ def __init__(
26802688 upcast_attention = upcast_attention ,
26812689 )
26822690
2683- self . ctrl_midblocks = []
2691+ ctrl_midblocks = []
26842692 for i in range (len (ctrl_channels_list )):
26852693 ctrl_channels = ctrl_channels_list [i ]
26862694 ctrl_max_norm_num_groups = ctrl_max_norm_num_groups_list [i ]
@@ -2701,14 +2709,15 @@ def __init__(
27012709 upcast_attention = upcast_attention ,
27022710 )
27032711
2704- self .ctrl_midblocks .append (ctrl_midblock )
2712+ ctrl_midblocks .append (ctrl_midblock )
2713+
2714+ self .ctrl_midblocks = nn .ModuleList (ctrl_midblocks )
27052715
27062716 # After the midblock application, information is added from control to base
27072717 # Addition requires change in number of channels
2708- self .ctrl_to_base = []
2709- for i in range (len (ctrl_channels_list )):
2710- ctrl_channels = ctrl_channels_list [i ]
2711- self .ctrl_to_base .append (make_zero_conv (ctrl_channels , base_channels ))
2718+ self .ctrl_to_base = nn .ModuleList (
2719+ [make_zero_conv (ctrl_channels_list [i ], base_channels ) for i in range (len (ctrl_channels_list ))]
2720+ )
27122721
27132722 self .gradient_checkpointing = False
27142723
@@ -2787,6 +2796,7 @@ def forward(
27872796 encoder_hidden_states : Tensor ,
27882797 hidden_states_ctrl : List [Tensor ],
27892798 conditioning_scale : List [float ],
2799+ ctrl_temb_list : Optional [List [Tensor ]] = None ,
27902800 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
27912801 attention_mask : Optional [Tensor ] = None ,
27922802 encoder_attention_mask : Optional [Tensor ] = None ,
@@ -2810,15 +2820,15 @@ def forward(
28102820
28112821 if apply_control :
28122822 for i in range (len (h_ctrl_list )):
2813- b2c = self .base_to_ctrl [i ].to (device = "cuda" , dtype = h_base .dtype )
2814- h_ctrl_list [i ] = torch .cat ([h_ctrl_list [i ], b2c (h_base )], dim = 1 )
2823+ h_ctrl_list [i ] = torch .cat ([h_ctrl_list [i ], self .base_to_ctrl [i ](h_base )], dim = 1 )
28152824 h_base = self .base_midblock (h_base , ** joint_args )
28162825 if apply_control :
28172826 for i in range (len (h_ctrl_list )):
2818- mid = self .ctrl_midblocks [i ].to (device = "cuda" , dtype = h_ctrl_list [i ].dtype )
2819- h_ctrl_list [i ] = mid (h_ctrl_list [i ], ** joint_args ) # apply ctrl mid block
2820- c2b = self .ctrl_to_base [i ].to (device = "cuda" , dtype = h_ctrl_list [i ].dtype )
2821- h_base = h_base + c2b (h_ctrl_list [i ]) * conditioning_scale [i ] # add ctrl -> base
2827+ ctrl_joint_args = dict (joint_args )
2828+ if ctrl_temb_list is not None :
2829+ ctrl_joint_args ["temb" ] = ctrl_temb_list [i ]
2830+ h_ctrl_list [i ] = self .ctrl_midblocks [i ](h_ctrl_list [i ], ** ctrl_joint_args ) # apply ctrl mid block
2831+ h_base = h_base + self .ctrl_to_base [i ](h_ctrl_list [i ]) * conditioning_scale [i ] # add ctrl -> base
28222832
28232833 return h_base , h_ctrl_list
28242834
0 commit comments