def decomposed_errors()

in mdr/retrieval/decomposed_analysis.py


import json


def decomposed_errors():
    # Top-1 predicted retrieval chains for the HotpotQA dev set, one JSON object per line.
    with open("/private/home/xwhan/data/hotpot/dense_val_b1_top1.json") as f:
        top1_pred = [json.loads(l) for l in f]
    analysis_folder = "/private/home/xwhan/data/hotpot/analysis"

    start_errors, bridge_errors, failed = [], [], []
    correct = []
    for item in top1_pred:
        pred_titles = [_[0] for _ in item["candidate_chains"][0]]
        gold_titles = [_[0] for _ in item["sp"]]
        if set(pred_titles) == set(gold_titles):
            # Keep correctly retrieved bridge questions for comparison.
            if item["type"] == "bridge":
                correct.append(item)
            continue
        if item["type"] == "bridge":
            # The starting passage is the gold passage that is not the bridge passage.
            start_title = None
            for t in gold_titles:
                if t != item["bridge"]:
                    start_title = t
            assert start_title is not None
            if item["bridge"] in pred_titles and start_title not in pred_titles:
                # Bridge passage retrieved, starting passage missed.
                start_errors.append(item)
            elif item["bridge"] not in pred_titles and start_title in pred_titles:
                # Starting passage retrieved, bridge passage missed.
                bridge_errors.append(item)
            else:
                # Neither gold passage retrieved.
                failed.append(item)
    # Write each subset as JSONL, keeping only the top-1 predicted chain per question.
    subsets = {
        "correct": correct,
        "start_errors": start_errors,
        "bridge_errors": bridge_errors,
        "total_errors": failed,
    }
    for name, items in subsets.items():
        with open(f"{analysis_folder}/{name}.json", "w") as g:
            for _ in items:
                _["predicted"] = _.pop("candidate_chains")[0]
                g.write(json.dumps(_) + "\n")

    print("correct:", len(correct))
    print("start errors:", len(start_errors))
    print("bridge errors:", len(bridge_errors))
    print("total errors:", len(failed))
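
For reference, a sketch of the per-question record shape the function assumes, inferred from the field accesses above; the titles and chain contents below are illustrative, not taken from the actual data:

# Illustrative input record, inferred from the field accesses in decomposed_errors();
# the passage titles and text are made up for the example.
example_item = {
    "type": "bridge",                                   # only "bridge" questions are analyzed
    "bridge": "Passage B",                              # title of the second-hop (bridge) gold passage
    "sp": [["Passage A", 0], ["Passage B", 1]],         # gold supporting passages, title first
    "candidate_chains": [                               # ranked predicted chains; index 0 is the top-1 chain
        [["Passage A", "...text..."], ["Passage C", "...text..."]],
    ],
}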