# engine/training_engine.py
def run(self, train_sampler=None):
    if train_sampler is None and self.is_master_node:
        logger.error("Train sampler cannot be None")
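    # Resolve run-level options: the epoch at which EMA weights may be copied,
    # the experiment output directory, and the configuration file to back up.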
    copy_at_epoch = getattr(self.opts, "ema.copy_at_epoch", -1)
    train_start_time = time.time()

    save_dir = getattr(self.opts, "common.exp_loc", "results")
    cfg_file = getattr(self.opts, "common.config_file", None)
    if cfg_file is not None and self.is_master_node:
        dst_cfg_file = "{}/config.yaml".format(save_dir)
        shutil.copy(src=cfg_file, dst=dst_cfg_file)
        logger.info('Configuration file is stored here: {}'.format(logger.color_text(dst_cfg_file)))

    keep_k_best_ckpts = getattr(self.opts, "common.k_best_checkpoints", 5)
    ema_best_metric = self.best_metric
    is_ema_best = False
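    # Main training loop. It is wrapped in try/except/finally so that keyboard
    # interrupts and runtime errors are reported and resources are released on exit.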
    try:
        max_epochs = getattr(self.opts, "scheduler.max_epochs", DEFAULT_EPOCHS)
        for epoch in range(self.start_epoch, max_epochs):
            # Note that we are using our own implementations of data samplers
            # and we have defined this function for both distributed and non-distributed cases
            train_sampler.set_epoch(epoch)
            train_sampler.update_scales(epoch=epoch, is_master_node=self.is_master_node)

            train_loss, train_ckpt_metric = self.train_epoch(epoch)
            val_loss, val_ckpt_metric = self.val_epoch(epoch=epoch, model=self.model)
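            # Optionally overwrite the main model with the EMA weights at a
            # user-specified epoch and re-run validation on the result.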
            if epoch == copy_at_epoch and self.model_ema is not None:
                if self.is_master_node:
                    logger.log('Copying EMA weights')
                # copy model_src weights to model_tgt
                self.model = copy_weights(model_tgt=self.model, model_src=self.model_ema)
                if self.is_master_node:
                    logger.log('EMA weights copied')
                    logger.log('Running validation after copying EMA model weights')
                self.val_epoch(epoch=epoch, model=self.model)

            gc.collect()
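            # Track the best validation checkpoint metric; depending on the metric,
            # "best" means either the maximum or the minimum value seen so far.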
            max_checkpoint_metric = getattr(self.opts, "stats.checkpoint_metric_max", False)
            if max_checkpoint_metric:
                is_best = val_ckpt_metric >= self.best_metric
                self.best_metric = max(val_ckpt_metric, self.best_metric)
            else:
                is_best = val_ckpt_metric <= self.best_metric
                self.best_metric = min(val_ckpt_metric, self.best_metric)
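            # Validate the EMA model (if enabled) and track its best metric separately.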
            val_ema_loss = None
            val_ema_ckpt_metric = None
            if self.model_ema is not None:
                val_ema_loss, val_ema_ckpt_metric = self.val_epoch(
                    epoch=epoch,
                    model=self.model_ema.ema_model,
                    extra_str=" (EMA)"
                )
                if max_checkpoint_metric:
                    is_ema_best = val_ema_ckpt_metric >= ema_best_metric
                    ema_best_metric = max(val_ema_ckpt_metric, ema_best_metric)
                else:
                    is_ema_best = val_ema_ckpt_metric <= ema_best_metric
                    ema_best_metric = min(val_ema_ckpt_metric, ema_best_metric)
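            # Only the master node writes checkpoints to disk.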
            if self.is_master_node:
                save_checkpoint(
                    iterations=self.train_iterations,
                    epoch=epoch,
                    model=self.model,
                    optimizer=self.optimizer,
                    best_metric=self.best_metric,
                    is_best=is_best,
                    save_dir=save_dir,
                    model_ema=self.model_ema,
                    is_ema_best=is_ema_best,
                    ema_best_metric=ema_best_metric,
                    gradient_scalar=self.gradient_scalar,
                    max_ckpt_metric=max_checkpoint_metric,
                    k_best_checkpoints=keep_k_best_ckpts
                )
                logger.info('Checkpoints saved at: {}'.format(save_dir), print_line=True)
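            # Log learning rates, losses, and checkpoint metrics to TensorBoard
            # (master node only).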
            if self.tb_log_writter is not None and self.is_master_node:
                lr_list = self.scheduler.retrieve_lr(self.optimizer)
                for g_id, lr_val in enumerate(lr_list):
                    self.tb_log_writter.add_scalar('LR/Group-{}'.format(g_id), round(lr_val, 6), epoch)
                self.tb_log_writter.add_scalar('Train/Loss', round(train_loss, 2), epoch)
                self.tb_log_writter.add_scalar('Val/Loss', round(val_loss, 2), epoch)
                self.tb_log_writter.add_scalar('Common/Best Metric', round(self.best_metric, 2), epoch)
                if val_ema_loss is not None:
                    self.tb_log_writter.add_scalar('Val_EMA/Loss', round(val_ema_loss, 2), epoch)
                # If val checkpoint metric is different from loss, add that too
                if self.ckpt_metric != 'loss':
                    self.tb_log_writter.add_scalar('Train/{}'.format(self.ckpt_metric.title()),
                                                   round(train_ckpt_metric, 2), epoch)
                    self.tb_log_writter.add_scalar('Val/{}'.format(self.ckpt_metric.title()),
                                                   round(val_ckpt_metric, 2), epoch)
                    if val_ema_ckpt_metric is not None:
                        self.tb_log_writter.add_scalar('Val_EMA/{}'.format(self.ckpt_metric.title()),
                                                       round(val_ema_ckpt_metric, 2), epoch)

            if self.max_iterations_reached and self.is_master_node:
                logger.info('Max. iterations for training reached')
                break
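    # Keyboard interrupts stop training gracefully; other exceptions are logged,
    # with a CUDA memory summary printed for out-of-memory errors.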
    except KeyboardInterrupt:
        if self.is_master_node:
            logger.log('Keyboard interruption. Exiting from training early')
    except Exception as e:
        if self.is_master_node:
            if 'out of memory' in str(e):
                logger.log('OOM exception occurred')
                n_gpus = getattr(self.opts, "dev.num_gpus", 1)
                for dev_id in range(n_gpus):
                    mem_summary = torch.cuda.memory_summary(device=torch.device('cuda:{}'.format(dev_id)),
                                                            abbreviated=True)
                    logger.log('Memory summary for device id: {}'.format(dev_id))
                    print(mem_summary)
            else:
                logger.log('Exception occurred that interrupted the training. {}'.format(str(e)))
                print(e)
                raise e
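    # Cleanup runs regardless of how training ended: tear down the DDP process group,
    # free cached GPU memory, close the TensorBoard writer, and report wall-clock time.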
    finally:
        use_distributed = getattr(self.opts, "ddp.use_distributed", False)
        if use_distributed:
            torch.distributed.destroy_process_group()

        torch.cuda.empty_cache()

        if self.is_master_node and self.tb_log_writter is not None:
            self.tb_log_writter.close()

        if self.is_master_node:
            train_end_time = time.time()
            hours, rem = divmod(train_end_time - train_start_time, 3600)
            minutes, seconds = divmod(rem, 60)
            train_time_str = "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)
            logger.log('Training took {}'.format(train_time_str))

        try:
            exit(0)
        except Exception:
            pass
        finally:
            pass