in parler_tts/modeling_parler_tts.py [0:0]
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
    """
    Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized if a
    new `generate` call requires a larger cache, or one of a different class, batch size or (for the
    cross-attention part) encoder length. Otherwise the existing cache is reset and reused.

    Args:
        cache_implementation (`str`):
            Key into `NEED_SETUP_CACHE_CLASSES_MAPPING` selecting the cache class to instantiate.
        max_batch_size (`int`):
            Batch size the cache must be built for (must match an existing cache exactly to reuse it).
        max_cache_len (`int`):
            Minimum decoder length the self-attention cache must hold. For the "sliding_window"
            implementation it is capped at `self.config.sliding_window`.
        model_kwargs (`dict`):
            Generation kwargs. When a cross-attention cache is required,
            `model_kwargs["encoder_outputs"][0].shape[1]` sizes the cross-attention cache
            (presumably the encoder sequence length of a (batch, seq, hidden) tensor -- TODO confirm).

    Returns:
        `Cache`: `self._cache` -- an `EncoderDecoderCache` wrapping two `cache_cls` instances when a
        cross-attention cache is required, otherwise a single `cache_cls` instance.
    """
    cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
    requires_cross_attention_cache = (
        self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
    )
    # The cross-attention cache length is fixed by the encoder output, so it is read once here.
    encoder_seq_len = model_kwargs["encoder_outputs"][0].shape[1] if requires_cross_attention_cache else None

    if cache_implementation == "sliding_window":
        max_cache_len = min(self.config.sliding_window, max_cache_len)

    existing_cache = getattr(self, "_cache", None)
    if requires_cross_attention_cache:
        # A previous call may have left a cache that is not an encoder-decoder wrapper; treat that as a
        # mismatch (forcing a rebuild) instead of failing on the missing sub-cache attributes.
        if isinstance(existing_cache, EncoderDecoderCache):
            self_attention_cache = existing_cache.self_attention_cache
            cross_attention_cache = existing_cache.cross_attention_cache
        else:
            self_attention_cache = cross_attention_cache = None
    else:
        self_attention_cache, cross_attention_cache = existing_cache, None

    need_new_cache = (
        not isinstance(self_attention_cache, cache_cls)
        or self_attention_cache.max_batch_size != max_batch_size
        or self_attention_cache.max_cache_len < max_cache_len
        or (
            requires_cross_attention_cache
            and (
                not isinstance(cross_attention_cache, cache_cls)
                or cross_attention_cache.max_cache_len != encoder_seq_len
            )
        )
    )

    if not need_new_cache:
        # Compatible cache already allocated: clear its contents and reuse it.
        existing_cache.reset()
        return existing_cache

    # Quantized models keep the original (pre-quantization) dtype for the cache tensors.
    if hasattr(self.config, "_pre_quantization_dtype"):
        cache_dtype = self.config._pre_quantization_dtype
    else:
        cache_dtype = self.dtype
    cache_kwargs = {
        "config": self.config.decoder,
        "max_batch_size": max_batch_size,
        "max_cache_len": max_cache_len,
        "device": self.device,
        "dtype": cache_dtype,
    }
    new_self_attention_cache = cache_cls(**cache_kwargs)

    if requires_cross_attention_cache:
        # Cross-attention may use a different number of KV heads than self-attention, so the
        # cross-attention cache gets its own copy of the decoder config with that value overridden.
        config_cross_attention_cache = copy.deepcopy(self.config.decoder)
        config_cross_attention_cache.update(
            {"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads}
        )
        encoder_kwargs = {
            **cache_kwargs,
            "config": config_cross_attention_cache,
            "max_cache_len": encoder_seq_len,
        }
        # Assign only once both caches are built so a failure cannot leave a half-built cache behind.
        self._cache = EncoderDecoderCache(new_self_attention_cache, cache_cls(**encoder_kwargs))
    else:
        self._cache = new_self_attention_cache
    return self._cache