# optimum/executorch/attentions/custom_kv_cache.py
def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
    """
    Helper function to replace the KV cache(s) in the module with ExecuTorch custom KV caches.

    Args:
        module: The module to modify
        config: The model configuration
        generation_config: The generation configuration; its ``cache_config``
            provides batch_size, max_cache_len, and device
        cache_dtype: The dtype to use for the replacement cache tensors

    Returns:
        The modified module
    """
    # Check if module has static_cache (TorchExportableModuleWithStaticCache)
    if hasattr(module, "static_cache"):
        assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"
        static_cache = ETCustomStaticCache(
            config=config,
            max_batch_size=generation_config.cache_config.batch_size,
            max_cache_len=generation_config.cache_config.max_cache_len,
            device=generation_config.cache_config.device,
            dtype=cache_dtype,
        )
        # TODO: Add replace_cache to the exported module in transformers' executorch.py
        if getattr(module, "replace_cache", None) is not None:
            module.replace_cache(static_cache)
        else:
            module.static_cache = static_cache
        # Re-register the per-layer cache tensors on the module. It is unclear
        # why this is needed even though CustomKVCache registers the attributes.
        for i in range(len(module.static_cache.kv_cache)):
            setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
            setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)
    # Check if module has cache (TorchExportableModuleWithHybridCache)
    elif hasattr(module, "cache"):
        assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"
        # Replace with ETCustomHybridCache
        hybrid_cache = ETCustomHybridCache(
            config=config,
            max_batch_size=generation_config.cache_config.batch_size,
            max_cache_len=generation_config.cache_config.max_cache_len,
            device=generation_config.cache_config.device,
            dtype=cache_dtype,
        )
        if getattr(module, "replace_cache", None) is not None:
            module.replace_cache(hybrid_cache)
        else:
            module.cache = hybrid_cache
        # Register cache attributes for each layer
        for i in range(len(module.cache.kv_cache)):
            setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
            setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
            if module.cache.is_sliding[i]:
                # Register cache_positions as a buffer for sliding-window layers.
                # This prevents it from being traced as a constant.
                module.register_buffer(
                    f"cache_positions_{i}",
                    module.cache.kv_cache[i].cache_positions_manager.cache_positions,
                    persistent=False,
                )
    else:
        raise ValueError(
            "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
            "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
        )
    return module
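
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): how this helper
# is typically invoked after wrapping a model for ExecuTorch export. The
# SimpleNamespace stand-in mirrors the attributes this function reads from
# generation_config; in real use, generation_config comes from transformers
# and `exportable_module` is a TorchExportableModuleWithStaticCache (or
# ...WithHybridCache) instance wrapping the eager model.
#
#     import torch
#     from types import SimpleNamespace
#
#     generation_config = SimpleNamespace(
#         cache_config=SimpleNamespace(
#             batch_size=1,
#             max_cache_len=128,
#             device="cpu",
#         )
#     )
#     exportable_module = _replace_with_et_custom_kv_cache(
#         exportable_module, model.config, generation_config, torch.float32
#     )
# ---------------------------------------------------------------------------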