def main()

in scripts/interactive_file.py [0:0]


def main(args):
    """Translate every line of ``args.input`` and write hypotheses to ``args.output_file``.

    Input lines are read ``args.buffer_size`` at a time via ``buffered_read``,
    grouped into batches by ``make_batches``, and run through the loaded model
    ensemble with ``task.inference_step``. For each input line, the top
    ``args.nbest`` hypotheses are post-processed, passed through the BPE /
    tokenizer decoders, and written one per line. Output follows input order.

    Args:
        args: parsed command-line namespace. Fields read here include ``path``,
            ``input``, ``output_file``, ``buffer_size``, ``max_tokens``,
            ``max_sentences``, ``sampling``, ``nbest``, ``beam``, ``cpu``,
            ``fp16``, ``no_beamable_mm``, ``print_alignment``,
            ``model_overrides``, ``replace_unk`` and ``remove_bpe``.
            ``buffer_size`` and ``max_sentences`` may be adjusted in place.

    Raises:
        AssertionError: if ``--sampling`` is used with ``--nbest`` different
            from ``--beam``, or if ``--max-sentences`` exceeds ``--buffer-size``.
    """
    utils.import_user_module(args)

    # Normalize batching options before validating them below.
    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        not args.max_sentences or args.max_sentences <= args.buffer_size
    ), "--max-sentences/--batch-size cannot be larger than --buffer-size"

    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble; multiple checkpoint paths are separated by ':'.
    print("| loading model(s) from {}".format(args.path))
    # NOTE(review): eval() runs arbitrary code from --model-overrides. That is
    # tolerable only because the string comes from the operator's own command
    # line; switch to ast.literal_eval if it can ever come from elsewhere.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"), arg_overrides=eval(args.model_overrides), task=task,
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        # Tokenize first, then apply BPE (inverse order of decode_fn).
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Undo BPE first, then detokenize (inverse order of encode_fn).
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(), *[model.max_positions() for model in models]
    )

    if args.buffer_size > 1:
        print("| Sentence buffer size:", args.buffer_size)

    # Global id of the first line in the current buffer; batch ids are local
    # to each buffer, so this offset makes them unique across the whole file.
    start_id = 0
    num_lines = get_num_lines(args.input)
    print("loading {} with {} lines".format(args.input, num_lines))
    # One progress tick per buffer. Ceil-divide so a trailing partial buffer
    # is part of the total instead of overshooting it.
    num_buffers = (num_lines + args.buffer_size - 1) // args.buffer_size

    # Explicit UTF-8 so translated text does not depend on the platform locale.
    # Both the file and the progress bar are closed even if translation raises.
    with open(args.output_file, "w", encoding="utf-8") as fp, tqdm(
        total=num_buffers
    ) as pb:
        for inputs in buffered_read(args.input, args.buffer_size):
            results = []
            for batch in make_batches(inputs, args, task, max_positions, encode_fn):
                src_tokens = batch.src_tokens
                src_lengths = batch.src_lengths
                if use_cuda:
                    src_tokens = src_tokens.cuda()
                    src_lengths = src_lengths.cuda()

                sample = {
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                    },
                }
                translations = task.inference_step(generator, models, sample)
                for i, (sample_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)
                ):
                    # NOTE(review): source padding is stripped with the *target*
                    # dictionary's pad index; presumably both dictionaries share
                    # the same pad index — confirm.
                    src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                    results.append((start_id + sample_id, src_tokens_i, hypos))

            # Batching may reorder sentences; sort by global id to match input order.
            for _sample_id, src_tokens_i, hypos in sorted(results, key=lambda x: x[0]):
                # Always bind src_str: without a source dictionary it previously
                # stayed undefined (NameError) or leaked the prior sentence's text.
                src_str = ""
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens_i, args.remove_bpe)

                # Write the top-nbest hypotheses; slicing already clamps to len(hypos).
                for hypo in hypos[: args.nbest]:
                    _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction(
                        hypo_tokens=hypo["tokens"].int().cpu(),
                        src_str=src_str,
                        alignment=hypo["alignment"],
                        align_dict=align_dict,
                        tgt_dict=tgt_dict,
                        remove_bpe=args.remove_bpe,
                    )
                    fp.write(decode_fn(hypo_str) + "\n")

            # update running id counter
            start_id += len(inputs)
            pb.update(1)