# optimum/tpu/xla_model_parallel.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import torch
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops used below
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
import torch.distributed.distributed_c10d as c10d
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn import Linear
from torch.nn.parameter import Parameter
EPS = torch.finfo(torch.float32).eps
# Any non-empty value of the USE_CUDA environment variable selects the CUDA
# (non-XLA) code paths; otherwise the XLA/TPU paths are used.
USE_CUDA = os.environ.get("USE_CUDA", False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
TAG = None
RANKSET = None
GROUP_SIZE = None
def set_g_group():
global TAG
global RANKSET
global GROUP_SIZE
assert USE_CUDA, "This hack is only for PyTorch non-XLA CUDA paths, i.e., eager and inductor."
TAG, RANKSET, GROUP_SIZE = fc._expand_group(c10d._get_default_group())
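# Illustrative usage sketch (not part of the library): on the CUDA debug path the
# default process group must be initialized before the functional collectives are
# tagged. The init call below is a hypothetical single-node setup.
#
#   import torch.distributed as dist
#   dist.init_process_group("nccl")  # hypothetical launcher-provided rendezvous
#   set_g_group()                    # populates TAG, RANKSET and GROUP_SIZE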
@dataclass
class TensorQConfig:
dtype: torch.dtype = torch.int8
axis: int = -1
quant_min: int = -128
quant_max: int = 127
symmetric_quant: bool = True
def _find_per_channel_min_max(x: torch.Tensor, axis: int):
x_dim = x.size()
new_axis_list = list(range(len(x_dim)))
new_axis_list[axis] = 0
new_axis_list[0] = axis
y = x.permute(new_axis_list)
y = torch.flatten(y, start_dim=1)
return torch.aminmax(y, dim=1)
def _find_qparams(x: torch.Tensor, qconfig: TensorQConfig):
    # Only per-channel symmetric quantization to int8 is supported for now.
axis = qconfig.axis
dtype = qconfig.dtype
symmetric_quant = qconfig.symmetric_quant
quant_min = qconfig.quant_min
quant_max = qconfig.quant_max
assert axis >= 0 and axis < len(x.shape)
assert dtype == torch.int8
min_val, max_val = _find_per_channel_min_max(x, axis)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
scale = torch.ones(min_val_neg.size(), dtype=torch.float32)
if symmetric_quant:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
eps = torch.zeros_like(scale).fill_(EPS)
scale = torch.max(scale, eps)
return scale, None
    else:
        raise NotImplementedError("Only symmetric per-channel quantization is supported.")
def _quantize_to_dtype(
x: torch.Tensor,
qconfig: TensorQConfig,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
):
if zero_point is None:
zero_point = torch.zeros_like(scale)
return torch.ops.quantized_decomposed.quantize_per_channel(
x,
scale,
zero_point,
qconfig.axis,
qconfig.quant_min,
qconfig.quant_max,
qconfig.dtype,
)
def quantize_tensor(x: torch.Tensor, qconfig: TensorQConfig):
scale, zp = _find_qparams(x, qconfig)
x_int = _quantize_to_dtype(x, qconfig, scale, zp)
return x_int, scale, zp
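# Illustrative sketch of the quantization helpers above (hypothetical shapes): a
# [out_features, in_features] weight is quantized per output channel (axis=0) to
# int8 and can be approximately reconstructed with the returned per-channel scale.
#
#   w = torch.randn(16, 32)
#   w_int8, scale, _ = quantize_tensor(w, TensorQConfig(axis=0))
#   w_approx = w_int8.to(torch.float32) * scale.unsqueeze(-1)  # dequantize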
def get_model_parallel_rank():
if USE_CUDA:
return dist.get_rank()
return xm.get_ordinal()
def get_model_parallel_world_size():
if USE_CUDA:
return dist.get_world_size()
return xm.xrt_world_size()
def get_model_parallel_group():
return None
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
return input_
@staticmethod
def backward(ctx, grad_output): # type: ignore
groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
return my_reduce(grad_output, groups, world_size, rank)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
return my_reduce(input_, groups, world_size, rank)
@staticmethod
def backward(ctx, grad_output): # type: ignore
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
return my_split(input_, groups, world_size, rank)
@staticmethod
def backward(ctx, grad_output): # type: ignore
groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
return my_gather(grad_output, groups, world_size, rank)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
return my_gather(input_, groups, world_size, rank)
@staticmethod
def backward(ctx, grad_output): # type: ignore
groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
return my_split(grad_output, groups, world_size, rank)
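# Note (descriptive, not part of the original source): the four autograd functions
# above come in conjugate pairs. The backward of the identity copy is an all-reduce
# (and vice versa), and the backward of a split is a gather (and vice versa), so
# gradients flow correctly across the model-parallel group.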
# -----------------
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_, groups, world_size, rank)
def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
return _ReduceFromModelParallelRegion.apply(input_, groups, world_size, rank)
def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
return _ScatterToModelParallelRegion.apply(input_, groups, world_size, rank)
def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_, groups, world_size, rank)
def ensure_divisibility(numerator: int, denominator: int) -> None:
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
def divide_and_check_no_remainder(numerator: int, denominator: int) -> int:
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False
) -> Tuple[torch.Tensor, ...]:
"""Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# Note: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
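# Illustrative example (hypothetical shapes): a tensor of shape [2, 8] split into
# 4 partitions yields 4 views of shape [2, 2] along the last dimension.
#
#   chunks = split_tensor_along_last_dim(torch.arange(16).reshape(2, 8), 4)
#   assert all(c.shape == (2, 2) for c in chunks)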
# Below copied from fairscale/nn/model_parallel/layers.py
def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# All-reduce.
if USE_CUDA:
input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG, RANKSET, GROUP_SIZE)
else:
input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)
return input_
def my_split(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice.
"""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
output = input_list[rank].contiguous()
return output
def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""Gather tensors and concatinate along the last dimension."""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
if USE_CUDA:
last_dim = input_.dim() - 1
        # Use all_reduce to emulate all_gather, as torch.ops.c10d_functional.all_gather_into_tensor
        # is buggy with 16-bit dtypes.
size = input_.size(last_dim)
padding = [0] * (2 * input_.dim())
ordinal = rank
left, right = ordinal, world_size - 1 - ordinal
idx = input_.dim() - 1 - last_dim
padding[2 * idx] = left * size
padding[2 * idx + 1] = right * size
output = torch.ops.c10d_functional.all_reduce(F.pad(input_, padding), "sum", TAG, RANKSET, GROUP_SIZE)
else:
output = xm.all_gather(input_, dim=-1, groups=groups)
return output
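# How the pad + all-reduce above emulates an all-gather on the CUDA path
# (illustrative, with world_size=2 and a per-rank shard of size 3 on the last dim):
#
#   rank 0 pads right:  [a b c 0 0 0]
#   rank 1 pads left:   [0 0 0 d e f]
#   sum (all-reduce):   [a b c d e f]   == concatenation of the shards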
def _initialize_affine_weight(
weight: torch.Tensor,
out_features: int,
in_features: int,
per_partition_size: int,
partition_dim: int,
init_method: Callable[[torch.Tensor], torch.Tensor],
world_size: int,
rank: int,
stride: int = 1,
return_master_weight: bool = False,
) -> Optional[torch.Tensor]:
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk.
"""
# If we only use 1 process for model parallelism, bypass scatter.
if world_size == 1:
init_method(weight)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False)
init_method(master_weight)
# Split and copy
per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
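# Illustrative sketch (hypothetical sizes): with world_size=2, rank=0 and stride=1,
# a partition of 8 output rows out of 16 is filled from the master weight by taking
# weight_list[0::2], i.e. the first of the two 8-row chunks.
#
#   w = torch.empty(8, 4)
#   _initialize_affine_weight(w, 16, 4, 8, 0, init.xavier_normal_, world_size=2, rank=0)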
class ParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the embedding dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
) -> None:
super(ParallelEmbedding, self).__init__()
if world_size is None:
self.groups = get_model_parallel_group()
self.world_size = get_model_parallel_world_size()
self.rank = get_model_parallel_rank()
else:
self.groups = groups
self.world_size = world_size
self.rank = rank
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
        self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self._weight = None
self.quant = quant
# Divide the weight matrix along the embedding dimension.
self.embedding_dim_per_partition = divide_and_check_no_remainder(self.embedding_dim, self.world_size)
# Allocate weights.
if quant:
self.weight = Parameter(
torch.empty(
(self.num_embeddings, self.embedding_dim_per_partition),
dtype=torch.int8,
),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.num_embeddings))
else:
self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim_per_partition))
# And initialize.
_initialize_affine_weight(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.embedding_dim_per_partition,
1,
init_method,
self.world_size,
self.rank,
stride=1,
return_master_weight=False,
)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
# PyTorch eager and inductor do not accept negative values in the input to embedding
# layers. Take the modulus to avoid this error.
if USE_CUDA:
input_parallel = torch.remainder(input_parallel, self.weight.shape[0])
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output_parallel = F.embedding(
input_parallel,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
return output
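# Illustrative usage sketch (hypothetical sizes; assumes each rank runs this under
# an XLA multi-process launcher so the all-gather has peers): every rank holds a
# [num_embeddings, embedding_dim // world_size] shard, and the forward gathers the
# partial embeddings back to the full hidden size.
#
#   emb = ParallelEmbedding(32000, 4096, world_size=4, rank=0, groups=None)
#   tokens = torch.randint(0, 32000, (1, 8))
#   hidden = emb(tokens)  # shape [1, 8, 4096] after the gather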
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias
        gather_output: If true, call all-gather on the output and make Y available
                       to all GPUs; otherwise, every GPU will have its own output Y_i = XA_i.
init_method: method to initialize weights. Note that bias is always set to
zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be set
to False. It returns the master weights used for initialization.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
gather_output: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
) -> None:
super(ColumnParallelLinear, self).__init__()
if world_size is None:
self.groups = get_model_parallel_group()
self.world_size = get_model_parallel_world_size()
self.rank = get_model_parallel_rank()
else:
self.groups = groups
self.world_size = world_size
self.rank = rank
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.quant = quant
# Divide the weight matrix along the last dimension.
self.output_size_per_partition = divide_and_check_no_remainder(out_features, self.world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
if quant:
self.weight = Parameter(
torch.empty(
(self.output_size_per_partition, self.in_features),
dtype=torch.int8,
),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
else:
self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.output_size_per_partition,
0,
init_method,
self.world_size,
self.rank,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(
self.weight.data.transpose(0, 1),
self.groups,
self.world_size,
self.rank,
).transpose_(0, 1)
def set_quantize(self):
assert not self.quant
self.weight = Parameter(
torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
self.quant = True
def quantize(self):
assert not self.quant
fp_w = deepcopy(self.weight.data)
orig_dtype = fp_w.dtype
fp_w = fp_w.to(torch.float32)
self.weight = Parameter(
torch.empty((self.output_size_per_partition, self.in_features), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.output_size_per_partition))
qconfig = TensorQConfig(axis=0)
self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
self.weight_scaler.data = scale.to(orig_dtype)
self.quant = True
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
        # Matrix multiply.
        if self.quant and USE_CUDA:
            # GPUs do not support mixed int8/bf16 computation, so dequantize the int8
            # weights (one scale per output channel) before the linear.
            scaled_weight = self.weight * self.weight_scaler.unsqueeze(-1)
            output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
        elif self.quant:
            # On XLA, run the matmul in int8 and apply the per-channel scale (and the
            # bias, if any) afterwards so the bias itself is not scaled.
            output_parallel = F.linear(input_parallel, self.weight)
            output_parallel = output_parallel * self.weight_scaler
            if self.bias is not None:
                output_parallel = output_parallel + self.bias
        else:
            output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
else:
output = output_parallel
return output
@classmethod
def create(
cls,
in_features: int,
out_features: int,
bias: bool = True,
gather_output: bool = True,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
):
if world_size == 1 or xr.is_spmd():
            # When SPMD is enabled, sharding is done via annotations on the Linear layer.
return Linear(in_features, out_features, bias=bias)
else:
return ColumnParallelLinear(
in_features,
out_features,
bias,
gather_output,
init_method,
stride,
keep_master_weight_for_test,
world_size,
rank,
groups,
quant,
)
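# Illustrative sketch (hypothetical sizes): a column-parallel projection keeps
# out_features // world_size output rows per rank; with gather_output=False the
# caller receives only the local shard, which is useful when a row-parallel layer
# consumes it next.
#
#   proj = ColumnParallelLinear(4096, 11008, bias=False, gather_output=False,
#                               world_size=4, rank=0, groups=None)
#   y_local = proj(torch.randn(1, 8, 4096))  # shape [1, 8, 11008 // 4]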
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already split
across the GPUs and we do not split again.
init_method: method to initialize weights. Note that bias is always set to
zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be set
to False. It returns the master weights used for initialization.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
input_is_parallel: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
):
super(RowParallelLinear, self).__init__()
if world_size is None:
self.groups = get_model_parallel_group()
self.world_size = get_model_parallel_world_size()
self.rank = get_model_parallel_rank()
else:
self.groups = groups
self.world_size = world_size
self.rank = rank
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.input_is_parallel = input_is_parallel
self.quant = quant
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide_and_check_no_remainder(in_features, self.world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
if quant:
self.weight = Parameter(
torch.empty(
(self.out_features, self.input_size_per_partition),
dtype=torch.int8,
),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
else:
self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.input_size_per_partition,
1,
init_method,
self.world_size,
self.rank,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data, self.groups, self.world_size, self.rank)
def set_quantize(self):
assert not self.quant
self.weight = Parameter(
torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
self.quant = True
def quantize(self):
assert not self.quant
fp_w = deepcopy(self.weight.data)
orig_dtype = fp_w.dtype
fp_w = fp_w.to(torch.float32)
self.weight = Parameter(
torch.empty((self.out_features, self.input_size_per_partition), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
qconfig = TensorQConfig(axis=0)
self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
self.weight_scaler.data = scale.to(orig_dtype)
self.quant = True
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_, self.groups, self.world_size, self.rank)
        # Matrix multiply. The bias (if any) is added only after the all-reduce
        # below, so it is not accumulated once per partition.
        if self.quant and USE_CUDA:
            # GPUs do not support mixed int8/bf16 computation, so dequantize the int8
            # weights (one scale per output channel) before the linear.
            scaled_weight = self.weight * self.weight_scaler.unsqueeze(-1)
            output_parallel = F.linear(input_parallel, scaled_weight)
        elif self.quant:
            # On XLA, run the matmul in int8 and apply the per-channel scale afterwards.
            output_parallel = F.linear(input_parallel, self.weight)
            output_parallel = output_parallel * self.weight_scaler
        else:
            output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel, self.groups, self.world_size, self.rank)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
@classmethod
def create(
cls,
in_features: int,
out_features: int,
bias: bool = True,
input_is_parallel: bool = False,
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
):
if world_size == 1 or xr.is_spmd():
            # When SPMD is enabled, sharding is done via annotations on the Linear layer.
return Linear(in_features, out_features, bias=bias)
else:
return RowParallelLinear(
in_features,
out_features,
bias,
input_is_parallel,
init_method,
stride,
keep_master_weight_for_test,
world_size,
rank,
groups,
quant,
)
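# Illustrative sketch (hypothetical sizes; assumes each rank runs this under an XLA
# multi-process launcher so the all-reduce has peers): the usual pairing is a
# column-parallel layer with gather_output=False feeding a row-parallel layer with
# input_is_parallel=True, so only one all-reduce is needed per block.
#
#   up = ColumnParallelLinear(4096, 11008, bias=False, gather_output=False,
#                             world_size=4, rank=0, groups=None)
#   down = RowParallelLinear(11008, 4096, bias=False, input_is_parallel=True,
#                            world_size=4, rank=0, groups=None)
#   y = down(up(torch.randn(1, 8, 4096)))  # all-reduce happens inside `down`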