def model_fn()

in bring-your-own-container/fairseq_translation/fairseq/sagemaker_translate.py [0:0]


def model_fn(model_dir):
    """Load the fairseq checkpoint in ``model_dir`` and build a translator.

    The checkpoint ``checkpoint_best.pt`` is read from ``model_dir``. The
    generation arguments are built from fairseq's interactive generation
    parser and then overwritten with the arguments stored in the checkpoint.

    Args:
        model_dir: Directory containing ``checkpoint_best.pt``. It is also
            passed to fairseq as the data directory (``args.data``).

    Returns:
        A dict with the keys ``translator`` (the ``SequenceGenerator``),
        ``task``, ``max_positions``, ``align_dict``, ``tgt_dict``, ``args``
        and ``device``.
    """
    model_name = "checkpoint_best.pt"
    model_path = os.path.join(model_dir, model_name)

    logger.info("Loading the model")
    # NOTE: torch.load unpickles the file; only trusted checkpoints should be
    # placed in model_dir.
    with open(model_path, "rb") as f:
        model_info = torch.load(f, map_location=torch.device("cpu"))

    # Defaults from the parser are overridden below by model_info["args"];
    # they are still needed for pre-trained models.
    parser = options.get_generation_parser(interactive=True)

    # fairseq parses sys.argv, so temporarily swap in the arguments we need.
    # The restore lives in ``finally`` so a parse failure (argparse raises
    # SystemExit on bad arguments) cannot leave sys.argv clobbered.
    saved_argv = list(sys.argv)
    sys.argv[1:] = ["--path", model_path, model_dir]
    try:
        args = options.parse_args_and_arch(parser)
    finally:
        sys.argv = saved_argv

    # Arguments stored in the checkpoint take precedence over parser defaults.
    for key, value in vars(model_info["args"]).items():
        setattr(args, key, value)

    args.data = [model_dir]
    print(args)

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Current device: {}".format(device))

    models, model_args = utils.load_ensemble_for_inference(
        [model_path], task, model_arg_overrides={}
    )

    # Set dictionaries
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()

    # Initialize generator
    translator = SequenceGenerator(
        models,
        tgt_dict,
        beam_size=args.beam,
        minlen=args.min_len,
        stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized),
        len_penalty=args.lenpen,
        unk_penalty=args.unkpen,
        sampling=args.sampling,
        sampling_topk=args.sampling_topk,
        sampling_temperature=args.sampling_temperature,
        diverse_beam_groups=args.diverse_beam_groups,
        diverse_beam_strength=args.diverse_beam_strength,
    )

    if device.type == "cuda":
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    # args.replace_unk is deliberately not used here; None is always passed.
    align_dict = utils.load_align_dict(None)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    return dict(
        translator=translator,
        task=task,
        max_positions=max_positions,
        align_dict=align_dict,
        tgt_dict=tgt_dict,
        args=args,
        device=device,
    )