@@ -1025,10 +1025,6 @@ class Anima(supported_models_base.BASE):
 
     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
-    def __init__(self, unet_config):
-        super().__init__(unet_config)
-        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
-
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Anima(self, device=device)
         return out
@@ -1038,6 +1034,12 @@ def clip_target(self, state_dict={}):
         detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
 
+    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
+        self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
+        if dtype is torch.float16:
+            self.memory_usage_factor *= 1.4
+        return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
+
 class CosmosI2VPredict2(CosmosT2IPredict2):
     unet_config = {
         "image_model": "cosmos_predict2",
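
For context, here is a minimal standalone sketch of the pattern this diff adopts: computing `memory_usage_factor` in `set_inference_dtype()` rather than `__init__()`, so the estimate can depend on the dtype actually chosen for inference (fp16 gets a 1.4x headroom bump). The `BaseModelConfig` and `AnimaLike` classes below are illustrative stand-ins, not ComfyUI's real `supported_models_base.BASE`.

```python
import torch

class BaseModelConfig:
    """Illustrative stand-in for a model-config base class (not ComfyUI's BASE)."""
    def __init__(self, unet_config):
        self.unet_config = unet_config
        self.memory_usage_factor = 1.0

    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
        # Record the dtypes chosen at load time.
        self.dtype = dtype
        self.manual_cast_dtype = manual_cast_dtype

class AnimaLike(BaseModelConfig):
    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
        # Scale the memory estimate with model width; 2048 channels is the
        # reference size, with a 0.95 base factor as in the diff above.
        self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
        if dtype is torch.float16:
            # fp16 inference gets extra headroom in the estimate.
            self.memory_usage_factor *= 1.4
        return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)

cfg = AnimaLike({"model_channels": 4096})
cfg.set_inference_dtype(torch.float16, None)
print(cfg.memory_usage_factor)  # (4096/2048) * 0.95 * 1.4 ≈ 2.66
```

Deferring the computation to `set_inference_dtype()` presumably works better than doing it in `__init__()` because the inference dtype is only known at model-load time, after the config object has already been constructed.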