megatron_patch/model/bloom/layers.py

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from megatron import get_args
from megatron.core import mpu
from megatron.core.parallel_state import get_tensor_model_parallel_rank
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_gpu
from megatron.core.tensor_parallel.layers import _initialize_affine_weight_cpu
from megatron.core.tensor_parallel.mappings import \
    reduce_from_tensor_model_parallel_region
from megatron.core.tensor_parallel.utils import VocabUtility


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.

    Keyword Arguments:
        init_method: method to initialize weights.
        params_dtype
        use_cpu_initialization
        perform_initialization
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 *,
                 init_method=init.xavier_normal_,
                 params_dtype: torch.dtype = torch.float32,
                 use_cpu_initialization: bool = False,
                 perform_initialization: bool = True):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = \
            get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Optional LayerNorm over the word embeddings (BLOOM's
        # word-embedding LayerNorm), only on the first pipeline stage.
        args = get_args()
        if mpu.is_pipeline_first_stage() and args.embed_layernorm:
            from megatron.model.fused_layer_norm \
                import MixedFusedLayerNorm as LayerNorm
            self.norm = LayerNorm(embedding_dim)

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            self.embedding_dim,
                            dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight, self.num_embeddings, self.embedding_dim,
                    self.num_embeddings_per_partition, 0, init_method,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(
                torch.empty(self.num_embeddings_per_partition,
                            self.embedding_dim,
                            device=torch.cuda.current_device(),
                            dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight,
                                              init_method,
                                              partition_dim=0,
                                              stride=1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)

        if hasattr(self, 'norm'):
            output = self.norm(output)

        return output
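
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how a
# VocabParallelEmbedding would typically be constructed and called once
# Megatron's global args and tensor-model-parallel groups have been
# initialized on each rank (e.g. via megatron.initialize.initialize_megatron);
# the vocabulary and hidden sizes below are placeholder values, not taken
# from this repo.
#
#   embedding = VocabParallelEmbedding(
#       250880,                           # global vocab size, sharded over TP ranks
#       1024,                             # hidden size
#       init_method=init.xavier_normal_,
#       params_dtype=torch.float16)
#
#   token_ids = torch.randint(0, 250880, (2, 16), device='cuda')  # [batch, seq]
#   hidden_states = embedding(token_ids)                          # [2, 16, 1024]
#
# Each rank embeds only the token ids that fall in its vocabulary slice,
# zeroes the rows it does not own, and the all-reduce in forward()
# reassembles the full embedding on every rank.
# ---------------------------------------------------------------------------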