in mdl.py [0:0]
def validate(args, trainer, task, epoch_itr, subsets, allowed_batches):
    """Evaluate the model on a restricted set of validation batches.

    Builds a batch iterator over the single split named in ``subsets``,
    replaces its batch list with ``allowed_batches``, runs
    ``trainer.valid_step`` on every resulting sample and returns the
    average of the trainer's ``'valid_loss'`` meter.

    Args:
        args: Options namespace. Reads ``max_tokens_valid``,
            ``max_sentences_valid``, ``skip_invalid_size_inputs_valid_test``,
            ``required_batch_size_multiple``, ``seed``,
            ``distributed_world_size``, ``distributed_rank`` and
            ``num_workers``, and is forwarded to the progress bar.
        trainer: Object providing ``get_model()``, ``get_meter(name)`` and
            ``valid_step(sample)``.
        task: Object providing ``get_batch_iterator(...)``,
            ``dataset(split)`` and ``max_positions()``.
        epoch_itr: Training epoch iterator. Only its ``epoch`` attribute is
            used, to label the progress bar.
        subsets: Sequence containing exactly one split name.
        allowed_batches: Batches to evaluate, stored as the iterator's
            ``frozen_batches``. Presumably a sequence of index lists in the
            same form the batch sampler produces -- TODO confirm against
            the caller.

    Returns:
        The ``avg`` of the ``'valid_loss'`` meter, read after the meters are
        reset and the validation steps of this call have run.

    Raises:
        AssertionError: If ``subsets`` does not hold exactly one entry
            (only while assertions are enabled).
        RuntimeError: If the trainer has no ``'valid_loss'`` meter.
    """
    assert len(subsets) == 1, 'validate() supports exactly one subset'
    valid_epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(subsets[0]),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )
    # Replace the iterator's precomputed batch list so only the requested
    # batches are evaluated. NOTE(review): this relies on the iterator
    # reading ``frozen_batches`` when ``next_epoch_itr`` is called -- confirm
    # against the batch-iterator implementation in use.
    valid_epoch_itr.frozen_batches = allowed_batches
    itr = valid_epoch_itr.next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch,
        prefix='next block validation',
        no_progress_bar='simple'
    )
    # Reset the loss meters so the value read back below reflects only this
    # call (assumes valid_step updates them -- TODO confirm).
    for k in ('valid_loss', 'valid_nll_loss', 'loss'):
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()
    # valid_step's return value is not needed: the loss is read back from
    # the trainer's meter afterwards.
    for sample in progress:
        trainer.valid_step(sample)
    # Guard the lookup like the reset loop does, so a missing meter gives a
    # clear error instead of an AttributeError on None.
    valid_meter = trainer.get_meter('valid_loss')
    if valid_meter is None:
        raise RuntimeError(
            "trainer has no 'valid_loss' meter; cannot report validation loss"
        )
    return valid_meter.avg