modules/SwissArmyTransformer/sat/mpu/layers.py

# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region

from .utils import divide, unscaled_init_method
from .utils import VocabUtility


def _initialize_affine_weight(weight, output_size, input_size,
                              per_partition_size, partition_dim, init_method,
                              stride=1, return_master_weight=False,
                              module=None, name=None, self=None):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""
    # If we only use 1 process for model parallelism, bypass scatter.
    world_size = get_model_parallel_world_size()
    if world_size == 1:
        init_method(weight, module=module, name=name)
        if return_master_weight:
            return weight
        return None

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=weight.dtype,
                                requires_grad=False,
                                device=weight.device)
    init_method(master_weight, module=module, name=name)

    weight_list = self.partition(full_weight=master_weight)
    rank = get_model_parallel_rank()
    with torch.no_grad():
        weight.copy_(weight_list[rank])
    del weight_list
    if return_master_weight:
        return master_weight
    return None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 params_dtype=torch.float,
                 init_method=unscaled_init_method(0.02),
                 skip_init=False, device=torch.device('cpu')):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_model_parallel_rank(),
                get_model_parallel_world_size())
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights.
        self.weight = Parameter(torch.empty(self.num_embeddings_per_partition,
                                            self.embedding_dim,
                                            dtype=params_dtype,
                                            device=device))
        self.weight.model_parallel = True
        self.weight.tensor_model_parallel = True
        # And initialize.
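        # _initialize_affine_weight builds the full
        # (num_embeddings, embedding_dim) master weight on every rank and
        # copies only this rank's row chunk into self.weight, so every rank
        # derives its shard from the same initialization.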
        if not skip_init:
            _initialize_affine_weight(
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method,
                self=self)

    def forward(self, input_):
        # Build the mask.
        input_mask = (input_ < self.vocab_start_index) | \
                     (input_ >= self.vocab_end_index)
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_model_parallel_region(output_parallel)
        return output

    def repartition(self):
        assert self.num_embeddings_per_partition == self.num_embeddings
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_model_parallel_rank(),
                get_model_parallel_world_size())
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index
        self.original_weight = self.weight
        self.weight = torch.nn.Parameter(torch.clone(
            self.weight[self.vocab_start_index:self.vocab_end_index],
        ).detach())
        del self.original_weight

    def partition(self, new_model_parallel_size=None, full_weight=None):
        assert self.num_embeddings_per_partition == self.num_embeddings \
            or full_weight is not None
        flag = 1
        if full_weight is None:
            full_weight = self.weight
            flag = 2
        if new_model_parallel_size is None:
            new_model_parallel_size = get_model_parallel_world_size()
        new_weights = []
        for rank in range(new_model_parallel_size):
            vocab_start_index, vocab_end_index = \
                VocabUtility.vocab_range_from_global_vocab_size(
                    self.num_embeddings, rank, new_model_parallel_size)
            weight = torch.clone(
                full_weight[vocab_start_index:vocab_end_index],
            ).detach()
            new_weights.append(weight)
        if flag == 1:
            return new_weights
        else:
            return new_weights, []

    def merge(self, new_weights, new_biases):
        self.weight.data.copy_(torch.cat(new_weights))


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on the output and make Y
                       available to all GPUs; otherwise, every GPU will have
                       its own output, Y_i = XA_i.
        init_method: method to initialize weights. Note that bias is always
                     set to zero.
        stride: For the strided linear layers. Only used in initialization.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master
                                     weights used for initialization.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=unscaled_init_method(0.02), stride=1,
                 keep_master_weight_for_test=False,
                 params_dtype=torch.float, module=None, name=None,
                 skip_init=False, device=torch.device('cpu')):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.stride = stride
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
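        # Since F.linear computes X @ W^T, self.weight holds A_i^T with shape
        # [output_size // world_size, input_size]: A is split along its
        # output (column) dimension, one chunk per model-parallel rank.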
        self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                            self.input_size,
                                            dtype=params_dtype,
                                            device=device))
        self.weight.model_parallel = True
        self.weight.tensor_model_parallel = True
        if bias:
            self.bias = Parameter(torch.empty(self.output_size_per_partition,
                                              dtype=params_dtype,
                                              device=device))
            self.bias.model_parallel = True
            self.bias.tensor_model_parallel = True
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        # Initialize weight.
        if not skip_init:
            self.master_weight = _initialize_affine_weight(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
                stride=self.stride,
                return_master_weight=keep_master_weight_for_test,
                module=module, name=name, self=self)

    def forward(self, input_):
        # Set up backprop all-reduce, and don't change the input.
        input_parallel = copy_to_model_parallel_region(input_)
        # Matrix multiply.
        # input_parallel: [seq_len, input_size]
        # weight: [output_size // mp_size, input_size]
        # bias: [output_size // mp_size]
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output

    def repartition(self):
        assert self.output_size_per_partition == self.output_size
        self.output_size_per_partition = divide(
            self.output_size, get_model_parallel_world_size())
        mp_rank = get_model_parallel_rank()
        mp_size = get_model_parallel_world_size()
        self.original_weight = self.weight
        # weight is arranged as [stride0...stride1...stride2] * [input_size];
        # extract the non-contiguous parts.
        strides = [1] * self.stride if isinstance(self.stride, int) else self.stride  # int means equal number of qkv, or ratios
        assert self.weight.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
        factor = self.weight.shape[0] // sum(strides)
        # decompose weight according to strides
        strided_weights, _acm = [], 0
        for i in range(len(strides)):
            strided_weights.append(self.weight[_acm:_acm + factor * strides[i], :].detach())
            _acm += factor * strides[i]
        new_weight = torch.cat([
            strided_weight[
                (strided_weight.shape[0] // mp_size) * mp_rank:
                (strided_weight.shape[0] // mp_size) * (mp_rank + 1)
            ]
            for strided_weight in strided_weights
        ], dim=0).contiguous().view(self.output_size_per_partition,
                                    self.input_size)
        self.weight = torch.nn.Parameter(new_weight)
        del self.original_weight
        if self.bias is not None and self.bias.numel() != 0:
            self.original_bias = self.bias
            # decompose bias according to strides
            strided_biases, _acm = [], 0
            for i in range(len(strides)):
                strided_biases.append(self.bias[_acm:_acm + factor * strides[i]].detach())
                _acm += factor * strides[i]
            new_bias = torch.cat([
                strided_bias[
                    (strided_bias.shape[0] // mp_size) * mp_rank:
                    (strided_bias.shape[0] // mp_size) * (mp_rank + 1)
                ]
                for strided_bias in strided_biases
            ], dim=0).contiguous().view(self.output_size_per_partition)
            self.bias = torch.nn.Parameter(new_bias)
            del self.original_bias

    def partition(self, new_model_parallel_size=None, full_weight=None):
        assert self.output_size_per_partition == self.output_size \
            or full_weight is not None
        flag = 1
        if full_weight is None:
            full_weight = self.weight
            flag = 2
        if new_model_parallel_size is None:
            new_model_parallel_size = get_model_parallel_world_size()
        output_size_per_partition = divide(self.output_size,
                                           new_model_parallel_size)
        new_weights = []
        new_biases = []
        mp_size = new_model_parallel_size
        # weight is arranged as [stride0...stride1...stride2] * [input_size];
        # extract the non-contiguous parts.
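        # Example: a fused QKV projection uses stride=3, i.e. the full weight
        # stacks [Q; K; V] along dim 0. Each block is sliced separately so
        # every rank ends up with matching Q/K/V sub-blocks rather than one
        # contiguous (and wrong) slice of the stacked matrix.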
        strides = [1] * self.stride if isinstance(self.stride, int) else self.stride  # int means equal number of qkv, or ratios
        assert full_weight.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
        factor = full_weight.shape[0] // sum(strides)
        # decompose weight according to strides
        strided_weights, _acm = [], 0
        for i in range(len(strides)):
            strided_weights.append(full_weight[_acm:_acm + factor * strides[i], :].detach())
            _acm += factor * strides[i]
        if flag == 2 and self.bias is not None and self.bias.numel() != 0:
            # decompose bias according to strides
            strided_biases, _acm = [], 0
            for i in range(len(strides)):
                strided_biases.append(self.bias[_acm:_acm + factor * strides[i]].detach())
                _acm += factor * strides[i]
        for rank in range(new_model_parallel_size):
            mp_rank = rank
            new_weight = torch.cat([
                strided_weight[
                    (strided_weight.shape[0] // mp_size) * mp_rank:
                    (strided_weight.shape[0] // mp_size) * (mp_rank + 1)
                ]
                for strided_weight in strided_weights
            ], dim=0).contiguous().view(output_size_per_partition,
                                        self.input_size)
            new_weights.append(torch.clone(new_weight).detach())
            if flag == 2 and self.bias is not None and self.bias.numel() != 0:
                new_bias = torch.cat([
                    strided_bias[
                        (strided_bias.shape[0] // mp_size) * mp_rank:
                        (strided_bias.shape[0] // mp_size) * (mp_rank + 1)
                    ]
                    for strided_bias in strided_biases
                ], dim=0).contiguous().view(output_size_per_partition)
                new_biases.append(torch.clone(new_bias).detach())
        if flag == 1:
            return new_weights
        else:
            return new_weights, new_biases

    def merge(self, new_weights, new_biases):
        strides = [1] * self.stride if isinstance(self.stride, int) else self.stride  # int means equal number of qkv, or ratios
        assert self.weight.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
        all_weights = []
        _acm = 0
        for stride in strides:
            for weight in new_weights:
                factor = weight.shape[0] // sum(strides)
                all_weights.append(weight[_acm:_acm + factor * stride])
            _acm += factor * stride
        self.weight.data.copy_(torch.cat(all_weights))
        if self.bias is not None and self.bias.numel() != 0:
            all_biases = []
            _acm = 0
            for stride in strides:
                for bias in new_biases:
                    factor = bias.shape[0] // sum(strides)
                    all_biases.append(bias[_acm:_acm + factor * stride])
                _acm += factor * stride
            self.bias.data.copy_(torch.cat(all_biases))


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:

               -   -
              | A_1 |
              |  .  |
          A = |  .  |        X = [X_1, ..., X_p]
              |  .  |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split again.
        init_method: method to initialize weights. Note that bias is always
                     set to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master
                                     weights used for initialization.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=unscaled_init_method(0.02), stride=1,
                 keep_master_weight_for_test=False,
                 params_dtype=torch.float, module=None, name=None,
                 skip_init=False, device=torch.device('cpu'),
                 final_bias=True):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension.
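        # The stored weight is the transpose, [output_size, input_size // mp],
        # so slicing its last dimension is equivalent to splitting A by rows.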
        world_size = get_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.final_bias = final_bias

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        self.weight = Parameter(torch.empty(self.output_size,
                                            self.input_size_per_partition,
                                            dtype=params_dtype,
                                            device=device))
        self.weight.model_parallel = True
        self.weight.tensor_model_parallel = True
        if bias:
            self.bias = Parameter(torch.empty(self.output_size,
                                              dtype=params_dtype,
                                              device=device))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        # Initialize weight.
        if not skip_init:
            self.master_weight = _initialize_affine_weight(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
                stride=stride,
                return_master_weight=keep_master_weight_for_test,
                module=module, name=name, self=self)

    def forward(self, input_):
        # Split the input vector along the last dimension.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Matrix multiply.
        # input_parallel: [seq_len, input_size // mp_size]
        # weight: [output_size, input_size // mp_size]
        if self.final_bias or self.bias is None:
            output_parallel = F.linear(input_parallel, self.weight)
        else:
            # Fold the bias into the partial products: each rank adds
            # bias / world_size, so the all-reduce sums to exactly one bias.
            output_parallel = F.linear(input_parallel, self.weight,
                                       self.bias / get_model_parallel_world_size())
        # All-reduce across all the partitions.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.final_bias and self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output

    def repartition(self):
        assert self.input_size_per_partition == self.input_size
        self.input_size_per_partition = divide(
            self.input_size, get_model_parallel_world_size())
        mp_rank = get_model_parallel_rank()
        self.original_weight = self.weight
        self.weight = torch.nn.Parameter(torch.clone(
            self.weight[:, mp_rank * self.input_size_per_partition
                        :(mp_rank + 1) * self.input_size_per_partition],
        ).detach())
        del self.original_weight

    def partition(self, new_model_parallel_size=None, full_weight=None):
        assert self.input_size_per_partition == self.input_size \
            or full_weight is not None
        flag = 1
        if full_weight is None:
            full_weight = self.weight
            flag = 2
        if new_model_parallel_size is None:
            new_model_parallel_size = get_model_parallel_world_size()
        input_size_per_partition = divide(self.input_size,
                                          new_model_parallel_size)
        new_weights = []
        new_biases = []
        for rank in range(new_model_parallel_size):
            mp_rank = rank
            weight = torch.clone(
                full_weight[:, mp_rank * input_size_per_partition
                            :(mp_rank + 1) * input_size_per_partition],
            ).detach()
            new_weights.append(weight)
            if flag == 2 and self.bias is not None and self.bias.numel() != 0:
                new_biases.append(torch.clone(self.bias.data).detach())
        if flag == 1:
            return new_weights
        else:
            return new_weights, new_biases

    def merge(self, new_weights, new_biases):
        self.weight.data.copy_(torch.cat(new_weights, 1))
        if self.bias is not None and self.bias.numel() != 0:
            self.bias.data.copy_(new_biases[0])
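
# Usage sketch (illustrative, not part of the original module). With a
# model-parallel world size of 1 these layers reduce to plain embeddings and
# linears; the initialization calls below follow the Megatron-style mpu
# convention, and the `initialize_model_parallel` import is an assumption
# about this package's exports:
#
#     import torch.distributed as dist
#     from sat.mpu import initialize_model_parallel  # assumed entry point
#
#     dist.init_process_group('gloo', init_method='tcp://localhost:23456',
#                             rank=0, world_size=1)
#     initialize_model_parallel(1)
#     dense_h_to_4h = ColumnParallelLinear(1024, 4096, gather_output=False)
#     dense_4h_to_h = RowParallelLinear(4096, 1024, input_is_parallel=True)
#     x = torch.randn(8, 1024)
#     y = dense_4h_to_h(F.gelu(dense_h_to_4h(x)))  # one all-reduce, inside the row layer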