# modules/SwissArmyTransformer/sat/model/official/t5_model.py

import math

import torch
import torch.nn.functional as F

from sat.model.mixins import BaseMixin
from sat.model.encoder_decoder_model import EncoderDecoderModel
from sat.model.base_model import non_conflict
from sat.mpu import get_model_parallel_world_size
from sat.model.transformer import standard_attention, SelfAttention, CrossAttention, MLP
from sat.mpu.mappings import copy_to_model_parallel_region
from sat.mpu.utils import divide, split_tensor_along_last_dim, unscaled_init_method
from sat.mpu.layers import ColumnParallelLinear, VocabParallelEmbedding


class T5PositionEmbeddingMixin(BaseMixin):
    def position_embedding_forward(self, position_ids, **kw_args):
        return None


class T5LayerNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # layer norm should always be calculated in float32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert back into float16 or bfloat16 if necessary
        if self.weight.dtype == torch.float16:
            hidden_states = hidden_states.to(torch.float16)
        elif self.weight.dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.bfloat16)
        return self.weight * hidden_states
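
# Illustrative sketch (not part of the original module): T5LayerNorm above is an
# RMS-style norm -- it rescales by the root-mean-square of the last dimension and
# never subtracts the mean or adds a bias. The helper below is a hypothetical demo
# that checks this against a manual computation on random float32 data.
def _demo_t5_layernorm(hidden_size=8, eps=1e-6):
    # hypothetical demo helper; not in the original file
    x = torch.randn(2, 4, hidden_size)
    ln = T5LayerNorm(hidden_size, eps=eps)
    manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) * ln.weight
    assert torch.allclose(ln(x), manual, atol=1e-5)
    return ln(x)
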

class T5AttentionMixin(BaseMixin):
    def __init__(self, relative_attention_num_buckets, num_attention_heads, is_decoder=False):
        super().__init__()
        self.relative_attention_num_buckets = relative_attention_num_buckets
        world_size = get_model_parallel_world_size()
        self.num_attention_heads_per_partition = divide(num_attention_heads, world_size)
        self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets,
                                                          self.num_attention_heads_per_partition)
        self.is_decoder = is_decoder

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate a relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute values of relative_position and larger buckets for larger absolute values. All relative
        positions >= max_distance map to the same bucket, and all relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias."""
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    @non_conflict
    def attention_fn(self, q, k, v, mask, dropout_fn, position_bias=None, old_impl=standard_attention,
                     cross_attention=False, **kw_args):
        log_attention_weights = None
        if not cross_attention:
            if position_bias is None:
                seq_length = q.size(2)
                key_length = k.size(2)
                position_bias = self.compute_bias(key_length, key_length)
                position_bias = position_bias[:, :, -seq_length:, :]
            # pass the bias on to later layers so it is only computed once
            kw_args['output_cross_layer']['position_bias'] = position_bias
            log_attention_weights = position_bias
        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, position_bias=position_bias,
                        log_attention_weights=log_attention_weights, scaling_attention_score=False, **kw_args)


class T5DecoderFinalMixin(BaseMixin):
    def __init__(self, vocab_size, hidden_size, tie_word_embeddings=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.tie_word_embeddings = tie_word_embeddings
        if not tie_word_embeddings:
            self.lm_head = VocabParallelEmbedding(
                vocab_size, hidden_size, init_method=unscaled_init_method(0.02))

    def final_forward(self, logits, **kwargs):
        logits_parallel = copy_to_model_parallel_region(logits)
        if self.tie_word_embeddings:
            # T5 rescales the hidden states before projecting onto the tied word embedding matrix
            logits_parallel = logits_parallel * (self.hidden_size ** -0.5)
            logits_parallel = F.linear(logits_parallel, self.transformer.word_embeddings.weight)
        else:
            logits_parallel = F.linear(logits_parallel, self.lm_head.weight)
        return logits_parallel
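
# Illustrative sketch (not part of the original module): the bucketing above is
# distance-based, so positions far beyond max_distance all collapse into the last
# bucket on their side. This hypothetical helper only exercises the static method,
# so it runs without initialising model parallelism.
def _demo_relative_position_bucket():
    # hypothetical demo helper; not in the original file
    rel = torch.tensor([[-500, -200, -1, 0, 1, 200, 500]])
    buckets = T5AttentionMixin._relative_position_bucket(
        rel, bidirectional=True, num_buckets=32, max_distance=128)
    # distances past max_distance share a bucket on each side
    assert buckets[0, 0] == buckets[0, 1]
    assert buckets[0, -1] == buckets[0, -2]
    return buckets
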

def t5_gelu(x):
    """
    Implementation of the GELU activation function currently in the Google BERT repo (identical to OpenAI GPT).
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


class T5GatedGeluMLPMixin(BaseMixin):
    def __init__(self, num_layers, hidden_size, inner_hidden_size=None, bias=True, init_method_std=0.02):
        super().__init__()
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.init_method_std = init_method_std
        self.gated_h_to_4h_list = torch.nn.ModuleList([
            ColumnParallelLinear(
                self.hidden_size,
                self.inner_hidden_size,
                gather_output=False,
                init_method=self._init_weights,
                bias=bias,
                module=self,
                name="gated_h_to_4h"
            )
            for layer_id in range(num_layers)])

    def _init_weights(self, weight, **kwargs):
        torch.nn.init.normal_(weight, mean=0, std=self.init_method_std * (self.hidden_size ** -0.5))

    def mlp_forward(self, hidden_states, layer_id=None, **kw_args):
        mlp_module = self.transformer.layers[layer_id].mlp
        hidden_gelu = t5_gelu(mlp_module.dense_h_to_4h(hidden_states))
        hidden_linear = self.gated_h_to_4h_list[layer_id](hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        output = mlp_module.dense_4h_to_h(hidden_states)
        if self.training:
            output = mlp_module.dropout(output)
        return output
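
# Illustrative sketch (not part of the original module): the mixin above swaps the plain MLP
# for a gated-GELU variant, out = dense_4h_to_h(t5_gelu(dense_h_to_4h(x)) * gated_h_to_4h(x)).
# The hypothetical helper below reproduces that dataflow with ordinary nn.Linear layers,
# so it runs without model parallelism or a full transformer.
def _demo_gated_gelu_mlp(hidden_size=8, inner_hidden_size=16):
    # hypothetical demo helper; not in the original file
    dense_h_to_4h = torch.nn.Linear(hidden_size, inner_hidden_size, bias=False)
    gated_h_to_4h = torch.nn.Linear(hidden_size, inner_hidden_size, bias=False)
    dense_4h_to_h = torch.nn.Linear(inner_hidden_size, hidden_size, bias=False)
    x = torch.randn(2, 4, hidden_size)
    return dense_4h_to_h(t5_gelu(dense_h_to_4h(x)) * gated_h_to_4h(x))
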

class T5Model(EncoderDecoderModel):
    def __init__(self, args, **kwargs):
        self.init_method_std = args.init_method_std
        super().__init__(args, tie_word_embeddings=True, **kwargs, use_bias=False,
                         layernorm=T5LayerNorm, activation_func=torch.nn.functional.relu,
                         init_method=self._init_weights)
        self.encoder.add_mixin(
            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, args.num_attention_heads)
        )
        self.encoder.add_mixin(
            "t5-position", T5PositionEmbeddingMixin()
        )
        del self.encoder.transformer.position_embeddings
        num_attention_heads = args.dec_num_attention_heads if args.dec_num_attention_heads is not None \
            else args.num_attention_heads
        self.decoder.add_mixin(
            "t5-attention", T5AttentionMixin(args.relative_attention_num_buckets, num_attention_heads,
                                             is_decoder=True)
        )
        self.decoder.add_mixin(
            "t5-position", T5PositionEmbeddingMixin()
        )
        self.decoder.add_mixin(
            "t5-final", T5DecoderFinalMixin(args.vocab_size, args.hidden_size,
                                            tie_word_embeddings=not args.no_share_embeddings)
        )
        del self.decoder.transformer.position_embeddings
        if args.gated_gelu_mlp:
            self.encoder.add_mixin(
                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size,
                                                 init_method_std=self.init_method_std,
                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
            )
            self.decoder.add_mixin(
                "gated-mlp", T5GatedGeluMLPMixin(args.num_layers, args.hidden_size,
                                                 init_method_std=self.init_method_std,
                                                 inner_hidden_size=args.inner_hidden_size, bias=False)
            )

    def _init_weights(self, weight, module, name):
        init_method_std = self.init_method_std
        if isinstance(module, MLP):
            if name == "dense_h_to_4h":
                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
            elif name == "dense_4h_to_h":
                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
            else:
                raise NotImplementedError(name)
        elif isinstance(module, SelfAttention):
            if name == "query_key_value":
                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
                torch.nn.init.normal_(weight[:module.inner_hidden_size], mean=0,
                                      std=init_method_std * ((module.hidden_size
                                                              * module.hidden_size_per_attention_head) ** -0.5))
            elif name == "dense":
                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
            else:
                raise NotImplementedError(name)
        elif isinstance(module, CrossAttention):
            if name == "query":
                torch.nn.init.normal_(weight, mean=0,
                                      std=init_method_std * ((module.hidden_size
                                                              * module.hidden_size_per_attention_head) ** -0.5))
            elif name == "key_value":
                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.hidden_size ** -0.5))
            elif name == "dense":
                torch.nn.init.normal_(weight, mean=0, std=init_method_std * (module.inner_hidden_size ** -0.5))
            else:
                raise NotImplementedError(name)
        else:
            raise NotImplementedError(module)

    @classmethod
    def add_model_specific_args(cls, parser):
        super().add_model_specific_args(parser)
        parser.add_argument("--relative-attention-num-buckets", type=int, default=None)
        parser.add_argument("--init-method-std", type=float, default=0.02)
        parser.add_argument("--gated-gelu-mlp", action='store_true')
        parser.add_argument("--no-share-embeddings", action='store_true')

    def encode(self, input_ids, attention_mask=None, **kw_args):
        return super().encode(input_ids, None, attention_mask, **kw_args)

    def decode(self, input_ids, attention_mask=None, encoder_outputs=None, cross_attention_mask=None, **kw_args):
        return super().decode(input_ids, None, attention_mask, encoder_outputs=encoder_outputs,
                              cross_attention_mask=cross_attention_mask, **kw_args)

    def forward(self, enc_input_ids, dec_input_ids, *, enc_attention_mask=None, dec_attention_mask=None,
                cross_attention_mask=None, **kw_args):
        batch_size, seq_length = enc_input_ids.size()[:2]
        if enc_attention_mask is None:
            enc_attention_mask = torch.ones(1, 1, 1, seq_length,
                                            dtype=self.encoder.transformer.word_embeddings.weight.dtype,
                                            device=enc_input_ids.device)
        if cross_attention_mask is None:
            cross_attention_mask = enc_attention_mask
        encoder_outputs = self.encode(enc_input_ids, enc_attention_mask, **kw_args)
        decoder_outputs, *mems = self.decode(dec_input_ids, dec_attention_mask,
                                             encoder_outputs=encoder_outputs,
                                             cross_attention_mask=cross_attention_mask, **kw_args)
        return (encoder_outputs, decoder_outputs, *mems)
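
# Illustrative sketch (not part of the original module): a minimal look at the command-line
# flags T5Model registers, assuming the parent class's flags are all optional. Real training
# or inference goes through SAT's own argument parsing and initialisation helpers (distributed
# setup, model-parallel init, checkpoint loading), which this hypothetical snippet skips.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    T5Model.add_model_specific_args(parser)  # also pulls in the EncoderDecoderModel flags
    known, _ = parser.parse_known_args(["--gated-gelu-mlp", "--relative-attention-num-buckets", "32"])
    print(known.gated_gelu_mlp, known.relative_attention_num_buckets)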