in kilt/readers/t5/evaluate_kilt_task.py [0:0]
import argparse
import glob
import os

import torch

# Seq2seqTransformer, generate_answers and calculate_rouge are assumed to be
# imported or defined elsewhere in this module.


def run_generate():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_size",
        type=str,
        help="T5 model size: one of 't5-small', 't5-base', 't5-large', 't5-3b' or 't5-11b'.",
    )
    parser.add_argument(
        "input_path",
        type=str,
        help="input file with one source/question per line, e.g. nqa/test_articles_questions.txt",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="where to save the generated answers",
    )
    parser.add_argument(
        "reference_path",
        type=str,
        help="file with the gold reference answers, e.g. nqa/test_reference_answers.txt",
    )
    parser.add_argument(
        "score_path",
        type=str,
        help="where to save the ROUGE score",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size: how many inputs to generate answers for at a time",
    )
    parser.add_argument(
        "--no_cuda",
        action="store_true",
        help="Force execution on CPU even if a GPU is available.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="directory containing the fine-tuned checkpoints (checkpointepoch=*.ckpt)",
    )
    args = parser.parse_args()
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    )
    source_lns = [x.rstrip() for x in open(args.input_path).readlines()]
    sq2sq = Seq2seqTransformer(args)
    # load the last (lexicographically sorted) fine-tuned checkpoint found in output_dir
    checkpoints = list(
        sorted(
            glob.glob(
                os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True
            )
        )
    )
    model = sq2sq.load_from_checkpoint(checkpoints[-1]).model
    tokenizer = sq2sq.tokenizer
    # generate an answer for every source line and write them to output_path
    generate_answers(
        source_lns, args.output_path, model, tokenizer, args.batch_size, args.device
    )
    # score the generated answers against the gold references with ROUGE
    output_lns = [x.rstrip() for x in open(args.output_path).readlines()]
    reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()]
    calculate_rouge(output_lns, reference_lns, args.score_path)
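
# Example invocation, a sketch only: every path and directory name below is an
# illustrative placeholder, not something shipped with the repository.
#   python kilt/readers/t5/evaluate_kilt_task.py t5-base \
#       nqa/test_articles_questions.txt \
#       nqa/generated_answers.txt \
#       nqa/test_reference_answers.txt \
#       nqa/rouge_score.txt \
#       --output_dir t5_finetuned_checkpoints \
#       --batch_size 16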