# passage_retrieval.py

import glob
import json
import time
from pathlib import Path

import transformers

import src.data
import src.index
import src.model
import src.util


def main(opt):
    logger = src.util.init_logger(is_main=True)  # capture the logger used below
    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    data = src.data.load_data(opt.data)
    model_class = src.model.Retriever
    model = model_class.from_pretrained(opt.model_path)

    model.cuda()
    model.eval()
    if not opt.no_fp16:
        model = model.half()

    index = src.index.Indexer(model.config.indexing_dimension, opt.n_subquantizers, opt.n_bits)

    # Index all passages, or reload a previously serialized FAISS index.
    input_paths = sorted(glob.glob(opt.passages_embeddings))
    embeddings_dir = Path(input_paths[0]).parent
    index_path = embeddings_dir / 'index.faiss'
    if opt.save_or_load_index and index_path.exists():
        index.deserialize_from(embeddings_dir)
    else:
        logger.info(f'Indexing passages from files {input_paths}')
        start_time_indexing = time.time()
        index_encoded_data(index, input_paths, opt.indexing_batch_size)
        logger.info(f'Indexing time: {time.time()-start_time_indexing:.1f} s.')
        if opt.save_or_load_index:
            index.serialize(embeddings_dir)

    questions_embedding = embed_questions(opt, data, model, tokenizer)

    # Retrieve the top-k passages for every question embedding.
    start_time_retrieval = time.time()
    top_ids_and_scores = index.search_knn(questions_embedding, opt.n_docs)
    logger.info(f'Search time: {time.time()-start_time_retrieval:.1f} s.')

    # Map passage id -> (text, title) so retrieved ids resolve to content.
    passages = src.util.load_passages(opt.passages)
    passages = {x[0]: (x[1], x[2]) for x in passages}

    add_passages(data, passages, top_ids_and_scores)
    hasanswer = validate(data, opt.validation_workers)
    add_hasanswer(data, hasanswer)

    output_path = Path(opt.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as fout:
        json.dump(data, fout, indent=4)
    logger.info(f'Saved results to {output_path}')
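

# `index_encoded_data` is referenced in main() but not shown in this snippet.
# A minimal sketch of what it might do, assuming each embeddings file is a
# pickle holding an (ids, embeddings) tuple and that the Indexer exposes an
# `index_data(ids, embeddings)` method; both the file format and the method
# signature are assumptions.
import pickle


def index_encoded_data(index, input_paths, indexing_batch_size):
    for path in input_paths:
        with open(path, 'rb') as fin:
            ids, embeddings = pickle.load(fin)
        # Feed the index in chunks so large embedding files do not spike memory.
        for i in range(0, len(ids), indexing_batch_size):
            index.index_data(ids[i:i + indexing_batch_size],
                             embeddings[i:i + indexing_batch_size])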
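

# `embed_questions` is likewise defined elsewhere in the file. A sketch,
# assuming each entry in `data` is a dict with a 'question' field and that
# the Retriever exposes an `embed_text(input_ids, attention_mask)` method
# returning one embedding per input; the method name, the batch-size option,
# and the max length are assumptions.
import torch


def embed_questions(opt, data, model, tokenizer):
    batch_size = getattr(opt, 'per_gpu_batch_size', 64)  # assumed option name
    questions = [example['question'] for example in data]
    embeddings = []
    with torch.no_grad():
        for i in range(0, len(questions), batch_size):
            batch = tokenizer(questions[i:i + batch_size], padding=True,
                              truncation=True, max_length=200,
                              return_tensors='pt')
            output = model.embed_text(batch['input_ids'].cuda(),
                                      batch['attention_mask'].cuda())
            embeddings.append(output.cpu())
    # FAISS expects a float32 numpy matrix, one row per question.
    return torch.cat(embeddings, dim=0).numpy().astype('float32')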
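

# Sketches of the bookkeeping helpers called near the end of main(), assuming
# load_passages yields (id, text, title) rows and the common DPR-style output
# schema in which each example carries a 'ctxs' list of retrieved passages;
# the exact field names are assumptions.
def add_passages(data, passages, top_ids_and_scores):
    # Attach the retrieved passages (with scores) to their question.
    for example, (ids, scores) in zip(data, top_ids_and_scores):
        example['ctxs'] = [
            {
                'id': pid,
                'title': passages[pid][1],
                'text': passages[pid][0],
                'score': float(score),
            }
            for pid, score in zip(ids, scores)
        ]


def add_hasanswer(data, hasanswer):
    # Mark each retrieved passage with whether it contains a gold answer.
    for example, flags in zip(data, hasanswer):
        for ctx, flag in zip(example['ctxs'], flags):
            ctx['hasanswer'] = flag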
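

# A simplified stand-in for validate(), which checks whether each retrieved
# passage contains one of the gold answers. The real routine typically uses a
# tokenized-match check; plain lowercase substring matching over an assumed
# 'answers' field is a deliberate simplification here.
from multiprocessing import Pool


def _check_example(example):
    answers = example.get('answers', [])
    return [any(answer.lower() in ctx['text'].lower() for answer in answers)
            for ctx in example['ctxs']]


def validate(data, validation_workers):
    with Pool(validation_workers) as pool:
        return pool.map(_check_example, data)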
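

# The snippet ends with main(); a plausible entry point, with option names
# mirroring the attributes accessed above. Flag spellings, defaults, and help
# strings are assumptions.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True, help='file with questions and gold answers')
    parser.add_argument('--model_path', required=True, help='path to the trained retriever')
    parser.add_argument('--passages', required=True, help='file with one (id, text, title) passage per row')
    parser.add_argument('--passages_embeddings', required=True, help='glob pattern matching the embedding files')
    parser.add_argument('--output_path', required=True, help='where to write retrieval results')
    parser.add_argument('--n_docs', type=int, default=100, help='passages to retrieve per question')
    parser.add_argument('--validation_workers', type=int, default=16)
    parser.add_argument('--indexing_batch_size', type=int, default=50000)
    parser.add_argument('--n_subquantizers', type=int, default=0, help='0 disables product quantization')
    parser.add_argument('--n_bits', type=int, default=8)
    parser.add_argument('--save_or_load_index', action='store_true')
    parser.add_argument('--no_fp16', action='store_true')
    main(parser.parse_args())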