# Copyright 2021 The HuggingFace Team. All rights reserved.
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
from typing import List, Optional, Tuple, Union
import poptorch
import torch
import torch.nn.functional as F
from peft import PeftModel, PeftType, get_peft_model
from torch import nn
from transformers import PreTrainedModel
from optimum.utils import logging
from .ipu_configuration import IncompatibleIPUConfigError, IPUConfig
logger = logging.get_logger(__name__)
_PRETRAINED_TO_PIPELINED_REGISTRY = {}
def register(transformers_cls=None):
def wrapper(cls):
orig_cls = transformers_cls
if orig_cls is None:
found = False
for base_cls in cls.__bases__:
if base_cls != PipelineMixin:
orig_cls = base_cls
found = True
break
if not found:
raise ValueError(f"Was not able to find original transformers class for {cls}")
_PRETRAINED_TO_PIPELINED_REGISTRY[orig_cls] = cls
return cls
return wrapper
def to_pipelined(model: nn.Module, ipu_config: IPUConfig, force: bool = False):
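    """
    Converts a model into its pipelined counterpart registered in
    `_PRETRAINED_TO_PIPELINED_REGISTRY`, handling `PeftModel` instances and models
    that already subclass `PipelineMixin`.

    A minimal usage sketch (the model class is illustrative):

    ```
    from transformers import BertForSequenceClassification

    ipu_config = IPUConfig(layers_per_ipu=[-1, -1])
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    pipelined_model = to_pipelined(model, ipu_config)
    ```
    """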
model_cls = model.get_base_model().__class__ if isinstance(model, PeftModel) else model.__class__
pipelined_cls = _PRETRAINED_TO_PIPELINED_REGISTRY.get(model_cls, None)
if pipelined_cls is not None and isinstance(model, PeftModel):
return pipelined_cls.from_peft(model, ipu_config)
elif pipelined_cls is not None:
return pipelined_cls.from_transformers(model, ipu_config)
    # If the user defined their own model that already subclasses PipelineMixin,
    # the model is already pipelined.
elif isinstance(model, PipelineMixin):
clone = copy.deepcopy(model)
clone.ipu_config = copy.deepcopy(ipu_config)
return clone
else:
if force:
logger.warning(
f"No pipelined version exists for {model_cls.__name__}, creating it dynamically so it might not work as expected."
)
pipelined_cls = type(f"Pipelined{model_cls.__name__}", (model_cls, PipelineMixin), {})
return pipelined_cls.from_model(model)
else:
raise KeyError(f"{model_cls.__name__} pipelined version not found in registry.")
class PipelineMixin:
@classmethod
def from_transformers(cls, model: PreTrainedModel, ipu_config: IPUConfig):
"""
Creates a pipelined version of model from a [`~transformers.PreTrainedModel`] instance.
Args:
model ([`~transformers.PreTrainedModel`]):
The model to convert to a pipelined model.
ipu_config ([`IPUConfig`]):
The `IPUConfig` instance of the pipelined model.
Returns:
The pipelined version of the model.
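
        Example (a minimal sketch; the pipelined class name is illustrative):

        ```
        from transformers import BertForSequenceClassification

        model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
        pipelined = PipelinedBertForSequenceClassification.from_transformers(model, ipu_config)
        ```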
"""
config = copy.deepcopy(model.config)
generation_config = copy.deepcopy(model.generation_config)
pipelined_model = cls(config)
pipelined_model.load_state_dict(model.state_dict())
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
pipelined_model.generation_config = generation_config
pipelined_model.training = model.training
return pipelined_model
@classmethod
def from_pretrained_transformers(cls, model_name_or_path: str, ipu_config: IPUConfig, *model_args, **kwargs):
"""
Creates a pipelined version of a model by using the `from_pretrained` function.
Args:
model_name_or_path (`str`):
The model name or path.
ipu_config ([`IPUConfig`]):
The `IPUConfig` of the pipelined model.
model_args (`Tuple[Any]`):
The positional arguments to use when instantiating the model.
kwargs (`Dict[str, Any]`):
The keyword arguments to use when instantiating the model.
Returns:
The pipelined model.
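
        Example (a minimal sketch; the pipelined class and checkpoint names are illustrative):

        ```
        pipelined = PipelinedBertForSequenceClassification.from_pretrained_transformers(
            "bert-base-uncased", ipu_config
        )
        ```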
"""
pipelined_model = cls.from_pretrained(model_name_or_path, *model_args, **kwargs)
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
return pipelined_model
@classmethod
def from_peft(cls, model: PeftModel, ipu_config: IPUConfig):
"""
Creates a pipelined version of model from a [`~peft.PeftModel`] instance.
Currently, only `peft.PeftType.LORA` is supported.
Args:
model ([`~peft.PeftModel`]):
The model to convert to a pipelined model.
ipu_config ([`IPUConfig`]):
The `IPUConfig` instance of the pipelined model.
Returns:
An instance of `peft.PeftModel` wrapping a pipelined version of the base model.
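
        Example (a minimal sketch; the LoRA settings and class names are illustrative):

        ```
        from peft import LoraConfig, get_peft_model

        peft_model = get_peft_model(base_model, LoraConfig(r=8, target_modules=["query", "value"]))
        pipelined_peft = PipelinedBertForSequenceClassification.from_peft(peft_model, ipu_config)
        ```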
"""
        # Technically speaking, instead of returning an instance of a `PipelineMixin`, such as Pipelined<Model>For<Task>,
        # we return an instance of a `peft.PeftModel` that wraps such a pipelined model and delegates attribute access to it.
if len(model.peft_config) > 1 or model.active_adapter != "default":
raise ValueError("Currently only `PeftModel` instances with the `'default'` adapter are supported.")
if model.peft_type != PeftType.LORA:
raise ValueError(f"Currently only LoRA is supported, received {model.peft_type}.")
pretrained = model.get_base_model()
config = copy.deepcopy(pretrained.config)
generation_config = copy.deepcopy(pretrained.generation_config)
peft_config = model.active_peft_config
pipelined_model = cls(config)
pipelined_model.ipu_config = copy.deepcopy(ipu_config)
pipelined_model.generation_config = generation_config
peft_pipelined_model = get_peft_model(pipelined_model, peft_config)
peft_pipelined_model.load_state_dict(model.state_dict())
peft_pipelined_model.training = model.training
return peft_pipelined_model
@classmethod
def from_model(cls, model: nn.Module):
clone = copy.deepcopy(model)
clone.__class__ = cls
# Just needed so that .parallelize() does not throw an error
clone.ipu_config = IPUConfig()
return clone
def _has_ipu_config_check(self):
_ipu_config = getattr(self, "_ipu_config", None)
if _ipu_config is None:
            raise AttributeError("No IPUConfig was found. Please set the `ipu_config` attribute.")
@property
def ipu_config(self):
"""Checks that the model has an [`IPUConfig`] attached, and returns it."""
self._has_ipu_config_check()
return self._ipu_config
@ipu_config.setter
def ipu_config(self, value: IPUConfig):
if not isinstance(value, IPUConfig):
raise TypeError(f"ipu_config must be an instance of IPUConfig, but {type(value)} was provided")
self._ipu_config = value
def parallelize(self):
"""Transforms the model to run in an IPU pipeline."""
self._hooks = []
self._has_ipu_config_check()
return self
def deparallelize(self):
"""
Undoes the changes to the model done by `parallelize`.
You should call this function before calling `save_pretrained` so that the `model.state_dict` dictionary is fully compatible with the
original model.
"""
# Remove hooks
if hasattr(self, "_hooks"):
for h in self._hooks:
h.remove()
# Remove poptorch Blocks
for m in self.modules():
if m is not self:
poptorch.removeBlocks(m)
return self
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
        Gets the number of (optionally, trainable or non-embedding) parameters in the module.
        Args:
            only_trainable (`bool`, *optional*, defaults to `False`):
                If `True`, only returns the number of trainable parameters.
            exclude_embeddings (`bool`, *optional*, defaults to `False`):
                If `True`, only returns the number of non-embedding parameters.
        Returns:
            `int`: The number of parameters.
"""
# TODO: actually overwrite this to handle SerializedEmbedding.
if exclude_embeddings:
            embedding_param_names = [
                f"{name}.weight" for name, module in self.named_modules() if isinstance(module, nn.Embedding)
            ]
non_embedding_parameters = [
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
]
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _expand_layers_per_ipu_wildcard(
ipu_config: IPUConfig, target_number_of_layers: Optional[Union[int, List]] = None
) -> List[int]:
"""
Expands any wildcard values in `layers_per_ipu` of the IPU configuration.
For example, if we have:
```
layers_per_ipu = [-1, -1]
target_number_of_layers = 9
```
this function will expand the wildcard values to `layers_per_ipu = [4, 5]`
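    Mixed integer and wildcard values are also supported; any remainder is placed on
    the last wildcard IPU. For example:
    ```
    layers_per_ipu = [2, -1, -1]
    target_number_of_layers = 9
    ```
    expands to `layers_per_ipu = [2, 3, 4]`.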
Args:
ipu_config ([`IPUConfig`]):
The `IPUConfig` instance of the model.
        target_number_of_layers (`int` or `List`, *optional*):
            The total number of target layers, or a list whose length is used as that total.
    Returns:
        `List[int]`: The `layers_per_ipu` list with wildcards replaced by the number of layers per IPU.
"""
layers_per_ipu = copy.deepcopy(ipu_config._layers_per_ipu)
layers_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name("layers_per_ipu")
ipus_per_replica = ipu_config._ipus_per_replica
ipus_per_replica_mode_str = ipu_config._get_managed_attr_mode_name("ipus_per_replica")
# Check inputs are valid
if not all(isinstance(n, int) and n >= -1 for n in layers_per_ipu):
raise IncompatibleIPUConfigError(
f"Invalid values in {layers_per_ipu_mode_str}. {layers_per_ipu_mode_str}={layers_per_ipu}"
)
if ipus_per_replica < 1:
raise IncompatibleIPUConfigError(
f"Invalid value for {ipus_per_replica_mode_str}. {ipus_per_replica_mode_str}={ipus_per_replica}"
)
if target_number_of_layers is not None:
if not isinstance(target_number_of_layers, int):
target_number_of_layers = len(target_number_of_layers)
# if ipus_per_replica is 1, then put everything on IPU0, ignoring layers_per_ipu
if ipus_per_replica == 1:
return [target_number_of_layers]
elif ipus_per_replica > 1:
# default/wildcards - split layers evenly over all ipus
if layers_per_ipu in ([-1], [-1] * ipus_per_replica):
quotient, remainder = divmod(target_number_of_layers, ipus_per_replica)
layers_per_ipu = [quotient] * ipus_per_replica
if remainder > 0:
# add any remainder layers to last wildcard IPU
layers_per_ipu[-1] += remainder
# combination of wildcards and integers
elif -1 in layers_per_ipu and len(layers_per_ipu) == ipus_per_replica:
wildcard_idxs = [idx for idx, v in enumerate(layers_per_ipu) if v == -1]
num_wildcard_ipus = len(wildcard_idxs)
# wildcard_layers = target_num_layers - num_non_wildcard_layers
num_wildcard_layers = target_number_of_layers - sum([l for l in layers_per_ipu if l != -1])
quotient, remainder = divmod(num_wildcard_layers, num_wildcard_ipus)
for idx in wildcard_idxs:
layers_per_ipu[idx] = quotient
if remainder > 0:
# add any remainder layers to last wildcard IPU
layers_per_ipu[wildcard_idxs[-1]] += remainder
            elif len(layers_per_ipu) != ipus_per_replica:
                raise IncompatibleIPUConfigError(
                    f"{layers_per_ipu_mode_str} has a non-default value set, but its length does not match {ipus_per_replica_mode_str}: "
                    f"{layers_per_ipu_mode_str}={layers_per_ipu}, {ipus_per_replica_mode_str}={ipus_per_replica}."
                )
# no wildcards used
elif sum(layers_per_ipu) != target_number_of_layers:
raise IncompatibleIPUConfigError(
f"{layers_per_ipu_mode_str} does not define the correct number of layers for the current model."
" The current IPU Config specifies IPU assignments for "
f"{sum(layers_per_ipu)} layers but there are {target_number_of_layers} layers "
f"in the model. {layers_per_ipu_mode_str}={layers_per_ipu}"
)
return layers_per_ipu
def split_encoder_decoder_ipu_config(
ipu_config: IPUConfig, num_encoder_layers: int, num_decoder_layers: int
) -> List[IPUConfig]:
"""
Splits an `IPUConfig` instance for an encoder-decoder model into a configuration for the encoder part and a configuration for the decoder part.
It also splits `layers_per_ipu` into two given the numbers of encoder and decoder layers.
Example:
```
>> ipu_config = IPUConfig(layers_per_ipu=[12, 12], ipus_per_replica=2)
>> encoder_ipu_config, decoder_ipu_config = split_encoder_decoder_ipu_config(ipu_config, 12, 12)
>> encoder_ipu_config
    => IPUConfig(layers_per_ipu=[12], ipus_per_replica=1)
    >> decoder_ipu_config
    => IPUConfig(layers_per_ipu=[12], ipus_per_replica=1)
```
Args:
ipu_config:
            The `IPUConfig` instance for the whole encoder-decoder model.
num_encoder_layers:
The number of encoder layers in the model.
num_decoder_layers:
The number of decoder layers in the model.
Returns:
The configuration for the encoder part, `encoder_ipu_config`, and the configuration for the decoder part, `decoder_ipu_config`.
"""
layers_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name("layers_per_ipu")
ipus_per_replica_mode_str = ipu_config._get_managed_attr_mode_name("ipus_per_replica")
# Need at least two IPUs to do the split
if ipu_config._ipus_per_replica < 2:
raise IncompatibleIPUConfigError(
f"Need {ipus_per_replica_mode_str} to be at least 2 to split ipu_config into encoder and decoder configs"
)
ipu_configs = {name: copy.deepcopy(ipu_config) for name in ["encoder", "decoder"]}
# Split layers_per_ipu between the given num layers
layers_per_ipu = _expand_layers_per_ipu_wildcard(ipu_config, num_encoder_layers + num_decoder_layers)
cumsum = [sum(layers_per_ipu[: i + 1]) for i in range(len(layers_per_ipu))]
try:
cut = [i + 1 for i, c in enumerate(cumsum) if c == num_encoder_layers]
        # Choose the largest cut index that is a power of 2, since each sub-pipeline needs a power-of-2 number of IPUs
cut = max([num for num in cut if num & (num - 1) == 0])
except Exception:
raise IncompatibleIPUConfigError(
f"Unable to find a valid split of ipu_config.{layers_per_ipu_mode_str}\n"
"Arguments: \n"
f"\tipu_config.{layers_per_ipu_mode_str}={ipu_config._layers_per_ipu}\n"
f"\tnum_encoder_layers={num_encoder_layers}\n"
f"\tnum_decoder_layers={num_decoder_layers}\n"
"Possible causes: \n"
"Encoder and decoder layers cannot be placed on the same IPUs.\n"
f"The encoder and decoder {layers_per_ipu_mode_str} splits each need a number of devices that's a power of 2."
)
ipu_configs["encoder"]._layers_per_ipu = layers_per_ipu[:cut]
ipu_configs["decoder"]._layers_per_ipu = layers_per_ipu[cut:]
# Split the per ipu configurations for SerializedEmbedding and SplitProjection if they have been provided
# Note that serialized layers across IPUs cannot be present in both the encoder and decoder
if ipu_config._serialized_embedding_splits_per_ipu is not None:
ipu_configs["encoder"]._serialized_embedding_splits_per_ipu = ipu_config._serialized_embedding_splits_per_ipu[
:cut
]
ipu_configs["decoder"]._serialized_embedding_splits_per_ipu = None
        # dec_split must contain only zeros because this layer cannot be split across the encoder and decoder
if sum(dec_split := ipu_config._serialized_embedding_splits_per_ipu[cut:]):
serialized_embedding_splits_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name(
"serialized_embedding_splits_per_ipu"
)
raise ValueError(
"The `SerializedEmbedding` must have all splits placed on the encoder, but"
f" {serialized_embedding_splits_per_ipu_mode_str}={ipu_config._serialized_embedding_splits_per_ipu} results in"
f" {dec_split} being placed on the decoder"
)
if ipu_config._serialized_projection_splits_per_ipu is not None:
ipu_configs["encoder"]._serialized_projection_splits_per_ipu = None
        ipu_configs["decoder"]._serialized_projection_splits_per_ipu = ipu_config._serialized_projection_splits_per_ipu[
            cut:
        ]
if sum(enc_split := ipu_config._serialized_projection_splits_per_ipu[:cut]):
serialized_projection_splits_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name(
"serialized_projection_splits_per_ipu"
)
raise ValueError(
"The `SplitProjection` must have all splits placed on the decoder, but"
f" {serialized_projection_splits_per_ipu_mode_str}={ipu_config._serialized_projection_splits_per_ipu} results in"
f" {enc_split} being placed on the encoder"
)
# Modify the ipus_per_replica
ipu_configs["encoder"]._ipus_per_replica = len(ipu_configs["encoder"]._layers_per_ipu)
ipu_configs["decoder"]._ipus_per_replica = len(ipu_configs["decoder"]._layers_per_ipu)
# Split matmul_proportion between the given num layers
matmul_proportion = ipu_config._matmul_proportion
if isinstance(matmul_proportion, list):
ipu_configs["encoder"]._matmul_proportion = matmul_proportion[:cut]
ipu_configs["decoder"]._matmul_proportion = matmul_proportion[cut:]
    return list(ipu_configs.values())
def get_layer_ipu(ipu_config: IPUConfig, target_number_of_layers: Optional[Union[int, List]] = None) -> List[int]:
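    """
    Returns the IPU ID for each layer, after expanding any wildcards in
    `layers_per_ipu` against `target_number_of_layers`.

    A minimal sketch (the configuration values are illustrative):

    ```
    ipu_config = IPUConfig(layers_per_ipu=[2, 3], ipus_per_replica=2)
    get_layer_ipu(ipu_config)  # -> [0, 0, 1, 1, 1]
    ```
    """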
layers_per_ipu = _expand_layers_per_ipu_wildcard(ipu_config, target_number_of_layers)
# List of the IPU Id for each layer
layer_ipu: List[int] = []
for ipu, n_layers in enumerate(layers_per_ipu):
layer_ipu += [ipu] * n_layers
return layer_ipu
def recomputation_checkpoint(module: nn.Module) -> torch.utils.hooks.RemovableHandle:
"""Annotates the output of a module to be checkpointed instead of
recomputed."""
def recompute_outputs(module, inputs, outputs):
if isinstance(outputs, torch.Tensor):
return poptorch.recomputationCheckpoint(outputs)
elif isinstance(outputs, tuple):
return tuple(poptorch.recomputationCheckpoint(y) for y in outputs)
return module.register_forward_hook(recompute_outputs)
def outline_attribute(module: nn.Module, value: str):
"""Adds an attribute to a module. This attribute will be used
when comparing operation equivalence in outlining.
For example:
```
layer1 = nn.Linear(...)
layer2 = nn.Linear(...)
layer3 = nn.Linear(...)
layer4 = nn.Linear(...)
outline_attribute(layer1, "A")
outline_attribute(layer2, "A")
outline_attribute(layer3, "B")
```
The code for `layer1` can be reused for `layer2`, but
it can't be used for `layer3` or `layer4`.
"""
context = poptorch.Attribute(__outline={"layer": value})
def enable(*args):
context.__enter__()
def disable(*args):
context.__exit__(None, None, None)
handles = []
handles.append(module.register_forward_pre_hook(enable))
handles.append(module.register_forward_hook(disable))
return handles
class SerializedEmbedding(nn.Module):
"""
    Wrapper for an `nn.Embedding` layer that splits the embedding look-up into
smaller serialized steps in order to reduce memory in the embedding gradient
calculation.
Args:
embedding:
An `nn.Embedding` instance to wrap.
serialization_factor:
The number of serialized embedding look-ups.
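
    Example (a minimal sketch; the sizes are illustrative):

    ```
    embedding = nn.Embedding(1024, 64)
    serialized = SerializedEmbedding.from_model(embedding, serialization_factor=4)
    out = serialized(torch.tensor([[1, 513, 1023]]))
    restored = serialized.to_model()  # back to a plain nn.Embedding
    ```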
"""
def __init__(self, embedding: nn.Embedding, serialization_factor: int):
super().__init__()
self.serialization_factor = serialization_factor
self.num_embeddings = embedding.num_embeddings
        # num_embeddings must be divisible by the serialization factor
        assert (
            self.num_embeddings % self.serialization_factor == 0
        ), f"num_embeddings ({self.num_embeddings}) must be divisible by serialization_factor ({serialization_factor})"
self.split_size = self.num_embeddings // self.serialization_factor
freeze = not embedding.weight.requires_grad
self.padding_idx = embedding.padding_idx
boundaries = torch.linspace(
self.split_size - 1, self.num_embeddings - 1, self.serialization_factor, dtype=torch.int
)
self.split_embeddings = nn.ModuleList(
[
nn.Embedding.from_pretrained(
embedding.weight[i * self.split_size : (i + 1) * self.split_size, :].detach(),
freeze=freeze,
                    padding_idx=self.padding_idx - i * self.split_size
                    if self.padding_idx is not None and i == torch.bucketize(self.padding_idx, boundaries).item()
                    else None,
)
for i in range(self.serialization_factor)
]
)
@classmethod
def from_model(cls, embedding: nn.Embedding, serialization_factor: int) -> SerializedEmbedding:
return cls(embedding, serialization_factor)
def parallelize(self, splits_per_ipu: List[int]):
for ipu_id, splits in enumerate(splits_per_ipu):
if splits:
from_split = sum(splits_per_ipu[:ipu_id])
to_split = from_split + splits - 1
shard_range = f"{from_split}-{to_split}" if from_split != to_split else f"{from_split}"
logger.info(f"Embedding splits: {shard_range} --> IPU {ipu_id}")
self.split_embeddings[from_split] = poptorch.BeginBlock(
self.split_embeddings[from_split], ipu_id=ipu_id, user_id=f"Embedding-{shard_range}"
)
return self
def to_model(self) -> nn.Embedding:
"""
Deserialize the internal wrapped embedding layer and return it as an
`nn.Embedding` object.
Returns:
An `nn.Embedding` layer.
"""
freeze = not self.split_embeddings[0].weight.requires_grad
return nn.Embedding.from_pretrained(
torch.vstack([l.weight for l in self.split_embeddings]), padding_idx=self.padding_idx, freeze=freeze
)
def forward(self, indices):
# iterate through the splits
x_sum = None
for i in range(self.serialization_factor):
# mask out the indices not in this split
split_indices = indices - i * self.split_size
mask = (split_indices >= 0) * (split_indices < self.split_size)
mask = mask.detach()
split_indices *= mask
# do the embedding lookup
x = self.split_embeddings[i](split_indices)
# multiply the output by mask
x *= mask.unsqueeze(-1)
# add to partial
if x_sum is not None:
x_sum += x
else:
x_sum = x
return x_sum
class SerializedLinear(nn.Linear):
"""
    Exactly equivalent to an `nn.Linear` layer, but with the matrix multiplication replaced with
a serialized matrix multiplication: `poptorch.serializedMatMul`.
The matrix multiplication is split into separate smaller multiplications, calculated one after the other,
to reduce the memory requirements of the multiplication and its gradient calculation.
Args:
in_features:
Size of each input sample
out_features:
Size of each output sample
factor:
Number of serialized multiplications. Must be a factor of
the dimension to serialize on.
        bias: If set to `False`, the layer will not learn an additive bias.
            Default: `False`.
mode: The dimension of the matmul to serialize on.
For matrix A (m by n) multiplied by matrix B (n by p).
* InputChannels: Split across the input channels (dimension m).
* ReducingDim: Split across the reducing dimension (n).
* OutputChannels: Split across the output channels (dimension p).
* Disabled: Same as an ordinary matrix multiplication.
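
    Example (a minimal sketch; the sizes are illustrative, and the serialized matmul
    only takes effect when the model is compiled with poptorch):

    ```
    linear = nn.Linear(1024, 4096)
    serialized = SerializedLinear.from_model(linear, factor=4)
    original = serialized.to_model()  # back to a plain nn.Linear
    ```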
"""
def __init__(
self,
in_features,
out_features,
factor,
bias=False,
mode=poptorch.MatMulSerializationMode.OutputChannels,
):
super().__init__(in_features, out_features, bias)
self.mode = mode
self.factor = factor
@classmethod
def from_model(
cls, model: nn.Linear, factor: int, mode=poptorch.MatMulSerializationMode.OutputChannels
) -> SerializedLinear:
clone = copy.deepcopy(model)
clone.__class__ = cls
clone.factor = factor
clone.mode = mode
return clone
    def to_model(self) -> nn.Linear:
        # Copy first so that this serialized layer is left intact
        original = copy.deepcopy(self)
        del original.factor
        del original.mode
        original.__class__ = nn.Linear
        return original
def forward(self, x):
output = poptorch.serializedMatMul(x, self.weight.t(), self.mode, self.factor)
if self.bias is not None:
output += self.bias
return output
class SplitProjection(torch.nn.Module):
"""
    Exactly equivalent to an `nn.Linear` layer, but with the linear layer split into
    partial linear layers in order to reduce resident memory. Depending on
    `serialization_mode`, the layer is split in equal parts along either the reducing
    dimension (`nn.Linear.in_features`) or the output channels (`nn.Linear.out_features`).
    The forward call aggregates the partial results obtained from each linear layer.
    Args:
        linear: An `nn.Linear` layer to wrap.
        serialization_factor: The number of partitions of the linear layer. This must
            be a factor of the dimension being split.
serialization_mode: The dimension of the matmul to serialize on.
For matrix A (m by n) multiplied by matrix B (n by p):
* ReducingDim: Split across the reducing dimension (n).
* OutputChannels: Split across the output channels (dimension p).
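
    Example (a minimal sketch; the sizes are illustrative):

    ```
    linear = torch.nn.Linear(1024, 4096)
    split = SplitProjection.from_model(linear, serialization_factor=4)
    out = split(torch.randn(2, 1024))  # matches linear(...) up to float error
    merged = split.to_model()          # back to a plain nn.Linear
    ```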
"""
def __init__(
self,
linear: torch.nn.Linear,
serialization_factor: int,
serialization_mode=poptorch.MatMulSerializationMode.OutputChannels,
) -> None:
super().__init__()
self.in_features = linear.in_features
self.out_features = linear.out_features
self.serialization_mode = serialization_mode
self.split_linear_layers = torch.nn.ModuleList()
if serialization_mode is poptorch.MatMulSerializationMode.OutputChannels:
if self.out_features % serialization_factor != 0:
raise ValueError(f"{linear.out_features=} must be divisible by {serialization_factor=}")
self.split_size = self.out_features // serialization_factor
for i in range(0, self.out_features, self.split_size):
split_linear = torch.nn.Linear(
self.in_features, self.split_size, bias=False, dtype=linear.weight.dtype
)
with torch.no_grad():
split_linear.weight.copy_(linear.weight[i : i + self.split_size, :].detach())
self.split_linear_layers.append(split_linear)
elif serialization_mode is poptorch.MatMulSerializationMode.ReducingDim:
if self.in_features % serialization_factor != 0:
raise ValueError(f"{linear.in_features=} must be divisible by {serialization_factor=}")
self.split_size = self.in_features // serialization_factor
for i in range(0, self.in_features, self.split_size):
split_linear = torch.nn.Linear(
self.split_size, self.out_features, bias=False, dtype=linear.weight.dtype
)
with torch.no_grad():
split_linear.weight.copy_(linear.weight[:, i : i + self.split_size].detach())
self.split_linear_layers.append(split_linear)
else:
raise ValueError(
f"`SplitProjection` `serialization_mode` can only be {poptorch.MatMulSerializationMode.OutputChannels} or {poptorch.MatMulSerializationMode.ReducingDim}."
f" You provided: {serialization_mode=}"
)
self.bias = None
if linear.bias is not None:
self.bias = torch.nn.Parameter(torch.zeros_like(linear.bias))
with torch.no_grad():
self.bias.copy_(linear.bias.detach())
    def forward(self, x):
        if self.serialization_mode is poptorch.MatMulSerializationMode.OutputChannels:
            out = []
            for split_linear_layer in self.split_linear_layers:
                out.append(split_linear_layer(x))
            out = torch.concat(out, -1)
        elif self.serialization_mode is poptorch.MatMulSerializationMode.ReducingDim:
            out = self.split_linear_layers[0](x[..., 0 : self.split_size])
            # Layer i consumes the slice x[..., i * split_size : (i + 1) * split_size]
            for i, split_linear_layer in enumerate(self.split_linear_layers[1:], start=1):
                out += split_linear_layer(x[..., i * self.split_size : (i + 1) * self.split_size])
if self.bias is not None:
out += self.bias
return out
@classmethod
def from_model(
cls,
linear: torch.nn.Linear,
serialization_factor: int,
serialization_mode=poptorch.MatMulSerializationMode.OutputChannels,
) -> SplitProjection:
return cls(linear, serialization_factor, serialization_mode)
def to_model(self) -> nn.Linear:
"""
        Merges the split linear layers back into a single layer.
        Returns:
            An `nn.Linear` layer.
"""
dtype = self.split_linear_layers[0].weight.dtype
layer = nn.Linear(self.in_features, self.out_features, bias=self.bias is not None)
if dtype == torch.float16:
layer = layer.half()
if self.serialization_mode is poptorch.MatMulSerializationMode.OutputChannels:
with torch.no_grad():
layer.weight.copy_(torch.vstack([l.weight.detach() for l in self.split_linear_layers]))
elif self.serialization_mode is poptorch.MatMulSerializationMode.ReducingDim:
with torch.no_grad():
layer.weight.copy_(torch.hstack([l.weight.detach() for l in self.split_linear_layers]))
if self.bias is not None:
with torch.no_grad():
layer.bias.copy_(self.bias)
return layer
def parallelize(self, splits_per_ipu: List[int]):
for ipu_id, splits in enumerate(splits_per_ipu):
if splits:
from_split = sum(splits_per_ipu[:ipu_id])
to_split = from_split + splits - 1
shard_range = f"{from_split}-{to_split}" if from_split != to_split else f"{from_split}"
logger.info(f"Linear splits: {shard_range} --> IPU {ipu_id}")
self.split_linear_layers[from_split] = poptorch.BeginBlock(
self.split_linear_layers[from_split], ipu_id=ipu_id, user_id=f"Linear-{shard_range}"
)
return self
class SharedEmbedding(nn.Module):
"""
Wrapper around the shared embedding between the encoder and the decoder stacks.
Attributes:
shared:
The shared embedding layer.
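
    Example (a minimal sketch; the sizes are illustrative):

    ```
    shared = SharedEmbedding(nn.Embedding(32128, 512))
    input_ids = torch.tensor([[0, 1, 2]])
    decoder_input_ids = torch.tensor([[0, 1]])
    encoder_embeds, decoder_embeds = shared(input_ids, decoder_input_ids)
    ```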
"""
def __init__(self, shared: nn.Embedding):
super().__init__()
self.shared = shared
def _combine_inputs(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor) -> Tuple[int, torch.Tensor]:
idx = input_ids.size(1)
return idx, torch.cat([input_ids, decoder_input_ids], dim=1)
def _separate_inputs(self, idx: int, embeds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return embeds[:, :idx, :], embeds[:, idx:, :]
def forward(
self,
input_ids: torch.Tensor,
decoder_input_ids: torch.Tensor,
encoder_embed_scale: Optional[float] = None,
decoder_embed_scale: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
        # TODO: use this once the TiedGather pattern issue is solved.
        # encoder_inputs_embeds, decoder_inputs_embeds = None, None
        # if input_ids is not None and encoder_embed_scale is not None:
        #     encoder_inputs_embeds = self.shared(input_ids) * encoder_embed_scale
        # if decoder_input_ids is not None and decoder_embed_scale is not None:
        #     decoder_inputs_embeds = self.shared(decoder_input_ids) * decoder_embed_scale
        # idx, combined = self._combine_inputs(input_ids, decoder_input_ids)
        # encoder_inputs_embeds, decoder_inputs_embeds = self._separate_inputs(idx, self.shared(combined))
encoder_inputs_embeds, decoder_inputs_embeds = None, None
if input_ids is None:
# call on decoder_input_ids only
decoder_inputs_embeds = self.shared(decoder_input_ids)
elif decoder_input_ids is None:
# call on input_ids only
encoder_inputs_embeds = self.shared(input_ids)
        else:
            # Combined case: both input_ids and decoder_input_ids are provided
idx, combined = self._combine_inputs(input_ids, decoder_input_ids)
encoder_inputs_embeds, decoder_inputs_embeds = self._separate_inputs(idx, self.shared(combined))
if encoder_embed_scale:
encoder_inputs_embeds = encoder_inputs_embeds * encoder_embed_scale
if decoder_embed_scale:
decoder_inputs_embeds = decoder_inputs_embeds * decoder_embed_scale
return encoder_inputs_embeds, decoder_inputs_embeds
class OnehotGather(nn.Module):
"""
Gathers selected indices from a tensor by transforming the list of indices
into a one-hot matrix and then multiplying the tensor by that matrix.
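
    Example (a minimal sketch; the shapes are illustrative):

    ```
    gather = OnehotGather()
    sequence = torch.randn(2, 10, 8)            # (batch, seq_len, hidden)
    positions = torch.tensor([[1, 4], [0, 9]])  # (batch, num_positions)
    gathered = gather(sequence, positions)      # (batch, num_positions, hidden)
    ```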
"""
def forward(self, sequence, positions):
"""
Gathers the vectors at the specific positions over a batch.
"""
        # the one-hot depth equals the sequence length (dimension 1 of `sequence`)
        num_classes = int(sequence.shape[1])
one_hot_positions = F.one_hot(positions, num_classes).to(dtype=sequence.dtype)
return torch.matmul(one_hot_positions.detach(), sequence)
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
"""
Shift input ids one token to the right.
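
    Example:

    ```
    labels = torch.tensor([[5, 6, -100]])
    shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
    # -> tensor([[2, 5, 6]])
    ```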
"""
# Upstream version:
# shifted_input_ids = input_ids.new_zeros(input_ids.shape)
# shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
# shifted_input_ids[:, 0] = decoder_start_token_id
# Change to fix slice assignment:
shifted_input_ids = torch.cat(
[torch.ones(input_ids.shape[0], 1, dtype=input_ids.dtype) * decoder_start_token_id, input_ids[:, :-1]], 1
)
if pad_token_id is None:
raise ValueError("self.model.config.pad_token_id has to be defined.")
# replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = torch.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
return shifted_input_ids