in fairnr_cli/train.py [0:0]
def _train_step_with_reset(trainer, samples):
    """Run one ``trainer.train_step``, retrying once after a trainer reset.

    If the first attempt raises ``ResetTrainerException``, the trainer's
    ``_wrapped_criterion``, ``_wrapped_model`` and ``_optimizer`` attributes
    are set to ``None`` and the step is retried exactly once. A failure of the
    retry (including another ``ResetTrainerException``) propagates.

    Returns:
        Whatever ``trainer.train_step`` returns; ``None`` means the step was
        skipped (e.g. OOM or overflow).
    """
    try:
        return trainer.train_step(samples)
    except ResetTrainerException:
        # NOTE(review): clearing these private attributes presumably forces the
        # trainer to lazily rebuild criterion/model/optimizer on the next
        # step -- confirm against the Trainer implementation.
        trainer._wrapped_criterion = None
        trainer._wrapped_model = None
        trainer._optimizer = None
        logger.info("reset the trainer at {}".format(trainer.get_num_updates()))
        return trainer.train_step(samples)


def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: parsed training options. Reads ``fix_batches_to_gpus``,
            ``curriculum``, ``update_freq``, ``log_format``, ``log_interval``,
            ``tensorboard_logdir``, ``no_progress_bar``, ``valid_subset`` and
            ``max_update``.
        trainer: the trainer that performs each update step.
        task: the task; its ``begin_epoch`` hook is called once per epoch.
        epoch_itr: epoch batch iterator that yields this epoch's batches.

    Returns:
        bool: ``True`` if training should stop entirely (``should_stop_early``
        fired or ``max_update`` was reached), otherwise ``False``.
    """
    # Initialize data iterator; batches are shuffled only once the epoch index
    # is past ``args.curriculum``.
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # One entry of ``update_freq`` per epoch; the last entry is reused for any
    # epoch beyond the list. Presumably the number of batches accumulated per
    # optimizer update -- confirm against GroupedIterator.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        # Only the master process writes tensorboard logs.
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    # A falsy max_update (0 / None) means "no limit".
    max_update = args.max_update or math.inf
    should_end_training = False
    for samples in progress:
        with metrics.aggregate('train_inner'):
            log_output = _train_step_with_reset(trainer, samples)
            if log_output is None:  # OOM, overflow, ...
                continue

        # log mid-epoch stats
        num_updates = trainer.get_num_updates()
        if num_updates % args.log_interval == 0:
            stats = get_training_stats(metrics.get_smoothed_values('train_inner'))
            progress.log(stats, tag='train_inner', step=num_updates)

            # reset mid-epoch stats after each log interval
            # the end-of-epoch stats will still be preserved
            metrics.reset_meters('train_inner')

        valid_losses = validate_and_save(args, trainer, task, epoch_itr, valid_subsets)
        if should_stop_early(args, valid_losses[0]) or num_updates >= max_update:
            should_end_training = True
            break

    # log end-of-epoch stats. Query the trainer directly: the loop-local
    # ``num_updates`` is unbound when the epoch produced no completed step
    # (empty iterator, or every step skipped via ``continue``).
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=trainer.get_num_updates())

    # reset epoch-level meters
    metrics.reset_meters('train')
    return should_end_training