def patched_LFA2__init__()

in src/hyperpod_nemo_adapter/patches/patch_llama_flash_attn_cp.py [0:0]


# Imports required by this function (assumed to be provided at the top of the patch
# module; the te alias here is an assumption for transformer_engine.pytorch):
import transformer_engine.pytorch as te
from transformers.utils import is_flash_attn_greater_or_equal_2_10


def patched_LFA2__init__(self, *args, **kwargs):
    # Call the parent-class __init__ (LlamaAttention, when this replaces
    # LlamaFlashAttention2.__init__) before attaching the Transformer Engine core attention.
    super(self.__class__, self).__init__(*args, **kwargs)

    # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.
    # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is
    # bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is
    # used to handle the difference.
    # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0
    # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case
    # q_seqlen == 1) produces a wrong (top-left) mask.
    self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
    ##### SAGEMAKER: add Transformer Engine core attention.
    llama_config = kwargs["config"]
    num_gqa_groups = llama_config.num_key_value_heads       # KV heads (GQA groups)
    num_attention_heads = llama_config.num_attention_heads  # query heads
    kv_channels = llama_config.hidden_size // num_attention_heads  # per-head dimension
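    # Example of the arithmetic above (hypothetical Llama-3-8B-style config):
    # hidden_size=4096, num_attention_heads=32, num_key_value_heads=8
    #   -> kv_channels = 4096 // 32 = 128 and 8 GQA key/value groups,
    #      each shared by 32 // 8 = 4 query heads.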

    # Transformer Engine fused core attention attached to this module.
    self.core_attention = te.attention.DotProductAttention(
        num_attention_heads,
        kv_channels,
        num_gqa_groups=num_gqa_groups,
        attention_dropout=self.attention_dropout if self.training else 0.0,
        qkv_format="sbhd",  # tensors laid out as (sequence, batch, heads, head_dim)
        tp_size=1,                # tensor parallelism is not managed by TE for this module
        get_rng_state_tracker=None,
        sequence_parallel=False,
        tp_group=None,
        layer_number=self.layer_idx + 1,  # HF layer_idx is 0-based; TE expects 1-based
        attention_type="self",
    )
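
A minimal sketch of how a patched __init__ like this is typically installed, assuming the
target is transformers' LlamaFlashAttention2 (present in transformers releases that still
define that class). The helper name apply_llama_flash_attn_cp_patch is hypothetical; the
real wiring lives elsewhere in the adapter:

from transformers.models.llama import modeling_llama

def apply_llama_flash_attn_cp_patch():
    # Hypothetical helper: replace the class-level __init__ so every newly constructed
    # attention module also builds the Transformer Engine core_attention shown above.
    modeling_llama.LlamaFlashAttention2.__init__ = patched_LFA2__init__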