Skip to content

Commit 039955c

Browse files
Some fixes to previous pr. (Comfy-Org#12339)
1 parent 6a26328 commit 039955c

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

comfy/ldm/cosmos/predict2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,6 @@ def _forward(
894894
**block_kwargs,
895895
)
896896

897-
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
897+
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
898898
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
899899
return x_B_C_Tt_Hp_Wp

comfy/supported_models.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,6 @@ class Anima(supported_models_base.BASE):
10251025

10261026
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
10271027

1028-
def __init__(self, unet_config):
1029-
super().__init__(unet_config)
1030-
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
1031-
10321028
def get_model(self, state_dict, prefix="", device=None):
10331029
out = model_base.Anima(self, device=device)
10341030
return out
@@ -1038,6 +1034,12 @@ def clip_target(self, state_dict={}):
10381034
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
10391035
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
10401036

1037+
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
1038+
self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
1039+
if dtype is torch.float16:
1040+
self.memory_usage_factor *= 1.4
1041+
return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
1042+
10411043
class CosmosI2VPredict2(CosmosT2IPredict2):
10421044
unet_config = {
10431045
"image_model": "cosmos_predict2",

0 commit comments

Comments
 (0)