def main()

in generate_passage_embeddings.py [0:0]


# Module-level imports the function relies on (from the surrounding script).
import pickle
from pathlib import Path

import transformers

import src.model
import src.util


def main(opt):
    logger = src.util.init_logger(is_main=True)
    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    model_class = src.model.Retriever
    #model, _, _, _, _, _ = src.util.load(model_class, opt.model_path, opt)
    model = model_class.from_pretrained(opt.model_path)

    # Put the retriever into inference mode on the target device,
    # using half precision unless it is explicitly disabled.
    model.eval()
    model = model.to(opt.device)
    if not opt.no_fp16:
        model = model.half()

    # All options come in through `opt` (the original script reads some of them
    # from a module-level `args`, which is the same argparse namespace).
    passages = src.util.load_passages(opt.passages)

    # Select this worker's shard of the passage collection; the last shard
    # absorbs any remainder left over by the integer division.
    shard_size = len(passages) // opt.num_shards
    start_idx = opt.shard_id * shard_size
    end_idx = start_idx + shard_size
    if opt.shard_id == opt.num_shards - 1:
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    logger.info(f'Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}')

    # Encode every passage in the shard; returns the passage ids and their embeddings.
    allids, allembeddings = embed_passages(opt, passages, model, tokenizer)

    # Each shard writes its own pickle, named <output_path>_00, <output_path>_01, ...
    output_path = Path(opt.output_path)
    save_file = output_path.parent / (output_path.name + f'_{opt.shard_id:02d}')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f'Saving {len(allids)} passage embeddings to {save_file}')
    with open(save_file, mode='wb') as f:
        pickle.dump((allids, allembeddings), f)

    logger.info(f'Total passages processed {len(allids)}. Written to {save_file}.')
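Downstream code typically reads these per-shard pickles back and concatenates them before indexing. The following is a minimal sketch of that step, not part of the original script: the helper name load_shard_embeddings and the glob pattern are illustrative, and it assumes each shard stores its embeddings as a NumPy array alongside the list of ids, matching the (allids, allembeddings) tuple written above.

import glob
import pickle

import numpy as np


def load_shard_embeddings(prefix):
    # Hypothetical helper: read every shard written by main()
    # (e.g. <prefix>_00, <prefix>_01, ...) and concatenate the
    # ids and embeddings in shard order.
    all_ids, all_embeddings = [], []
    for path in sorted(glob.glob(f'{prefix}_*')):
        with open(path, 'rb') as f:
            ids, embeddings = pickle.load(f)
        all_ids.extend(ids)
        all_embeddings.append(embeddings)
    return all_ids, np.concatenate(all_embeddings, axis=0)

For example, calling load_shard_embeddings with the same value passed as opt.output_path returns the ids and a single embedding matrix covering all shards.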