
Commit c5ed5be

adding training setup for unified+instance unfrozen; need a code review
1 parent eb582e4 commit c5ed5be

8 files changed: 69 additions & 15 deletions


configs/train/train_scene_crossover.yaml

Lines changed: 5 additions & 1 deletion
@@ -73,7 +73,11 @@ task:
   train : [Scannet, Scan3R, MultiScan, ARKitScenes]
   val : [Scannet, Scan3R, MultiScan, ARKitScenes]
   object_enc_ckpt : /drive/dumps/multimodal-spaces/runs/new_runs/instance_crossover_scannet+scan3r+multiscan+arkitscenes.pth
-
+  freeze_object_enc : False
+  scene_level_grounding_eval:
+    name: GroundingEval
+    config_path: /home/sayan/Documents/code/multimodal-reality/CrossOver/configs/train/train_instance_crossover.yaml
+
 trainer: UnifiedTrainer

 model:
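
For orientation, a minimal sketch of how the new keys would be read downstream, following the cfg.task.get(cfg.task.name) access pattern already used in model/build.py and optim/build.py (the load call and printed values are illustrative only):

from omegaconf import OmegaConf

# Illustrative only: load the training config and pull the task-level options
# introduced in this commit via the same accessor the build functions use.
cfg = OmegaConf.load('configs/train/train_scene_crossover.yaml')
task_cfg = cfg.task.get(cfg.task.name)

print(task_cfg.freeze_object_enc)                       # False -> object encoder trains jointly
print(task_cfg.scene_level_grounding_eval.name)         # GroundingEval
print(task_cfg.scene_level_grounding_eval.config_path)  # instance-level training config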

evaluator/grounding_eval.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def __init__(self, cfg: DictConfig, accelerator: Accelerator, **kwargs: Any) ->
         """Initialize the grounding evaluator with configuration and accelerator."""
         self.task_name = cfg.task.name

-        if 'scene' in self.task_name.lower():
+        if 'scene' or 'unified' in self.task_name.lower():
             self.eval_func = eval_utils.calculate_topK_err_batch
         elif 'object' in self.task_name.lower():
             self.eval_func = eval_utils.calculate_topK_err
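
One thing worth flagging for the requested review: Python parses 'scene' or 'unified' in self.task_name.lower() as 'scene' or ('unified' in ...), and the non-empty literal 'scene' is always truthy, so this branch fires for every task name. A small self-contained sketch of the presumably intended membership test (pick_eval_func is a hypothetical stand-in for the constructor logic, not repository code):

def pick_eval_func(task_name: str) -> str:
    # Each substring must be tested against the task name separately;
    # `'scene' or 'unified' in name` short-circuits on the truthy literal 'scene'.
    name = task_name.lower()
    if 'scene' in name or 'unified' in name:
        return 'calculate_topK_err_batch'
    elif 'object' in name:
        return 'calculate_topK_err'
    raise ValueError(f'unrecognised task name: {task_name}')

assert pick_eval_func('UnifiedTrain') == 'calculate_topK_err_batch'
assert pick_eval_func('ObjectLevelGrounding') == 'calculate_topK_err'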

evaluator/retrieval_eval.py

Lines changed: 21 additions & 1 deletion
@@ -2,11 +2,12 @@
 from pathlib import Path
 from itertools import combinations
 from typing import Dict, Any
-from omegaconf import DictConfig
+from omegaconf import DictConfig, OmegaConf
 from accelerate import Accelerator

 from evaluator.build import EVALUATOR_REGISTRY, BaseEvaluator
 from . import eval_utils
+import yaml

 @EVALUATOR_REGISTRY.register()
 class RetrievalEval(BaseEvaluator):
@@ -31,6 +32,16 @@ def __init__(self, cfg: DictConfig, accelerator: Accelerator, **kwargs: Any) ->
             self.eval_dict[src_modality + '_' + ref_modality + '_err_top5'] = []

         self.eval_dict['target_metric'] = []
+        self.freeze_object_enc = self.cfg.task.get(self.cfg.task.name).freeze_object_enc
+
+        # HOW TO LOAD GROUNDING EVAL ELEGANTLY I DONT LIKE THIS? DO WE EVEN NEED INSTANCE EVAL?
+        if not self.freeze_object_enc:
+            # self.grounding_config_path = self.cfg.task.get(self.cfg.task.name).scene_level_grounding_eval.config_path
+            # with open(self.grounding_config_path, 'r') as f:
+            #     grounding_config = yaml.safe_load(f)
+
+            # grounding_eval_cfg = OmegaConf.create(grounding_config)
+            self.grounding_eval = EVALUATOR_REGISTRY.get('GroundingEval')(cfg, accelerator, **kwargs)

     def batch_metrics(self, data_dict: Dict[str, Any]) -> Dict[str, float]:
         """Calculate retrieval metrics for a batch of embeddings."""
@@ -60,4 +71,13 @@ def batch_metrics(self, data_dict: Dict[str, Any]) -> Dict[str, float]:
         metrics['err_top1'] = float(sum(all_top1_metric)) / len(all_top1_metric)
         metrics['err_top5'] = float(sum(all_top5_metric)) / len(all_top5_metric)

+        if not self.freeze_object_enc:
+            instance_data_dict = {}
+            instance_data_dict['embeddings'] = data_dict['object_modality_embeddings']
+            instance_data_dict['masks'] = data_dict['masks']
+            instance_metrics = self.grounding_eval.batch_metrics(instance_data_dict)
+            metrics['instance_target_metric'] = instance_metrics['target_metric']
+            metrics['instance_err_top1'] = instance_metrics['err_top1']
+            metrics['instance_err_top3'] = instance_metrics['err_top3']
+
         return metrics
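
Regarding the inline question about loading the grounding evaluator more elegantly: a hedged sketch of the commented-out route, assuming the YAML referenced by scene_level_grounding_eval.config_path is a complete config that GroundingEval can consume on its own (the helper build_grounding_eval is hypothetical, not part of the repository):

from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf

from evaluator.build import EVALUATOR_REGISTRY


def build_grounding_eval(cfg: DictConfig, accelerator: Accelerator, **kwargs):
    # Hypothetical helper: resolve the instance-level config named in the unified
    # task and construct the registered evaluator from it, instead of reusing the
    # unified cfg directly.
    task_cfg = cfg.task.get(cfg.task.name)
    grounding_cfg = OmegaConf.load(task_cfg.scene_level_grounding_eval.config_path)
    evaluator_cls = EVALUATOR_REGISTRY.get(task_cfg.scene_level_grounding_eval.name)
    return evaluator_cls(grounding_cfg, accelerator, **kwargs)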

model/build.py

Lines changed: 4 additions & 1 deletion
@@ -3,5 +3,8 @@
 MODEL_REGISTRY = Registry("model")

 def build_model(cfg):
-    model = MODEL_REGISTRY.get(cfg.model.name)(cfg.model, cfg.task.get(cfg.task.name).modalities)
+    if 'unified' in cfg.model.name.lower():
+        model = MODEL_REGISTRY.get(cfg.model.name)(cfg.model, cfg.task.get(cfg.task.name).modalities, cfg.task.get(cfg.task.name).freeze_object_enc)
+    else:
+        model = MODEL_REGISTRY.get(cfg.model.name)(cfg.model, cfg.task.get(cfg.task.name).modalities)
     return model

model/unified_enc.py

Lines changed: 7 additions & 3 deletions
@@ -10,10 +10,11 @@

 @MODEL_REGISTRY.register()
 class UnifiedEncoder(nn.Module):
-    def __init__(self, args: DictConfig, modalities: List[str]) -> None:
+    def __init__(self, args: DictConfig, modalities: List[str], freeze_object_enc: bool) -> None:
         super().__init__()

         self.modalities = modalities
+        self.freeze_object_enc = freeze_object_enc
         self.out_dim = args.out_dim
         self.objectwise_modality_encoder = SceneLevelEncoder(args, self.modalities)

@@ -89,7 +90,10 @@ def get_opt_params(self, lr: float) -> List[torch.nn.Parameter]:

         optimizer_grouped_parameters += self.fusion.parameters()

-        for param in self.objectwise_modality_encoder.parameters():
-            param.requires_grad = False
+        if self.freeze_object_enc:
+            for param in self.objectwise_modality_encoder.parameters():
+                param.requires_grad = False
+        else:
+            optimizer_grouped_parameters += self.objectwise_modality_encoder.parameters()

         return optimizer_grouped_parameters
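
A common variant when a pretrained sub-encoder is unfrozen is to give it its own optimizer parameter group with a reduced learning rate while the freshly initialised fusion head keeps the base rate. A sketch under the assumption that the optimizer accepts per-group dicts; the grouping and the 0.1 scale are illustrative, not what get_opt_params currently returns:

import torch.nn as nn


def get_grouped_opt_params(model: nn.Module, lr: float, enc_lr_scale: float = 0.1):
    # Illustrative: train the pretrained object-wise encoder more gently than the
    # fusion module; torch.optim optimizers accept this list of per-group dicts.
    return [
        {'params': model.fusion.parameters(), 'lr': lr},
        {'params': model.objectwise_modality_encoder.parameters(), 'lr': lr * enc_lr_scale},
    ]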

optim/build.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ def build_optim(cfg, params, total_steps):
     scheduler = get_scheduler(cfg, optimizer, total_steps)

     if 'retrieval' in cfg.model.loss.lower():
-        loss = LOSS_REGISTRY.get(cfg.model.loss)()
+        loss = LOSS_REGISTRY.get(cfg.model.loss)(cfg.task.get(cfg.task.name).freeze_object_enc)
     else:
         loss = LOSS_REGISTRY.get(cfg.model.loss)(cfg.model.base_modality)

optim/loss/contrastive_loss.py

Lines changed: 23 additions & 6 deletions
@@ -11,10 +11,13 @@

 @LOSS_REGISTRY.register()
 class RetrievalLoss(nn.Module):
-    def __init__(self):
+    def __init__(self, freeze_object_encoders: bool = True):
         super(RetrievalLoss, self).__init__()
         self.logit_scale = nn.Parameter((torch.ones([]) * np.log(1 / 0.07)).exp())
-
+        self.freeze_object_enc = freeze_object_encoders
+        if not self.freeze_object_enc:
+            self.instance_loss = SceneWiseContrastiveLoss(base_modality='rgb')
+
     def calculate_loss(self, src_embed: torch.tensor, ref_embed: torch.tensor, mask: torch.tensor=None) -> torch.tensor:
         logit_scale = torch.clamp(self.logit_scale, max=100)

@@ -51,11 +54,25 @@ def forward(self, data_dict: Dict[str, Any]) -> torch.tensor:
             loss = self.calculate_loss(a_embed, b_embed, mask)
             loss_dict[f'loss_{modality_type}'] = loss

-        loss_dict['total_loss'] = sum(loss_dict.values())
-
-        assert not torch.any(torch.isnan(loss_dict['total_loss'])), 'Loss Coming NaN!!!'
+        scene_loss = sum(loss_dict.values())
+        loss_dict['scene_loss'] = scene_loss

-        return loss_dict['total_loss'], loss_dict
+        assert not torch.any(torch.isnan(scene_loss)), 'Loss Coming NaN!!!'
+
+        if self.freeze_object_enc:
+            total_loss = scene_loss
+            loss_dict['total_loss'] = scene_loss
+            return total_loss, loss_dict
+        else:
+            instance_data_dict = {}
+            instance_data_dict['embeddings'] = data_dict['object_modality_embeddings']
+            instance_data_dict['masks'] = data_dict['masks']
+            instance_loss, instance_loss_dict = self.instance_loss(instance_data_dict)
+            loss_dict['instance_loss'] = instance_loss
+            # loss_dict.update(instance_loss_dict)
+            total_loss = scene_loss + instance_loss
+            loss_dict['total_loss'] = total_loss
+            return total_loss, loss_dict

 class ContrastiveLoss(nn.Module):
     def __init__(self, base_modality: str):
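
For context on what calculate_loss is built around, a minimal self-contained sketch of a CLIP-style symmetric contrastive loss with a clamped logit scale; the normalisation, masking, and batching details of the repository's implementation may differ, so treat the shapes and the diagonal-positive assumption as illustrative:

import torch
import torch.nn.functional as F


def symmetric_contrastive_loss(src_embed: torch.Tensor,
                               ref_embed: torch.Tensor,
                               logit_scale: torch.Tensor) -> torch.Tensor:
    # src_embed, ref_embed: (B, D), assumed L2-normalised; matched pairs share an index.
    scale = torch.clamp(logit_scale, max=100)
    logits = scale * src_embed @ ref_embed.t()
    targets = torch.arange(src_embed.size(0), device=src_embed.device)
    # Average the src->ref and ref->src cross-entropies.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))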

trainer/unified_trainer.py

Lines changed: 7 additions & 1 deletion
@@ -16,6 +16,8 @@ def __init__(self, cfg: DictConfig) -> None:
         super().__init__(cfg)

         self.task_config = rgetattr(cfg.task, cfg.task.name)
+        self.freeze_object_enc = self.task_config.freeze_object_enc
+

         # ckpt = '/drive/dumps/multimodal-spaces/runs/new_runs/scene_crossover_scannet+scan3r_scratch.pth'
         # self.logger.info(f"Loading Initial Weights from {ckpt}")
@@ -53,7 +55,11 @@ def train_step(self, epoch: int) -> None:
             loss, loss_dict = self.loss(data_dict)
             # calculate evaluator
             metrics = self.evaluator['train'].batch_metrics(data_dict)
-            self.backward(loss)
+            if self.freeze_object_enc:
+                self.backward(loss)
+            else:
+                self.backward(loss_dict['scene_loss'])
+                self.backward(loss_dict['instance_loss'])

             self.global_step += 1
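
One point to check in review: if scene_loss and instance_loss share any part of the forward graph (the now-unfrozen object encoder is the obvious candidate), the second of two separate backward() calls will fail unless the first is given retain_graph=True; a single backward on their sum, which the loss module already returns as total_loss, accumulates the same gradients. A toy PyTorch sketch of the two patterns (the modules here are stand-ins, not the repository's):

import torch
import torch.nn as nn

enc = nn.Linear(4, 4)        # stand-in for a shared (unfrozen) object encoder
scene_head = nn.Linear(4, 1)
instance_head = nn.Linear(4, 1)
x = torch.randn(8, 4)

feats = enc(x)
scene_loss = scene_head(feats).pow(2).mean()
instance_loss = instance_head(feats).pow(2).mean()

# Two backward calls over a shared graph: the first needs retain_graph=True,
# otherwise the second raises "Trying to backward through the graph a second time".
scene_loss.backward(retain_graph=True)
instance_loss.backward()

# Equivalent single call (gradients accumulate identically):
#   total_loss = scene_loss + instance_loss
#   total_loss.backward()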
