def validate()

in mdl.py [0:0]


def validate(args, trainer, task, epoch_itr, subsets, allowed_batches):
    """Run validation over a fixed set of batches and return the average loss.

    Builds a batch iterator for the single subset named in ``subsets``,
    replaces its ``frozen_batches`` with ``allowed_batches``, and passes every
    resulting sample to ``trainer.valid_step``.

    Args:
        args: Options object. Read here for ``max_tokens_valid``,
            ``max_sentences_valid``, ``skip_invalid_size_inputs_valid_test``,
            ``required_batch_size_multiple``, ``seed``,
            ``distributed_world_size``, ``distributed_rank`` and
            ``num_workers``; also forwarded to the progress-bar builder.
        trainer: Object providing ``get_model()``, ``get_meter(name)`` and
            ``valid_step(sample)``.
        task: Object providing ``dataset(name)``, ``max_positions()`` and
            ``get_batch_iterator(...)``.
        epoch_itr: Only its ``epoch`` attribute is read, and only to label
            the progress bar.
        subsets: Sequence that must contain exactly one subset name.
        allowed_batches: Value assigned to the validation iterator's
            ``frozen_batches`` before iteration starts.

    Returns:
        The ``avg`` attribute of the trainer's ``'valid_loss'`` meter, read
        after all samples have been processed.

    Raises:
        AssertionError: If ``subsets`` does not hold exactly one entry.
    """

    # Only subsets[0] is ever used below, so reject anything else up front.
    assert len(subsets) == 1, (
        'validate() expects exactly one subset, got {}'.format(len(subsets))
    )

    valid_epoch_itr = task.get_batch_iterator(
        dataset=task.dataset(subsets[0]),
        max_tokens=args.max_tokens_valid,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_shards=args.distributed_world_size,
        shard_id=args.distributed_rank,
        num_workers=args.num_workers,
    )

    # Override whatever batches the iterator computed with the caller's set.
    # NOTE(review): assumes `frozen_batches` is a plain writable attribute on
    # the iterator returned by get_batch_iterator -- confirm.
    valid_epoch_itr.frozen_batches = allowed_batches

    itr = valid_epoch_itr.next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch,
        prefix='next block validation',
        no_progress_bar='simple'
    )

    # Reset the validation loss meters; a meter reported as None is skipped.
    for k in ['valid_loss', 'valid_nll_loss', 'loss']:
        meter = trainer.get_meter(k)
        if meter is not None:
            meter.reset()

    # The per-step return value is not needed: the result is read from the
    # 'valid_loss' meter afterwards (presumably updated inside valid_step --
    # verify against the trainer implementation).
    for sample in progress:
        trainer.valid_step(sample)

    # Unlike the reset loop above, this read is not guarded against a None
    # meter; 'valid_loss' is expected to exist on the trainer.
    return trainer.get_meter('valid_loss').avg