def change_attention_class()

in optimum/graphcore/models/t5/modeling_t5.py


    def change_attention_class(self, restore=False, **kwargs):
        """Changes the attention layers to either use the original T5Attention forward
        or IPUT5Attention forward.

        Args:
            restore (bool, optional): whether to restore the attention layers to their original version or not. Defaults to False.
        """
        use_cache = kwargs.get("use_cache", False)
        use_cross_cache = kwargs.get("use_cross_cache", False)
        batch_size = kwargs.get("batch_size", 1)
        max_length = kwargs.get("max_length", 128)
        encoder_max_length = kwargs.get("encoder_max_length", 1500)
        num_beams = kwargs.get("num_beams", 1)

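        # Encoder self-attention runs in a single forward pass, so it never uses a KV cache (use_cache=False).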
        for layer in self.encoder.block:
            if restore:
                layer.layer[0].SelfAttention = layer.layer[0].SelfAttention.to_model(T5Attention)
                continue

            layer.layer[0].SelfAttention = IPUT5Attention.from_model(
                layer.layer[0].SelfAttention,
                use_cache=False,
            )

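        # Decoder layers: self-attention may cache its own keys/values across decoding steps,
        # while cross-attention may cache the keys/values computed from the encoder output.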
        for layer in self.decoder.block:
            if restore:
                layer.layer[0].SelfAttention = layer.layer[0].SelfAttention.to_model(T5Attention)
                layer.layer[1].EncDecAttention = layer.layer[1].EncDecAttention.to_model(T5Attention)
                continue

            layer.layer[0].SelfAttention = IPUT5Attention.from_model(
                layer.layer[0].SelfAttention,
                use_cache=use_cache,
                use_cross_cache=False,
                batch_size=batch_size,
                max_length=max_length,
                num_beams=num_beams,
                num_heads=layer.layer[0].SelfAttention.n_heads,
                head_dim=layer.layer[0].SelfAttention.key_value_proj_dim,
                dtype=layer.layer[0].SelfAttention.k.weight.dtype,
            )

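            # The cross-attention cache holds encoder-side keys/values, so it is sized by
            # encoder_max_length rather than the decoder's max_length.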
            layer.layer[1].EncDecAttention = IPUT5Attention.from_model(
                layer.layer[1].EncDecAttention,
                use_cache=False,
                use_cross_cache=use_cross_cache,
                batch_size=batch_size,
                encoder_max_length=encoder_max_length,
                num_beams=num_beams,
                num_heads=layer.layer[1].EncDecAttention.n_heads,
                head_dim=layer.layer[1].EncDecAttention.key_value_proj_dim,
                dtype=layer.layer[1].EncDecAttention.k.weight.dtype,
            )
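
A minimal usage sketch, assuming a pipelined T5 model from optimum-graphcore that exposes this method (e.g. `PipelinedT5ForConditionalGeneration` from the same file); the kwarg values below are illustrative, not taken from the source:

    # Swap in IPUT5Attention with KV caching enabled for generation.
    # Illustrative sizes only; pick values matching your generation config.
    model.change_attention_class(
        use_cache=True,
        use_cross_cache=True,
        batch_size=1,
        max_length=128,
        encoder_max_length=1500,
        num_beams=4,
    )

    # ... run generation on IPU ...

    # Swap the original T5Attention modules back in.
    model.change_attention_class(restore=True)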