in optimum/intel/ipex/modeling_base.py [0:0]
def generate(self, *args, **kwargs):
    if self._add_patch and kwargs.get("assistant_model", None):
        raise ValueError(
            f"Assisted decoding is not supported for patched models for now; supported methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
        )
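    # With a patched model, register the IPEX paged cache with transformers so that
    # generate() builds an IPEXPagedCache instead of the default cache implementation.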
    # Patch functions to support ipex_paged cache
    if self._add_patch:
        transformers.generation.utils.NEED_SETUP_CACHE_CLASSES_MAPPING["ipex_paged"] = IPEXPagedCache
        self.generation_config.cache_implementation = "ipex_paged"
        if is_transformers_version(">=", "4.45.0"):
            if "ipex_paged" not in transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS:
                transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("ipex_paged")
        if kwargs.get("generation_config", None):
            # Change cache implementation temporarily
            orig_cache_implementation = kwargs["generation_config"].cache_implementation
            kwargs["generation_config"].cache_implementation = "ipex_paged"
if self._add_patch and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
elif self._add_patch:
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
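    # The patches above are process-wide monkey patches on transformers, so make sure
    # they are rolled back if generation fails.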
    try:
        result = super().generate(*args, **kwargs)
    except Exception as e:
        transformers.generation.utils._crop_past_key_values = _crop_past_key_values
        transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
        raise e
if self._add_patch and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
# change back cache_implementation
if self._add_patch and kwargs.get("generation_config", None):
kwargs["generation_config"].cache_implementation = orig_cache_implementation
return result
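
A minimal usage sketch (not part of modeling_base.py): this generate() override is reached by calling generate() on an IPEXModelForCausalLM; the model id and generation arguments below are only illustrative.

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = IPEXModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The IPEX paged cache", return_tensors="pt")
# For patched models, the override above registers the "ipex_paged" cache
# implementation before delegating to transformers' GenerationMixin.generate().
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))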