in optimum/graphcore/models/bart/modeling_bart.py
def change_bart_attention_class(self, restore: bool, **kwargs):
    """Changes the attention layers to either use the original BartAttention forward
    or the IPUBartAttention forward.

    Args:
        restore: whether to restore the attention layers to their original version or not.
    """
    use_cache = kwargs.get("use_cache", False)
    batch_size = kwargs.get("batch_size", 1)
    max_length = kwargs.get("max_length", 128)
    num_beams = kwargs.get("num_beams", 1)

    # Encoder self-attention never uses the KV cache, so it is always wrapped
    # with caching disabled.
    for encoder_layer in self.encoder.layers:
        if restore:
            encoder_layer.self_attn = encoder_layer.self_attn.to_model(BartAttention)
            continue
        encoder_layer.self_attn = IPUBartAttention.from_model(
            encoder_layer.self_attn,
            use_cache=False,
        )

    for decoder_layer in self.decoder.layers:
        if restore:
            decoder_layer.self_attn = decoder_layer.self_attn.to_model(BartAttention)
            decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(BartAttention)
            continue
        # Decoder self-attention is the only place the KV cache applies, so it
        # receives the generation parameters that size the cache buffers.
        decoder_layer.self_attn = IPUBartAttention.from_model(
            decoder_layer.self_attn,
            use_cache=use_cache,
            batch_size=batch_size,
            max_length=max_length,
            num_beams=num_beams,
            dtype=decoder_layer.self_attn.k_proj.weight.dtype,
        )
        # Cross-attention attends over the fixed encoder output, so it needs
        # no cache either.
        decoder_layer.encoder_attn = IPUBartAttention.from_model(
            decoder_layer.encoder_attn,
            use_cache=False,
        )
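
For context, a minimal usage sketch of how this hook might be driven. This is hypothetical caller code, not part of the file above; it assumes `model` is an instance of the class defining this method, with the keyword arguments matching those read via `kwargs.get` in the body:

# Swap in the IPU attention before generation, sizing the KV cache
# from the generation settings.
model.change_bart_attention_class(
    restore=False,
    use_cache=True,
    batch_size=4,
    max_length=128,
    num_beams=1,
)
# ... run generation on the IPU ...
# Restore the stock BartAttention layers afterwards; the cache-sizing
# kwargs are ignored on the restore path.
model.change_bart_attention_class(restore=True)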