def query_qa_dense_index_nn()

in longform-qa/lfqa_utils.py [0:0]


def query_qa_dense_index_nn(passage, qa_embedder, tokenizer, wiki_passages, wiki_index, n_results=10, min_length=20):
    """Retrieve the dense-index nearest neighbours of a passage.

    The query text is embedded with ``embed_passages_for_retrieval`` (the
    passage-side encoder), and ``2 * n_results`` candidates are requested from
    ``wiki_index`` so that enough remain after short passages are dropped.

    Args:
        passage: text used as the query.
        qa_embedder: model passed through to ``embed_passages_for_retrieval``.
        tokenizer: tokenizer passed through to ``embed_passages_for_retrieval``.
        wiki_passages: indexable collection of passage records, each with a
            ``"passage_text"`` field, exposing ``column_names`` (presumably a
            ``datasets.Dataset`` -- confirm against callers).
        wiki_index: nearest-neighbour index whose ``search(x, k)`` returns a
            ``(scores, ids)`` pair of 2-D arrays, one row per query
            (presumably a FAISS index).
        n_results: maximum number of records returned in ``res_list``.
        min_length: passages with ``min_length`` whitespace-separated words or
            fewer are excluded from ``res_list``.

    Returns:
        ``(support_doc, res_list)``.

        ``support_doc`` is the text of *every* retrieved candidate, before
        length filtering, joined as ``"<P> t1 <P> t2 ..."``.

        ``res_list`` holds up to ``n_results`` dicts. Each dict carries the
        passage's columns plus that passage's own ``"passage_id"`` and
        ``"score"``.
    """
    a_rep = embed_passages_for_retrieval([passage], tokenizer, qa_embedder)
    scores, ids = wiki_index.search(a_rep, 2 * n_results)
    # Only one query was embedded, so only row 0 of the search output is used.
    return _collect_nn_results(
        scores[0], ids[0], wiki_passages, wiki_passages.column_names, n_results, min_length
    )


def _collect_nn_results(scores, ids, wiki_passages, column_names, n_results, min_length):
    """Turn one row of nearest-neighbour search output into ``(support_doc, res_list)``.

    Kept separate from the embedding/search step so the post-processing can be
    exercised without a model or an index.
    """
    # FAISS pads a result row with id -1 when fewer than k hits exist; drop
    # those so wiki_passages[-1] (the last row) is never returned by accident.
    hits = [(int(i), float(sc)) for sc, i in zip(scores, ids) if int(i) >= 0]
    passages = [wiki_passages[i] for i, _ in hits]
    support_doc = "<P> " + " <P> ".join(p["passage_text"] for p in passages)
    res_list = []
    for p, (passage_id, score) in zip(passages, hits):
        if len(p["passage_text"].split()) <= min_length:
            continue
        res = {k: p[k] for k in column_names}
        # Attach id/score while the passage is still paired with its own hit.
        # Zipping after filtering would shift them onto the wrong passages.
        res["passage_id"] = passage_id
        res["score"] = score
        res_list.append(res)
    return support_doc, res_list[:n_results]