in parler_tts/modeling_parler_tts.py
# Module-level imports this method relies on (import paths as in recent `transformers` releases):
import torch
from transformers.cache_utils import Cache
from transformers.utils import is_torchdynamo_compiling

def _get_initial_cache_position(self, input_ids, model_kwargs):
    """Calculates `cache_position` for the pre-fill stage based on `input_ids` and, optionally, the past length."""
    # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
    if "inputs_embeds" in model_kwargs:
        cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
    else:
        cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1

    past_length = 0
    if model_kwargs.get("past_key_values") is not None:
        cache = model_kwargs["past_key_values"]
        # Legacy caches are per-layer tuples of (key, value) tensors, with the
        # sequence length in dim 2 of the key; `Cache` objects report it directly.
        if not isinstance(cache, Cache):
            past_length = cache[0][0].shape[2]
        elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
            past_length = cache.get_seq_length()

        # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
        # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
        if not is_torchdynamo_compiling():
            cache_position = cache_position[past_length:]

    model_kwargs["cache_position"] = cache_position
    return model_kwargs
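
For intuition, here is a minimal standalone sketch (tensor shapes and values are illustrative, not taken from the model) showing why the `cumsum` trick is equivalent to `torch.arange`, how the prefill positions are sliced when the cache already holds tokens, and where the legacy-cache past length comes from:

import torch

# A toy prefill batch: batch size 1, sequence length 5 (illustrative only).
input_ids = torch.zeros(1, 5, dtype=torch.long)

# `ones_like(...).cumsum(0) - 1` builds [0, 1, 2, 3, 4] purely from the
# tensor's shape, avoiding a direct `torch.arange` call so the trace stays
# `torch.compile`-friendly.
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
assert torch.equal(cache_position, torch.arange(5))

# If the KV cache already holds 3 tokens, only the trailing positions are
# new, so the prefix is sliced off (the eager-only branch above).
past_length = 3
assert torch.equal(cache_position[past_length:], torch.tensor([3, 4]))

# Legacy past-length extraction: a legacy cache is a per-layer tuple of
# (key, value) tensors, keys shaped (batch, heads, seq_len, head_dim),
# so `cache[0][0].shape[2]` reads the cached sequence length.
legacy_cache = ((torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8)),)
assert legacy_cache[0][0].shape[2] == 3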