in optimum/graphcore/models/whisper/modeling_whisper.py [0:0]
def change_attention_class(self, restore=False, **kwargs):
    """Change the attention layers to IPU versions that support a KV cache, or restore the originals.

    Args:
        restore: if True, restore the original `WhisperAttention` layers instead of replacing them.

    Keyword Args:
        batch_size: batch size used to size the KV caches (default: 1).
        num_beams: number of beams used during generation (default: 1).
        use_cache: whether the decoder self-attention layers use a KV cache (default: False).
        max_length: maximum decoder sequence length, used to size the self-attention KV cache (default: 448).
        use_cross_cache: whether the decoder cross-attention layers cache the encoder key/value states (default: False).
        encoder_max_length: maximum encoder sequence length, used to size the cross-attention cache (default: 1500).
        batch_serialization_factor: serialization factor forwarded to the encoder self-attention layers (default: 1).
        sequence_serialization_factor: serialization factor forwarded to the encoder self-attention layers (default: 1).
    """
    batch_size = kwargs.get("batch_size", 1)
    num_beams = kwargs.get("num_beams", 1)
    use_cache = kwargs.get("use_cache", False)
    max_length = kwargs.get("max_length", 448)
    use_cross_cache = kwargs.get("use_cross_cache", False)
    encoder_max_length = kwargs.get("encoder_max_length", 1500)
    batch_serialization_factor = kwargs.get("batch_serialization_factor", 1)
    sequence_serialization_factor = kwargs.get("sequence_serialization_factor", 1)
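    # The encoder runs a single forward pass, so its self-attention never needs a KV
    # cache (use_cache=False); only the serialization factors are forwarded here.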
    for encoder_layer in self.model.encoder.layers:
        if restore:
            encoder_layer.self_attn = encoder_layer.self_attn.to_model(WhisperAttention)
            continue

        encoder_layer.self_attn = IPUWhisperAttention.from_model(
            encoder_layer.self_attn,
            use_cache=False,
            batch_serialization_factor=batch_serialization_factor,
            sequence_serialization_factor=sequence_serialization_factor,
        )
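    # Decoder self-attention optionally caches its own past keys/values (sized by
    # max_length), while cross-attention optionally caches the encoder keys/values
    # (sized by encoder_max_length).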
    for decoder_layer in self.model.decoder.layers:
        if restore:
            decoder_layer.self_attn = decoder_layer.self_attn.to_model(WhisperAttention)
            decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(WhisperAttention)
            continue

        decoder_layer.self_attn = IPUWhisperAttention.from_model(
            decoder_layer.self_attn,
            use_cache=use_cache,
            use_cross_cache=False,
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            dtype=decoder_layer.self_attn.k_proj.weight.dtype,
        )
        decoder_layer.encoder_attn = IPUWhisperAttention.from_model(
            decoder_layer.encoder_attn,
            use_cache=False,
            use_cross_cache=use_cross_cache,
            batch_size=batch_size,
            encoder_max_length=encoder_max_length,
            num_beams=num_beams,
            dtype=decoder_layer.encoder_attn.k_proj.weight.dtype,
        )
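
# Minimal usage sketch (illustrative only, not part of the original file): assuming
# `model` is an instance of the class defining `change_attention_class`, enabling the
# decoder KV caches for beam-search generation might look like:
#
#     model.change_attention_class(
#         restore=False,
#         batch_size=1,
#         num_beams=5,
#         use_cache=True,
#         max_length=448,
#         use_cross_cache=True,
#         encoder_max_length=1500,
#     )
#
# Calling `model.change_attention_class(restore=True)` afterwards swaps the original
# `WhisperAttention` layers back in.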