in generate_passage_embeddings.py
import logging

import torch
from torch.utils.data import DataLoader

import src.data

# Module-level logger; in the original file it is configured elsewhere.
logger = logging.getLogger(__name__)


def embed_passages(opt, passages, model, tokenizer):
    batch_size = opt.per_gpu_batch_size * opt.world_size
    collator = src.data.TextCollator(tokenizer, model.config.passage_maxlength)
    dataset = src.data.TextDataset(passages, title_prefix='title:', passage_prefix='context:')
    dataloader = DataLoader(dataset, batch_size=batch_size, drop_last=False, num_workers=10, collate_fn=collator)
    total = 0
    allids, allembeddings = [], []
    with torch.no_grad():
        for k, (ids, text_ids, text_mask) in enumerate(dataloader):
            # Encode one batch of passages; masking and CLS pooling follow the model config.
            embeddings = model.embed_text(
                text_ids=text_ids.cuda(),
                text_mask=text_mask.cuda(),
                apply_mask=model.config.apply_passage_mask,
                extract_cls=model.config.extract_cls,
            )
            embeddings = embeddings.cpu()
            total += len(ids)
            allids.append(ids)
            allembeddings.append(embeddings)
            if k % 100 == 0:
                logger.info('Encoded passages %d', total)

    # Concatenate per-batch embeddings into a single (num_passages, dim) array
    # and flatten the id lists so they line up with the embedding rows.
    allembeddings = torch.cat(allembeddings, dim=0).numpy()
    allids = [x for idlist in allids for x in idlist]
    return allids, allembeddings
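
A minimal usage sketch of how embed_passages might be driven and its output persisted for later indexing. The helper name save_shard_embeddings, the default output_path, and the assumption that opt, passages, model, and tokenizer are already prepared by the surrounding script are all illustrative, not part of the original file.

import pickle

def save_shard_embeddings(opt, passages, model, tokenizer, output_path='passage_embeddings.pkl'):
    # Hypothetical driver: embed all passages, then pickle (ids, embeddings)
    # as a single pair for a downstream retriever/index builder to load.
    allids, allembeddings = embed_passages(opt, passages, model, tokenizer)
    with open(output_path, 'wb') as f:
        pickle.dump((allids, allembeddings), f)
    logger.info('Saved %d passage embeddings to %s', len(allids), output_path)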