in src/beir_utils.py [0:0]
def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray:
    # Under torch.distributed, each rank encodes only its own shard of the queries.
    if dist.is_initialized():
        idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()]
    else:
        idx = range(len(queries))
    queries = [queries[i] for i in idx]

    allemb = []
    nbatch = (len(queries) - 1) // batch_size + 1
    with torch.no_grad():
        for k in range(nbatch):
            start_idx = k * batch_size
            end_idx = min((k + 1) * batch_size, len(queries))

            # Tokenize the current batch and move it to the GPU.
            qencode = self.tokenizer.batch_encode_plus(
                queries[start_idx:end_idx],
                max_length=self.maxlength,
                padding=True,
                truncation=True,
                add_special_tokens=self.add_special_tokens,
                return_tensors="pt",
            )
            ids, mask = qencode["input_ids"], qencode["attention_mask"]
            ids, mask = ids.cuda(), mask.cuda()

            # Encode the batch, optionally normalizing the query embeddings.
            emb = self.query_encoder(ids, mask, normalize=self.norm_query)
            allemb.append(emb)

    allemb = torch.cat(allemb, dim=0)
    # Shards can have unequal sizes, so gather them with a variable-size all-gather
    # before converting to a single NumPy array on every rank.
    if dist.is_initialized():
        allemb = dist_utils.varsize_gather_nograd(allemb)
    allemb = allemb.cpu().numpy()
    return allemb
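
Because this method takes a list of query strings plus a batch size and returns a NumPy array, it matches the model interface expected by BEIR's exact-search retriever, which is how a wrapper like this is typically driven. The sketch below shows that usage under stated assumptions: the enclosing class name DenseEncoderModel, its constructor arguments, and the encoder/tokenizer placeholders are not shown in the excerpt and are assumptions; the beir.* calls are the standard BEIR evaluation API.

# Minimal usage sketch, assuming the method above lives in a class named
# DenseEncoderModel that also implements encode_corpus (BEIR model interface).
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

# corpus/queries/qrels for any BEIR dataset already downloaded to this folder.
corpus, queries, qrels = GenericDataLoader("datasets/scifact").load(split="test")

# encoder and tokenizer are placeholders for the query/doc encoder module and
# its HuggingFace tokenizer; the constructor signature here is an assumption.
model = DenseEncoderModel(query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer)

retriever = EvaluateRetrieval(DRES(model, batch_size=128), score_function="dot")
results = retriever.retrieve(corpus, queries)
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)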