megatron_patch/model/mixtral_bak/model.py

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal, Optional

from torch import Tensor

from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from megatron.core.models.common.language_module.language_module import LanguageModule
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.enums import AttnMaskType, ModelType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint


class GPTModel(LanguageModule):
    """GPT Transformer language model.

    Args:
        config (TransformerConfig): Transformer config.
        transformer_layer_spec (ModuleSpec): Specifies the module to use for transformer layers.
        vocab_size (int): Vocabulary size.
        max_sequence_length (int): Maximum sequence length. This is used for positional embeddings.
        pre_process (bool, optional): Include the embedding layer (used with pipeline parallelism).
            Defaults to True.
        post_process (bool, optional): Include an output layer (used with pipeline parallelism).
            Defaults to True.
        fp16_lm_cross_entropy (bool, optional): Compute the unreduced LM-head cross-entropy loss
            in fp16. Defaults to False.
        parallel_output (bool, optional): Do not gather the outputs; keep them split across
            tensor parallel ranks. Defaults to True.
        share_embeddings_and_output_weights (bool, optional): When True, input embeddings and
            output logit weights are shared. Defaults to False.
        position_embedding_type (Literal['learned_absolute', 'rope'], optional): Position
            embedding type. Defaults to 'learned_absolute'.
        rotary_percent (float, optional): Percent of the rotary dimension to use for rotary
            position embeddings. Ignored unless position_embedding_type is 'rope'. Defaults to 1.0.
        rotary_base (int, optional): Base period for rotary position embeddings. Ignored unless
            position_embedding_type is 'rope'. Defaults to 10000.
        seq_len_interpolation_factor (Optional[float], optional): Scale for linearly interpolating
            RoPE to longer sequences. The value must be a float larger than 1.0. Defaults to None.
""" def __init__( self, config: TransformerConfig, transformer_layer_spec: ModuleSpec, vocab_size: int, max_sequence_length: int, pre_process: bool = True, post_process: bool = True, fp16_lm_cross_entropy: bool = False, parallel_output: bool = True, share_embeddings_and_output_weights: bool = False, position_embedding_type: Literal['learned_absolute', 'rope'] = 'learned_absolute', rotary_percent: float = 1.0, rotary_base: int = 10000, seq_len_interpolation_factor: Optional[float] = None, ) -> None: super().__init__(config=config) self.transformer_layer_spec: ModuleSpec = transformer_layer_spec self.vocab_size = vocab_size self.max_sequence_length = max_sequence_length self.pre_process = pre_process self.post_process = post_process self.fp16_lm_cross_entropy = fp16_lm_cross_entropy self.parallel_output = parallel_output self.share_embeddings_and_output_weights = share_embeddings_and_output_weights self.position_embedding_type = position_embedding_type # megatron core pipelining currently depends on model type # TODO: remove this dependency ? self.model_type = ModelType.encoder_or_decoder if self.pre_process: self.embedding = LanguageModelEmbedding( config=self.config, vocab_size=self.vocab_size, max_sequence_length=self.max_sequence_length, position_embedding_type=position_embedding_type, ) if self.position_embedding_type == 'rope': self.rotary_pos_emb = RotaryEmbedding( kv_channels=self.config.kv_channels, rotary_percent=rotary_percent, seq_len_interpolation_factor=seq_len_interpolation_factor, rotary_base=rotary_base, ) # Transformer. self.decoder = TransformerBlock( config=self.config, spec=transformer_layer_spec, pre_process=self.pre_process, post_process=self.post_process, ) # Output if post_process: self.output_layer = tensor_parallel.ColumnParallelLinear( config.hidden_size, self.vocab_size, config=config, init_method=config.init_method, bias=False, skip_bias_add=False, gather_output=not self.parallel_output, skip_weight_param_allocation=self.pre_process and self.share_embeddings_and_output_weights, ) if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process): self.initialize_last_stage_with_word_embeddings() def set_input_tensor(self, input_tensor: Tensor) -> None: """Sets input tensor to the model. See megatron.model.transformer.set_input_tensor() Args: input_tensor (Tensor): Sets the input tensor for the model. """ # This is usually handled in schedules.py but some inference code still # gives us non-lists or None if not isinstance(input_tensor, list): input_tensor = [input_tensor] assert len(input_tensor) == 1, 'input_tensor should only be length 1 for gpt/bert' self.decoder.set_input_tensor(input_tensor[0]) def forward( self, input_ids: Tensor, position_ids: Tensor, attention_mask: Tensor, decoder_input: Tensor = None, labels: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoeder and finally into the post processing layer (optional). It either returns the Loss values if labels are given or the final hidden units Args: input_ids (Tensor): Batch of input token IDs. position_ids (Tensor): Batch of positional encodings corresponding to `input_ids`. attention_mask (Tensor): Mask to avoid attention on padding token indices in `input_ids`. decoder_input (Tensor, optional): Optional pre-calculated embeddings passed directly to the decoder. 
            labels (Tensor, optional): Target output token IDs for supervised training.
            inference_params (InferenceParams, optional): Parameters for controlling inference behavior.
            packed_seq_params (PackedSeqParams, optional): Parameters for processing packed sequences.
            extra_block_kwargs (dict, optional): Additional keyword arguments to pass to the transformer blocks.

        Returns:
            Tensor: If `labels` is not provided, the method returns the logits tensor of shape
                [batch_size, seq_length, vocab_size]. If `labels` is provided, it returns the
                loss value as a tensor.
        """
        # If decoder_input is provided (not None), then input_ids and position_ids are ignored.
        # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.

        # Decoder embedding.
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids)
        else:
            # intermediate stage of pipeline
            # decoder will get hidden_states from encoder.input_tensor
            decoder_input = None

        # Rotary positional embeddings (embedding is None for PP intermediate devices)
        rotary_pos_emb = None
        if self.position_embedding_type == 'rope':
            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
                inference_params, self.decoder, decoder_input, self.config
            )
            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

        # Run decoder.
        hidden_states = self.decoder(
            hidden_states=decoder_input,
            attention_mask=attention_mask,
            inference_params=inference_params,
            rotary_pos_emb=rotary_pos_emb,
            packed_seq_params=packed_seq_params,
            **(extra_block_kwargs or {}),
        )

        if not self.post_process:
            return hidden_states

        # logits and loss
        output_weight = None
        if self.share_embeddings_and_output_weights:
            output_weight = self.shared_embedding_or_output_weight()
        logits, _ = self.output_layer(hidden_states, weight=output_weight)

        if labels is None:
            # [s b h] => [b s h]
            return logits.transpose(0, 1).contiguous()

        loss = self.compute_language_model_loss(labels, logits)

        return loss
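
    # Illustrative note (an assumption about typical usage, not exercised in this file):
    # with pre_process and post_process both True and no labels, `forward` returns logits
    # of shape [batch_size, seq_length, vocab_size]; with labels it returns the loss tensor
    # from `compute_language_model_loss` instead. A hypothetical call might look like:
    #
    #     logits = model(input_ids, position_ids, attention_mask)                 # [b, s, vocab]
    #     loss = model(input_ids, position_ids, attention_mask, labels=labels)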

    def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = ()) -> ShardedStateDict:
        """Returns a sharded state dictionary for distributed training, where model parameters
        are partitioned across different devices or processes.

        Args:
            prefix (str, optional): A prefix to prepend to each key in the state dictionary. This
                can be used to distinguish parameters of different model components when combining
                state dictionaries.
            sharded_offsets (tuple, optional): Offsets for sharding the state dictionary. It
                typically contains the start index for each shard and the overall size of the
                model parameter. It should be empty for this model, as offsets are not expected.

        Returns:
            ShardedStateDict: A dictionary containing model parameters with keys prepended by
                `prefix`. The parameters are sharded according to the tensor parallelism
                configuration.
        """
        assert not sharded_offsets, "Unexpected sharded offsets"

        sharded_state_dict = {}

        if self.pre_process:
            embedding_prefix = f'{prefix}embedding.'
            embedding_sharded_state_dict = self.embedding.sharded_state_dict(
                prefix=embedding_prefix
            )
            sharded_state_dict.update(embedding_sharded_state_dict)

        decoder_prefix = f'{prefix}decoder.'
        decoder_sharded_state_dict = self.decoder.sharded_state_dict(prefix=decoder_prefix)
        sharded_state_dict.update(decoder_sharded_state_dict)

        if self.post_process:
            output_layer_prefix = f'{prefix}output_layer.'
            output_layer_key = f'{output_layer_prefix}weight'
            if self.share_embeddings_and_output_weights:
                if not self.pre_process:
                    # when sharing embeddings with the last stage, we need to use the weights
                    # from the first stage; on the pipeline's first rank, word embeddings are
                    # saved to {prefix}embedding.word_embeddings.weight
                    tensor = self.shared_embedding_or_output_weight()
                    first_stage_word_emb_key = f'{prefix}embedding.word_embeddings.weight'
                    last_stage_word_emb_replica_id = (
                        1,  # copy of first stage embedding
                        0,
                        parallel_state.get_data_parallel_rank(with_context_parallel=True),
                    )

                    sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
                        tensor=tensor,
                        key=first_stage_word_emb_key,
                        replica_id=last_stage_word_emb_replica_id,
                        allow_shape_mismatch=True,
                    )

                    sharded_state_dict[output_layer_key] = sharded_output_layer_tensor
            else:
                output_layer_state_dict = self.output_layer.state_dict(
                    prefix=output_layer_prefix, keep_vars=True
                )
                output_layer_tensor = output_layer_state_dict[output_layer_key]
                # independent output layer
                sharded_output_layer_tensor = make_tp_sharded_tensor_for_checkpoint(
                    tensor=output_layer_tensor, key=output_layer_key, allow_shape_mismatch=True,
                )

                sharded_state_dict[output_layer_key] = sharded_output_layer_tensor

        return sharded_state_dict
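

# ---------------------------------------------------------------------------
# Minimal construction sketch (illustrative only; not part of the patched model
# code). It assumes `get_gpt_layer_local_spec` from
# megatron.core.models.gpt.gpt_layer_specs and that tensor/pipeline parallel
# state has already been initialized; exact signatures and required config
# fields may differ across Megatron-Core versions.
#
#     from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
#
#     config = TransformerConfig(
#         num_layers=2,
#         hidden_size=256,
#         num_attention_heads=8,
#         use_cpu_initialization=True,
#     )
#     model = GPTModel(
#         config=config,
#         transformer_layer_spec=get_gpt_layer_local_spec(),
#         vocab_size=32000,
#         max_sequence_length=2048,
#         position_embedding_type='rope',
#     )
#
# The dictionary returned by `model.sharded_state_dict()` is the input expected
# by megatron.core.dist_checkpointing.save for distributed checkpointing.
# ---------------------------------------------------------------------------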