in lama/modules/elmo_connector.py [0:0]
def get_batch_generation(self, sentences_list, logger=None, try_cuda=True):
    """Run the ELMo biLM on a batch of sentences and return per-token
    averaged forward/backward log-probabilities, the corresponding token
    ids, and the positions of the masked tokens."""
    if not sentences_list:
        return None
    if try_cuda:
        self.try_cuda()

    # Tokenize every input sentence group.
    tokenized_text_list = []
    for sentences in sentences_list:
        tokenized_text_list.append(get_text(sentences))

    if logger is not None:
        logger.debug("\n{}\n".format(tokenized_text_list))

    # Look for masked indices and replace the mask token, which the ELMo
    # vocabulary does not contain, with <unk>.
    masked_indices_list = []
    for tokenized_text in tokenized_text_list:
        masked_indices = []
        for i in range(len(tokenized_text)):
            token = tokenized_text[i]
            if token == MASK:
                masked_indices.append(i + 1)  # to align with the shift applied below
                tokenized_text[i] = ELMO_UNK  # replace MASK with <unk>
        masked_indices_list.append(masked_indices)

    # Convert tokens to character ids for the ELMo character CNN.
    character_ids = batch_to_ids(tokenized_text_list)
    batch_size = character_ids.shape[0]

    with torch.no_grad():
        bilm_input = character_ids.to(self._model_device)

        # After loading the pre-trained model, the first few batches will be
        # negatively impacted until the biLM can reset its internal states:
        # run the batch through the model a few times to warm up the states
        # before making predictions.
        bilm_output = None
        for _ in range(self.warm_up_cycles):
            bilm_output = self.elmo_lstm(bilm_input)

        # Keep only the last biLM layer and move it back to the CPU.
        elmo_activations = bilm_output['activations'][-1].cpu()

    # The last dimension concatenates the forward and backward LSTM states.
    forward_sequence_output, backward_sequence_output = torch.split(
        elmo_activations, int(self.hidden_size), dim=-1)

    logits_forward = self.output_layer(forward_sequence_output)
    logits_backward = self.output_layer(backward_sequence_output)

    log_softmax = torch.nn.LogSoftmax(dim=-1)
    log_probs_forward = log_softmax(logits_forward)
    log_probs_backward = log_softmax(logits_backward)

    # The forward LM at position i predicts token i+1 and the backward LM at
    # position i predicts token i-1: shift both streams by one position,
    # padding the boundary with zeros, so that index i scores token i.
    pad = torch.zeros([batch_size, 1, len(self.vocab)], dtype=torch.float)
    log_probs_forward_splitted = torch.split(log_probs_forward, 1, dim=1)
    log_probs_backward_splitted = torch.split(log_probs_backward, 1, dim=1)
    log_probs_forward = torch.cat([pad] + list(log_probs_forward_splitted[:-1]), dim=1)   # shift forward by +1
    log_probs_backward = torch.cat(list(log_probs_backward_splitted[1:]) + [pad], dim=1)  # shift backward by -1

    avg_log_probs = (log_probs_forward + log_probs_backward) / 2

    # Map tokens to vocabulary ids, padding with the end-of-sentence id up to
    # the sequence length of the averaged log-probabilities.
    num_tokens = avg_log_probs.shape[1]
    token_ids_list = []
    for tokenized_text in tokenized_text_list:
        token_ids = self.__get_tokend_ids(" ".join(tokenized_text).strip())
        while len(token_ids) < num_tokens:
            token_ids = np.append(token_ids, self.inverse_vocab[ELMO_END_SENTENCE])
        token_ids_list.append(token_ids)

    return avg_log_probs, token_ids_list, masked_indices_list
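
# A standalone, minimal sketch of the shift-and-average alignment used above,
# with hypothetical shapes (a batch of 1 sentence, 4 tokens, a vocabulary of
# 5); none of these names come from the connector itself.
import torch

batch_size, num_tokens, vocab_size = 1, 4, 5
fwd = torch.randn(batch_size, num_tokens, vocab_size).log_softmax(dim=-1)
bwd = torch.randn(batch_size, num_tokens, vocab_size).log_softmax(dim=-1)

# Shift the forward stream right by one and the backward stream left by one,
# padding the boundary with zeros, so that index i scores token i in both.
pad = torch.zeros(batch_size, 1, vocab_size)
fwd_shifted = torch.cat([pad, fwd[:, :-1, :]], dim=1)
bwd_shifted = torch.cat([bwd[:, 1:, :], pad], dim=1)
avg_log_probs = (fwd_shifted + bwd_shifted) / 2
print(avg_log_probs.shape)  # torch.Size([1, 4, 5])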