in utils_nlp/models/bert/sequence_encoding.py [0:0]
def get_hidden_states(self, text, batch_size=32):
"""Extract the hidden states from the pretrained model
Args:
text: List of documents to extract features from.
batch_size: Batch size, defaults to 32.
Returns:
pd.DataFrame with columns:
text_index (int), token (str), layer_index (int), values (list[float]).
"""
device, num_gpus = get_device(self.num_gpus)
self.model = move_model_to_device(self.model, device)
self.model = parallelize_model(self.model, device, self.num_gpus)
self.model.eval()
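# Tokenize every document and convert the tokens into fixed-length
# (self.max_len) encoder inputs: token ids, attention mask, and segment ids.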
tokens = self.tokenizer.tokenize(text)
(
tokens,
input_ids,
input_mask,
input_type_ids,
) = self.tokenizer.preprocess_encoder_tokens(tokens, max_len=self.max_len)
input_ids = torch.tensor(input_ids, dtype=torch.long, device=device)
input_mask = torch.tensor(input_mask, dtype=torch.long, device=device)
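# Carry the example (document) index of each row through the DataLoader so the
# encoder outputs can be mapped back to their tokens; the segment ids returned
# above are not needed here (token_type_ids is passed as None below).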
example_indices = torch.arange(
input_ids.size(0), dtype=torch.long, device=device
)
eval_data = TensorDataset(input_ids, input_mask, example_indices)
eval_dataloader = DataLoader(
eval_data, sampler=SequentialSampler(eval_data), batch_size=batch_size
)
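# Accumulate one row per (document, token, layer); the sequential sampler walks
# the examples in their original order, one batch at a time.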
hidden_states = {"text_index": [], "token": [], "layer_index": [], "values": []}
for (
input_ids_tensor,
input_mask_tensor,
example_indices_tensor,
) in eval_dataloader:
with torch.no_grad():
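# The encoder returns one hidden-state tensor per layer plus a pooled
# output, which is not needed here.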
all_encoder_layers, _ = self.model(
input_ids_tensor,
token_type_ids=None,
attention_mask=input_mask_tensor,
)
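# Hidden size of the encoder outputs (e.g. 768 for bert-base models).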
self.embedding_dim = all_encoder_layers[0].size()[-1]
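# For every document in the batch, record each token's hidden state from each
# of the requested layers.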
for b, example_index in enumerate(example_indices_tensor):
for (i, token) in enumerate(tokens[example_index.item()]):
for layer_index in self.layer_index:
layer_output = (
all_encoder_layers[int(layer_index)].detach().cpu().numpy()
)
layer_output = layer_output[b]
hidden_states["text_index"].append(example_index.item())
hidden_states["token"].append(token)
hidden_states["layer_index"].append(layer_index)
hidden_states["values"].append(
[round(x.item(), 6) for x in layer_output[i]]
)
# Free the batch tensors and release cached GPU memory before the next batch.
del input_ids_tensor, input_mask_tensor, example_indices_tensor
torch.cuda.empty_cache()
# Free the full input tensors and release cached GPU memory.
del input_ids, input_mask, example_indices
torch.cuda.empty_cache()
return pd.DataFrame.from_dict(hidden_states)
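
# Usage sketch (illustrative only; assumes this method belongs to the sentence
# encoder class defined in this module and that an instance named `encoder` has
# already been constructed with a tokenizer, model, layer_index, and max_len):
#
#     docs = ["The quick brown fox.", "BERT encodes text."]
#     df = encoder.get_hidden_states(docs, batch_size=8)
#     # df has one row per (document, token, layer). A simple sentence embedding
#     # can be obtained by averaging the token vectors per document and layer:
#     emb = (
#         df.groupby(["text_index", "layer_index"])["values"]
#         .apply(lambda v: [sum(x) / len(x) for x in zip(*v)])
#     )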