# modules/SwissArmyTransformer/sat/model/official/cait_model.py
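"""CaiT (Class-Attention in Image Transformers, Touvron et al., 2021) on top of SAT's ViT stack.

The model is split into two stages:
  * CaiTEncoder -- self-attention over patch tokens, with talking-heads attention (AttnMixin)
    and LayerScale residual scaling (EncForward);
  * CaiTDecoder -- class-attention layers in which only the CLS token is updated while it
    attends to the patch tokens produced by the encoder (DecForward).
"""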
import math
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

from sat import mpu
from sat.model import EncoderDecoderModel
from sat.model.base_model import BaseMixin, BaseModel, non_conflict
from sat.model.official.vit_model import ViTModel, ClsMixin
from sat.model.transformer import standard_attention
from sat.mpu.utils import split_tensor_along_last_dim
class AttnMixin(BaseMixin):
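    """Talking-heads attention for the CaiT self-attention stage.

    Two per-layer linear maps mix information across the head dimension: proj_l acts on the
    attention logits before the softmax, proj_w on the attention weights after it.
    """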
def __init__(self, num_heads, num_layers):
super().__init__()
self.num_layers = num_layers
self.proj_l = nn.ModuleList([nn.Linear(num_heads, num_heads) for i in range(num_layers)])
self.proj_w = nn.ModuleList([nn.Linear(num_heads, num_heads) for i in range(num_layers)])
def attention_fn(self, query_layer, key_layer, value_layer, attention_mask,
attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
# adapted from https://github.com/THUDM/sat/blob/main/sat/mpu/transformer.py#L47
if scaling_attention_score:
query_layer = query_layer * (query_layer.shape[-1]**-0.5) # / math.sqrt(query_layer.shape[-1])
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
if log_attention_weights is not None:
attention_scores += log_attention_weights
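        # Talking heads: mix the attention logits across the head dimension before the softmax.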
attention_scores = self.proj_l[kwargs['layer_id']](attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if not (attention_mask.shape[-2] == 1 and (attention_mask > 0).all()):
            # The mask is non-trivial, so apply it (a trivial all-ones mask skips this branch).
attention_scores = torch.mul(attention_scores, attention_mask) - \
10000.0 * (1.0 - attention_mask)
attention_probs = F.softmax(attention_scores, dim=-1)
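        # Talking heads: mix the attention weights across the head dimension after the softmax.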
attention_probs = self.proj_w[kwargs['layer_id']](attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
if attention_dropout is not None:
if mpu.get_cuda_rng_tracker is not None:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = attention_dropout(attention_probs)
else:
attention_probs = attention_dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
return context_layer
    def reinit(self, parent_model=None):
        # Initialize with identity weights and zero biases so that pretrained weights trained
        # with standard attention can be reused without changing the attention at the start.
        for i in range(self.num_layers):
            nn.init.eye_(self.proj_l[i].weight)
            nn.init.zeros_(self.proj_l[i].bias)
            nn.init.eye_(self.proj_w[i].weight)
            nn.init.zeros_(self.proj_w[i].bias)
class EncForward(BaseMixin):
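    """LayerScale for the self-attention (encoder) layers.

    Each residual branch is scaled by a learnable per-channel vector (gamma_1 for attention,
    gamma_2 for the MLP), initialized to a small value so that deep stacks start close to
    the identity mapping.
    """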
def __init__(self, dim, num_layers, init_values=1e-4):
super().__init__()
self.gamma_1 = nn.ParameterList([nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) for i in range(num_layers)])
self.gamma_2 = nn.ParameterList([nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) for i in range(num_layers)])
def layer_forward(self, hidden_states, mask, *args, **kw_args):
layer = self.transformer.layers[kw_args['layer_id']]
        # Layer norm at the beginning of the transformer layer.
layernorm_output1 = layer.input_layernorm(hidden_states)
# Self attention.
attention_output = layer.attention(layernorm_output1, mask, **kw_args)
# Residual connection.
layernorm_input = hidden_states + self.gamma_1[kw_args['layer_id']] * attention_output
# Layer norm post the self attention.
layernorm_output = layer.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = layer.mlp(layernorm_output, **kw_args)
# Second residual connection.
output = layernorm_input + self.gamma_2[kw_args['layer_id']] * mlp_output
return output
class DecForward(BaseMixin):
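    """Class-attention (decoder) layers with LayerScale.

    Only the CLS token is updated in this stage: queries come from the CLS token alone, while
    keys and values come from the CLS token concatenated with the patch tokens produced by the
    encoder. No position embedding is added here.
    """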
def __init__(self, dim, num_layers, init_values=1e-4):
super().__init__()
self.gamma_1 = nn.ParameterList([nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) for i in range(num_layers)])
self.gamma_2 = nn.ParameterList([nn.Parameter(init_values * torch.ones((dim)), requires_grad=True) for i in range(num_layers)])
def position_embedding_forward(self, position_ids, **kwargs):
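        # The class-attention stage adds no position embedding of its own; return 0 so nothing is added.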
return 0
def layer_forward(self, hidden_states, mask, *args, **kw_args):
'''
hidden_states: [batch, seq_len, hidden_size]
mask: [(1, 1), seq_len, seq_len]
'''
layer = self.transformer.layers[kw_args['layer_id']]
encoder_outputs = kw_args['encoder_outputs']
assert encoder_outputs is not None
        # Layer norm at the beginning of the transformer layer.
u = torch.cat([hidden_states, encoder_outputs], 1)
layernorm_output1 = layer.input_layernorm(u)
assert 'cross_attention_mask' in kw_args
# Cross attention
attention_output = layer.cross_attention(layernorm_output1, **kw_args)
# Residual connection.
layernorm_input = hidden_states + self.gamma_1[kw_args['layer_id']] * attention_output
# Layer norm post the cross attention
layernorm_output = layer.post_cross_attention_layernorm(layernorm_input)
# MLP.
mlp_output = layer.mlp(layernorm_output, **kw_args)
# Second residual connection.
output = layernorm_input + self.gamma_2[kw_args['layer_id']] * mlp_output
return output
def cross_attention_forward(self, hidden_states, cross_attention_mask, encoder_outputs, **kw_args):
# adapted from https://github.com/THUDM/sat/blob/d8c9d1e0a9bb2af1e1d26a68b35f16d84aafcc2f/sat/mpu/transformer.py#L216
        # To use a customized attention_fn, inherit this mixin in your attention mixin and
        # call self.attention_fn here instead of self.hooks['attention_fn'].
layer = self.transformer.layers[kw_args['layer_id']].cross_attention
attention_fn = standard_attention
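        # Class-attention: only the first token (CLS) forms the queries, while keys and values
        # are computed from the full (CLS + patch) sequence.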
mixed_query_layer = layer.query(hidden_states[:, :1])
mixed_x_layer = layer.key_value(hidden_states)
(mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 2)
dropout_fn = layer.attention_dropout if layer.training else None
# Reshape and transpose [b, np, s, hn]
query_layer = layer._transpose_for_scores(mixed_query_layer)
key_layer = layer._transpose_for_scores(mixed_key_layer)
value_layer = layer._transpose_for_scores(mixed_value_layer)
context_layer = attention_fn(query_layer, key_layer, value_layer, cross_attention_mask, dropout_fn,
cross_attention=True, **kw_args)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (layer.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
output = layer.dense(context_layer)
if layer.training:
output = layer.output_dropout(output)
return output
class CaiTEncoder(ViTModel):
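    """CaiT self-attention stage: a ViT backbone with talking-heads attention and LayerScale,
    without a classification head (classification is done by CaiTDecoder)."""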
def __init__(self, args, transformer=None, layernorm_epsilon=1e-6, use_final_layernorm=False):
super().__init__(args, transformer=transformer, layernorm_epsilon=layernorm_epsilon, use_final_layernorm=use_final_layernorm)
self.del_mixin('cls')
self.add_mixin('attn', AttnMixin(args.num_attention_heads, args.num_layers))
self.add_mixin('enc_forward', EncForward(args.hidden_size, args.num_layers, init_values=args.init_scale))
@classmethod
def add_model_specific_args(cls, parser):
group = parser.add_argument_group('CaiT-enc', 'CaiT encoder Configurations')
group.add_argument('--init-scale', type=float, default=1e-4)
return super().add_model_specific_args(parser)
class CaiTDecoder(BaseModel):
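    """CaiT class-attention stage: refines the CLS token against the encoder's patch tokens
    and produces class logits through ClsMixin."""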
def __init__(self, args, transformer=None, layernorm_epsilon=1e-6):
super().__init__(args, is_decoder=True, transformer=transformer, layernorm_epsilon=layernorm_epsilon)
self.add_mixin('cls', ClsMixin(args.hidden_size, args.num_classes))
self.add_mixin('dec_forward', DecForward(args.hidden_size, args.num_layers, init_values=args.init_scale))
@classmethod
def add_model_specific_args(cls, parser):
return super().add_model_specific_args(parser)
class CaiT(EncoderDecoderModel):
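    """Full CaiT model: CaiTEncoder (patch self-attention) followed by CaiTDecoder (class-attention).

    Decoder hyper-parameters default to the encoder's values and can be overridden through
    'dec_'-prefixed attributes on args (e.g. dec_num_layers).
    """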
def __init__(self, args, transformer=None, layernorm_epsilon=1e-6):
encoder = CaiTEncoder(args, transformer=transformer, layernorm_epsilon=layernorm_epsilon)
dec_args = argparse.Namespace(**vars(args))
# dec_args.enc_hidden_size = dec_args.hidden_size # used for cross attn
override_attrs = ['num_layers', 'hidden_size', 'num_attention_heads', 'layernorm_order',
'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
for name in override_attrs:
dec_attr = getattr(dec_args, 'dec_' + name, None)
if dec_attr is not None: # else use encoder-config
setattr(dec_args, name, dec_attr)
decoder = CaiTDecoder(dec_args, transformer=transformer, layernorm_epsilon=layernorm_epsilon)
super().__init__(args, encoder=encoder, decoder=decoder)
def forward(self, input_ids, enc_position_ids, dec_position_ids, *, enc_attention_mask=None, dec_attention_mask=None, cross_attention_mask=None, **kw_args):
# Please use self.decoder for auto-regressive generation.
if enc_attention_mask is None:
enc_attention_mask = torch.ones(1, 1, dtype=self.encoder.transformer.word_embeddings.weight.dtype, device=input_ids.device)
if cross_attention_mask is None:
cross_attention_mask = enc_attention_mask
encoder_outputs = self.encode(input_ids, enc_position_ids, enc_attention_mask, **kw_args)
decoder_outputs, *mems = self.decode(input_ids, dec_position_ids, dec_attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask, **kw_args)
return (encoder_outputs, decoder_outputs, *mems)
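# Usage sketch (illustrative only; the argument plumbing below follows SAT's usual flow and the
# variable names are placeholders, not part of this file):
#
#   model = CaiT(args)  # args: an argparse.Namespace carrying the ViT/CaiT fields SAT expects
#   encoder_outputs, decoder_outputs, *mems = model(
#       input_ids, enc_position_ids, dec_position_ids)
#
# For auto-regressive generation use model.decoder directly, as noted in CaiT.forward.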