modules/SwissArmyTransformer/sat/transformer_defaults.py (251 lines of code) (raw):

# coding=utf-8
# -*- encoding: utf-8 -*-
'''
@File    :   transformer_defaults.py
@Time    :   2022/06/01 21:44:17
@Author  :   Ming Ding
@Contact :   dm18@mails.tsinghua.edu.cn
'''

import math
import torch
import torch.nn.functional as F
from sat import mpu
from sat.mpu.utils import split_tensor_along_last_dim
import contextlib


def standard_attention(query_layer, key_layer, value_layer, attention_mask,
                       attention_dropout=None, log_attention_weights=None,
                       scaling_attention_score=True, **kwargs):
    """Plain (non-fused) scaled dot-product attention.

    Args:
        query_layer, key_layer, value_layer: attention operands; the last
            dimension of ``query_layer`` is used for the 1/sqrt(d) scaling.
        attention_mask: multiplicative mask broadcastable to the score
            matrix (positive = keep). ``None`` is treated as "attend
            everywhere".
        attention_dropout: optional callable applied to the attention
            probabilities.
        log_attention_weights: optional additive bias on the raw scores.
        scaling_attention_score: when True, the query is divided by
            ``sqrt(head_dim)`` before the matmul.
        **kwargs: ignored; accepted so hooks can forward extra arguments.

    Returns:
        ``softmax(scores) @ value_layer``.
    """
    # We disable the PB-relax-Attention and only changes the order of computation, because it is enough for most of training.
    # The implementation in the paper can be done very easily, if you really need it to train very deep transformers.

    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    if log_attention_weights is not None:
        attention_scores += log_attention_weights

    # A missing mask means nothing is masked out; previously this crashed on
    # `attention_mask.shape`.
    if attention_mask is not None and not (
            attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
        # if auto-regressive, skip
        attention_scores = torch.mul(attention_scores, attention_mask) - \
            10000.0 * (1.0 - attention_mask)

    attention_probs = F.softmax(attention_scores, dim=-1)

    if attention_dropout is not None:
        if mpu.get_cuda_rng_tracker is not None:
            # Fork the tracked RNG state for the dropout call.
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = attention_dropout(attention_probs)
        else:
            attention_probs = attention_dropout(attention_probs)

    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer


def attention_fn_default(query_layer, key_layer, value_layer, attention_mask,
                         attention_dropout=None, log_attention_weights=None,
                         scaling_attention_score=True, **kwargs):
    """Default attention function: fused PyTorch 2 kernel when possible.

    Key/value heads are first expanded to match the number of query heads
    (multi-query attention). The fused ``scaled_dot_product_attention`` is
    used only when torch >= 2, score scaling is enabled, no
    ``log_attention_weights`` are given, and the mask is either all-positive
    (or ``None``) or exactly lower-triangular. Every other case falls back to
    :func:`standard_attention`.

    Args and return value are the same as :func:`standard_attention`.
    """
    # expand head dim to query dim, if necessary
    # only useful for multi-query attention
    batch_size, num_query_heads = query_layer.shape[:2]  # [b, np, s, hn]
    num_kv_heads = key_layer.shape[1]  # [b, np, s, hn]
    key_layer = key_layer.unsqueeze(2).expand(
        -1, -1, num_query_heads // num_kv_heads, -1, -1
    ).contiguous().view(batch_size, num_query_heads, *key_layer.shape[2:])
    value_layer = value_layer.unsqueeze(2).expand(
        -1, -1, num_query_heads // num_kv_heads, -1, -1
    ).contiguous().view(batch_size, num_query_heads, *value_layer.shape[2:])

    # The fused kernel has no slot for an additive score bias, so it must not
    # be used when log_attention_weights is supplied (they would be dropped).
    fused_eligible = (
        int(torch.__version__.split('.')[0]) >= 2
        and scaling_attention_score
        and log_attention_weights is None
    )

    if fused_eligible:
        # Inspect the mask only when the fused path is still possible; the
        # None check comes first so a missing mask no longer crashes.
        if attention_mask is None:
            is_full = True
        else:
            is_full = bool((attention_mask > 0).all())
        is_low_triangle = False
        if not is_full:
            is_low_triangle = bool(
                (attention_mask == torch.ones_like(attention_mask, dtype=torch.float).tril()).all()
            )

        if is_full or is_low_triangle:
            # Pytorch 2.0 attention uses very much memory if attention_mask is float, and has NaN bug if attention_mask is None.
            dropout_p = 0. if attention_dropout is None or not attention_dropout.training else attention_dropout.p
            if dropout_p > 0 and mpu.get_cuda_rng_tracker is not None:
                context = mpu.get_cuda_rng_tracker().fork()
            else:
                context = contextlib.nullcontext()
            with context:
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer,
                    attn_mask=None,
                    dropout_p=dropout_p,
                    is_causal=not is_full
                )
            return attn_output

    return standard_attention(
        query_layer, key_layer, value_layer, attention_mask,
        attention_dropout=attention_dropout,
        log_attention_weights=log_attention_weights,
        scaling_attention_score=scaling_attention_score,
        **kwargs
    )


def attention_forward_default(self, hidden_states, mask, **kw_args):
    """Default self-attention forward for layer ``kw_args['layer_id']``.

    Projects ``hidden_states`` to query/key/value, optionally applies the
    transformer's rotary position embedding, runs the attention function
    (the ``'attention_fn'`` hook if registered, else
    :func:`attention_fn_default`) and applies the output projection.
    """
    self = self.transformer.layers[kw_args['layer_id']].attention
    attention_fn = attention_fn_default
    if 'attention_fn' in self.hooks:
        attention_fn = self.hooks['attention_fn']

    mixed_raw_layer = self.query_key_value(hidden_states)
    (mixed_query_layer,
        mixed_key_layer,
        mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, self.stride)

    # Dropout is only handed to the attention function in training mode.
    dropout_fn = self.attention_dropout if self.training else None

    query_layer = self._transpose_for_scores(mixed_query_layer)
    key_layer = self._transpose_for_scores(mixed_key_layer)
    value_layer = self._transpose_for_scores(mixed_value_layer)

    # rotary position embedding
    if self.transformer.is_rotary_emb:
        query_layer, key_layer = self.transformer.position_embeddings(
            query_layer, key_layer, kw_args['position_ids'],
            max_seqlen=kw_args['position_ids'].max() + 1,
            layer_id=kw_args['layer_id']
        )

    context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)

    # Merge the head dimension back into the hidden dimension.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)
    output = self.dense(context_layer)

    if self.training:
        output = self.output_dropout(output)
    return output


def cross_attention_forward_default(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
    """Default cross-attention forward for layer ``kw_args['layer_id']``.

    ``encoder_outputs`` is either a tensor (projected here to key/value) or
    an indexable of per-layer ``(key_layer, value_layer)`` pairs that is
    indexed by ``layer_id``.
    """
    self = self.transformer.layers[kw_args['layer_id']].cross_attention
    attention_fn = attention_fn_default
    if 'attention_fn' in self.hooks:
        attention_fn = self.hooks['attention_fn']

    mixed_query_layer = self.query(hidden_states)
    query_layer = self._transpose_for_scores(mixed_query_layer)
    dropout_fn = self.attention_dropout if self.training else None
    if isinstance(encoder_outputs, torch.Tensor):
        mixed_x_layer = self.key_value(encoder_outputs)
        (mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
        # Reshape and transpose [b, np, s, hn]
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        mem_cross = (key_layer, value_layer)
    else:
        key_layer, value_layer = encoder_outputs[kw_args['layer_id']]
        mem_cross = (key_layer, value_layer)

    context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
                                 cross_attention=True, mem_cross=mem_cross, **kw_args)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
    # [b, s, hp]
    context_layer = context_layer.view(*new_context_layer_shape)

    # Output. [b, s, h]
    output = self.dense(context_layer)
    if self.training:
        output = self.output_dropout(output)
    return output


def routing_forward_default(self, hidden_states, **kw_args):
    """Example router: picks 2 experts per token from random logits.

    Returns:
        ``(routing_weights, selected_experts)``, both of shape
        ``(batch * sequence_length, 2)``; the weights are renormalised to
        sum to 1 per token and cast to ``hidden_states.dtype``.
    """
    num_experts = self.transformer.num_experts
    # This is just an example that select 2 experts randomly.
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = torch.randn((batch_size * sequence_length, num_experts),
                                device=hidden_states.device, dtype=hidden_states.dtype)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, 2, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
    # we cast back to the input dtype
    routing_weights = routing_weights.to(hidden_states.dtype)
    return routing_weights, selected_experts


from functools import partial


def mlp_forward_default(self, hidden_states, expert_id=-1, **kw_args):
    """Default MLP forward, with an optional mixture-of-experts dispatch.

    With a single expert, or when ``expert_id`` is given explicitly, the
    dense (optionally gated) MLP of layer ``kw_args['layer_id']`` is applied;
    expert ``i > 0`` uses the sub-modules suffixed ``_i`` while expert 0 uses
    the unsuffixed ones. Otherwise tokens are routed by the
    ``'routing_forward'`` hook and each selected expert's output is
    accumulated, weighted by its routing weight.
    """
    if self.transformer.num_experts == 1 or expert_id > -1:
        self = self.transformer.layers[kw_args['layer_id']].mlp
        suffix = f"_{expert_id}" if expert_id > 0 else ""
        if self.is_gated_mlp:
            intermediate_parallel = getattr(self, "dense_h_to_4h" + suffix)(hidden_states)
            gated_intermediate_parallel = getattr(self, "dense_h_to_4h_gate" + suffix)(hidden_states)
            intermediate_parallel = self.activation_func(gated_intermediate_parallel) * intermediate_parallel
            output = getattr(self, "dense_4h_to_h" + suffix)(intermediate_parallel)
        else:
            intermediate_parallel = getattr(self, "dense_h_to_4h" + suffix)(hidden_states)
            intermediate_parallel = self.activation_func(intermediate_parallel)
            output = getattr(self, "dense_4h_to_h" + suffix)(intermediate_parallel)
        return output
    else:
        # Hooks are resolved on the incoming `self` before it is rebound to
        # the layer's mlp module below.
        mlp_forward = self.hooks.get('mlp_forward', partial(mlp_forward_default, self))
        routing_forward = self.hooks.get('routing_forward', partial(routing_forward_default, self))
        self = self.transformer.layers[kw_args['layer_id']].mlp
        fwd_weight, fwd_idx = routing_forward(hidden_states, **kw_args)

        # Adapted from mixtral-8x7b https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(fwd_idx, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[top_x_list]  # I don't know why using hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = mlp_forward(current_state, expert_id=expert_idx, **kw_args) * fwd_weight[top_x_list, idx_list, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        output = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return output


def word_embedding_forward_default(self, input_ids, output_cross_layer, **kw_args):
    """Look up ``input_ids`` in the transformer's word embedding table."""
    return self.transformer.word_embeddings(input_ids)


def position_embedding_forward_default(self, position_ids, output_cross_layer, **kw_args):
    """Return position embeddings, or ``None`` when rotary embedding is used."""
    if not self.transformer.is_rotary_emb:
        return self.transformer.position_embeddings(position_ids)
    return None


from sat.mpu import gather_from_model_parallel_region


def final_forward_default(self, logits, **kw_args):
    """Project onto the word-embedding weight to obtain vocabulary logits.

    The result is gathered across the model-parallel region unless
    ``kw_args['parallel_output']`` is truthy.
    """
    logits_parallel = F.linear(logits, self.transformer.word_embeddings.weight)
    if not kw_args['parallel_output']:
        logits_parallel = gather_from_model_parallel_region(logits_parallel)
    return logits_parallel


def _drop_path(tensor, drop_prob):
    """Apply DropPath: zero whole samples with probability ``drop_prob``.

    One random draw is made per entry of the leading dimension; kept samples
    are rescaled by ``1 / (1 - drop_prob)``.
    """
    keep_prob = 1 - drop_prob
    # drop_path percentage 0, others 1/(1-p)
    random_tensor = (keep_prob + torch.rand((tensor.shape[0],), dtype=tensor.dtype,
                                            device=tensor.device)).floor_() / keep_prob
    return random_tensor.view(-1, 1, 1) * tensor


def layer_forward_default(self, hidden_states, mask, *args, **kw_args):
    '''
        hidden_states: [batch, seq_len, hidden_size]
        mask: [(1, 1), seq_len, seq_len]

        Runs one transformer layer (self-attention, optional cross-attention
        for decoders, then MLP) honouring ``layernorm_order``
        ('post', 'sandwich', or any other value for pre-layernorm).
    '''
    self = self.transformer.layers[kw_args['layer_id']]
    # Layer norm at the begining of the transformer layer.
    attention_input = self.input_layernorm(hidden_states)
    # Self attention.
    attention_output = self.attention(attention_input, mask, **kw_args)

    # Third LayerNorm
    if self.layernorm_order == 'sandwich':
        attention_output = self.third_layernorm(attention_output)

    # DropPath for attention
    if self.training and self.drop_path > 0.:
        attention_output = _drop_path(attention_output, self.drop_path)

    # Residual connection.
    if self.layernorm_order == 'post':
        hidden_states = attention_input + attention_output
        mlp_input = self.post_attention_layernorm(hidden_states)
    else:
        hidden_states = hidden_states + attention_output

    if self.is_decoder:
        encoder_outputs = kw_args['encoder_outputs']
        if encoder_outputs is not None:
            assert 'cross_attention_mask' in kw_args
            # Cross attention
            if self.layernorm_order == 'post':
                attention_output = self.cross_attention(mlp_input, **kw_args)
                # Residual connection.
                hidden_states = mlp_input + attention_output
                # Layer norm post the cross attention
                mlp_input = self.post_cross_attention_layernorm(hidden_states)
            else:
                cross_input = self.post_cross_attention_layernorm(hidden_states)
                attention_output = self.cross_attention(cross_input, **kw_args)
                hidden_states = hidden_states + attention_output

    if self.layernorm_order != 'post':
        mlp_input = self.post_attention_layernorm(hidden_states)

    # MLP.
    mlp_output = self.mlp(mlp_input, **kw_args)

    # Fourth LayerNorm
    if self.layernorm_order == 'sandwich':
        mlp_output = self.fourth_layernorm(mlp_output)

    # DropPath for mlp
    if self.training and self.drop_path > 0.:
        mlp_output = _drop_path(mlp_output, self.drop_path)

    # Second residual connection.
    if self.layernorm_order == 'post':
        output = mlp_input + mlp_output
    else:
        output = hidden_states + mlp_output

    return output


# Hook name -> default implementation.
HOOKS_DEFAULT = {
    'attention_fn': attention_fn_default,
    'attention_forward': attention_forward_default,
    'cross_attention_forward': cross_attention_forward_default,
    'routing_forward': routing_forward_default,
    'mlp_forward': mlp_forward_default,
    'word_embedding_forward': word_embedding_forward_default,
    'position_embedding_forward': position_embedding_forward_default,
    'final_forward': final_forward_default,
    'layer_forward': layer_forward_default
}

# Each entry maps a name to a (name, default value) pair.
ARGS_DEFAULT = {
    'embedding_dropout_prob': ('hidden_dropout', 0),
    'attention_dropout_prob': ('attention_dropout', 0),
    'output_dropout_prob': ('hidden_dropout', 0),
    'inner_hidden_size': ('inner_hidden_size', None),
    'hidden_size_per_attention_head': ('hidden_size_per_attention_head', None),
    'cross_hidden_size_per_attention_head': ('cross_hidden_size_per_attention_head', None),
    'checkpoint_activations': ('checkpoint_activations', False),
    'checkpoint_num_layers': ('checkpoint_num_layers', 1),
    'checkpoint_skip_layers': ('checkpoint_skip_layers', 0),
    'is_decoder': ('is_decoder', False),
    'cross_attn_hidden_size': ('cross_attn_hidden_size', None),
    'use_final_layernorm': ('use_final_layernorm', True),
    'layernorm_epsilon': ('layernorm_epsilon', 1e-5),
    'use_bias': ('use_bias', True),
    'use_qkv_bias': ('use_qkv_bias', False),
    'num_multi_query_heads': ('num_multi_query_heads', 0),
    'cross_num_multi_query_heads': ('cross_num_multi_query_heads', 0),
    'drop_path': ('drop_path', 0.),
    'row_parallel_linear_final_bias': ('row_parallel_linear_final_bias', True),
    'is_gated_mlp': ('is_gated_mlp', False),
    'is_rotary_emb': ('is_rotary_emb', False),
    'parallel_output': ('parallel_output', False),
    'num_experts': ('num_experts', 1),
}

from sat.ops.layernorm import LayerNorm, RMSNorm

# Normalisation module classes; presumably excluded from weight decay — confirm at the use site.
NO_WD_MODULES = [LayerNorm, torch.nn.LayerNorm, RMSNorm]