22from pathlib import Path
33from itertools import combinations
44from typing import Dict , Any
5- from omegaconf import DictConfig
5+ from omegaconf import DictConfig , OmegaConf
66from accelerate import Accelerator
77
88from evaluator .build import EVALUATOR_REGISTRY , BaseEvaluator
99from . import eval_utils
10+ import yaml
1011
1112@EVALUATOR_REGISTRY .register ()
1213class RetrievalEval (BaseEvaluator ):
@@ -31,6 +32,16 @@ def __init__(self, cfg: DictConfig, accelerator: Accelerator, **kwargs: Any) ->
3132 self .eval_dict [src_modality + '_' + ref_modality + '_err_top5' ] = []
3233
3334 self .eval_dict ['target_metric' ] = []
35+ self .freeze_object_enc = self .cfg .task .get (self .cfg .task .name ).freeze_object_enc
36+
37+         #TODO: find a cleaner way to wire in the grounding eval here; also decide whether instance-level eval is needed at all.
38+ if not self .freeze_object_enc :
39+ # self.grounding_config_path = self.cfg.task.get(self.cfg.task.name).scene_level_grounding_eval.config_path
40+ # with open(self.grounding_config_path, 'r') as f:
41+ # grounding_config = yaml.safe_load(f)
42+
43+ # grounding_eval_cfg = OmegaConf.create(grounding_config)
44+ self .grounding_eval = EVALUATOR_REGISTRY .get ('GroundingEval' )(cfg , accelerator , ** kwargs )
3445
3546 def batch_metrics (self , data_dict : Dict [str , Any ]) -> Dict [str , float ]:
3647 """Calculate retrieval metrics for a batch of embeddings."""
@@ -60,4 +71,13 @@ def batch_metrics(self, data_dict: Dict[str, Any]) -> Dict[str, float]:
6071 metrics ['err_top1' ] = float (sum (all_top1_metric )) / len (all_top1_metric )
6172 metrics ['err_top5' ] = float (sum (all_top5_metric )) / len (all_top5_metric )
6273
74+ if not self .freeze_object_enc :
75+ instance_data_dict = {}
76+ instance_data_dict ['embeddings' ] = data_dict ['object_modality_embeddings' ]
77+ instance_data_dict ['masks' ] = data_dict ['masks' ]
78+ instance_metrics = self .grounding_eval .batch_metrics (instance_data_dict )
79+ metrics ['instance_target_metric' ] = instance_metrics ['target_metric' ]
80+ metrics ['instance_err_top1' ] = instance_metrics ['err_top1' ]
81+ metrics ['instance_err_top3' ] = instance_metrics ['err_top3' ]
82+
6383 return metrics
0 commit comments