def forward()

in parler_tts/modeling_parler_tts.py [0:0]


    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        """Return the positional-embedding rows for the current chunk of positions.

        Args:
            input_ids: A rank-3 tensor; only its size and device are used. Dimension 1
                is taken as the sequence length. (Presumably (batch, seq_len, hidden)
                embeddings despite the name; TODO confirm against callers.)
            past_key_values_length: Number of positions already processed. The
                returned rows start at this offset.

        Returns:
            A detached tensor holding rows
            ``past_key_values_length .. past_key_values_length + seq_len - 1`` of
            ``self.weights``. The table is first grown if it is too short to contain
            them.
        """
        _, seq_len, _ = input_ids.size()
        # Position ids for this chunk continue after any cached (past) positions.
        position_ids = torch.arange(seq_len, device=input_ids.device) + past_key_values_length
        # The largest id looked up is past_key_values_length + seq_len - 1, so the table
        # needs at least past_key_values_length + seq_len rows. Checking seq_len alone
        # lets incremental decoding index past the end of the table.
        required_positions = past_key_values_length + seq_len
        if required_positions > self.weights.size(0):
            # NOTE(review): `offset` is not assigned anywhere visible here. Use it when
            # present and fall back to 0 instead of raising AttributeError.
            offset = getattr(self, "offset", 0)
            self.make_weights(required_positions + offset, self.embedding_dim)
        return self.weights.index_select(0, position_ids.view(-1)).detach()