# train_net.py
# Best-effort seeding of all RNG sources for reproducibility.
import random
import numpy as np
import torch
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# torch.autograd.set_detect_anomaly(True)
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
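# Note: seeding alone does not guarantee run-to-run determinism on CUDA.
# Full determinism would additionally require the flags commented out above,
# plus CUBLAS_WORKSPACE_CONFIG=:4096:8 (or :16:8) set in the environment
# before CUDA initialization; cuDNN autotuned kernels are otherwise
# nondeterministic.
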
import copy
import itertools
import logging
import os
from collections import OrderedDict
from typing import Any, Dict, List, Set
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.data import build_detection_train_loader, build_detection_test_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.engine import hooks
from detectron2.evaluation import (
    DatasetEvaluators,
    verify_results,
)
from detectron2.projects.deeplab import build_lr_scheduler
from detectron2.solver.build import maybe_add_gradient_clipping
from evaluator.evaluator import TrajectoryEvaluator
from dataloader.dataset_mappers.dataset_mapper import DatasetMapper
from model import setup
from common.envs import *
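

# Trainer specializes detectron2's DefaultTrainer with project-specific data
# mappers, a trajectory evaluator, and an optimizer that assigns per-module
# hyperparameters (learning rate and weight decay) to parameter groups.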
class Trainer(DefaultTrainer):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.eval_period = cfg.TEST.EVAL_PERIOD
        # Empty by default; populate to enable evaluation during training.
        self.training_evaluators = []
        for e in self.training_evaluators:
            e.reset()

    def run_step(self):
        self._trainer.iter = self.iter
        self._trainer.run_step()
        # Every eval_period iterations, report the metrics accumulated over the
        # last window, then clear the evaluators for the next one.
        if self._trainer.iter % self.eval_period == 0:
            for e in self.training_evaluators:
                e.evaluate()
            for e in self.training_evaluators:
                e.reset()
        # Feed this step's inputs and predictions to the training-time evaluators.
        for e in self.training_evaluators:
            e.process(self._trainer.model.training_input, self._trainer.model.training_inference)

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        # Trajectory evaluation is only enabled when DEPLOY_SEQS is set in the environment.
        if os.getenv("DEPLOY_SEQS") is not None:
            evaluator_list.append(TrajectoryEvaluator(dataset_name, distributed=True, output_dir=output_folder))
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def build_train_loader(cls, cfg):
        mapper = DatasetMapper(cfg, is_train=True)
        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        mapper = DatasetMapper(cfg, is_train=False)
        # mapper = TrajectoryPredictionDatasetMapper(cfg, is_train=False)
        return build_detection_test_loader(cfg, dataset_name, mapper=mapper)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """
        It now calls :func:`detectron2.solver.build_lr_scheduler`.
        Overwrite it if you'd like a different scheduler.
        """
        # return torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.9999839709739701)  # ** 3200 = 0.95
        return build_lr_scheduler(cfg, optimizer)

    @classmethod
    def build_optimizer(cls, cfg, model):
        # return torch.optim.Adam(model.parameters(), lr=cfg.SOLVER.BASE_LR * cfg.SOLVER.BACKBONE_MULTIPLIER)
        weight_decay_norm = cfg.SOLVER.WEIGHT_DECAY_NORM
        weight_decay_embed = cfg.SOLVER.WEIGHT_DECAY_EMBED

        defaults = {}
        defaults["lr"] = cfg.SOLVER.BASE_LR
        defaults["weight_decay"] = cfg.SOLVER.WEIGHT_DECAY

        norm_module_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
            # NaiveSyncBatchNorm inherits from BatchNorm2d
            torch.nn.GroupNorm,
            torch.nn.InstanceNorm1d,
            torch.nn.InstanceNorm2d,
            torch.nn.InstanceNorm3d,
            torch.nn.LayerNorm,
            torch.nn.LocalResponseNorm,
        )

        params: List[Dict[str, Any]] = []
        memo: Set[torch.nn.parameter.Parameter] = set()
        # Build one parameter group per parameter so each can carry its own
        # learning rate and weight decay.
        for module_name, module in model.named_modules():
            for module_param_name, value in module.named_parameters(recurse=False):
                if not value.requires_grad:
                    continue
                # Avoid duplicating parameters
                if value in memo:
                    continue
                memo.add(value)

                hyperparams = copy.copy(defaults)
                if "backbone" in module_name:
                    hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER
                # Fixed learning-rate overrides for these submodules.
                if "sidecar_bone" in module_name:
                    hyperparams["lr"] = 0.00001
                if "interfere_bone" in module_name:
                    hyperparams["lr"] = 0.00001
                # Positional-embedding tables are trained without weight decay.
                if (
                    "relative_position_bias_table" in module_param_name
                    or "absolute_pos_embed" in module_param_name
                ):
                    print(module_param_name)
                    hyperparams["weight_decay"] = 0.0
                if isinstance(module, norm_module_types):
                    hyperparams["weight_decay"] = weight_decay_norm
                if isinstance(module, torch.nn.Embedding):
                    hyperparams["weight_decay"] = weight_decay_embed
                params.append({"params": [value], **hyperparams})

        def maybe_add_full_model_gradient_clipping(optim):
            # detectron2 doesn't have full model gradient clipping now
            clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
            enable = (
                cfg.SOLVER.CLIP_GRADIENTS.ENABLED
                and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
                and clip_norm_val > 0.0
            )

            class FullModelGradientClippingOptimizer(optim):
                def step(self, closure=None):
                    # Clip the gradient norm over all parameters jointly before stepping.
                    all_params = itertools.chain(*[x["params"] for x in self.param_groups])
                    torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
                    super().step(closure=closure)

            return FullModelGradientClippingOptimizer if enable else optim

        optimizer_type = cfg.SOLVER.OPTIMIZER
        if optimizer_type == "SGD":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
                params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM
            )
        elif optimizer_type == "ADAMW":
            optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
                params, cfg.SOLVER.BASE_LR
            )
        else:
            raise NotImplementedError(f"no optimizer type {optimizer_type}")
        if cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != "full_model":
            optimizer = maybe_add_gradient_clipping(cfg, optimizer)
        return optimizer

    @classmethod
    def test(cls, cfg, model, evaluators=None):
        # Run evaluation under autocast when mixed-precision training is enabled.
        if not cfg.SOLVER.AMP.ENABLED:
            results = super(Trainer, cls).test(cfg, model, evaluators)
        else:
            with torch.cuda.amp.autocast(enabled=True):
                results = super(Trainer, cls).test(cfg, model, evaluators)
        return results

    @classmethod
    def test_with_TTA(cls, cfg, model):
        logger = logging.getLogger("detectron2.trainer")
        # At the end of training, run an evaluation with TTA.
        logger.info("Running inference with test-time augmentation ...")
        # NOTE: ModelWithTTA is not imported in this file; it is assumed to be
        # provided elsewhere in the project.
        model = ModelWithTTA(cfg, model)
        evaluators = [
            cls.build_evaluator(
                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
            )
            for name in cfg.DATASETS.TEST
        ]
        if not cfg.SOLVER.AMP.ENABLED:
            res = cls.test(cfg, model, evaluators)
        else:
            with torch.cuda.amp.autocast(enabled=True):
                res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res

    def build_hooks(self):
        ret = super(Trainer, self).build_hooks()
        cfg = self.cfg.clone()
        # Keep a checkpoint of the best model (by the "metrics" key) on the main process.
        if comm.is_main_process():
            ret.append(hooks.BestCheckpointer(cfg.TEST.EVAL_PERIOD, self.checkpointer, "metrics"))
        return ret
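

# The entry points below follow detectron2's standard train_net.py layout:
# main() builds the config, then either evaluates a checkpoint (--eval-only)
# or constructs a Trainer and runs the training loop.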
def main(args):
    cfg = setup(args)

    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if cfg.TEST.AUG.ENABLED:
            res.update(Trainer.test_with_TTA(cfg, model))
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
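
# Example invocations (flags come from detectron2's default_argument_parser;
# the config path below is illustrative, not part of this repository):
#   python train_net.py --config-file configs/example.yaml --num-gpus 8
#   python train_net.py --config-file configs/example.yaml --eval-only \
#       MODEL.WEIGHTS /path/to/checkpoint.pth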