def change_bart_attention_class()

in optimum/graphcore/models/bart/modeling_bart.py


    def change_bart_attention_class(self, restore: bool, **kwargs):
        """Changes the attention layers to either use the original BartAttention forward or
        BartAttentionWithoutException forward.

        Args:
            restore: whether to restore the attention layers to their original version or not.
        """
        use_cache = kwargs.get("use_cache", False)
        batch_size = kwargs.get("batch_size", 1)
        max_length = kwargs.get("max_length", 128)
        num_beams = kwargs.get("num_beams", 1)

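        # Encoder layers only have self-attention and never need a KV cache,
        # so they are converted (or restored) without the generation kwargs.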
        for encoder_layer in self.encoder.layers:
            if restore:
                encoder_layer.self_attn = encoder_layer.self_attn.to_model(BartAttention)
                continue

            encoder_layer.self_attn = IPUBartAttention.from_model(
                encoder_layer.self_attn,
                use_cache=False,
            )

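        # Decoder layers have both self-attention and cross-attention. Only the
        # self-attention receives the generation kwargs and the projection dtype;
        # the cross-attention attends over the fixed encoder output and is
        # converted with use_cache=False.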
        for decoder_layer in self.decoder.layers:
            if restore:
                decoder_layer.self_attn = decoder_layer.self_attn.to_model(BartAttention)
                decoder_layer.encoder_attn = decoder_layer.encoder_attn.to_model(BartAttention)
                continue

            decoder_layer.self_attn = IPUBartAttention.from_model(
                decoder_layer.self_attn,
                use_cache=use_cache,
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                dtype=decoder_layer.self_attn.k_proj.weight.dtype,
            )
            decoder_layer.encoder_attn = IPUBartAttention.from_model(
                decoder_layer.encoder_attn,
                use_cache=False,
            )
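
A minimal usage sketch, assuming model is an instance of the class defining this
method (a BART model exposing .encoder and .decoder); the generation settings
below are illustrative values, not defaults taken from the repository:

    # Swap in the IPU attention layers before generation; the kwargs are
    # forwarded to the decoder self-attention conversion.
    model.change_bart_attention_class(
        restore=False,
        use_cache=True,
        batch_size=1,
        max_length=128,
        num_beams=4,
    )

    # ... run generation ...

    # Restore the original BartAttention layers afterwards.
    model.change_bart_attention_class(restore=True)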