def batch_query_qa_dense_index_nn()

in longform-qa/lfqa_utils.py [0:0]


def batch_query_qa_dense_index_nn(passages, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10):
    """Retrieve the nearest wiki passages for every query in a batch via a dense index.

    Args:
        passages: batch of query texts; embedded with ``embed_passages_for_retrieval``.
        qa_embedder: model forwarded to ``embed_passages_for_retrieval``.
        tokenizer: tokenizer forwarded to ``embed_passages_for_retrieval``.
        wiki_passages: indexable passage collection exposing ``column_names``; each item
            is a mapping with at least a ``"passage_text"`` field (presumably a
            ``datasets.Dataset`` -- confirm against callers).
        wiki_index: nearest-neighbour index whose ``search(reps, k)`` returns a pair
            ``(scores, ids)`` with one row per query (faiss-style -- TODO confirm).
        n_results: number of neighbours to retrieve per query.

    Returns:
        A tuple ``(support_doc_lst, all_res_lists)``:

        - ``support_doc_lst``: one string per query. Each string is the retrieved
          passage texts joined with ``" <P> "`` and prefixed with ``"<P> "``.
        - ``all_res_lists``: one list per query of dicts. Each dict holds every
          ``wiki_passages`` column of a retrieved passage, plus ``"passage_id"`` (int)
          and ``"score"`` (float) taken from the index search.
    """
    q_reps = embed_passages_for_retrieval(passages, tokenizer, qa_embedder)
    scores, ids = wiki_index.search(q_reps, n_results)
    # Column names are the same for every passage; look them up once.
    columns = wiki_passages.column_names

    support_doc_lst = []
    all_res_lists = []
    for score_row, id_row in zip(scores, ids):
        # NOTE(review): if wiki_index is faiss, ids of -1 mark "no result" when the index
        # holds fewer than n_results vectors; int(-1) would silently select the last passage.
        res_passages = [wiki_passages[int(i)] for i in id_row]
        support_doc_lst.append("<P> " + " <P> ".join(p["passage_text"] for p in res_passages))

        res_list = []
        for p, sc, i in zip(res_passages, score_row, id_row):
            # Copy the passage's columns so the added keys never touch the source record.
            r = {k: p[k] for k in columns}
            r["passage_id"] = int(i)
            r["score"] = float(sc)
            res_list.append(r)
        all_res_lists.append(res_list)
    return support_doc_lst, all_res_lists