# modules/SwissArmyTransformer/sat/mpu/layers.py
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import math
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .utils import divide, unscaled_init_method
from .utils import VocabUtility
def _initialize_affine_weight(weight, output_size, input_size,
per_partition_size, partition_dim, init_method,
stride=1, return_master_weight=False, module=None, name=None, self=None):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
# If we only use 1 process for model parallelism, bypass scatter.
world_size = get_model_parallel_world_size()
if world_size == 1:
init_method(weight, module=module, name=name)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=weight.dtype,
requires_grad=False,
device=weight.device)
init_method(master_weight, module=module, name=name)
weight_list = self.partition(full_weight=master_weight)
rank = get_model_parallel_rank()
with torch.no_grad():
weight.copy_(weight_list[rank])
del weight_list
if return_master_weight:
return master_weight
return None
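
# A minimal sketch (not part of the original file) of the scatter semantics
# above, assuming a plain dim-0 split and a 2-way model-parallel group: the
# master weight is built identically on every process, then each rank keeps
# only its own chunk. All names here are hypothetical, for illustration only.
def _demo_affine_weight_scatter():
    output_size, input_size, world_size, rank = 8, 4, 2, 0  # assumed toy sizes
    master_weight = torch.empty(output_size, input_size)
    init.normal_(master_weight, mean=0.0, std=0.02)
    # divide(output_size, world_size) rows per rank, split along dim 0
    weight_list = torch.split(master_weight, output_size // world_size, dim=0)
    local_weight = weight_list[rank].clone()
    assert local_weight.shape == (output_size // world_size, input_size)
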
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim, params_dtype=torch.float, init_method=unscaled_init_method(0.02), skip_init=False, device=torch.device('cpu')):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
        # Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(),
get_model_parallel_world_size())
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.empty(self.num_embeddings_per_partition,
self.embedding_dim, dtype=params_dtype,
device=device))
self.weight.model_parallel = True
self.weight.tensor_model_parallel = True
# And initialize.
if not skip_init:
_initialize_affine_weight(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method, self=self)
def forward(self, input_):
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
return output
    def repartition(self):
        # Only valid while this module still holds the full, unpartitioned weight.
        assert self.num_embeddings_per_partition == self.num_embeddings
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(),
get_model_parallel_world_size())
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
self.original_weight = self.weight
self.weight = torch.nn.Parameter(torch.clone(
self.weight[self.vocab_start_index:self.vocab_end_index],
).detach())
del self.original_weight
def partition(self, new_model_parallel_size=None, full_weight=None):
assert self.num_embeddings_per_partition == self.num_embeddings or full_weight is not None
        # flag == 1: a full weight was passed in by the caller (return weights only);
        # flag == 2: we are partitioning this layer's own weight (also return biases).
        flag = 1
        if full_weight is None:
            full_weight = self.weight
            flag = 2
if new_model_parallel_size is None:
new_model_parallel_size = get_model_parallel_world_size()
new_weights = []
for rank in range(new_model_parallel_size):
vocab_start_index, vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, rank,
new_model_parallel_size)
weight = torch.clone(
full_weight[vocab_start_index:vocab_end_index],
).detach()
new_weights.append(weight)
        if flag == 1:
            return new_weights
        else:
            return new_weights, []  # embeddings have no bias to partition
def merge(self, new_weights, new_biases):
self.weight.data.copy_(torch.cat(new_weights))
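
# A minimal sketch (not part of the original file) of the masking trick used
# in VocabParallelEmbedding.forward, assuming one rank owns vocab ids
# [100, 200): out-of-range ids are shifted and clamped to row 0, looked up,
# then zeroed, so the cross-rank all-reduce restores the correct embedding
# for every id. All names here are hypothetical, for illustration only.
def _demo_vocab_masking():
    vocab_start_index, vocab_end_index = 100, 200  # assumed toy vocab range
    weight = torch.randn(vocab_end_index - vocab_start_index, 16)
    input_ = torch.tensor([5, 150, 199, 300])
    input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
    masked_input = input_.clone() - vocab_start_index
    masked_input[input_mask] = 0
    output_parallel = F.embedding(masked_input, weight)
    output_parallel[input_mask, :] = 0.0  # zero the rows this rank does not own
    # summing output_parallel across ranks yields the full embedding lookup
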
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       partition of the output, Y_i = XA_i.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers. Only used in initialization.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
"""
def __init__(self, input_size, output_size, bias=True, gather_output=True,
init_method=unscaled_init_method(0.02), stride=1,
keep_master_weight_for_test=False, params_dtype=torch.float, module=None, name=None, skip_init=False, device=torch.device('cpu')):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.stride = stride
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size, dtype=params_dtype,
device=device))
self.weight.model_parallel = True
self.weight.tensor_model_parallel = True
if bias:
self.bias = Parameter(torch.empty(self.output_size_per_partition,dtype=params_dtype, device=device))
self.bias.model_parallel = True
self.bias.tensor_model_parallel = True
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
if not skip_init:
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=self.stride, return_master_weight=keep_master_weight_for_test, module=module, name=name, self=self)
def forward(self, input_):
# Set up backprop all-reduce, and don't change the input.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
# input_parallel: [seq_len, input_size]
# weight: [output_size // mp_size, input_size]
# bias: [output_size // mp_size]
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
return output
def repartition(self):
assert self.output_size_per_partition == self.output_size
self.output_size_per_partition = divide(self.output_size, get_model_parallel_world_size())
mp_rank = get_model_parallel_rank()
mp_size = get_model_parallel_world_size()
self.original_weight = self.weight
        # The weight rows are laid out as [stride_0 rows | stride_1 rows | ...]
        # over the full output dimension, so each rank must take its slice of
        # every stride group (a non-contiguous selection).
        strides = [1]*self.stride if isinstance(self.stride, int) else self.stride  # an int means that many equal groups (e.g. fused QKV); a list gives the ratios
        assert self.weight.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
factor = self.weight.shape[0] // sum(strides)
# decompose weight according to strides
strided_weights, _acm = [], 0
for i in range(len(strides)):
strided_weights.append(self.weight[_acm:_acm+factor*strides[i], :].detach())
_acm += factor*strides[i]
new_weight = torch.cat([
strided_weight[
(strided_weight.shape[0]//mp_size)*mp_rank:
(strided_weight.shape[0]//mp_size)*(mp_rank+1)
]
for strided_weight in strided_weights
], dim=0).contiguous().view(self.output_size_per_partition, self.input_size)
self.weight = torch.nn.Parameter(new_weight)
del self.original_weight
if self.bias is not None and self.bias.numel() != 0:
self.original_bias = self.bias
# decompose bias according to strides
strided_biases, _acm = [], 0
for i in range(len(strides)):
strided_biases.append(self.bias[_acm:_acm+factor*strides[i]].detach())
_acm += factor*strides[i]
new_bias = torch.cat([
strided_bias[
(strided_bias.shape[0]//mp_size)*mp_rank:
(strided_bias.shape[0]//mp_size)*(mp_rank+1)
]
for strided_bias in strided_biases
], dim=0).contiguous().view(self.output_size_per_partition)
self.bias = torch.nn.Parameter(new_bias)
del self.original_bias
    def partition(self, new_model_parallel_size=None, full_weight=None):
        assert self.output_size_per_partition == self.output_size or full_weight is not None
        flag = 1  # 1: caller-provided full_weight, return weights only; 2: own weight, also return biases
        if full_weight is None:
            full_weight = self.weight
            flag = 2
if new_model_parallel_size is None:
new_model_parallel_size = get_model_parallel_world_size()
output_size_per_partition = divide(self.output_size, new_model_parallel_size)
new_weights = []
new_biases = []
mp_size = new_model_parallel_size
        # Each rank takes its slice of every stride group (non-contiguous rows);
        # see the layout comment in repartition above.
        strides = [1]*self.stride if isinstance(self.stride, int) else self.stride  # an int means that many equal groups (e.g. fused QKV); a list gives the ratios
assert full_weight.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
factor = full_weight.shape[0] // sum(strides)
# decompose weight according to strides
strided_weights, _acm = [], 0
for i in range(len(strides)):
strided_weights.append(full_weight[_acm:_acm+factor*strides[i], :].detach())
_acm += factor*strides[i]
if flag == 2 and self.bias is not None and self.bias.numel() != 0:
# decompose bias according to strides
strided_biases, _acm = [], 0
for i in range(len(strides)):
strided_biases.append(self.bias[_acm:_acm+factor*strides[i]].detach())
_acm += factor*strides[i]
for rank in range(new_model_parallel_size):
mp_rank = rank
new_weight = torch.cat([
strided_weight[
(strided_weight.shape[0]//mp_size)*mp_rank:
(strided_weight.shape[0]//mp_size)*(mp_rank+1)
]
for strided_weight in strided_weights
], dim=0).contiguous().view(output_size_per_partition, self.input_size)
new_weights.append(torch.clone(new_weight).detach())
if flag == 2 and self.bias is not None and self.bias.numel() != 0:
new_bias = torch.cat([
strided_bias[
(strided_bias.shape[0]//mp_size)*mp_rank:
(strided_bias.shape[0]//mp_size)*(mp_rank+1)
]
for strided_bias in strided_biases
], dim=0).contiguous().view(output_size_per_partition)
new_biases.append(torch.clone(new_bias).detach())
if flag == 1:
return new_weights
else:
return new_weights, new_biases
def merge(self, new_weights, new_biases):
        strides = [1]*self.stride if isinstance(self.stride, int) else self.stride  # an int means that many equal groups (e.g. fused QKV); a list gives the ratios
assert self.weight.shape[0] % sum(strides) == 0, 'cannot divide weight evenly'
all_weights = []
_acm = 0
for stride in strides:
for weight in new_weights:
factor = weight.shape[0] // sum(strides)
all_weights.append(weight[_acm:_acm+factor*stride])
_acm += factor*stride
self.weight.data.copy_(torch.cat(all_weights))
if self.bias is not None and self.bias.numel() != 0:
all_biases = []
_acm = 0
for stride in strides:
for bias in new_biases:
factor = bias.shape[0] // sum(strides)
all_biases.append(bias[_acm:_acm+factor*stride])
_acm += factor*stride
self.bias.data.copy_(torch.cat(all_biases))
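
# A minimal sketch (not part of the original file) of the stride-aware split
# used by ColumnParallelLinear.partition above, assuming a fused QKV weight
# (stride=3) and 2 model-parallel ranks: each rank takes its slice of every
# stride group, so Q, K, and V all stay evenly divided across ranks. All
# names here are hypothetical, for illustration only.
def _demo_strided_partition():
    hidden, mp_size, stride = 4, 2, 3  # assumed toy sizes
    full_weight = torch.arange(stride * hidden * hidden,
                               dtype=torch.float).view(stride * hidden, hidden)
    strided_weights = torch.split(full_weight, hidden, dim=0)  # [Q, K, V] blocks
    per_rank = []
    for mp_rank in range(mp_size):
        per_rank.append(torch.cat([
            w[(w.shape[0] // mp_size) * mp_rank:(w.shape[0] // mp_size) * (mp_rank + 1)]
            for w in strided_weights
        ], dim=0))
    # each rank holds stride * hidden // mp_size rows: half of Q, K, and V
    assert per_rank[0].shape == (stride * hidden // mp_size, hidden)
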
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
"""
def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False,
init_method=unscaled_init_method(0.02), stride=1,
keep_master_weight_for_test=False, params_dtype=torch.float, module=None, name=None, skip_init=False, device=torch.device('cpu'), final_bias=True):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = get_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.final_bias = final_bias
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition, dtype=params_dtype, device=device))
self.weight.model_parallel = True
self.weight.tensor_model_parallel = True
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype, device=device))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
if not skip_init:
self.master_weight = _initialize_affine_weight(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test, module=module, name=name, self=self)
def forward(self, input_):
# Split the input vector along the last dimension.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
# Matrix multiply.
# input_parallel: [seq_len, input_size // mp_size]
# weight: [output_size, input_size // mp_size]
        if self.final_bias or self.bias is None:
            output_parallel = F.linear(input_parallel, self.weight)
        else:
            # Fold the bias in before the all-reduce: each rank adds
            # bias / world_size, so the reduced sum carries exactly one
            # copy of the bias.
            output_parallel = F.linear(input_parallel, self.weight, self.bias / get_model_parallel_world_size())
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel)
if self.final_bias and self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
def repartition(self):
assert self.input_size_per_partition == self.input_size
self.input_size_per_partition = divide(self.input_size, get_model_parallel_world_size())
mp_rank = get_model_parallel_rank()
self.original_weight = self.weight
self.weight = torch.nn.Parameter(torch.clone(
self.weight[:, mp_rank*self.input_size_per_partition
:(mp_rank+1)*self.input_size_per_partition],
).detach())
del self.original_weight
    def partition(self, new_model_parallel_size=None, full_weight=None):
        assert self.input_size_per_partition == self.input_size or full_weight is not None
        flag = 1  # 1: caller-provided full_weight, return weights only; 2: own weight, also return biases
        if full_weight is None:
            full_weight = self.weight
            flag = 2
if new_model_parallel_size is None:
new_model_parallel_size = get_model_parallel_world_size()
input_size_per_partition = divide(self.input_size, new_model_parallel_size)
new_weights = []
new_biases = []
for rank in range(new_model_parallel_size):
mp_rank = rank
weight = torch.clone(
full_weight[:, mp_rank*input_size_per_partition
:(mp_rank+1)*input_size_per_partition],
).detach()
new_weights.append(weight)
if flag == 2 and self.bias is not None and self.bias.numel() != 0:
new_biases.append(torch.clone(self.bias.data).detach())
if flag == 1:
return new_weights
else:
return new_weights, new_biases
def merge(self, new_weights, new_biases):
self.weight.data.copy_(torch.cat(new_weights, 1))
if self.bias is not None and self.bias.numel() != 0:
self.bias.data.copy_(new_biases[0])
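
# A minimal end-to-end sketch (not part of the original file), assuming the
# model-parallel group has already been initialized (e.g. with a world size
# of 1): the classic pairing of a ColumnParallelLinear (gather_output=False)
# feeding a RowParallelLinear (input_is_parallel=True), so the intermediate
# activation is never gathered and the pair costs a single all-reduce in the
# forward pass. All names here are hypothetical, for illustration only.
def _demo_parallel_mlp():
    hidden = 16  # assumed toy size
    dense_h_to_4h = ColumnParallelLinear(hidden, 4 * hidden,
                                         gather_output=False, skip_init=True)
    dense_4h_to_h = RowParallelLinear(4 * hidden, hidden,
                                      input_is_parallel=True, skip_init=True)
    with torch.no_grad():  # skip_init leaves the weights uninitialized
        dense_h_to_4h.weight.normal_(std=0.02)
        dense_4h_to_h.weight.normal_(std=0.02)
    x = torch.randn(2, hidden)
    y = dense_4h_to_h(F.gelu(dense_h_to_4h(x)))
    assert y.shape == (2, hidden)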