def do_train()

in d2go/runner/default_runner.py [0:0]


    def do_train(self, cfg, model, resume):
        """Train ``model`` as described by ``cfg`` and return the final config.

        In order, this: attaches a FLOP-printing hook to the model, builds the
        optimizer, LR scheduler and checkpointer, loads weights (possibly
        resuming), builds the training data loader, constructs an
        ``AMPTrainer`` or ``SimpleTrainer``, registers the training hooks and
        runs ``trainer.train`` from the start iteration up to
        ``cfg.SOLVER.MAX_ITER``.

        Args:
            cfg: training config. ``cfg.OUTPUT_DIR`` is used as the
                checkpointer's save dir and, on the main process, as the
                location of ``metrics.json``.
            model: the model to train.
            resume: forwarded to ``checkpointer.resume_or_load``. When true and
                the checkpointer reports an existing checkpoint, training
                continues after the ``"iteration"`` stored in it.

        Returns:
            dict: ``{"model_final": trained_cfg}``. ``trained_cfg`` is a clone
            of ``self.original_cfg`` (when the runner has that attribute) or of
            ``cfg``, with ``MODEL.WEIGHTS`` set to
            ``checkpointer.get_checkpoint_file()``.
        """
        # FLOPs counted this early may be off for models whose control flow
        # depends on the input.
        add_flop_printing_hook(model, cfg.OUTPUT_DIR)

        optimizer = self.build_optimizer(cfg, model)
        scheduler = self.build_lr_scheduler(cfg, optimizer)
        checkpointer = self.build_checkpointer(
            cfg,
            model,
            save_dir=cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=scheduler,
        )

        ckpt_state = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume)
        # A stored "iteration" is the last step that already finished, so we
        # pick up at the next one. Without a resumable checkpoint, start at 0.
        if resume and checkpointer.has_checkpoint():
            start_iter = ckpt_state.get("iteration", -1) + 1
        else:
            start_iter = 0

        max_iter = cfg.SOLVER.MAX_ITER
        periodic_ckpt = PeriodicCheckpointer(
            checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
        )

        train_loader = self.build_detection_train_loader(cfg)

        trainer_cls = AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer
        # When cfg.ABNORMAL_CHECKER.ENABLED is set, the trainer receives the
        # model wrapped in AbnormalLossCheckerWrapper instead of the raw model.
        train_model = model
        if cfg.ABNORMAL_CHECKER.ENABLED:
            checker_writers = abnormal_checker.get_writers(
                cfg, self.get_tbx_writer(cfg)
            )
            loss_checker = abnormal_checker.AbnormalLossChecker(
                start_iter, checker_writers
            )
            train_model = abnormal_checker.AbnormalLossCheckerWrapper(
                model, loss_checker
            )
        trainer = trainer_cls(train_model, train_loader, optimizer)

        def _evaluate():
            # Looked up lazily, so it reports the trainer's iteration at the
            # moment evaluation actually runs.
            return self.do_test(cfg, model, train_iter=trainer.iter)

        # Order matters: hooks run in list order. Entries are None when the
        # corresponding feature is disabled in cfg.
        hook_list = [
            hooks.IterationTimer(),
            model_ema.EMAHook(cfg, model) if cfg.MODEL_EMA.ENABLED else None,
            self._create_data_loader_hook(cfg),
            self._create_after_step_hook(
                cfg, model, optimizer, scheduler, periodic_ckpt
            ),
            hooks.EvalHook(cfg.TEST.EVAL_PERIOD, _evaluate),
            kmeans_anchors.compute_kmeans_anchors_hook(self, cfg),
            self._create_qat_hook(cfg) if cfg.QUANTIZATION.QAT.ENABLED else None,
        ]

        # Only the main process writes console / JSON / tensorboard metrics.
        if comm.is_main_process():
            tbx_writer = self.get_tbx_writer(cfg)
            metrics_path = os.path.join(cfg.OUTPUT_DIR, "metrics.json")
            hook_list.append(
                hooks.PeriodicWriter(
                    [
                        CommonMetricPrinter(max_iter),
                        JSONWriter(metrics_path),
                        tbx_writer,
                    ]
                )
            )

        update_hooks_from_registry(hook_list)
        trainer.register_hooks(hook_list)
        trainer.train(start_iter, max_iter)

        # If the runner kept an original config, report how the training cfg
        # diverged from it and base the returned config on the original.
        base_cfg = cfg
        if hasattr(self, "original_cfg"):
            base_cfg = self.original_cfg
            diff_table = get_cfg_diff_table(cfg, base_cfg)
            logger.info(
                "GeneralizeRCNN Runner ignoring training config change: \n"
                + diff_table
            )
        trained_cfg = base_cfg.clone()

        with temp_defrost(trained_cfg):
            trained_cfg.MODEL.WEIGHTS = checkpointer.get_checkpoint_file()
        return {"model_final": trained_cfg}