def _compute_embeddings()

in src/rime/models/zero_shot/item_knn.py


    def _compute_embeddings(self, titles):
        """ find embedding of a batch of sequences """
        with _to_cuda(self.model) as model:  # temporarily place the model on GPU if available (see sketch below)
            embeddings = []

            # split titles into consecutive chunks of at most batch_size (see demo below)
            for batch in tqdm(np.split(titles, range(0, len(titles), self.batch_size)[1:])):
                inputs = self.tokenizer(batch.tolist(), padding=True, return_tensors='pt')
                if hasattr(self.model, 'bert'):
                    # BERT accepts at most 512 positions; truncate every field
                    # ('input_ids', 'attention_mask', 'token_type_ids') together
                    # so they stay aligned
                    for key in inputs.keys():
                        inputs[key] = inputs[key][:, :512]
                    offset = 1  # skip the special tokens in [cls] seq [sep] when pooling
                    hidden_states = self.model.bert(**inputs.to(model.device))[0]
                else:
                    offset = 0  # GPT-style models: no [CLS]/[SEP] tokens to strip
                    hidden_states = self.model.transformer(**inputs.to(model.device))[0]

                if self.pooling == 'mean':
                    # pool over the real (unpadded) tokens of each sequence;
                    # offset trims the special tokens at both ends for BERT
                    segments = [slice(offset, n - offset) for n in
                                inputs['attention_mask'].sum(1).tolist()]
                elif self.pooling == 'cls':
                    # keep only the first token of each sequence
                    segments = [slice(0, 1) for _ in inputs['attention_mask']]
                else:
                    raise NotImplementedError(f"unsupported pooling: {self.pooling}")

                hidden_states = torch.vstack([  # reduce each segment to a single vector
                    x[slc].mean(0, keepdim=True) for x, slc in zip(hidden_states, segments)])

                embeddings.append(hidden_states.double().cpu().numpy())

        return np.vstack(embeddings)
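
The _to_cuda helper used above is defined elsewhere in the repository and not shown here. As a rough sketch only, assuming it does nothing more than move the model onto a GPU for the duration of the block and back afterwards (hypothetical implementation, not the repo's actual code):

    import contextlib

    import torch


    @contextlib.contextmanager
    def _to_cuda(model):
        """ hypothetical sketch; the real helper lives elsewhere in rime """
        if torch.cuda.is_available():
            try:
                yield model.cuda()  # model.device becomes cuda inside the block
            finally:
                model.cpu()  # restore the model to CPU on exit
        else:
            yield model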
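
The batching expression is dense: range(0, len(titles), batch_size)[1:] produces the split points [batch_size, 2*batch_size, ...], which np.split turns into consecutive chunks of at most batch_size titles. A minimal demonstration with made-up inputs:

    import numpy as np

    titles = np.array(['t0', 't1', 't2', 't3', 't4'])
    batch_size = 2
    split_points = range(0, len(titles), batch_size)[1:]  # [2, 4]
    print(np.split(titles, split_points))
    # [array(['t0', 't1'], ...), array(['t2', 't3'], ...), array(['t4'], ...)]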
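
For context, here is a standalone sketch of the same mean-pooling recipe written directly against the Hugging Face transformers API. The checkpoint name, example titles, and the torch.no_grad() wrapper are illustrative assumptions, not configuration taken from item_knn.py:

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')

    titles = ['red running shoes', 'wireless noise-cancelling headphones']
    inputs = tokenizer(titles, padding=True, truncation=True, max_length=512,
                       return_tensors='pt')

    with torch.no_grad():
        hidden_states = model(**inputs).last_hidden_state  # (batch, seq, dim)

    # mean-pool the real tokens of each title, skipping [CLS] and [SEP]
    lengths = inputs['attention_mask'].sum(1).tolist()
    embeddings = torch.vstack([
        x[1:n - 1].mean(0, keepdim=True) for x, n in zip(hidden_states, lengths)])
    print(embeddings.shape)  # torch.Size([2, 768])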