in src/rime/models/zero_shot/item_knn.py [0:0]
def _compute_embeddings(self, titles):
    """Encode a batch of title strings into a dense embedding matrix.

    Splits ``titles`` into chunks of ``self.batch_size``, runs each chunk
    through the underlying encoder (``model.bert`` when present, otherwise
    ``model.transformer``), pools the last hidden states according to
    ``self.pooling``, and stacks the per-chunk results.

    Parameters
    ----------
    titles : array-like of str
        Sequences to embed; must support ``len()`` and ``np.split``.

    Returns
    -------
    np.ndarray
        Float64 array of shape ``(len(titles), hidden_size)``.

    Raises
    ------
    NotImplementedError
        If ``self.pooling`` is neither ``'mean'`` nor ``'cls'``.
    """
    # no_grad: this is pure inference — without it, autograd would retain
    # every layer's activations for the whole forward pass.
    with _to_cuda(self.model) as model, torch.no_grad():
        embeddings = []
        for batch in tqdm(np.split(titles, range(0, len(titles), self.batch_size)[1:])):
            inputs = self.tokenizer(batch.tolist(), padding=True, return_tensors='pt')
            if hasattr(self.model, 'bert'):
                # BERT-style encoder: hard-truncate to its 512-position
                # limit, and skip the [CLS]/[SEP] boundary tokens when
                # mean-pooling (offset=1).
                for key in inputs.keys():  # 'input_ids', 'attention_mask', 'token_type_ids'
                    inputs[key] = inputs[key][:, :512]
                offset = 1  # [cls] seq [sep]
                hidden_states = self.model.bert(**inputs.to(model.device))[0]
            else:
                # Causal LM head (e.g. GPT-2): no boundary tokens to strip.
                offset = 0
                hidden_states = self.model.transformer(**inputs.to(model.device))[0]
            if self.pooling == 'mean':
                # Pool only over real (unpadded) tokens; attention_mask.sum
                # gives each sequence's true length after truncation.
                segments = [slice(offset, n - offset) for n in
                            inputs['attention_mask'].sum(1).tolist()]
            elif self.pooling == 'cls':
                # 'cls' pooling reduces the mean below to the first token.
                segments = [slice(0, 1) for _ in inputs['attention_mask']]
            else:
                raise NotImplementedError
            pooled = torch.vstack([  # per-sequence pooling of hidden states
                x[slc].mean(0, keepdim=True) for x, slc in zip(hidden_states, segments)])
            embeddings.append(pooled.double().cpu().numpy())
    return np.vstack(embeddings)