in mdr/retrieval/decomposed_analysis.py [0:0]
import json


def decomposed_errors():
    """Bucket HotpotQA bridge questions by which half of the gold chain the top-1 retrieved chain missed."""
    top1_pred = [json.loads(l) for l in open("/private/home/xwhan/data/hotpot/dense_val_b1_top1.json").readlines()]
    analysis_folder = "/private/home/xwhan/data/hotpot/analysis"
    start_errors, bridge_errors, failed = [], [], []
    correct = []
    for item in top1_pred:
        pred_titles = [_[0] for _ in item["candidate_chains"][0]]
        gold_titles = [_[0] for _ in item["sp"]]
        if set(pred_titles) == set(gold_titles):
            # Exact title match; only bridge questions are kept for further analysis.
            if item["type"] == "bridge":
                correct.append(item)
            continue
        if item["type"] == "bridge":
            # The gold chain has two titles: the bridge passage and the starting passage.
            start_title = None
            for t in gold_titles:
                if t != item["bridge"]:
                    start_title = t
            assert start_title is not None
            if item["bridge"] in pred_titles and start_title not in pred_titles:
                # Retrieved the bridge passage but missed the starting passage.
                start_errors.append(item)
            elif item["bridge"] not in pred_titles and start_title in pred_titles:
                # Retrieved the starting passage but missed the bridge passage.
                bridge_errors.append(item)
            else:
                # Missed both gold passages.
                failed.append(item)
    # Dump each bucket as one JSON object per line, keeping only the
    # top-1 predicted chain under the "predicted" key.
    buckets = [
        ("correct.json", correct),
        ("start_errors.json", start_errors),
        ("bridge_errors.json", bridge_errors),
        ("total_errors.json", failed),
    ]
    for filename, items in buckets:
        with open(analysis_folder + "/" + filename, "w") as g:
            for rec in items:
                rec["predicted"] = rec.pop("candidate_chains")[0]
                g.write(json.dumps(rec) + "\n")
    # Report bucket sizes.
    print("correct:", len(correct))
    print("start errors:", len(start_errors))
    print("bridge errors:", len(bridge_errors))
    print("total errors:", len(failed))
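
A minimal usage sketch (not part of the original file), assuming the hard-coded HotpotQA prediction and analysis paths exist on disk:

if __name__ == "__main__":
    # Hypothetical entry point: run the bridge-question error breakdown end to end.
    decomposed_errors()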