in bring-your-own-container/fairseq_translation/fairseq/train_driver.py [0:0]
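# NOTE: a minimal sketch of the module-level imports this function relies on,
# assuming the fairseq 0.6.x package layout used by this example; the actual
# import block at the top of train_driver.py may differ slightly, and
# validate(), save_checkpoint(), and get_training_stats() are defined
# elsewhere in the same file.
import collections
import math

from fairseq import progress_bar
from fairseq.data import iterators
from fairseq.meters import AverageMeter

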
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch."""
    # Update parameters every N batches; --update-freq can give a different
    # accumulation factor per epoch, with the last value reused thereafter.
    if epoch_itr.epoch <= len(args.update_freq):
        update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
        update_freq = args.update_freq[-1]

    # Initialize data iterator; GroupedIterator yields `update_freq` mini-batches
    # at a time so each train_step accumulates gradients over the group.
    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        no_progress_bar="simple",
    )

    extra_meters = collections.defaultdict(lambda: AverageMeter())
    first_valid = args.valid_subset.split(",")[0]
    max_update = args.max_update or math.inf
    for i, samples in enumerate(progress, start=epoch_itr.iterations_in_epoch):
        log_output = trainer.train_step(samples)
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(trainer)
        for k, v in log_output.items():
            if k in ["loss", "nll_loss", "ntokens", "nsentences", "sample_size"]:
                continue  # these are already logged above
            if "loss" in k:
                extra_meters[k].update(v, log_output["sample_size"])
            else:
                extra_meters[k].update(v)
            stats[k] = extra_meters[k].avg
        progress.log(stats)

        # ignore the first mini-batch in words-per-second calculation
        if i == 0:
            trainer.get_meter("wps").reset()

        # mid-epoch validation and checkpointing every --save-interval-updates updates
        num_updates = trainer.get_num_updates()
        if (
            args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(trainer)
    for k, meter in extra_meters.items():
        stats[k] = meter.avg
    progress.print(stats)

    # reset training meters
    for k in [
        "train_loss",
        "train_nll_loss",
        "wps",
        "ups",
        "wpb",
        "bsz",
        "gnorm",
        "clip",
    ]:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
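

# A rough sketch (not verbatim from this repo) of how the driver's main() loop
# typically calls train() once per epoch, following upstream fairseq's train.py;
# the helper name _example_epoch_loop is hypothetical, and max_epoch,
# valid_subsets, and lr are assumed to be set up as shown.
def _example_epoch_loop(args, trainer, task, epoch_itr):
    valid_subsets = args.valid_subset.split(",")
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    valid_losses = [None]
    while (
        lr > args.min_lr
        and epoch_itr.epoch < max_epoch
        and trainer.get_num_updates() < max_update
    ):
        # train for one epoch
        train(args, trainer, task, epoch_itr)

        # validate on the configured subsets at the end of the epoch
        if epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)

        # only use the first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save an end-of-epoch checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_checkpoint(args, trainer, epoch_itr, valid_losses[0])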