in pytorch_translate/evals.py [0:0]
def get_training_stats(trainer):
    """Collect the trainer's running statistics into an ordered dict for logging.

    Each statistic is read from ``trainer.get_meter(name)``. A meter that is
    missing (``get_meter`` returns ``None``) contributes no entry at all; a
    meter that exists but whose value is falsy (``None`` or ``0``) is reported
    as ``None``. ``ppl``, ``num_updates`` and ``lr`` are always present.

    Args:
        trainer: object exposing ``get_meter(name)``, ``get_num_updates()``
            and ``get_lr()``. Presumably a fairseq-style trainer whose meters
            provide the ``avg`` / ``count`` / ``sum`` / ``elapsed_time``
            attributes read below — confirm against the caller.

    Returns:
        ``OrderedDict`` mapping stat name -> display value (formatted string,
        rounded int, raw value or ``None``), in the fixed order loss,
        nll_loss, ppl, wps, ups, wpb, bsz, num_updates, lr, gnorm, clip, oom,
        loss_scale, wall, train_wall. ``ppl`` is ``-1.0`` when no loss
        average is available to derive it from.
    """
    stats = OrderedDict()

    def _rounded(value):
        # Round a (possibly tensor-valued) statistic to an int; falsy -> None.
        return round(utils.item(value)) if value else None

    def _formatted(value, spec):
        # Render a statistic with a format spec (equivalent to an f-string);
        # falsy -> None.
        return format(value, spec) if value else None

    def _add(name, render):
        # Record render(meter) under ``name`` only when that meter exists;
        # each meter is looked up exactly once.
        meter = trainer.get_meter(name)
        if meter is not None:
            stats[name] = render(meter)

    loss_meter = trainer.get_meter("train_loss")
    loss_avg = loss_meter.avg if loss_meter is not None else None
    if loss_avg is not None:
        stats["loss"] = f"{loss_avg:.3f}"

    # Guard against a missing NLL meter, as is done for every other meter here;
    # dereferencing ``.count`` on None would raise AttributeError.
    nll_meter = trainer.get_meter("train_nll_loss")
    if nll_meter is not None and nll_meter.count > 0:
        nll_loss = nll_meter.avg
        stats["nll_loss"] = f"{nll_loss:.3f}"
    else:
        # No NLL samples were recorded (e.g. the criterion only logs the plain
        # loss), so derive perplexity from the training loss rather than from
        # the NLL meter, which has no samples to average.
        nll_loss = loss_avg
    stats["ppl"] = get_perplexity(nll_loss) if nll_loss is not None else -1.0

    _add("wps", lambda m: _rounded(m.avg))
    _add("ups", lambda m: _formatted(m.avg, ".1f"))
    _add("wpb", lambda m: _rounded(m.avg))
    _add("bsz", lambda m: _rounded(m.avg))
    stats["num_updates"] = trainer.get_num_updates()
    stats["lr"] = trainer.get_lr()
    _add("gnorm", lambda m: _formatted(m.avg, ".3f"))
    _add("clip", lambda m: _formatted(m.avg, ".0%"))
    _add("oom", lambda m: m.avg or None)
    _add("loss_scale", lambda m: _formatted(m.avg, ".3f"))
    # Wall-clock stats read different meter attributes than the averages above.
    _add("wall", lambda m: _rounded(m.elapsed_time))
    _add("train_wall", lambda m: _rounded(m.sum))
    return stats