in mdr/qa/qa_trainer.py [0:0]
def _train(self) -> Optional[float]:
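    """Run the main training loop.

    Performs forward/backward passes with optional fp16 (`amp`) and gradient
    accumulation, logs batch and smoothed losses to TensorBoard, and, on the
    master process only, evaluates periodically and at the end of each epoch,
    saving the checkpoint with the best `final_metric`. Returns that best score.
    """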
    job_env = submitit.JobEnvironment()
    batch_step = 0  # forward batch count
    best_metric = 0
    train_loss_meter = AverageMeter()
    print("Start training", flush=True)
    # Start from the loaded epoch
    start_epoch = self._state.epoch
    global_step = self._state.global_step
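    # Resume from the epoch / global step stored in self._state (e.g. after a
    # requeue) rather than always starting from scratch.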
    for epoch in range(start_epoch, self._train_cfg.num_train_epochs):
        print(f"Start epoch {epoch}", flush=True)
        self._state.model.train()
        self._state.epoch = epoch
        for batch in self._train_loader:
            batch_step += 1
            batch_inputs = move_to_cuda(batch["net_inputs"])
            loss = self._state.model(batch_inputs)
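            # With more than one GPU the (DataParallel-style) model returns a loss per
            # device, so reduce to a scalar; dividing by gradient_accumulation_steps
            # keeps the effective gradient an average over the accumulation window.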
            if torch.cuda.device_count() > 1:
                loss = loss.mean()
            if self._train_cfg.gradient_accumulation_steps > 1:
                loss = loss / self._train_cfg.gradient_accumulation_steps
            if self._train_cfg.fp16:
                with amp.scale_loss(loss, self._state.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            train_loss_meter.update(loss.item())
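            # Step the optimizer only once per gradient-accumulation window;
            # batch_step is 1-based, so a multiple of the window size marks its end.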
            if batch_step % self._train_cfg.gradient_accumulation_steps == 0:
                if self._train_cfg.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self._state.optimizer), self._train_cfg.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self._state.model.parameters(), self._train_cfg.max_grad_norm)
                self._state.optimizer.step()
                self._state.lr_scheduler.step()
                self._state.model.zero_grad()
                global_step += 1
                self._state.global_step = global_step
                self.tb_logger.add_scalar('batch_train_loss',
                                          loss.item(), global_step)
                self.tb_logger.add_scalar('smoothed_train_loss',
                                          train_loss_meter.avg, global_step)
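                # Periodic evaluation and best-checkpoint saving happen only on the
                # global master process.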
                if job_env.global_rank == 0:
                    if self._train_cfg.eval_period != -1 and global_step % self._train_cfg.eval_period == 0:
                        metrics = self._eval()
                        for k, v in metrics.items():
                            self.tb_logger.add_scalar(k, v * 100, global_step)
                        score = metrics[self._train_cfg.final_metric]
                        if best_metric < score:
                            print("Saving model with best %s: %.2f -> %.2f" % (
                                self._train_cfg.final_metric, best_metric * 100, score * 100), flush=True)
                            torch.save(self._state.model.state_dict(),
                                       os.path.join(self._train_cfg.output_dir, "checkpoint_best.pt"))
                            best_metric = score
        # Checkpoint only on the master
        if job_env.global_rank == 0:
            self.checkpoint(rm_init=False)
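            # End-of-epoch evaluation mirrors the periodic evaluation above.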
            metrics = self._eval()
            for k, v in metrics.items():
                self.tb_logger.add_scalar(k, v * 100, global_step)
            score = metrics[self._train_cfg.final_metric]
            if best_metric < score:
                print("Saving model with best %s: %.2f -> %.2f" % (
                    self._train_cfg.final_metric, best_metric * 100, score * 100), flush=True)
                torch.save(self._state.model.state_dict(),
                           os.path.join(self._train_cfg.output_dir, "checkpoint_best.pt"))
                best_metric = score
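            # Record an epoch-level summary (best/current score, smoothed loss).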
            self.log({
                "best_score": best_metric,
                "curr_score": score,
                "smoothed_loss": train_loss_meter.avg,
                "epoch": epoch
            })
    return best_metric