in training/trainer.py [0:0]
def train_epoch(self, train_loader):
    # Init stat meters
    batch_time_meter = AverageMeter("Batch Time", self.device, ":.2f")
    data_time_meter = AverageMeter("Data Time", self.device, ":.2f")
    mem_meter = MemMeter("Mem (GB)", self.device, ":.2f")
    data_times = []
    phase = Phase.TRAIN

    iters_per_epoch = len(train_loader)

    loss_names = []
    for batch_key in self.loss.keys():
        loss_names.append(f"Losses/{phase}_{batch_key}_loss")

    loss_mts = OrderedDict(
        [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
    )
    extra_loss_mts = {}
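
    # Group the timing, memory and loss meters so they can be displayed together
    # every log_freq iterations (see progress.display below)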
    progress = ProgressMeter(
        iters_per_epoch,
        [
            batch_time_meter,
            data_time_meter,
            mem_meter,
            self.time_elapsed_meter,
            *loss_mts.values(),
        ],
        self._get_meters([phase]),
        prefix="Train Epoch: [{}]".format(self.epoch),
    )

    # Model training loop
    self.model.train()
    end = time.time()

    for data_iter, batch in enumerate(train_loader):
        # measure data loading time
        data_time_meter.update(time.time() - end)
        data_times.append(data_time_meter.val)
        batch = batch.to(
            self.device, non_blocking=True
        )  # move tensors in a tensorclass
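
        # The per-batch forward/backward (self._run_step) and the optimizer update run
        # inside this try block so that diverging (NaN/Inf) losses surface as a
        # FloatingPointError in the except clause below.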
        try:
            self._run_step(batch, phase, loss_mts, extra_loss_mts)

            # compute gradient and do optim step
            exact_epoch = self.epoch + float(data_iter) / iters_per_epoch
            self.where = float(exact_epoch) / self.max_epochs
            assert self.where <= 1 + self.EPSILON
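            # e.g. epoch=3, data_iter=50, iters_per_epoch=100, max_epochs=10
            # -> exact_epoch = 3.5 and where = 0.35, i.e. 35% of training completed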
            if self.where < 1.0:
                self.optim.step_schedulers(
                    self.where, step=int(exact_epoch * iters_per_epoch)
                )
            else:
                logging.warning(
                    f"Skipping scheduler update since training is at the end, i.e., {self.where} of [0, 1]."
                )

            # Log schedulers
            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                for j, param_group in enumerate(self.optim.optimizer.param_groups):
                    for option in self.optim.schedulers[j]:
                        optim_prefix = (
                            f"{j}_"
                            if len(self.optim.optimizer.param_groups) > 1
                            else ""
                        )
                        self.logger.log(
                            os.path.join("Optim", optim_prefix, option),
                            param_group[option],
                            self.steps[phase],
                        )

            # Clipping gradients and detecting diverging gradients
            if self.gradient_clipper is not None:
                self.scaler.unscale_(self.optim.optimizer)
                self.gradient_clipper(model=self.model)

            if self.gradient_logger is not None:
                self.gradient_logger(
                    self.model, rank=self.distributed_rank, where=self.where
                )

            # Optimizer step: the scaler will make sure gradients are not
            # applied if the gradients are infinite
            self.scaler.step(self.optim.optimizer)
            self.scaler.update()
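            # scaler.update() adjusts the loss scale for the next iteration
            # (the scale shrinks after steps skipped due to inf/NaN gradients)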

            # measure elapsed time
            batch_time_meter.update(time.time() - end)
            end = time.time()

            self.time_elapsed_meter.update(
                time.time() - self.start_time + self.ckpt_time_elapsed
            )

            mem_meter.update(reset_peak_usage=True)
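
            # Periodic progress display (log_freq) and per-meter scalar logging
            # (log_scalar_frequency)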
            if data_iter % self.logging_conf.log_freq == 0:
                progress.display(data_iter)

            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                # Log progress meters.
                for progress_meter in progress.meters:
                    self.logger.log(
                        os.path.join("Step_Stats", phase, progress_meter.name),
                        progress_meter.val,
                        self.steps[phase],
                    )

        # Catching NaN/Inf errors in the loss; re-raise to propagate the failure to the caller
        except FloatingPointError as e:
            raise e
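
    # End of epoch: record the estimated epoch time, sync/log timers and data-loading
    # times, and collect average losses into the output dict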
    self.est_epoch_time[Phase.TRAIN] = batch_time_meter.avg * iters_per_epoch
    self._log_timers(Phase.TRAIN)
    self._log_sync_data_times(Phase.TRAIN, data_times)

    out_dict = self._log_meters_and_save_best_ckpts([Phase.TRAIN])

    for k, v in loss_mts.items():
        out_dict[k] = v.avg
    for k, v in extra_loss_mts.items():
        out_dict[k] = v.avg

    out_dict.update(self._get_trainer_state(phase))
    logging.info(f"Losses and meters: {out_dict}")
    self._reset_meters([phase])
    return out_dict
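
# --- Illustrative sketch only (not part of trainer.py): a minimal stand-in showing the
# --- AverageMeter interface the loop above relies on, i.e. AverageMeter(name, device, fmt),
# --- .update(value), .val (latest value) and .avg (running mean). The repo's real meter
# --- utilities (AverageMeter, MemMeter, ProgressMeter) may differ in their details.
class _AverageMeterSketch:
    def __init__(self, name, device, fmt=":f"):
        self.name, self.device, self.fmt = name, device, fmt
        self.reset()

    def reset(self):
        # Latest value, running sum, update count, and running average.
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count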