megatron_patch/model/falcon/language_model.py (491 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from megatron import get_args
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.enums import LayerType
from megatron.model.module import MegatronModule
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal

from .transformer import ParallelTransformer


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights.

    Arguments:
        input_: activations fed to the output projection.
        word_embeddings_weight: weight used as the projection matrix.
        parallel_output: if True, return the logits as produced by the
            parallel linear; otherwise gather them across the tensor
            model parallel region before returning.
        bias: optional bias forwarded to the linear layer.
    """
    args = get_args()
    # Parallel logits.
    if args.async_tensor_model_parallel_allreduce or \
            args.sequence_parallel:
        # The input is passed through untouched in this mode.
        input_parallel = input_
        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
        # Async grad all-reduce is only requested when tensor parallelism
        # is actually in use and sequence parallelism is off.
        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
            model_parallel and not args.sequence_parallel
    else:
        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(
            input_)
        async_grad_allreduce = False

    # Matrix multiply.
    logits_parallel = \
        tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=word_embeddings_weight,
            bias=bias,
            gradient_accumulation_fusion=args.gradient_accumulation_fusion,
            async_grad_allreduce=async_grad_allreduce,
            sequence_parallel_enabled=args.sequence_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return tensor_parallel.gather_from_tensor_model_parallel_region(
        logits_parallel)


def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_encoder=True,
                       add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save.

    When `init_method` / `scaled_init_method` are not supplied they are
    derived from `args.init_method_std` (and `args.num_layers` for the
    scaled variant).

    Returns:
        A tuple `(language_model, language_model_key)` where the key is the
        string under which the model is stored in checkpoints.
    """
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process)
    # key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        """Return tanh(dense(hidden_states[sequence_index]))."""
        # hidden_states: [s, b, h]
        # sequence_index: index of the token to pool.

        # gather data along sequence dimensions
        # same pooler is run on all tensor parallel nodes
        if self.sequence_parallel:
            tpg = tensor_parallel.gather_from_sequence_parallel_region
            hidden_states = tpg(hidden_states,
                                tensor_parallel_output_grad=False)

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled


class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method,
            params_dtype=args.params_dtype,
            use_cpu_initialization=args.use_cpu_initialization,
            perform_initialization=args.perform_initialization)
        self._word_embeddings_key = 'word_embeddings'

        self.position_embedding_type = args.position_embedding_type
        if self.position_embedding_type == 'absolute':
            # Position embedding (serial).
            self.position_embeddings = torch.nn.Embedding(
                max_sequence_length, self.hidden_size)
            self._position_embeddings_key = 'position_embeddings'
            # Initialize the position embeddings.
            if args.perform_initialization:
                self.init_method(self.position_embeddings.weight)
        else:
            # No learned position table for any other position embedding
            # type; every consumer must handle the None case.
            self.position_embeddings = None

        # Token type embedding.
        # Add this as an optional field that can be added through
        # method call so we can load a pretrain model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(
                self.num_tokentypes, self.hidden_size)
            # Initialize the token-type embeddings.
            if args.perform_initialization:
                self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        self.fp32_residual_connection = args.fp32_residual_connection
        self.sequence_parallel = args.sequence_parallel
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def zero_parameters(self):
        """Zero out all parameters in embedding."""
        self.word_embeddings.weight.data.fill_(0)
        self.word_embeddings.weight.shared = True
        # position_embeddings is None unless the position embedding type is
        # 'absolute'; skip it instead of raising AttributeError.
        if self.position_embeddings is not None:
            self.position_embeddings.weight.data.fill_(0)
            self.position_embeddings.weight.shared = True
        if self.num_tokentypes > 0:
            self.tokentype_embeddings.weight.data.fill_(0)
            self.tokentype_embeddings.weight.shared = True

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        """Embed the inputs and return the result transposed to [s b h].

        Word embeddings are always used; position embeddings are added only
        for the 'absolute' position embedding type, and token-type
        embeddings only when `tokentype_ids` is given.
        """
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        if self.position_embedding_type == 'absolute':
            assert self.position_embeddings is not None
            embeddings = embeddings + self.position_embeddings(position_ids)
        else:
            assert self.position_embeddings is None

        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Data format change to avoid explicit transposes:
        # [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()

        # If the input flag for fp32 residual
        # connection is set, convert for float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()

        # Dropout.
        if self.sequence_parallel:
            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(
                embeddings)
            # Dropout runs under the forked CUDA RNG tracker in this mode.
            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)

        return embeddings

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict with one nested state dict per embedding table that
        exists on this module.
        """
        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(prefix=prefix,
                                              keep_vars=keep_vars)
        if self.position_embedding_type == 'absolute':
            state_dict_[self._position_embeddings_key] \
                = self.position_embeddings.state_dict(prefix=prefix,
                                                      keep_vars=keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(prefix=prefix,
                                                       keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Accepts both the nested layout written by
        `state_dict_for_save_checkpoint` and a flat layout whose keys
        contain '<table>_embeddings.' (backward compatibility).
        """
        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] \
                        = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self.position_embedding_type == 'absolute':
            if self._position_embeddings_key in state_dict:
                state_dict_ = state_dict[self._position_embeddings_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'position_embeddings' in key:
                        state_dict_[key.split('position_embeddings.')[1]] \
                            = state_dict[key]
            self.position_embeddings.load_state_dict(state_dict_,
                                                     strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                # Missing token-type weights are tolerated with a warning.
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it')


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        init_method: weight initialization method, used for the embedding,
            the pooler, the untied output layer and passed to the
            transformer stacks.
        output_layer_init_method: second initialization method passed to
            the transformer stacks.
        encoder_attn_mask_type: self-attention mask type for the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
        add_encoder: build the encoder stack.
        add_decoder: build the decoder stack.
        decoder_attn_mask_type: self-attention mask type for the decoder.
        add_pooler: build a Pooler (only when `post_process` is True).
        pre_process: build the input embedding on this stage.
        post_process: build the pooler / untied output layer on this stage.
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_encoder=True,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        args = get_args()
        # TODO: passing share_word_embeddings=False
        # will not work correctly for T5 and embeddings
        # will not be synced. Fix later for T5.
        if args.untie_embeddings_and_output_weights:
            assert not add_decoder
        super(TransformerLanguageModel, self).__init__(
            share_word_embeddings=not args.untie_embeddings_and_output_weights)

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.add_encoder = add_encoder
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler
        self.encoder_hidden_state = None
        self.untie_embeddings_and_output_weights = \
            args.untie_embeddings_and_output_weights
        self.seq_length = args.seq_length

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        # Encoder (usually set to True, False if part of an encoder-decoder
        # architecture and in encoder-only stage).
        if self.add_encoder:
            self.encoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                self_attn_mask_type=self.encoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._encoder_key = 'encoder'
        else:
            self.encoder = None

        # Decoder (usually set to False, True if part of an encoder-decoder
        # architecture and in decoder-only stage).
        if self.add_decoder:
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type,
                pre_process=self.pre_process,
                post_process=self.post_process)
            self._decoder_key = 'decoder'
        else:
            self.decoder = None

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

            if self.untie_embeddings_and_output_weights:
                tpc = tensor_parallel.ColumnParallelLinear
                self.output_layer = tpc(args.hidden_size,
                                        args.padded_vocab_size,
                                        bias=False,
                                        init_method=self.init_method)
                self._output_layer_key = 'output_layer'

    def set_input_tensor(self, input_tensor):
        """ See megatron.model.transformer.set_input_tensor()"""

        # This is usually handled in schedules.py but some inference code still
        # gives us non-lists or None
        if not isinstance(input_tensor, list):
            input_tensor = [input_tensor]

        if self.add_encoder and self.add_decoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only' \
                ' be length 1 for stage with both encoder and decoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_encoder:
            assert len(input_tensor) == 1, \
                'input_tensor should only' \
                ' be length 1 for stage with only encoder'
            self.encoder.set_input_tensor(input_tensor[0])
        elif self.add_decoder:
            if len(input_tensor) == 2:
                self.decoder.set_input_tensor(input_tensor[0])
                self.encoder_hidden_state = input_tensor[1]
            elif len(input_tensor) == 1:
                # Only the encoder hidden state was provided.
                self.decoder.set_input_tensor(None)
                self.encoder_hidden_state = input_tensor[0]
            else:
                raise Exception('input_tensor must have either length 1 or 2')
        else:
            raise Exception(
                'Stage must have at least either encoder or decoder')

    # Copied from transformers.models.bart.modeling_bart._make_causal_mask
    def _make_causal_mask(self,
                          input_ids_shape,
                          dtype,
                          device,
                          past_key_values_length=0):
        """
        Make the additive causal mask used for uni-directional
        self-attention.

        Entries on and below the diagonal are 0; entries above it are
        `torch.finfo(dtype).min`. When `past_key_values_length > 0`, that
        many all-zero columns are prepended.

        Returns a tensor of shape
        `[bsz, 1, tgt_len, tgt_len + past_key_values_length]`.
        """
        bsz, tgt_len = input_ids_shape
        mask = torch.full((tgt_len, tgt_len),
                          torch.finfo(dtype).min,
                          device=device)
        mask_cond = torch.arange(mask.size(-1), device=device)
        # Zero out position (i, j) whenever j < i + 1, i.e. j <= i.
        mask.masked_fill_(
            mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
        mask = mask.to(dtype)

        if past_key_values_length > 0:
            mask = torch.cat([
                torch.zeros(tgt_len,
                            past_key_values_length,
                            dtype=dtype,
                            device=device),
                mask
            ], dim=-1)
        return mask[None, None, :, :].expand(
            bsz, 1, tgt_len, tgt_len + past_key_values_length)

    # Copied from transformers.models.bart.modeling_bart._expand_mask
    def _expand_mask(self, mask, dtype, tgt_len=None):
        """
        Expands attention_mask from `[bsz, seq_len]` to
        `[bsz, 1, tgt_seq_len, src_seq_len]` and converts it to an additive
        mask (0 where the input is 1, `torch.finfo(dtype).min` where the
        input is 0).

        A 4-D mask is also accepted: its zero entries are first replaced by
        True/1, so a boolean or 0/1 mask produces an all-zero additive mask.
        The caller's tensor is not modified.

        Raises:
            ValueError: if `mask` is neither 2-D nor 4-D.
        """
        if mask.dim() == 2:
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len
            expanded_mask = mask[:, None, None, :].expand(
                bsz, 1, tgt_len, src_len).to(dtype)
        elif mask.dim() == 4:
            # Out-of-place equivalent of `mask[mask == 0] = True`, so the
            # caller's mask is left untouched.
            expanded_mask = mask.masked_fill(mask == 0, True).to(dtype)
        else:
            raise ValueError(
                'attention mask must be 2-D or 4-D, got {} dimensions'.format(
                    mask.dim()))

        inverted_mask = 1.0 - expanded_mask

        return inverted_mask.masked_fill(inverted_mask.to(torch.bool),
                                         torch.finfo(dtype).min)

    def _prepare_decoder_attention_mask(self,
                                        attention_mask,
                                        input_shape,
                                        dtype,
                                        device,
                                        past_key_values_length=0):
        """Combine the causal mask with an optional expanded attention mask.

        Returns the sum of both additive masks, whichever one exists when
        only one does, or None when `input_shape[-1] <= 1` and
        `attention_mask` is None.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = self._make_causal_mask(
                input_shape,
                dtype,
                device=device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = self._expand_mask(
                attention_mask, dtype, tgt_len=input_shape[-1]).to(device)
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask)

        return combined_attention_mask

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None,
                inference_params=None,
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
        """Run the embedding, encoder and (optionally) pooler and decoder.

        Returns:
            * `encoder_output`, or `(encoder_output, pooled_output)` when a
              pooler is present, if there is no decoder or
              `output_enc_hidden` is True;
            * otherwise `(decoder_output, encoder_output)`, with
              `pooled_output` appended when a pooler is present.
        """
        args = get_args()
        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
                                           tokentype_ids=tokentype_ids)
        else:
            encoder_input = None

        # Replace the incoming mask by an additive mask. Without inference
        # params the target length is the configured args.seq_length;
        # otherwise it is read from the mask itself.
        # NOTE(review): the training path assumes the inputs have exactly
        # args.seq_length tokens - confirm against callers.
        batch_size = enc_input_ids.shape[0]
        if inference_params is None:
            tgt_len = self.seq_length
        else:
            tgt_len = enc_attn_mask.size()[-2]
        enc_attn_mask = self._prepare_decoder_attention_mask(
            enc_attn_mask, (batch_size, tgt_len),
            args.params_dtype, enc_input_ids.device)

        # Run encoder.
        if enc_hidden_states is None:
            if self.encoder is not None:
                encoder_output = self.encoder(
                    encoder_input,
                    enc_position_ids,
                    enc_attn_mask,
                    inference_params=inference_params)
            else:
                encoder_output = self.encoder_hidden_state
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        if self.pre_process:
            decoder_input = self.embedding(dec_input_ids,
                                           dec_position_ids)
        else:
            decoder_input = None

        # Run decoder.
        decoder_output = self.decoder(
            decoder_input,
            dec_attn_mask,
            encoder_output=encoder_output,
            enc_dec_attn_mask=enc_dec_attn_mask,
            inference_params=inference_params)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load.

        Returns a dict with one nested state dict per sub-module that
        exists on this stage.
        """
        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars)
        if self.add_encoder:
            state_dict_[self._encoder_key] \
                = self.encoder.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] = \
                    self.pooler.state_dict_for_save_checkpoint(
                        prefix=prefix, keep_vars=keep_vars)
            if self.untie_embeddings_and_output_weights:
                state_dict_[self._output_layer_key] = \
                    self.output_layer.state_dict(prefix=prefix,
                                                 keep_vars=keep_vars)

        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    prefix=prefix, keep_vars=keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load.

        Handles the nested layout written by
        `state_dict_for_save_checkpoint` as well as older layouts
        ('transformer' key / 'transformer.' prefixes, '.attention.' names).
        """
        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self.add_encoder:
            if self._encoder_key in state_dict:
                state_dict_ = state_dict[self._encoder_key]
            # For backward compatibility.
            elif 'transformer' in state_dict:
                state_dict_ = state_dict['transformer']
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if 'transformer.' in key:
                        state_dict_[key.split('transformer.')[1]] \
                            = state_dict[key]

            # For backward compatibility: rename '.attention.' to
            # '.self_attention.'.
            state_dict_self_attention = {}
            for key in state_dict_.keys():
                if '.attention.' in key:
                    state_dict_self_attention[key.replace(
                        '.attention.',
                        '.self_attention.')] = state_dict_[key]
                else:
                    state_dict_self_attention[key] = state_dict_[key]
            state_dict_ = state_dict_self_attention

            # NOTE(review): the encoder is always loaded non-strictly,
            # regardless of the `strict` argument - confirm this is intended.
            self.encoder.load_state_dict(state_dict_, strict=False)

        # Pooler.
        if self.post_process:
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
            if self.untie_embeddings_and_output_weights:
                assert 'output_layer' in state_dict, \
                    'could not find data for output_layer in the checkpoint'
                self.output_layer.load_state_dict(
                    state_dict[self._output_layer_key], strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)