in d2go/runner/default_runner.py [0:0]
def do_train(self, cfg, model, resume):
    # Note that FLOPs reported at the beginning of training are often
    # inaccurate if the model has input-dependent logic.
    add_flop_printing_hook(model, cfg.OUTPUT_DIR)
    optimizer = self.build_optimizer(cfg, model)
    scheduler = self.build_lr_scheduler(cfg, optimizer)
    checkpointer = self.build_checkpointer(
        cfg,
        model,
        save_dir=cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=scheduler,
    )
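    # Checkpointer.resume_or_load: with resume=True and an existing checkpoint in
    # save_dir, the full training state (model, optimizer, scheduler, iteration)
    # is restored; otherwise only the weights from cfg.MODEL.WEIGHTS are loaded.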
    checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
    start_iter = (
        checkpoint.get("iteration", -1)
        if resume and checkpointer.has_checkpoint()
        else -1
    )
    # The checkpoint stores the training iteration that just finished, thus we start
    # at the next iteration (or iter zero if there's no checkpoint).
    start_iter += 1
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )
    data_loader = self.build_detection_train_loader(cfg)
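    # The train loader is an infinite stream; the trainer pulls one batch from
    # it per run_step() call, so termination is governed by max_iter alone.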

    def _get_model_with_abnormal_checker(model):
        # Optionally wrap the model so that abnormal loss values are detected
        # and reported through the configured writers.
        if not cfg.ABNORMAL_CHECKER.ENABLED:
            return model
        tbx_writer = self.get_tbx_writer(cfg)
        writers = abnormal_checker.get_writers(cfg, tbx_writer)
        checker = abnormal_checker.AbnormalLossChecker(start_iter, writers)
        ret = abnormal_checker.AbnormalLossCheckerWrapper(model, checker)
        return ret
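    # With cfg.SOLVER.AMP.ENABLED, AMPTrainer runs the forward pass under
    # torch.cuda.amp.autocast and scales the loss with a GradScaler; otherwise
    # the plain fp32 SimpleTrainer is used.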
    trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
        _get_model_with_abnormal_checker(model), data_loader, optimizer
    )
    trainer_hooks = [
        hooks.IterationTimer(),
        model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
        self._create_data_loader_hook(cfg),
        self._create_after_step_hook(
            cfg, model, optimizer, scheduler, periodic_checkpointer
        ),
        hooks.EvalHook(
            cfg.TEST.EVAL_PERIOD,
            lambda: self.do_test(cfg, model, train_iter=trainer.iter),
        ),
        kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
        self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
    ]
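    # Disabled features leave None entries in the list above; detectron2's
    # TrainerBase.register_hooks filters out None before registering.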
    if comm.is_main_process():
        tbx_writer = self.get_tbx_writer(cfg)
        writers = [
            CommonMetricPrinter(max_iter),
            JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
            tbx_writer,
        ]
        trainer_hooks.append(hooks.PeriodicWriter(writers))
    update_hooks_from_registry(trainer_hooks)
    trainer.register_hooks(trainer_hooks)
    trainer.train(start_iter, max_iter)
    if hasattr(self, "original_cfg"):
        table = get_cfg_diff_table(cfg, self.original_cfg)
        logger.info(
            "GeneralizedRCNN Runner ignoring training config change: \n" + table
        )
        trained_cfg = self.original_cfg.clone()
    else:
        trained_cfg = cfg.clone()
    with temp_defrost(trained_cfg):
        trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
    return {"model_final": trained_cfg}