in optimum/graphcore/generation/utils.py [0:0]
def _validate_kv_cache(self, use_cache, num_beams=1, max_length=128):
    """Reconcile a requested `use_cache` with the model's static KV-cache setup.

    The consistency checks only run on the "first call", i.e. while `self` has
    no `_poptorch_decoder` attribute yet. Later calls only apply the
    supported-class override below.

    Args:
        use_cache: Whether the caller requested KV caching.
        num_beams: Beam count `model.generate()` will use. On the first call
            it must equal the `_num_beams` of the first cache-bearing module.
        max_length: Maximum length `model.generate()` will use. On the first
            call it must equal dimension 2 of that module's `_k_cache.shape`.

    Returns:
        The effective `use_cache` value. It is forced to `False` when
        `self.__class__` is not in `MODELS_SUPPORTING_KV_CACHE`; otherwise it
        is the value passed in.

    Raises:
        ValueError: On the first call with `use_cache=True`, if no submodule
            reports `kv_cache_initialized`, or if the caches' beam count or
            length disagrees with `num_beams` / `max_length`.

    Side effects:
        On the first call with `use_cache=True` and an initialized cache, sets
        `self.kv_cache_enabled = True`. This happens before the beam/length
        checks, so it is set even if one of them then raises.
    """
    # NOTE(review): `_poptorch_decoder` is presumably created after this
    # validation passes, so its absence marks the first generate() call.
    first_call = not hasattr(self, "_poptorch_decoder")

    if use_cache and self.__class__ not in MODELS_SUPPORTING_KV_CACHE:
        if first_call:
            # `Logger.warn` is a deprecated alias; use `Logger.warning`.
            logger.warning(
                f"{self.__class__} does not support KV caching, but `use_cache=True`. "
                "Overriding to `use_cache=False`. If you believe your pipelined model "
                "supports static KV caching, please decorate it using `supports_kv_cache`."
            )
        use_cache = False

    if not use_cache or not first_call:
        return use_cache

    # From here on `use_cache` is True and this is the first call. Find the
    # first submodule with an initialized cache in a single pass.
    module_with_cache = next(
        (m for m in self.modules() if getattr(m, "kv_cache_initialized", False)),
        None,
    )
    if module_with_cache is None:
        raise ValueError(
            f"{self.__class__.__name__} supports KV caching and `use_cache=True`, but no KV caches have been initialized. "
            f"Please pass `use_cache=True` to the `parallelize` method of {self.__class__.__name__}."
        )
    self.kv_cache_enabled = True

    cache_num_beams = module_with_cache._num_beams
    # Dimension 2 of the key cache is what gets compared against `max_length`.
    cache_max_length = module_with_cache._k_cache.shape[2]

    generic_kwarg_msg = (
        "KV caches are created with `kwargs` that are directly provided to `parallelize`, or where such "
        "kwargs are missing, we optionally retrieve values from the `model.generation_config`. "
        "On the other hand, `model.generate()` will determine generation kwargs in the priority of "
        "`kwargs` > `kwargs['generation_config']` > `model.generation_config`. "
        "Mismatches between the two flows can be reconciled by ensuring that the kwargs provided to `parallelize` "
        "match the `kwargs` and / or `kwargs['generation_config']` passed to `model.generate()`."
    )
    if cache_num_beams != num_beams:
        raise ValueError(
            f"KV caches were created with num_beams={cache_num_beams}, but `model.generate()` is being called "
            f"with {num_beams=}."
            f"\n{generic_kwarg_msg}"
        )
    if cache_max_length != max_length:
        raise ValueError(
            f"KV caches were created with max_length={cache_max_length}, but `model.generate()` is being called "
            f"with {max_length=}."
            f"\n{generic_kwarg_msg}"
        )
    return use_cache