in dpr_scale/data_prep/prep_conv_datasets.py [0:0]
def prep_dpr_dstc7(infile, outfile):
    """Convert a JSON dialogue file into JSON-lines retrieval records.

    The whole of ``infile`` is parsed as one JSON document and its entries
    are iterated in order. For every entry that contains the key
    ``"options-for-correct-answers"``, one JSON object is written to
    ``outfile`` on its own line with these keys:

    * ``"question"``: result of ``get_question(entry["messages-so-far"])``.
    * ``"answers"``: always an empty list.
    * ``"positive_ctxs"``: first value returned by
      ``get_pos_ctxs(entry["options-for-correct-answers"])``.
    * ``"hard_negative_ctxs"``: result of
      ``get_neg_ctxs(entry["options-for-next"], pos_ctx_ids)``, where
      ``pos_ctx_ids`` is the second value returned by ``get_pos_ctxs``.

    Entries without the ``"options-for-correct-answers"`` key are not
    written; they are counted, and the count is printed next to ``infile``
    once processing has finished.

    NOTE(review): the helpers ``get_question``, ``get_pos_ctxs`` and
    ``get_neg_ctxs`` are defined outside this block; presumably the negatives
    are the candidate options minus the positive ids -- confirm against the
    helper definitions.

    Args:
        infile: Path of the JSON file to read.
        outfile: Path of the JSON-lines file to write. It is opened in
            ``"w"`` mode, so any existing content is replaced.

    Raises:
        KeyError: If an entry that has ``"options-for-correct-answers"`` is
            missing ``"messages-so-far"`` or ``"options-for-next"``
            (assuming entries are dicts -- TODO confirm).
    """
    skipped = 0
    # JSON text is UTF-8 by specification; passing the encoding explicitly
    # keeps decoding independent of the platform's locale default.
    with open(infile, encoding="utf-8") as fin, open(
        outfile, "w", encoding="utf-8"
    ) as fout:
        json_obj = ujson.load(fin)
        for line in tqdm(json_obj):
            # Entries with no annotated correct answers cannot yield a
            # positive context, so they are only counted.
            if "options-for-correct-answers" not in line:
                skipped += 1
                continue

            question = get_question(line["messages-so-far"])
            pos_ctxs, pos_ctx_ids = get_pos_ctxs(
                line["options-for-correct-answers"]
            )
            # The positive ids are handed to the negative-context helper
            # together with the full candidate list.
            neg_ctxs = get_neg_ctxs(line["options-for-next"], pos_ctx_ids)
            out_json = ujson.dumps(
                {
                    "question": question,
                    "answers": [],
                    "positive_ctxs": pos_ctxs,
                    "hard_negative_ctxs": neg_ctxs,
                }
            )
            fout.write(f"{out_json}\n")
    # Report how many entries were dropped for this input file.
    print(f"{infile}: {skipped}")