def run()

in engine/training_engine.py


    def run(self, train_sampler=None):
        if train_sampler is None and self.is_master_node:
            logger.error("Train sampler cannot be None")

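        # Epoch at which EMA weights are copied into the main model; the default of -1 disables the copy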
        copy_at_epoch = getattr(self.opts, "ema.copy_at_epoch", -1)
        train_start_time = time.time()
        save_dir = getattr(self.opts, "common.exp_loc", "results")

        cfg_file = getattr(self.opts, "common.config_file", None)
        if cfg_file is not None and self.is_master_node:
            dst_cfg_file = "{}/config.yaml".format(save_dir)
            shutil.copy(src=cfg_file, dst=dst_cfg_file)
            logger.info('Configuration file is stored here: {}'.format(logger.color_text(dst_cfg_file)))

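        # The EMA model's best metric is tracked separately from the main model's best metric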
        keep_k_best_ckpts = getattr(self.opts, "common.k_best_checkpoints", 5)
        ema_best_metric = self.best_metric
        is_ema_best = False
        try:
            max_epochs = getattr(self.opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
            for epoch in range(self.start_epoch, max_epochs):
                # Note that we are using our own implementations of data samplers,
                # and these functions are defined for both distributed and non-distributed cases
                train_sampler.set_epoch(epoch)
                train_sampler.update_scales(epoch=epoch, is_master_node=self.is_master_node)

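                # Train for one epoch, then validate the current (non-EMA) model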
                train_loss, train_ckpt_metric = self.train_epoch(epoch)

                val_loss, val_ckpt_metric = self.val_epoch(epoch=epoch, model=self.model)

                if epoch == copy_at_epoch and self.model_ema is not None:
                    if self.is_master_node:
                        logger.log('Copying EMA weights')
                    # copy model_src weights to model_tgt
                    self.model = copy_weights(model_tgt=self.model, model_src=self.model_ema)
                    if self.is_master_node:
                        logger.log('EMA weights copied')
                        logger.log('Running validation after copying EMA model weights')
                    self.val_epoch(epoch=epoch, model=self.model)

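                # Release unreferenced Python objects between epochs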
                gc.collect()

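                # Update the best checkpoint metric; depending on the task, higher (e.g., accuracy) or lower (e.g., loss) is better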
                max_checkpoint_metric = getattr(self.opts, "stats.checkpoint_metric_max", False)
                if max_checkpoint_metric:
                    is_best = val_ckpt_metric >= self.best_metric
                    self.best_metric = max(val_ckpt_metric, self.best_metric)
                else:
                    is_best = val_ckpt_metric <= self.best_metric
                    self.best_metric = min(val_ckpt_metric, self.best_metric)

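                # Validate the EMA model, if one is maintained, using the same checkpoint metric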
                val_ema_loss = None
                val_ema_ckpt_metric = None
                if self.model_ema is not None:
                    val_ema_loss, val_ema_ckpt_metric = self.val_epoch(
                        epoch=epoch,
                        model=self.model_ema.ema_model,
                        extra_str=" (EMA)"
                    )
                    if max_checkpoint_metric:
                        is_ema_best = val_ema_ckpt_metric >= ema_best_metric
                        ema_best_metric = max(val_ema_ckpt_metric, ema_best_metric)
                    else:
                        is_ema_best = val_ema_ckpt_metric <= ema_best_metric
                        ema_best_metric = min(val_ema_ckpt_metric, ema_best_metric)

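                # Only the master node writes checkpoints and logs the save location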
                if self.is_master_node:
                    save_checkpoint(
                        iterations=self.train_iterations,
                        epoch=epoch,
                        model=self.model,
                        optimizer=self.optimizer,
                        best_metric=self.best_metric,
                        is_best=is_best,
                        save_dir=save_dir,
                        model_ema=self.model_ema,
                        is_ema_best=is_ema_best,
                        ema_best_metric=ema_best_metric,
                        gradient_scalar=self.gradient_scalar,
                        max_ckpt_metric=max_checkpoint_metric,
                        k_best_checkpoints=keep_k_best_ckpts
                    )
                    logger.info('Checkpoints saved at: {}'.format(save_dir), print_line=True)

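                # Log learning rates, losses, and checkpoint metrics to TensorBoard (master node only)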
                if self.tb_log_writter is not None and self.is_master_node:
                    lr_list = self.scheduler.retrieve_lr(self.optimizer)
                    for g_id, lr_val in enumerate(lr_list):
                        self.tb_log_writter.add_scalar('LR/Group-{}'.format(g_id), round(lr_val, 6), epoch)
                    self.tb_log_writter.add_scalar('Train/Loss', round(train_loss, 2), epoch)
                    self.tb_log_writter.add_scalar('Val/Loss', round(val_loss, 2), epoch)
                    self.tb_log_writter.add_scalar('Common/Best Metric', round(self.best_metric, 2), epoch)
                    if val_ema_loss is not None:
                        self.tb_log_writter.add_scalar('Val_EMA/Loss', round(val_ema_loss, 2), epoch)

                    # If val checkpoint metric is different from loss, add that too
                    if self.ckpt_metric != 'loss':
                        self.tb_log_writter.add_scalar('Train/{}'.format(self.ckpt_metric.title()),
                                                       round(train_ckpt_metric, 2), epoch)
                        self.tb_log_writter.add_scalar('Val/{}'.format(self.ckpt_metric.title()),
                                                       round(val_ckpt_metric, 2), epoch)
                        if val_ema_ckpt_metric is not None:
                            self.tb_log_writter.add_scalar('Val_EMA/{}'.format(self.ckpt_metric.title()),
                                                           round(val_ema_ckpt_metric, 2), epoch)

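                # Stop training once the maximum number of iterations has been reached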
                if self.max_iterations_reached:
                    if self.is_master_node:
                        logger.info('Max. iterations for training reached')
                    break

        except KeyboardInterrupt:
            if self.is_master_node:
                logger.log('Keyboard interruption. Exiting early from training')
        except Exception as e:
            if self.is_master_node:
                if 'out of memory' in str(e):
                    logger.log('OOM exception occurred')
                    n_gpus = getattr(self.opts, "dev.num_gpus", 1)
                    for dev_id in range(n_gpus):
                        mem_summary = torch.cuda.memory_summary(device=torch.device('cuda:{}'.format(dev_id)),
                                                                abbreviated=True)
                        logger.log('Memory summary for device id: {}'.format(dev_id))
                        print(mem_summary)
                else:
                    logger.log('Exception occurred that interrupted the training. {}'.format(str(e)))
                    print(e)
                    raise e
        finally:
            use_distributed = getattr(self.opts, "ddp.use_distributed", False)
            if use_distributed:
                torch.distributed.destroy_process_group()

            torch.cuda.empty_cache()

            if self.is_master_node and self.tb_log_writter is not None:
                self.tb_log_writter.close()

            if self.is_master_node:
                train_end_time = time.time()
                hours, rem = divmod(train_end_time - train_start_time, 3600)
                minutes, seconds = divmod(rem, 60)
                train_time_str = "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)
                logger.log('Training took {}'.format(train_time_str))
            exit(0)
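
The loop above depends on a custom data sampler that exposes set_epoch() and update_scales() in both distributed and non-distributed runs. Below is a minimal sketch of that interface only; the class name, constructor arguments, and the resolution schedule are illustrative assumptions, not the project's sampler implementation.

    import random

    class VariableScaleSamplerSketch:
        """Hypothetical sampler exposing the two per-epoch hooks that run() calls."""

        def __init__(self, n_samples: int, base_size: int = 224) -> None:
            self.n_samples = n_samples
            self.img_size = base_size
            self.epoch = 0

        def set_epoch(self, epoch: int) -> None:
            # The sampler owns the epoch state, so the same call works with and
            # without torch.distributed being initialized.
            self.epoch = epoch

        def update_scales(self, epoch: int, is_master_node: bool = False) -> None:
            # Example policy: cycle the training resolution across epochs.
            scales = [192, 224, 256]
            self.img_size = scales[epoch % len(scales)]
            if is_master_node:
                print("Epoch {}: training resolution set to {}".format(epoch, self.img_size))

        def __iter__(self):
            # Deterministic per-epoch shuffling keyed on the epoch set via set_epoch()
            order = list(range(self.n_samples))
            random.Random(self.epoch).shuffle(order)
            yield from order

        def __len__(self) -> int:
            return self.n_samples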