def load()

in custom/evaluate_utils.py [0:0]


def load(args, task=None, itr=None, generator=None, log=False):
    """Return task, model, generator, dataset iterator and model step for `args`.

    Seeds Python's ``random`` and torch's RNG with 42, imports the user module,
    optionally builds the task, loads the model ensemble from ``args.path`` and
    prepares every model for generation.

    Args:
        args: parsed generation arguments (presumably a fairseq-style
            namespace; must provide ``path``, ``gen_subset``, ``cpu``,
            ``fp16``, ``beam``, batching options, etc.).
        task: pre-built task. When None, one is built with
            ``tasks.setup_task(args)`` and its ``args.gen_subset`` split is
            loaded.
        itr: pre-built batch iterator. When None, one is built over
            ``args.gen_subset`` with ``shuffle=False``.
        generator: pre-built generator. When None,
            ``task.build_generator(args)`` is used.
        log: when True, print ``args`` and the checkpoint path(s).

    Returns:
        ``(task, model, generator, itr, step)``, where ``model`` is the first
        model of the loaded ensemble and ``step`` is ``num_updates`` from the
        last ``optimizer_history`` entry of the first checkpoint.

    Raises:
        AssertionError: if ``args.path`` is None.
    """
    assert args.path is not None, '--path required for generation!'
    import random
    # Fixed seeds so repeated evaluation runs are reproducible.
    random.seed(42)
    torch.manual_seed(42)
    utils.import_user_module(args)
    if log:
        print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    if task is None:
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)

    # `--path` may list several checkpoints separated by ':' (an ensemble).
    checkpoint_paths = args.path.split(':')

    # Load ensemble
    if log:
        print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        checkpoint_paths,
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line; ast.literal_eval would suffice if overrides are always a
        # plain dict literal.
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
    # Only the first model of the ensemble is returned to the caller.
    model = models[0]

    if itr is None:
        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(args.gen_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=args.tokens_per_sample,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

    # Get model step from the first checkpoint. The raw `args.path` is not a
    # loadable file when it lists an ensemble, so use the split path.
    # map_location='cpu' keeps tensors off the GPU (and works on CPU-only
    # hosts); only a single integer is read from the checkpoint.
    checkpoint = torch.load(checkpoint_paths[0], map_location='cpu')
    step = checkpoint['optimizer_history'][-1]['num_updates']
    # Drop the full checkpoint state before building the generator.
    del checkpoint

    if generator is None:
        # Initialize generator
        generator = task.build_generator(args)
    return task, model, generator, itr, step