def _replace_with_et_custom_kv_cache()

in optimum/executorch/attentions/custom_kv_cache.py [0:0]


def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
    """
    Helper function to recursively replace KV caches in the module.

    Args:
        module: The module to modify
        config: The model configuration

    Returns:
        The modified module
    """
    # Check if module has static_cache (TorchExportableModuleWithStaticCache)
    if hasattr(module, "static_cache"):
        assert isinstance(module.static_cache, StaticCache), f"Expected StaticCache, got {type(module.static_cache)}"

        # TODO: Add replace_cache to the exported module
        # in transformers' executorch.py
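        # If the exported module exposes a replace_cache() hook, use it;
        # otherwise assign the new cache directly and mirror its per-layer
        # k/v tensors as module attributes.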
        if getattr(module, "replace_cache", None) is not None:
            static_cache = ETCustomStaticCache(
                config=config,
                max_batch_size=generation_config.cache_config.batch_size,
                max_cache_len=generation_config.cache_config.max_cache_len,
                device=generation_config.cache_config.device,
                dtype=cache_dtype,
            )
            module.replace_cache(static_cache)
        else:
            module.static_cache = ETCustomStaticCache(
                config=config,
                max_batch_size=generation_config.cache_config.batch_size,
                max_cache_len=generation_config.cache_config.max_cache_len,
                device=generation_config.cache_config.device,
                dtype=cache_dtype,
            )
            # Unclear why this is needed even though
            # CustomKVCache already registers these attributes
            for i in range(len(module.static_cache.kv_cache)):
                setattr(module, f"key_cache_{i}", module.static_cache.kv_cache[i].k_cache)
                setattr(module, f"value_cache_{i}", module.static_cache.kv_cache[i].v_cache)

    # Check if module has cache (TorchExportableModuleWithHybridCache)
    elif hasattr(module, "cache"):
        assert isinstance(module.cache, HybridCache), f"Expected HybridCache, got {type(module.cache)}"

        # Replace with ETCustomHybridCache
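        # Same two strategies as the static-cache branch above, with extra
        # buffer registration for sliding-window layers below.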
        if getattr(module, "replace_cache", None) is not None:
            hybrid_cache = ETCustomHybridCache(
                config=config,
                max_batch_size=generation_config.cache_config.batch_size,
                max_cache_len=generation_config.cache_config.max_cache_len,
                device=generation_config.cache_config.device,
                dtype=cache_dtype,
            )
            module.replace_cache(hybrid_cache)
        else:
            module.cache = ETCustomHybridCache(
                config=config,
                max_batch_size=generation_config.cache_config.batch_size,
                max_cache_len=generation_config.cache_config.max_cache_len,
                device=generation_config.cache_config.device,
                dtype=cache_dtype,
            )
            # Register cache attributes for each layer
            for i in range(len(module.cache.kv_cache)):
                setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
                setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
                if module.cache.is_sliding[i]:
                    # Register cache_positions as buffer for sliding window layers
                    # This prevents it from being traced as a constant
                    module.register_buffer(
                        f"cache_positions_{i}",
                        module.cache.kv_cache[i].cache_positions_manager.cache_positions,
                        persistent=False,
                    )
    else:
        raise ValueError(
            "Module must have either 'static_cache' (TorchExportableModuleWithStaticCache) "
            "or 'cache' (TorchExportableModuleWithHybridCache) attribute"
        )

    return module
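
Usage sketch (illustrative, not from the source): this assumes transformers'
TorchExportableModuleWithStaticCache wrapper and a GenerationConfig whose
cache_config carries batch_size, max_cache_len, and device. The checkpoint
name and the wrapper's constructor signature are assumptions.

import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache

from optimum.executorch.attentions.custom_kv_cache import _replace_with_et_custom_kv_cache

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # assumed checkpoint
generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 128, "device": "cpu"},
)
model.generation_config = generation_config

# Wrap the model so it exposes a `static_cache` attribute, then swap in
# the ExecuTorch custom KV cache before export.
wrapper = TorchExportableModuleWithStaticCache(model)  # constructor signature assumed
wrapper = _replace_with_et_custom_kv_cache(
    wrapper, model.config, generation_config, cache_dtype=torch.float32
)

For a model wrapped in TorchExportableModuleWithHybridCache, the same call
takes the `cache` attribute path and installs an ETCustomHybridCache instead.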