# custom/evaluate_utils.py
def load(args, task=None, itr=None, generator=None, log=False):
    """Load a task, model ensemble, generator, and dataset iterator for evaluation.

    Seeds ``random`` and ``torch`` with a fixed value (42) for reproducible
    generation, then loads the checkpoint(s) named in ``args.path``
    (colon-separated for an ensemble) and prepares them for inference.

    Args:
        args: fairseq argparse namespace; ``args.path`` is required and may
            contain several checkpoint paths joined by ':'.
        task: optional pre-built task; when ``None`` it is set up from ``args``
            and ``args.gen_subset`` is loaded.
        itr: optional pre-built batch iterator; when ``None`` one is created
            over ``args.gen_subset``.
        generator: optional pre-built sequence generator; when ``None`` one is
            built via ``task.build_generator(args)``.
        log: if True, print the args and checkpoint path while loading.

    Returns:
        Tuple ``(task, model, generator, itr, step)`` where ``model`` is the
        first model of the ensemble and ``step`` is the checkpoint's
        ``num_updates`` counter.
    """
    assert args.path is not None, '--path required for generation!'

    # Fixed seeds so repeated evaluation runs are reproducible.
    import random
    random.seed(42)
    torch.manual_seed(42)

    utils.import_user_module(args)
    if log:
        print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    if task is None:
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)

    # Load ensemble
    if log:
        print('| loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(':'),
        # SECURITY: eval() executes arbitrary code from the CLI string.
        # Fairseq convention passes a dict literal here; do not feed this
        # from untrusted input (ast.literal_eval would be the safe swap).
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
    # Only the first model of the ensemble is returned to the caller.
    model = models[0]

    if itr is None:
        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(args.gen_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=args.tokens_per_sample,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

    # Get model step from the optimizer history. BUGFIX: args.path may be a
    # colon-separated ensemble list (see load_model_ensemble above), so load
    # only the first checkpoint; map to CPU to avoid a redundant GPU load.
    step = torch.load(
        args.path.split(':')[0], map_location='cpu'
    )['optimizer_history'][-1]['num_updates']

    if generator is None:
        # Initialize generator
        generator = task.build_generator(args)

    return task, model, generator, itr, step