# megatron_patch/model/mixtral_bak/transformer/attention.py
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

import torch

from megatron.core import parallel_state, tensor_parallel
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import divide

# Transformer Engine is optional here: when SplitAlongDim cannot be imported,
# get_query_key_value_tensors falls back to torch.split (see the
# `if SplitAlongDim is not None` check below).
try:
    from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
except ImportError:
    SplitAlongDim = None
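

# These dataclasses are plain containers: a layer spec fills each field with a
# ModuleSpec (or a class), and build_module() below instantiates them.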
@dataclass
class SelfAttentionSubmodules:
linear_qkv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None


@dataclass
class CrossAttentionSubmodules:
linear_q: Union[ModuleSpec, type] = None
linear_kv: Union[ModuleSpec, type] = None
core_attention: Union[ModuleSpec, type] = None
linear_proj: Union[ModuleSpec, type] = None


class Attention(MegatronModule, ABC):
    """Attention layer abstract class.

    This layer only contains common modules required for the "self attn" and
    "cross attn" specializations.
    """

    def __init__(
self,
config: TransformerConfig,
submodules: Union[SelfAttentionSubmodules, CrossAttentionSubmodules],
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
):
super().__init__(config=config)
self.config = config
self.layer_number = layer_number
self.attn_mask_type = attn_mask_type
self.attention_type = attention_type
# For normal attention without groups, num_query_groups == num_attention_heads,
# so these two will be the same
self.query_projection_size = self.config.kv_channels * self.config.num_attention_heads
self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = divide(
self.query_projection_size, self.config.num_attention_heads
)
self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size)
self.num_query_groups_per_partition = divide(self.config.num_query_groups, world_size)
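        # Worked example (hypothetical values): with num_attention_heads=32,
        # num_query_groups=8, kv_channels=128 and tensor-parallel world_size=4,
        # query_projection_size = 128 * 32 = 4096 and kv_projection_size =
        # 128 * 8 = 1024; each rank then holds 32 / 4 = 8 query heads and
        # 8 / 4 = 2 query groups, each of head size 4096 / 32 = 128.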
self.core_attention = build_module(
submodules.core_attention,
config=self.config,
layer_number=self.layer_number,
attn_mask_type=self.attn_mask_type,
attention_type=self.attention_type,
)
self.checkpoint_core_attention = self.config.recompute_granularity == 'selective'
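        # With 'selective' recompute granularity, only this core attention
        # block is recomputed during backward (see
        # _checkpointed_attention_forward) rather than the whole layer.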
# Output.
self.linear_proj = build_module(
submodules.linear_proj,
self.query_projection_size,
self.config.hidden_size,
config=self.config,
init_method=self.config.output_layer_init_method,
bias=self.config.add_bias_linear,
input_is_parallel=True,
skip_bias_add=True,
is_expert=False,
tp_comm_buffer_name='proj',
)

    def _checkpointed_attention_forward(
self,
query,
key,
value,
attention_mask,
rotary_pos_emb=None,
attn_mask_type=None,
packed_seq_params=None,
):
"""Forward method with selective activation checkpointing."""
def custom_forward(*inputs):
query = inputs[0]
key = inputs[1]
value = inputs[2]
attention_mask = inputs[3]
attn_mask_type = inputs[5]
attn_mask_type = AttnMaskType(attn_mask_type.item())
output_ = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
return output_
if attn_mask_type is None:
attn_mask_type = self.attn_mask_type
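        # The mask type rides through tensor_parallel.checkpoint as a
        # one-element int tensor and is rebuilt via .item() inside
        # custom_forward (inputs[5]); rotary_pos_emb (inputs[4]) is threaded
        # through but unused there.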
attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int)
hidden_states = tensor_parallel.checkpoint(
custom_forward,
False,
query,
key,
value,
attention_mask,
rotary_pos_emb,
attn_mask_type,
)
return hidden_states

    def _allocate_memory(self, inference_max_sequence_length, batch_size, dtype):
        """Allocate memory to store the KV cache during inference."""
return torch.empty(
inference_max_sequence_length,
batch_size,
self.num_query_groups_per_partition,
self.hidden_size_per_attention_head,
dtype=dtype,
device=torch.cuda.current_device(),
)

    def _adjust_key_value_for_inference(self, inference_params, key, value, rotary_pos_emb):
        """
        Saves the generated key and value tensors at the current offset of the
        per-layer buffers in inference_params, and returns the full-length keys
        and values accumulated so far, along with rotary_pos_emb and
        attn_mask_type adjusted for the current decoding step.

        Returns a tuple: (key, value, rotary_pos_emb, attn_mask_type)
        """
attn_mask_type = self.attn_mask_type
if inference_params is None:
return key, value, rotary_pos_emb, attn_mask_type
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
is_first_step = False
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_length = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, key.dtype
)
inference_value_memory = self._allocate_memory(
inf_max_seq_length, inf_max_batch_size, value.dtype
)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory,
inference_value_memory,
)
is_first_step = True
else:
# Get the pre-allocated buffers for this layer
inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[
self.layer_number
]
attn_mask_type = AttnMaskType.no_mask
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = key
inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value
key = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value = inference_value_memory[:sequence_end, batch_start:batch_end, ...]
        # Adjust the rotary positional embeddings for the current decoding step.
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
# need to cross check this condition during inference
# if not set_inference_key_value_memory:
if not is_first_step:
# In inference, we compute one token at a time.
# Select the correct positional embedding
# (only the last token in the sequence)
q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end]
else:
# In the first forward pass of inference,
# we use the entire provided prefix.
# q_pos_emb here has the rope embeddings of the entire
# prefix + to-be-generated output so
# we slice to just the prefix.
q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
return key, value, rotary_pos_emb, attn_mask_type
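
    # Decode-time sketch (hypothetical driver loop; `inference_params` is a
    # megatron.core InferenceParams-style object carrying the attributes used
    # above):
    #
    #   inference_params.sequence_len_offset = 0       # prefill with the prompt
    #   out, _ = attn(prompt_states, mask, inference_params=inference_params)
    #   inference_params.sequence_len_offset += prompt_len
    #   out, _ = attn(token_states, mask, inference_params=inference_params)
    #
    # The first call writes the whole prefix into the KV buffers; later calls
    # append one token at a time and switch attn_mask_type to no_mask.
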
@abstractmethod
def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
This method needs to be implemented based on whether the derived class
is "self-attn" or "cross-attn".
"""

    def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
):
# hidden_states: [sq, b, h]
        # For self attention we just duplicate rotary_pos_emb if it isn't
        # already a (q_pos_emb, k_pos_emb) tuple.
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2
# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
)
if packed_seq_params is not None:
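            # In the packed case the batch axis is a dummy dimension of size 1,
            # so the tensors collapse to [t, np, hn] for the varlen kernels.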
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)
# ================================================
# relative positional embedding (rotary embedding)
# ================================================
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(
query, q_pos_emb, fused=self.config.apply_rope_fusion, cu_seqlens=cu_seqlens_q
)
key = apply_rotary_pos_emb(
key, k_pos_emb, fused=self.config.apply_rope_fusion, cu_seqlens=cu_seqlens_kv
)
            # TODO: rotary embeddings could also be applied to the value tensor
            # to give it absolute positional information; as written, only the
            # relative positional signal from rotating query/key takes effect.
            # value = apply_rotary_pos_emb(value, k_pos_emb)
# ==================================
# core attention computation
# ==================================
if self.checkpoint_core_attention:
core_attn_out = self._checkpointed_attention_forward(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query,
key,
value,
attention_mask,
attn_mask_type=attn_mask_type,
packed_seq_params=packed_seq_params,
)
if packed_seq_params is not None:
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.linear_proj(core_attn_out)
return output, bias


class SelfAttention(Attention):
    """Self-attention layer class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
self,
config: TransformerConfig,
submodules: SelfAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
attention_type="self",
)
self.linear_qkv = build_module(
submodules.linear_qkv,
self.config.hidden_size,
self.query_projection_size + 2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
tp_comm_buffer_name='qkv',
)

    def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
"""
Derives `query`, `key` and `value` tensors from `hidden_states`.
"""
        # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
mixed_qkv, _ = self.linear_qkv(hidden_states)
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_qkv.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)
split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
]
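        # e.g. (hypothetical values) with np/ng = 4 query heads per group and
        # hn = 128, the last dim of size (4 + 2) * 128 = 768 splits into
        # [512, 128, 128] for query, key and value.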
        if SplitAlongDim is not None:
            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list)
        else:
            # [sq, b, ng, (np/ng + 2) * hn]
            # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
return query, key, value
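
# Construction sketch (hypothetical spec; real layouts come from a layer spec
# such as the GPT layer specs in megatron.core):
#
#   spec = ModuleSpec(
#       module=SelfAttention,
#       params={"attn_mask_type": AttnMaskType.causal},
#       submodules=SelfAttentionSubmodules(
#           linear_qkv=..., core_attention=..., linear_proj=...,
#       ),
#   )
#   attn = build_module(spec, config=config, layer_number=1)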


class CrossAttention(Attention):
    """Cross-attention layer class.

    Cross-attention layer takes input with size [s, b, h] and context with size
    [s, b, h] and returns output of the same size.
    """

    def __init__(
self,
config: TransformerConfig,
submodules: CrossAttentionSubmodules,
layer_number: int,
attn_mask_type=AttnMaskType.padding,
):
super().__init__(
config=config,
submodules=submodules,
layer_number=layer_number,
attn_mask_type=attn_mask_type,
attention_type="cross",
)
        if self.config.num_query_groups != self.config.num_attention_heads:
            raise ValueError(
                "Group query attention is not currently supported in cross attention."
            )
        assert self.query_projection_size == self.kv_projection_size
self.linear_q = build_module(
submodules.linear_q,
self.config.hidden_size,
self.query_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
)
self.linear_kv = build_module(
submodules.linear_kv,
self.config.hidden_size,
2 * self.kv_projection_size,
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=self.config.add_bias_linear,
skip_bias_add=False,
is_expert=False,
)

    def get_query_key_value_tensors(self, hidden_states, key_value_states):
"""
Derives `query` tensor from `hidden_states`, and `key`/`value` tensors
from `key_value_states`.
"""
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv, _ = self.linear_kv(key_value_states)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv.size()[:-1] + (
self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head,
)
mixed_kv = mixed_kv.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query, _ = self.linear_q(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query.size()[:-1] + (
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
query = query.view(*new_tensor_shape)
return query, key, value
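
# Usage note: unlike SelfAttention, CrossAttention.forward must be passed
# key_value_states (the context activations, shaped [sk, b, h]); hidden_states
# feeds only the query projection. A hypothetical call:
#
#   out, bias = cross_attn(decoder_states, mask, key_value_states=encoder_states)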