in training/trainer.py [0:0]
    def val_epoch(self, val_loader, phase):
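        # Per-iteration timing meters and a peak-GPU-memory meter for this validation epoch.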
        batch_time = AverageMeter("Batch Time", self.device, ":.2f")
        data_time = AverageMeter("Data Time", self.device, ":.2f")
        mem = MemMeter("Mem (GB)", self.device, ":.2f")

        iters_per_epoch = len(val_loader)

        curr_phases = [phase]
        curr_models = [self.model]
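
        # One loss meter per (phase, loss key) pair, named "Losses/<phase>_<key>_loss".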
        loss_names = []
        for p in curr_phases:
            for key in self.loss.keys():
                loss_names.append(f"Losses/{p}_{key}_loss")

        loss_mts = OrderedDict(
            [(name, AverageMeter(name, self.device, ":.2e")) for name in loss_names]
        )
        extra_loss_mts = {}
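
        # Put every model in eval mode and call the optional epoch-start hook.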
        for model in curr_models:
            model.eval()
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_start"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_start()
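
        # Group the meters for periodic display via progress.display() below.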
        progress = ProgressMeter(
            iters_per_epoch,
            [batch_time, data_time, mem, self.time_elapsed_meter, *loss_mts.values()],
            self._get_meters(curr_phases),
            prefix="Val Epoch: [{}]".format(self.epoch),
        )

        end = time.time()
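
        # Main validation loop: forward passes only, with autocast following the
        # optimizer's AMP settings (disabled when no optimizer config is present).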
        for data_iter, batch in enumerate(val_loader):

            # measure data loading time
            data_time.update(time.time() - end)

            batch = batch.to(self.device, non_blocking=True)

            # compute output
            with torch.no_grad():
                with torch.cuda.amp.autocast(
                    enabled=(self.optim_conf.amp.enabled if self.optim_conf else False),
                    dtype=(
                        get_amp_type(self.optim_conf.amp.amp_dtype)
                        if self.optim_conf
                        else None
                    ),
                ):
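                    # One forward/loss pass per (phase, model) pair; here there is a
                    # single validation phase and a single model.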
                    for phase, model in zip(curr_phases, curr_models):
                        loss_dict, batch_size, extra_losses = self._step(
                            batch,
                            model,
                            phase,
                        )
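
                        # _step is expected to return exactly one named core loss plus
                        # a dict of auxiliary losses.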
                        assert len(loss_dict) == 1
                        loss_key, loss = loss_dict.popitem()

                        loss_mts[loss_key].update(loss.item(), batch_size)
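
                        # Auxiliary losses get their own meters, created lazily the
                        # first time each key appears.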
                        for k, v in extra_losses.items():
                            if k not in extra_loss_mts:
                                extra_loss_mts[k] = AverageMeter(k, self.device, ":.2e")
                            extra_loss_mts[k].update(v.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            self.time_elapsed_meter.update(
                time.time() - self.start_time + self.ckpt_time_elapsed
            )
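
            # Record GPU memory usage; MemMeter is assumed to read (and reset) the
            # CUDA peak-memory statistics.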
            if torch.cuda.is_available():
                mem.update(reset_peak_usage=True)

            if data_iter % self.logging_conf.log_freq == 0:
                progress.display(data_iter)
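
            # Log the current meter values as scalars under Step_Stats/<phase>/<meter>.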
            if data_iter % self.logging_conf.log_scalar_frequency == 0:
                # Log progress meters.
                for progress_meter in progress.meters:
                    self.logger.log(
                        os.path.join("Step_Stats", phase, progress_meter.name),
                        progress_meter.val,
                        self.steps[Phase.VAL],
                    )
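
            # Periodic barrier to keep distributed ranks in step during validation.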
            if data_iter % 10 == 0:
                dist.barrier()
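
        # End of epoch: estimate the epoch time, log timers, and call the optional
        # epoch-end hook on each model.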
        self.est_epoch_time[phase] = batch_time.avg * iters_per_epoch
        self._log_timers(phase)
        for model in curr_models:
            if hasattr(unwrap_ddp_if_wrapped(model), "on_validation_epoch_end"):
                unwrap_ddp_if_wrapped(model).on_validation_epoch_end()
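
        # Collect averaged losses and trainer state into the returned dict;
        # best-checkpoint bookkeeping is assumed to happen inside
        # _log_meters_and_save_best_ckpts, as its name suggests.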
        out_dict = self._log_meters_and_save_best_ckpts(curr_phases)

        for k, v in loss_mts.items():
            out_dict[k] = v.avg
        for k, v in extra_loss_mts.items():
            out_dict[k] = v.avg

        for phase in curr_phases:
            out_dict.update(self._get_trainer_state(phase))
        self._reset_meters(curr_phases)
        logging.info(f"Meters: {out_dict}")
        return out_dict