def _validate_kv_cache()

in optimum/graphcore/generation/utils.py [0:0]


    def _validate_kv_cache(self, use_cache, num_beams=1, max_length=128):
        """Validate that the requested `use_cache` setting is compatible with the model's KV caches.

        Checks that the model class supports static KV caching, that caches were actually
        initialized during `parallelize`, and that the cache geometry (num_beams, max_length)
        matches the generation kwargs. Sets `self.kv_cache_enabled` as a side effect.

        Args:
            use_cache (`bool`): whether the caller requested KV caching.
            num_beams (`int`, defaults to 1): beam count `model.generate()` will run with.
            max_length (`int`, defaults to 128): max sequence length `model.generate()` will run with.

        Returns:
            `bool`: the possibly-overridden `use_cache` value to actually use.

        Raises:
            ValueError: if caching is requested but no caches were initialized, or if the
                cache was built for a different `num_beams` / `max_length`.
        """
        # Heuristic: the poptorch decoder wrapper is only created after the first call,
        # so its absence means this is the first invocation (warn only once).
        first_call = not hasattr(self, "_poptorch_decoder")

        if use_cache and self.__class__ not in MODELS_SUPPORTING_KV_CACHE:
            if first_call:
                # `logging.Logger.warn` is a deprecated alias; use `warning`.
                logger.warning(
                    f"{self.__class__} does not support KV caching, but `use_cache=True`. "
                    "Overriding to `use_cache=False`. If you believe your pipelined model "
                    "supports static KV caching, please decorate it using `supports_kv_cache`."
                )
            use_cache = False

        # Geometry checks below are only meaningful on the first call with caching on;
        # afterwards the compiled executable already pins these values.
        if not use_cache or not first_call:
            return use_cache

        model_has_kv_cache_initialized = any(getattr(m, "kv_cache_initialized", False) for m in self.modules())
        if use_cache and not model_has_kv_cache_initialized:
            raise ValueError(
                f"{self.__class__.__name__} supports KV caching and `use_cache=True`, but no KV caches have been initialized. "
                f"Please pass `use_cache=True` to the `parallelize` method of {self.__class__.__name__}."
            )
        self.kv_cache_enabled = use_cache and model_has_kv_cache_initialized

        if not self.kv_cache_enabled:
            return use_cache

        # Guaranteed to exist: `model_has_kv_cache_initialized` was True above.
        module_with_cache = next(m for m in self.modules() if getattr(m, "kv_cache_initialized", False))
        cache_shape = module_with_cache._k_cache.shape
        cache_num_beams = module_with_cache._num_beams
        # Cache layout: dim 2 of the key cache is the (static) max sequence length.
        cache_max_length = cache_shape[2]

        generic_kwarg_msg = (
            "KV caches are created with `kwargs` that are directly provided to `parallelize`, or where such "
            "kwargs are missing, we optionally retrieve values from the `model.generation_config`. "
            "On the other hand, `model.generate()` will determine generation kwargs in the priority of "
            "`kwargs` > `kwargs['generation_config']` > `model.generation_config`. "
            "Mismatches between the two flows can be reconciled by ensuring that the kwargs provided to `parallelize` "
            "match the `kwargs` and / or `kwargs['generation_config']` passed to `model.generate()`."
        )
        if cache_num_beams != num_beams:
            raise ValueError(
                f"KV caches were created with num_beams={cache_num_beams}, but `model.generate()` is being called "
                f"with {num_beams=}."
                f"\n{generic_kwarg_msg}"
            )
        if cache_max_length != max_length:
            raise ValueError(
                f"KV caches were created with max_length={cache_max_length}, but `model.generate()` is being called "
                f"with {max_length=}."
                f"\n{generic_kwarg_msg}"
            )

        return use_cache