# modules/SwissArmyTransformer/sat/model/official/mixtral_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from sat.model import BaseMixin, BaseModel
from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding
from sat.mpu import ColumnParallelLinear
from sat.mpu.utils import split_tensor_along_last_dim
from sat.transformer_defaults import attention_fn_default


class RotaryMixin(BaseMixin):
    """Applies rotary position embeddings to queries and keys inside attention."""

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.rotary_emb = FastRotaryEmbedding(hidden_size // num_heads, base=1000000)

    def attention_forward(self, hidden_states, mask, **kw_args):
        # Keep a handle to the mixin, then re-bind `self` to the attention module of
        # the current layer so the body below mirrors sat's default attention_forward.
        origin = self
        self = self.transformer.layers[kw_args['layer_id']].attention
        attention_fn = attention_fn_default
        if 'attention_fn' in self.hooks:
            attention_fn = self.hooks['attention_fn']
        mixed_raw_layer = self.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, self.stride)
        dropout_fn = self.attention_dropout if self.training else None
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        # Rotate queries and keys before the attention scores are computed.
        query_layer, key_layer = origin.rotary_emb(
            query_layer, key_layer, kw_args['position_ids'],
            max_seqlen=kw_args['position_ids'].max() + 1, layer_id=kw_args['layer_id'])
        context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
        # [b, np, s, hn] -> [b, s, np * hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)
        output = self.dense(context_layer)
        if self.training:
            output = self.output_dropout(output)
        return output
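

# --- Illustrative sketch (not part of the original file) --------------------------
# FastRotaryEmbedding above is a fused Triton kernel; the reference math it is
# expected to compute looks roughly like the plain-PyTorch version below (up to the
# exact rotation layout). The helper names `_rotate_half_reference` and
# `_apply_rotary_reference` are made up for this sketch and do not exist in sat.
def _rotate_half_reference(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def _apply_rotary_reference(q, k, position_ids, head_dim, base=1000000):
    # q, k: [batch, num_heads, seq_len, head_dim]; position_ids: [batch, seq_len]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32, device=q.device) / head_dim))
    freqs = position_ids[..., None].float() * inv_freq           # [batch, seq_len, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)                       # [batch, seq_len, head_dim]
    cos, sin = emb.cos()[:, None], emb.sin()[:, None]             # broadcast over the head axis
    cos, sin = cos.to(q.dtype), sin.to(q.dtype)
    return q * cos + _rotate_half_reference(q) * sin, k * cos + _rotate_half_reference(k) * sin
# -----------------------------------------------------------------------------------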


class MixtralMlpMixin(BaseMixin):
    """Per-layer top-k expert routing for the Mixtral mixture-of-experts MLP."""

    def __init__(self, num_layers, in_features, num_experts, num_experts_per_tok):
        super().__init__()
        self.top_k = num_experts_per_tok
        # One router (gate) per transformer layer.
        self.gates = nn.ModuleList([nn.Linear(in_features, num_experts, bias=False) for _ in range(num_layers)])

    def routing_forward(self, hidden_states, **kw_args):
        # Adapted from Mixtral-8x7B: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gates[kw_args['layer_id']](hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # Renormalize the selected top-k weights so they sum to 1 per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # Cast back to the input dtype.
        routing_weights = routing_weights.to(hidden_states.dtype)
        return routing_weights, selected_experts
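

# --- Illustrative sketch (not part of the original file) --------------------------
# routing_forward only picks experts; the per-expert MLPs are dispatched and
# combined elsewhere in sat's transformer. Under that assumption, the combine step
# is expected to look roughly like the loop below (`expert_mlps` is a hypothetical
# list of per-expert feed-forward modules, not a real sat attribute).
def _moe_combine_sketch(flat_hidden, routing_weights, selected_experts, expert_mlps):
    # flat_hidden: [tokens, hidden]; routing_weights, selected_experts: [tokens, top_k]
    output = torch.zeros_like(flat_hidden)
    for expert_idx, expert in enumerate(expert_mlps):
        token_idx, k_idx = torch.where(selected_experts == expert_idx)
        if token_idx.numel() == 0:
            continue
        # Run only the tokens routed to this expert and scale by their gate weight.
        output[token_idx] += routing_weights[token_idx, k_idx, None] * expert(flat_hidden[token_idx])
    return output
# -----------------------------------------------------------------------------------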


class LMMixin(BaseMixin):
    """Untied language-model head projecting final hidden states to vocabulary logits."""

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.lm_head = ColumnParallelLinear(
            hidden_size,
            vocab_size,
            gather_output=True,
            # init_method=init_method,
            bias=False,
            # params_dtype=params_dtype,
            module=self,
            name="lm_head",
            # skip_init=skip_init,
            # device=device
        )

    def final_forward(self, logits, **kwargs):
        return self.lm_head(logits)
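

# --- Illustrative note (not part of the original file) ----------------------------
# With a model-parallel world size of 1, ColumnParallelLinear(..., bias=False,
# gather_output=True) is expected to behave like a plain projection, i.e. roughly:
#   lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
#   logits = lm_head(final_hidden_states)   # [batch, seq_len, vocab_size]
# -----------------------------------------------------------------------------------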


from sat.ops.layernorm import RMSNorm


class MixtralModel(BaseModel):
    def __init__(self, args, transformer=None, layernorm=RMSNorm, activation_func=nn.functional.silu, **kwargs):
        super().__init__(args, transformer=transformer, layernorm=layernorm, activation_func=activation_func, init_method_std=0.01, **kwargs)
        # Rotary embeddings replace the learned absolute position embeddings.
        del self.transformer.position_embeddings
        if 'inner_hidden_size' not in args:
            args.inner_hidden_size = None
        self.add_mixin("rotary", RotaryMixin(args.hidden_size, args.num_attention_heads))
        self.add_mixin("lm", LMMixin(args.vocab_size, args.hidden_size))
        self.add_mixin("mlp", MixtralMlpMixin(args.num_layers, args.hidden_size, args.num_experts, args.num_experts_per_tok))

    def position_embedding_forward(self, *args, **kwargs):
        # Positions enter through rotary embeddings in attention, so no additive embedding is returned.
        return None

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('Mixtral-8x7b', 'Mixtral-8x7b Configurations')
        group.add_argument('--bos-token-id', type=int, default=1)
        group.add_argument('--eos-token-id', type=int, default=2)
        group.add_argument('--num-experts-per-tok', type=int, default=2)
        return parser
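

# --- Illustrative usage sketch (not part of the original file) ---------------------
# Assuming sat's usual entry points (this file's add_model_specific_args plus the
# framework's base model arguments), building the model is expected to look
# roughly like:
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser = MixtralModel.add_model_specific_args(parser)
#   # ... add sat's standard base/model arguments (hidden size, layers, experts, ...)
#   args = parser.parse_args()
#   model = MixtralModel(args)
# -----------------------------------------------------------------------------------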