def main(opt)

in passage_retrieval.py
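
This is the retrieval entry point: it embeds the input questions with a BERT-based Retriever, builds (or loads) a FAISS index over precomputed passage embeddings, runs top-k nearest-neighbor search, attaches the retrieved passages and answer-validation flags to each example, and writes the result to JSON.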


# Imports and module-level logger required by the body below (elided in the excerpt).
import glob
import json
import logging
import time
from pathlib import Path

import transformers

import src.data
import src.index
import src.model
import src.util

logger = logging.getLogger(__name__)


def main(opt):
    src.util.init_logger(is_main=True)
    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    data = src.data.load_data(opt.data)
    model_class = src.model.Retriever
    model = model_class.from_pretrained(opt.model_path)

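    # Move the retriever to the GPU, switch to inference mode, and optionally
    # cast to fp16 to halve memory use during encoding.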
    model.cuda()
    model.eval()
    if not opt.no_fp16:
        model = model.half()

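    # Build a FAISS index sized to the retriever's embedding dimension;
    # n_subquantizers and n_bits control optional product quantization.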
    index = src.index.Indexer(model.config.indexing_dimension, opt.n_subquantizers, opt.n_bits)

    # Index all passages: either load a previously serialized FAISS index or
    # build one from the embedding shards.
    input_paths = sorted(glob.glob(opt.passages_embeddings))
    embeddings_dir = Path(input_paths[0]).parent
    index_path = embeddings_dir / 'index.faiss'
    if opt.save_or_load_index and index_path.exists():
        index.deserialize_from(embeddings_dir)
    else:
        logger.info(f'Indexing passages from files {input_paths}')
        start_time_indexing = time.time()
        index_encoded_data(index, input_paths, opt.indexing_batch_size)  # helper defined elsewhere in this file
        logger.info(f'Indexing time: {time.time()-start_time_indexing:.1f} s.')
        if opt.save_or_load_index:
            index.serialize(embeddings_dir)

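    # Encode every question into the same vector space as the passages.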
    questions_embedding = embed_questions(opt, data, model, tokenizer)

    # retrieve the top-k nearest passages for every question
    start_time_retrieval = time.time()
    top_ids_and_scores = index.search_knn(questions_embedding, opt.n_docs)
    logger.info(f'Search time: {time.time()-start_time_retrieval:.1f} s.')

    passages = src.util.load_passages(opt.passages)
    passages = {x[0]: (x[1], x[2]) for x in passages}  # passage id -> (text, title)

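    # Attach the retrieved passages to each example, then check whether any of
    # them contains a gold answer, parallelized across validation_workers.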
    add_passages(data, passages, top_ids_and_scores)
    hasanswer = validate(data, opt.validation_workers)
    add_hasanswer(data, hasanswer)
    output_path = Path(opt.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as fout:
        json.dump(data, fout, indent=4)
    logger.info(f'Saved results to {output_path}')
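
For reference, here is a minimal sketch of how main(opt) could be driven from the command line. The option names simply mirror the attributes the function reads (opt.data, opt.model_path, opt.passages, ...); the defaults and help strings are assumptions for illustration, not the repository's actual parser.

if __name__ == '__main__':
    import argparse

    # Hypothetical CLI wiring -- every option mirrors an attribute read by main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', required=True,
                        help='file with the questions to retrieve passages for')
    parser.add_argument('--model_path', required=True,
                        help='path to the pretrained Retriever checkpoint')
    parser.add_argument('--passages', required=True,
                        help='file mapping passage id to (text, title)')
    parser.add_argument('--passages_embeddings', required=True,
                        help='glob pattern matching the precomputed embedding shards')
    parser.add_argument('--output_path', required=True,
                        help='where to write the retrieval results (json)')
    parser.add_argument('--n_docs', type=int, default=100,
                        help='number of passages to retrieve per question')
    parser.add_argument('--validation_workers', type=int, default=16)
    parser.add_argument('--indexing_batch_size', type=int, default=50000)
    parser.add_argument('--n_subquantizers', type=int, default=0)
    parser.add_argument('--n_bits', type=int, default=8)
    parser.add_argument('--no_fp16', action='store_true')
    parser.add_argument('--save_or_load_index', action='store_true')
    main(parser.parse_args())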