# generate_passage_embeddings.py
import pickle
from pathlib import Path

import transformers

import src.model
import src.util


def main(opt):
    logger = src.util.init_logger(is_main=True)
    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    model_class = src.model.Retriever
    #model, _, _, _, _, _ = src.util.load(model_class, opt.model_path, opt)
    model = model_class.from_pretrained(opt.model_path)

    model.eval()
    model = model.to(opt.device)
    if not opt.no_fp16:
        model = model.half()

    # Load all passages, then keep only the slice assigned to this shard.
    passages = src.util.load_passages(opt.passages)

    shard_size = len(passages) // opt.num_shards
    start_idx = opt.shard_id * shard_size
    end_idx = start_idx + shard_size
    if opt.shard_id == opt.num_shards - 1:
        # The last shard also absorbs the remainder left by integer division.
        end_idx = len(passages)

    passages = passages[start_idx:end_idx]
    logger.info(f'Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}')

    allids, allembeddings = embed_passages(opt, passages, model, tokenizer)

    # Each shard writes its own file, suffixed with the zero-padded shard id.
    output_path = Path(opt.output_path)
    save_file = output_path.parent / (output_path.name + f'_{opt.shard_id:02d}')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f'Saving {len(allids)} passage embeddings to {save_file}')
    with open(save_file, mode='wb') as f:
        pickle.dump((allids, allembeddings), f)

    logger.info(f'Total passages processed {len(allids)}. Written to {save_file}.')
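

# Illustrative sketch, not part of the original excerpt: a minimal entry point
# that could drive main(). The flag names below are assumptions inferred from the
# opt.* attributes used above; the actual script may define them (and additional
# options used by embed_passages) differently. Downstream, each shard file can be
# read back with pickle.load(f) to recover the (allids, allembeddings) pair.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True,
                        help='path to the pretrained retriever checkpoint')
    parser.add_argument('--passages', type=str, required=True,
                        help='file containing the passages to embed')
    parser.add_argument('--output_path', type=str, required=True,
                        help='prefix for the per-shard embedding files')
    parser.add_argument('--shard_id', type=int, default=0,
                        help='index of the shard processed by this run')
    parser.add_argument('--num_shards', type=int, default=1,
                        help='total number of shards the passages are split into')
    parser.add_argument('--no_fp16', action='store_true',
                        help='keep the model in fp32 instead of casting to half precision')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device the model is moved to')

    opt = parser.parse_args()
    main(opt)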