in parler_tts/modeling_parler_tts.py [0:0]
def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
    """Return sinusoidal position embeddings for the current sequence.

    Args:
        input_ids: 3-D tensor unpacked as ``(bsz, seq_len, _)`` — despite the
            name this looks like an embeddings-shaped tensor; only its second
            dimension (sequence length) and device are actually used here.
            TODO(review): confirm callers pass ``(bsz, seq_len, hidden)``.
        past_key_values_length: number of positions already consumed by the
            KV cache; position ids start from this offset so that embeddings
            line up during incremental decoding.

    Returns:
        ``(seq_len, embedding_dim)`` slice of ``self.weights``, detached from
        the autograd graph (positional tables are not trained).
    """
    bsz, seq_len, _ = input_ids.size()
    # Positions continue from the end of the cached prefix.
    position_ids = torch.arange(seq_len, device=input_ids.device) + past_key_values_length
    # Expand the embedding table if any requested position falls outside it.
    # BUGFIX: the previous check compared only seq_len against the table size,
    # under-allocating when past_key_values_length > 0 (cached decoding steps
    # have seq_len == 1 but position indices up to past_key_values_length).
    max_pos = past_key_values_length + seq_len
    if max_pos > self.weights.size(0):
        self.make_weights(max_pos + self.offset, self.embedding_dim)
    return self.weights.index_select(0, position_ids.view(-1)).detach()