in optimum/exporters/executorch/integrations.py [0:0]
def __init__(self, model, max_static_cache_length, batch_size):
    """Wrap a seq2seq model's decoder with a pre-allocated static KV cache.

    Args:
        model: A ``transformers`` conditional-generation model providing
            ``get_decoder()``, ``config``, and an output projection
            (``proj_out`` for Whisper, ``lm_head`` otherwise).
        max_static_cache_length: Fixed maximum sequence length of the cache.
        batch_size: Fixed maximum batch size of the cache.
    """
    super().__init__()
    self.decoder = model.get_decoder()
    # Whisper names its output projection differently from other models.
    if isinstance(model, WhisperForConditionalGeneration):
        self.proj_out = model.proj_out
    else:
        self.proj_out = model.lm_head
    self.config = model.config
    # Pre-allocate a fixed-shape KV cache so tensor shapes are static,
    # which the export path requires.
    self.static_cache = StaticCache(
        config=self.config,
        max_batch_size=batch_size,
        max_cache_len=max_static_cache_length,
        device="cpu",
        dtype=torch.float32,
    )
    # Expose each cache tensor as a non-persistent buffer so the export
    # machinery can discover and track it on the module.
    cache_pairs = zip(self.static_cache.key_cache, self.static_cache.value_cache)
    for idx, (key_tensor, value_tensor) in enumerate(cache_pairs):
        self.register_buffer(f"key_cache_{idx}", key_tensor, persistent=False)
        self.register_buffer(f"value_cache_{idx}", value_tensor, persistent=False)