def analyze_results()

in mdr/retrieval/decomposed_analysis.py [0:0]


def _read_jsonl(path):
    """Parse the file at ``path`` as one JSON document per line and return a list."""
    # The context manager closes the handle even when a line fails to parse.
    with open(path, encoding="utf-8") as handle:
        return [json.loads(line) for line in handle]


def analyze_results(
    decomposed_path="/private/home/xwhan/data/QDMR/qdmr_decomposed_results.json",
    e2e_path="/private/home/xwhan/data/QDMR/qdmr_e2e_results.json",
):
    """Compare top-1 retrieval chains of the decomposed and end-to-end runs.

    Both files are read as JSON Lines and walked in lockstep. For every pair
    of records, the set of first elements of the decomposed record's ``"sp"``
    entries (presumably supporting-passage titles -- confirm against the code
    that writes these files) is compared with the set of first elements of
    each run's first candidate chain (``record["candidate_chains"][0]``).

    Four integers are printed, one per line, in this order:

    1. ``both``   -- both runs' top-1 chain equals the ``"sp"`` set
    2. ``better`` -- only the end-to-end chain equals it
    3. ``worse``  -- only the decomposed chain equals it
    4. the number of records in the decomposed file

    Pairs where neither chain matches are not counted anywhere. If the two
    files differ in length, ``zip`` stops at the shorter one while the fourth
    number still reports the decomposed file's full length.

    Args:
        decomposed_path: JSON Lines file with the decomposed-retrieval results.
        e2e_path: JSON Lines file with the end-to-end results, in the same
            record order as ``decomposed_path``.

    Returns:
        None. The counts are only printed.

    Raises:
        AssertionError: if two paired records carry different ``"_id"`` values.
        OSError: if either file cannot be opened.
        json.JSONDecodeError: if a line (including a blank one) is not valid JSON.
        KeyError, IndexError: if a record lacks ``"_id"``, ``"sp"`` or
            ``"candidate_chains"``, or its candidate list is empty.
    """
    decomposed_results = _read_jsonl(decomposed_path)
    e2e_results = _read_jsonl(e2e_path)

    better = 0
    worse = 0
    both = 0
    for decomposed, e2e in zip(decomposed_results, e2e_results):
        # Check alignment before touching any other field, so a misaligned
        # pair is reported as such instead of surfacing as an unrelated error.
        assert decomposed["_id"] == e2e["_id"], (
            "result files are misaligned: %r != %r" % (decomposed["_id"], e2e["_id"])
        )

        # The reference set is taken from the decomposed record only.
        gold = {entry[0] for entry in decomposed["sp"]}
        decomposed_hit = {hop[0] for hop in decomposed["candidate_chains"][0]} == gold
        e2e_hit = {hop[0] for hop in e2e["candidate_chains"][0]} == gold

        if e2e_hit and decomposed_hit:
            both += 1
        elif e2e_hit:
            better += 1
        elif decomposed_hit:
            worse += 1

    print(both)
    print(better)
    print(worse)
    print(len(decomposed_results))