def get_batch_generation()

in lama/modules/roberta_connector.py


    def get_batch_generation(self, sentences_list, logger=None, try_cuda=True):
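        """Tokenize a batch of [MASK]-annotated sentences, run RoBERTa's
        masked LM over them, and return the per-token scores over the
        vocabulary together with the encoded inputs and the positions of
        the masked tokens."""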
        if not sentences_list:
            return None
        if try_cuda:
            self.try_cuda()

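        # per-batch accumulators: raw token tensors, <mask> positions, and the
        # encoded token ids handed back to the caller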
        tensor_list = []
        masked_indices_list = []
        max_len = 0
        output_tokens_list = []
        for masked_inputs_list in sentences_list:

            tokens_list = []

            for idx, masked_input in enumerate(masked_inputs_list):

                # substitute [MASK] with RoBERTa's <mask> token
                masked_input = masked_input.replace(MASK, ROBERTA_MASK)

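                # BPE-encode the text on either side of the mask and rejoin the
                # pieces with the literal <mask> token, which must not itself
                # pass through the BPE encoder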
                text_spans = masked_input.split(ROBERTA_MASK)
                text_spans_bpe = (
                    (" {0} ".format(ROBERTA_MASK))
                    .join(
                        [
                            self.bpe.encode(text_span.rstrip())
                            for text_span in text_spans
                        ]
                    )
                    .strip()
                )

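                # only the first sentence of an example gets the <s> start marker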
                prefix = ""
                if idx == 0:
                    prefix = ROBERTA_START_SENTENCE

                tokens_list.append(
                    self.task.source_dictionary.encode_line(
                        (prefix + " " + text_spans_bpe).strip(), append_eos=True
                    )
                )

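            # concatenate all sentence spans and truncate to the model's limit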
            tokens = torch.cat(tokens_list)[: self.max_sentence_length]
            output_tokens_list.append(tokens.long().cpu().numpy())

            if len(tokens) > max_len:
                max_len = len(tokens)
            tensor_list.append(tokens)
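
            # record the position of every <mask> token in this sequence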
            masked_index = (tokens == self.task.mask_idx).nonzero().numpy()
            for x in masked_index:
                masked_indices_list.append([x[0]])

        pad_id = self.task.source_dictionary.pad()
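        # right-pad shorter sequences with pad_id so they stack into one batch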
        padded_tokens_list = []
        for tokens in tensor_list:
            pad_length = max_len - len(tokens)
            if pad_length > 0:
                pad_tensor = torch.full([pad_length], pad_id, dtype=torch.int)
                tokens = torch.cat((tokens, pad_tensor))
            padded_tokens_list.append(tokens)

        batch_tokens = torch.stack(padded_tokens_list)

        with torch.no_grad():
            # make sure dropout is disabled on both the hub wrapper and the model
            self.model.eval()
            self.model.model.eval()
            logits, _ = self.model.model(
                batch_tokens.long().to(device=self._model_device),
                features_only=False,
                return_all_hiddens=False,
            )
            # the LM head returns raw vocabulary logits; take log_softmax so the
            # values handed back to the caller are true log-probabilities
            log_probs = logits.log_softmax(dim=-1)

        return log_probs.cpu(), output_tokens_list, masked_indices_list
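
For orientation, a minimal usage sketch. It assumes `roberta` is an already-constructed connector instance from this module and that `[MASK]` is LAMA's generic mask placeholder; the example sentence and the top-5 decoding are illustrative only, reusing the same `task.source_dictionary` and `bpe` objects the method itself relies on:

    import torch

    sentences_list = [["The capital of France is [MASK] ."]]
    log_probs, token_ids_list, masked_indices = roberta.get_batch_generation(
        sentences_list, try_cuda=False
    )

    # log_probs has shape [batch, seq_len, vocab]; look at the first masked slot
    masked_pos = masked_indices[0][0]
    top = torch.topk(log_probs[0, masked_pos], k=5)

    for score, idx in zip(top.values.tolist(), top.indices):
        # dictionary id -> BPE symbol string -> surface text
        symbol = roberta.task.source_dictionary.string(idx.unsqueeze(0))
        print(roberta.bpe.decode(symbol), score)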