in bring-your-own-container/fairseq_translation/fairseq/sagemaker_translate.py [0:0]
def model_fn(model_dir):
    """Load a fairseq translation model for SageMaker inference.

    Loads the best checkpoint from ``model_dir``, rebuilds the fairseq
    generation args (checkpoint args override command-line defaults),
    builds an ensemble ``SequenceGenerator``, and returns everything the
    transform/predict functions need.

    Args:
        model_dir: Directory containing ``checkpoint_best.pt`` and the
            dictionaries/data files fairseq needs to set up the task.

    Returns:
        dict with keys ``translator``, ``task``, ``max_positions``,
        ``align_dict``, ``tgt_dict``, ``args``, ``device``.
    """
    model_name = "checkpoint_best.pt"
    model_path = os.path.join(model_dir, model_name)

    logger.info("Loading the model")
    # Load once on CPU just to recover the training-time args stored in
    # the checkpoint; the ensemble itself is loaded again below via
    # utils.load_ensemble_for_inference.
    with open(model_path, "rb") as f:
        model_info = torch.load(f, map_location=torch.device("cpu"))

    # Will be overridden by model_info['args'] below — still needed so
    # pre-trained models get sane defaults for any missing options.
    parser = options.get_generation_parser(interactive=True)

    # fairseq's parser reads sys.argv, so temporarily swap in the
    # arguments we want, and restore the original argv even if parsing
    # fails (otherwise a raised exception leaves argv clobbered for the
    # rest of the process).
    argv_backup = list(sys.argv)
    try:
        sys.argv[1:] = ["--path", model_path, model_dir]
        args = options.parse_args_and_arch(parser)
    finally:
        sys.argv = argv_backup

    # Overlay the args saved in the checkpoint on top of the parsed
    # defaults, then point the data path at the model directory.
    saved_args = model_info["args"]
    for key, value in vars(saved_args).items():
        setattr(args, key, value)
    args.data = [model_dir]
    logger.info("Generation args: %s", args)

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Current device: {}".format(device))

    models, model_args = utils.load_ensemble_for_inference(
        [model_path], task, model_arg_overrides={}
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator
    translator = SequenceGenerator(
        models,
        tgt_dict,
        beam_size=args.beam,
        minlen=args.min_len,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups,
        diverse_beam_strength=args.diverse_beam_strength,
    )
    if device.type == "cuda":
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    # align_dict = utils.load_align_dict(args.replace_unk)
    align_dict = utils.load_align_dict(None)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    return dict(
        translator=translator,
        task=task,
        max_positions=max_positions,
        align_dict=align_dict,
        tgt_dict=tgt_dict,
        args=args,
        device=device,
    )