in src/hyperpod_nemo_adapter/patches/patch_llama_flash_attn_cp.py [0:0]
def patched_LFA2__init__(self, *args, **kwargs):
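    """Patched LlamaFlashAttention2 constructor.

    Runs the stock constructor, then builds a Transformer Engine DotProductAttention as
    ``self.core_attention`` so the attention computation can be routed through TE (e.g. for
    context parallelism, as this patch module's name suggests).
    """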
super(self.__class__, self).__init__(*args, **kwargs)
    # TODO: Remove this once Flash Attention for ROCm is bumped to 2.1.
    # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is
    # bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is
    # used to handle that difference. Reference:
    # https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0
    # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case
    # q_seqlen == 1) produces a wrong (top-left aligned) mask.
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
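    # Downstream, transformers' flash-attention forward path reads this flag to decide whether
    # the causal flag can be kept for single-token (query_length == 1) decode calls. The exact
    # check varies across transformers versions, but is roughly:
    #     causal = self.is_causal and not (self._flash_attn_uses_top_left_mask and query_length == 1)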
    ##### SAGEMAKER: add the Transformer Engine core attention module
    llama_config = kwargs["config"]
    num_gqa_groups = llama_config.num_key_value_heads  # number of KV heads (grouped-query attention)
    num_attention_heads = llama_config.num_attention_heads
    kv_channels = llama_config.hidden_size // num_attention_heads  # per-head dimension
    # Core attention (Transformer Engine DotProductAttention).
    self.core_attention = te.attention.DotProductAttention(
        num_attention_heads,
        kv_channels,
        num_gqa_groups=num_gqa_groups,
        attention_dropout=self.attention_dropout if self.training else 0.0,
        qkv_format="sbhd",  # query/key/value tensors laid out as (seq, batch, head, dim)
        tp_size=1,  # tensor parallelism is not exposed to TE here
        get_rng_state_tracker=None,
        sequence_parallel=False,
        tp_group=None,
        layer_number=self.layer_idx + 1,  # HF layer_idx is 0-based; TE expects 1-based layer numbers
        attention_type="self",
    )
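

# Illustrative sketch (not part of the original patch): one way this constructor could be
# applied and reverted by monkey-patching. The helper names below are hypothetical, and the
# assumption that the target is transformers' LlamaFlashAttention2 (in transformers versions
# that still ship that class) is inferred from this module's name.
from transformers.models.llama.modeling_llama import LlamaFlashAttention2

_original_LFA2_init = LlamaFlashAttention2.__init__


def apply_llama_flash_attn_cp_patch():
    # Every LlamaFlashAttention2 constructed afterwards also builds a TE core_attention module.
    LlamaFlashAttention2.__init__ = patched_LFA2__init__


def unapply_llama_flash_attn_cp_patch():
    # Restore the stock constructor.
    LlamaFlashAttention2.__init__ = _original_LFA2_init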