def get_hidden_states()

in utils_nlp/models/bert/sequence_encoding.py [0:0]


    def get_hidden_states(self, text, batch_size=32):
        """Extract the hidden states from the pretrained model

        Args:
            text: List of documents to extract features from.
            batch_size: Batch size, defaults to 32.

        Returns:
            pd.DataFrame with columns:
                text_index (int), token (str), layer_index (int), values (list[float]).
        """
        device, num_gpus = get_device(self.num_gpus)
        self.model = move_model_to_device(self.model, device)
        self.model = parallelize_model(self.model, device, self.num_gpus)

        self.model.eval()

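        # Tokenize each document with the class tokenizer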
        tokens = self.tokenizer.tokenize(text)

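        # Convert tokens to input ids, attention masks, and segment ids,
        # padded/truncated to max_len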
        (
            tokens,
            input_ids,
            input_mask,
            input_type_ids,
        ) = self.tokenizer.preprocess_encoder_tokens(tokens, max_len=self.max_len)

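        # Convert the encoded inputs to tensors on the target device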
        input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
        input_mask = torch.tensor(input_mask, dtype=torch.long, device=device)
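        # NOTE: input_type_ids is repurposed here as a running example index so each
        # batch row can be traced back to its source document; the segment ids are not
        # needed because the model is called with token_type_ids=None below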
        input_type_ids = torch.arange(
            input_ids.size(0), dtype=torch.long, device=device
        )

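        # Batch the examples and iterate over them in order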
        eval_data = TensorDataset(input_ids, input_mask, input_type_ids)
        eval_dataloader = DataLoader(
            eval_data, sampler=SequentialSampler(eval_data), batch_size=batch_size
        )

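        # Accumulate one row per (document, token, layer) combination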
        hidden_states = {"text_index": [], "token": [], "layer_index": [], "values": []}
        for (
            input_ids_tensor,
            input_mask_tensor,
            example_indices_tensor,
        ) in eval_dataloader:
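            # Forward pass without gradients; the model returns the hidden states of the
            # encoder layers, and the hidden size is recorded as a side effect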
            with torch.no_grad():
                all_encoder_layers, _ = self.model(
                    input_ids_tensor,
                    token_type_ids=None,
                    attention_mask=input_mask_tensor,
                )
                self.embedding_dim = all_encoder_layers[0].size()[-1]

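            # Map each example in the batch back to its tokens and record the requested
            # layers' vectors, rounded to 6 decimal places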
            for b, example_index in enumerate(example_indices_tensor):
                for i, token in enumerate(tokens[example_index.item()]):
                    for layer_index in self.layer_index:
                        layer_output = (
                            all_encoder_layers[int(layer_index)].detach().cpu().numpy()
                        )
                        layer_output = layer_output[b]
                        hidden_states["text_index"].append(example_index.item())
                        hidden_states["token"].append(token)
                        hidden_states["layer_index"].append(layer_index)
                        hidden_states["values"].append(
                            [round(x.item(), 6) for x in layer_output[i]]
                        )

            # Release this batch's tensors and clear the CUDA cache
            del input_ids_tensor, input_mask_tensor, example_indices_tensor
            torch.cuda.empty_cache()

        # Release the full input tensors and clear the CUDA cache
        del input_ids, input_mask, input_type_ids
        torch.cuda.empty_cache()

        return pd.DataFrame.from_dict(hidden_states)
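
A minimal usage sketch follows for context. It assumes this method belongs to the BERTSentenceEncoder class defined in this module and that the constructor accepts max_len and layer_index keyword arguments; check the class definition for the exact signature.

    # Hypothetical usage sketch: the constructor keywords below are assumptions,
    # not a verified API.
    from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder

    encoder = BERTSentenceEncoder(max_len=128, layer_index=[-2, -1])  # assumed kwargs
    docs = ["The quick brown fox.", "It jumps over the lazy dog."]
    df = encoder.get_hidden_states(docs, batch_size=2)

    # Long-format result: one row per (document, token, layer); "values" holds the vector
    last_layer = df[df["layer_index"] == -1]
    print(last_layer.head())

Because the result is long-format, per-document embeddings can be recovered by grouping on text_index and layer_index and stacking the values lists.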