in src/beir_utils.py [0:0]
def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
    # When running distributed, shard the corpus so each rank encodes only its slice.
    if dist.is_initialized():
        idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()]
    else:
        idx = range(len(corpus))
    corpus = [corpus[i] for i in idx]
    # Prepend the title to the passage text when a title is present.
    corpus = [
        c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus
    ]

    allemb = []
    nbatch = (len(corpus) - 1) // batch_size + 1
    with torch.no_grad():
        for k in range(nbatch):
            start_idx = k * batch_size
            end_idx = min((k + 1) * batch_size, len(corpus))

            # Tokenize the batch, padding and truncating to the configured maximum length.
            cencode = self.tokenizer.batch_encode_plus(
                corpus[start_idx:end_idx],
                max_length=self.maxlength,
                padding=True,
                truncation=True,
                add_special_tokens=self.add_special_tokens,
                return_tensors="pt",
            )
            ids, mask = cencode["input_ids"], cencode["attention_mask"]
            ids, mask = ids.cuda(), mask.cuda()

            # Encode the batch; embeddings are normalized when self.norm_doc is set.
            emb = self.doc_encoder(ids, mask, normalize=self.norm_doc)
            allemb.append(emb)

    allemb = torch.cat(allemb, dim=0)
    # In the distributed case, gather the variable-sized shards from all ranks.
    if dist.is_initialized():
        allemb = dist_utils.varsize_gather_nograd(allemb)
    allemb = allemb.cpu().numpy()
    return allemb
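
For context, a minimal sketch of how an encoder exposing this encode_corpus (and a matching encode_queries) is typically plugged into a BEIR evaluation. The `model` variable and the "datasets/nfcorpus" path below are illustrative assumptions, not part of src/beir_utils.py; the BEIR classes themselves are the library's standard API.

# Hypothetical usage sketch; assumes `model` is an instance of the class that
# defines encode_corpus above together with an encode_queries counterpart.
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.search.dense import DenseRetrievalExactSearch
from beir.retrieval.evaluation import EvaluateRetrieval

# Load one BEIR dataset (path is an assumption for illustration).
corpus, queries, qrels = GenericDataLoader("datasets/nfcorpus").load(split="test")

# Wrap the encoder in BEIR's exact-search retriever and run retrieval + evaluation.
searcher = DenseRetrievalExactSearch(model, batch_size=128)
retriever = EvaluateRetrieval(searcher, score_function="dot")
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)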