def get_batch_generation()

in lama/modules/elmo_connector.py
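The method below is an excerpt from the ELMo connector class. For context, here is a minimal sketch of the module-level imports and constants it relies on (`batch_to_ids` is AllenNLP's helper; the exact import path of the LAMA constants is an assumption):

    import numpy as np
    import torch
    from allennlp.modules.elmo import batch_to_ids
    # MASK, ELMO_UNK and ELMO_END_SENTENCE are assumed to be defined in the
    # LAMA constants module (e.g. lama.modules.base_connector)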


    def get_batch_generation(self, sentences_list, logger=None,
                             try_cuda=True):
        """Score every token of the given sentences with the biLM.

        Returns the per-token log-probabilities averaged over the forward
        and backward language models, the vocabulary ids of the input
        tokens, and the positions of the masked tokens in each sentence.
        """
        if not sentences_list:
            return None
        if try_cuda:
            self.try_cuda()

        # tokenize every input (get_text returns a list of tokens)
        tokenized_text_list = []
        for sentences in sentences_list:
            tokenized_text_list.append(get_text(sentences))

        if logger is not None:
            logger.debug("\n{}\n".format(tokenized_text_list))

        # find the masked positions and replace MASK with ELMo's <unk> token
        masked_indices_list = []
        for tokenized_text in tokenized_text_list:
            masked_indices = []
            for i, token in enumerate(tokenized_text):
                if token == MASK:
                    # +1 to align with the shifted log-probs computed below
                    # (the biLM prepends a <S> boundary token at position 0)
                    masked_indices.append(i + 1)
                    tokenized_text[i] = ELMO_UNK  # ELMo has no mask token
            masked_indices_list.append(masked_indices)

        character_ids = batch_to_ids(tokenized_text_list)
        batch_size = character_ids.shape[0]
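        # For intuition: batch_to_ids([["The", "cat"], ["Dogs"]]).shape is
        # (2, 2, 50); AllenNLP allots 50 character ids per word and pads
        # every sentence to the longest one in the batch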

        with torch.no_grad():
            
            bilm_input = character_ids.to(self._model_device)

            # After loading the pre-trained model, the first few batches are
            # negatively impacted until the biLM can reset its internal
            # states; running a few batches through the model warms up the
            # states before making predictions (although this has not been
            # an issue in practice).
            bilm_output = None
            for _ in range(self.warm_up_cycles):
                bilm_output = self.elmo_lstm(bilm_input)
            
            # keep the top biLM layer; its last dimension stacks the forward
            # and backward LSTM outputs, so split it into the two directions
            elmo_activations = bilm_output['activations'][-1].cpu()

            forward_sequence_output, backward_sequence_output = torch.split(
                elmo_activations, int(self.hidden_size), dim=-1)

            # project each direction onto the vocabulary and normalize
            logits_forward = self.output_layer(forward_sequence_output)
            logits_backward = self.output_layer(backward_sequence_output)

            log_softmax = torch.nn.LogSoftmax(dim=-1)
            log_probs_forward = log_softmax(logits_forward)
            log_probs_backward = log_softmax(logits_backward)

        # a zero row used to pad the shifted score sequences below
        pad = torch.zeros([batch_size, 1, len(self.vocab)], dtype=torch.float)

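        # Alignment, sketched on tokens [<S>, t0, t1, </S>]: the forward LM
        # output at position p scores the token at p+1, and the backward LM
        # output at position p scores the token at p-1. Shifting the forward
        # scores right and the backward scores left makes position p of both
        # tensors score the token at position p, so they can be averaged.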
        log_probs_forward_splitted = torch.split(log_probs_forward, 1, dim=1)
        log_probs_backward_splitted = torch.split(log_probs_backward, 1, dim=1)

        # shift forward scores one step right, backward scores one step left
        log_probs_forward = torch.cat([pad] + list(log_probs_forward_splitted[:-1]), dim=1)
        log_probs_backward = torch.cat(list(log_probs_backward_splitted[1:]) + [pad], dim=1)

        # average the two directions into a single per-token distribution
        avg_log_probs = (log_probs_forward + log_probs_backward) / 2

        num_tokens = avg_log_probs.shape[1]

        # map every input token to its vocabulary id, padding each row with
        # the </S> id up to the scored sequence length
        token_ids_list = []
        for tokenized_text in tokenized_text_list:
            token_ids = self.__get_tokend_ids(" ".join(tokenized_text).strip())
            while len(token_ids) < num_tokens:
                token_ids = np.append(token_ids, self.inverse_vocab[ELMO_END_SENTENCE])
            token_ids_list.append(token_ids)

        return avg_log_probs, token_ids_list, masked_indices_list
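A minimal usage sketch under stated assumptions: `model` is an already-constructed instance of this connector, MASK is the same constant used above, and `model.vocab` is an id-to-token list (as `len(self.vocab)` and `self.inverse_vocab` suggest). Only the return contract is taken from the code itself.

    import torch

    sentences_list = [["The theory of relativity was developed by " + MASK + " ."]]
    log_probs, token_ids_list, masked_indices_list = model.get_batch_generation(
        sentences_list, try_cuda=False)

    # inspect the top-5 predictions at the first masked position
    masked_position = masked_indices_list[0][0]
    topk = torch.topk(log_probs[0, masked_position], k=5)
    for score, idx in zip(topk.values.tolist(), topk.indices.tolist()):
        print(model.vocab[idx], score)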