def _get_cache()

in parler_tts/modeling_parler_tts.py


    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
        """
        Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
        new `generate` call requires a larger cache.

        Returns the resulting cache object.
        """
        cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
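        # Encoder-decoder models also need a cross-attention cache alongside the self-attention cache.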
        requires_cross_attention_cache = (
            self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
        )

        if hasattr(self, "_cache"):
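            # With cross-attention, `self._cache` is an `EncoderDecoderCache`; compare against its self-attention half.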
            cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache

        if cache_implementation == "sliding_window":
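            # A sliding-window cache never needs to be longer than the window itself.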
            max_cache_len = min(self.config.sliding_window, max_cache_len)

        need_new_cache = (
            not hasattr(self, "_cache")
            or (not isinstance(cache_to_check, cache_cls))
            or cache_to_check.max_batch_size != max_batch_size
            or cache_to_check.max_cache_len < max_cache_len
        )

        if requires_cross_attention_cache and hasattr(self, "_cache"):
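            # The cross-attention cache is sized to the encoder output length, so its length must match exactly.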
            need_new_cache = (
                need_new_cache
                or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
            )

        if need_new_cache:
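            # Quantized models allocate the cache in the original (pre-quantization) dtype.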
            if hasattr(self.config, "_pre_quantization_dtype"):
                cache_dtype = self.config._pre_quantization_dtype
            else:
                cache_dtype = self.dtype
            cache_kwargs = {
                "config": self.config.decoder,
                "max_batch_size": max_batch_size,
                "max_cache_len": max_cache_len,
                "device": self.device,
                "dtype": cache_dtype,
            }
            self._cache = cache_cls(**cache_kwargs)
            if requires_cross_attention_cache:
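                # The cross-attention cache copies the decoder settings, swapping in the encoder sequence
                # length and the cross-attention key/value head count.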
                encoder_kwargs = cache_kwargs.copy()
                encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
                config_cross_attention_cache = copy.deepcopy(self.config.decoder)
                config_cross_attention_cache.update(
                    {"num_key_value_heads": self.config.decoder.num_cross_attention_key_value_heads}
                )
                encoder_kwargs["config"] = config_cross_attention_cache
                self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
        else:
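            # The existing cache is large enough for this call: clear its contents and reuse it.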
            self._cache.reset()
        return self._cache
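
For context, a minimal sketch of how this cache persistence surfaces during generation. The checkpoint name, the `prompt_input_ids` argument, and the `cache_implementation` field follow the Parler-TTS README and the standard `transformers` generation config; they are assumptions about the surrounding API, not part of this file:

    from parler_tts import ParlerTTSForConditionalGeneration
    from transformers import AutoTokenizer

    model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
    tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

    # Opting into a static cache is what routes generation through `_get_cache`
    # (a standard `transformers` generation-config field, assumed here).
    model.generation_config.cache_implementation = "static"

    description = tokenizer("A calm female voice.", return_tensors="pt").input_ids
    prompt = tokenizer("Hello there!", return_tensors="pt").input_ids

    # First call: no `_cache` attribute exists yet, so a new cache is allocated.
    audio = model.generate(input_ids=description, prompt_input_ids=prompt)

    # Second call with the same batch size and generation length: the existing
    # cache passes every `need_new_cache` check and is reused after `reset()`.
    audio = model.generate(input_ids=description, prompt_input_ids=prompt)

Reusing one persistent cache this way avoids re-allocating large key/value buffers on every `generate` call, which matters most when the model is compiled against a static cache shape.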