def run_generate()

in kilt/readers/t5/evaluate_kilt_task.py

Loads the most recent fine-tuned T5 checkpoint, generates an answer for each input question, and scores the output against the reference answers with ROUGE.


import argparse
import glob
import os

import torch

# Seq2seqTransformer, generate_answers and calculate_rouge are assumed to be
# defined elsewhere in this module.


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_size",
        type=str,
        nargs="?",
        default="t5-base",
        help="T5 model size: one of 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b'. Defaults to 't5-base'.",
    )
    parser.add_argument(
        "input_path",
        type=str,
        help="input questions, one per line, e.g. nqa/test_articles_questions.txt",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="where to save the generated answers",
    )
    parser.add_argument(
        "reference_path",
        type=str,
        help="reference answers, one per line, e.g. nqa/test_reference_answers.txt",
    )
    parser.add_argument(
        "score_path",
        type=str,
        help="where to save the ROUGE score",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size: how many questions to answer at a time",
    )
    parser.add_argument(
        "--no_cuda",
        action="store_true",
        help="force execution on CPU even if a GPU is available",
    )
    # The checkpoint lookup below reads args.output_dir, so it must be exposed
    # as an argument.
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="directory containing the fine-tuned 'checkpointepoch=*.ckpt' files",
    )

    args = parser.parse_args()
    # Run on GPU when one is available, unless --no_cuda is set.
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    )

    with open(args.input_path) as f:
        source_lns = [x.rstrip() for x in f]

    sq2sq = Seq2seqTransformer(args)

    # Pick the last fine-tuning checkpoint found in output_dir (lexicographic order).
    checkpoints = sorted(
        glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"))
    )
    if not checkpoints:
        raise FileNotFoundError(
            f"no 'checkpointepoch=*.ckpt' files found in {args.output_dir}"
        )

    model = sq2sq.load_from_checkpoint(checkpoints[-1]).model
    tokenizer = sq2sq.tokenizer

    # Generate an answer for every question and write them to output_path.
    generate_answers(
        source_lns, args.output_path, model, tokenizer, args.batch_size, args.device
    )
    with open(args.output_path) as f:
        output_lns = [x.rstrip() for x in f]
    with open(args.reference_path) as f:
        reference_lns = [x.rstrip() for x in f]

    calculate_rouge(output_lns, reference_lns, args.score_path)
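
For reference, a minimal invocation sketch: run_generate() reads its arguments from sys.argv, so a small driver can populate them before calling it. All file paths below are hypothetical placeholders, not files shipped with the repository.

import sys

# Hypothetical task files and output locations; substitute your own.
sys.argv = [
    "evaluate_kilt_task.py",
    "t5-base",                          # model_size
    "nqa/test_articles_questions.txt",  # input_path
    "answers.txt",                      # output_path (generated answers)
    "nqa/test_reference_answers.txt",   # reference_path
    "rouge_score.txt",                  # score_path
    "--batch_size", "16",
    "--output_dir", "checkpoints",
]
run_generate()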