optimum/graphcore/modeling_utils.py (479 lines of code) (raw):

# Copyright 2021 The HuggingFace Team. All rights reserved. # Copyright (c) 2022 Graphcore Ltd. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import copy from typing import List, Optional, Tuple, Union import poptorch import torch import torch.nn.functional as F from peft import PeftModel, PeftType, get_peft_model from torch import nn from transformers import PreTrainedModel from optimum.utils import logging from .ipu_configuration import IncompatibleIPUConfigError, IPUConfig logger = logging.get_logger(__name__) _PRETRAINED_TO_PIPELINED_REGISTRY = {} def register(transformers_cls=None): def wrapper(cls): orig_cls = transformers_cls if orig_cls is None: found = False for base_cls in cls.__bases__: if base_cls != PipelineMixin: orig_cls = base_cls found = True break if not found: raise ValueError(f"Was not able to find original transformers class for {cls}") _PRETRAINED_TO_PIPELINED_REGISTRY[orig_cls] = cls return cls return wrapper def to_pipelined(model: nn.Module, ipu_config: IPUConfig, force: bool = False): model_cls = model.get_base_model().__class__ if isinstance(model, PeftModel) else model.__class__ pipelined_cls = _PRETRAINED_TO_PIPELINED_REGISTRY.get(model_cls, None) if pipelined_cls is not None and isinstance(model, PeftModel): return pipelined_cls.from_peft(model, ipu_config) elif pipelined_cls is not None: return pipelined_cls.from_transformers(model, ipu_config) # If the user defined his/her own model and already subclassed from PipelineMixin. I.e., the model is already pipelined. elif isinstance(model, PipelineMixin): clone = copy.deepcopy(model) clone.ipu_config = copy.deepcopy(ipu_config) return clone else: if force: logger.warning( f"No pipelined version exists for {model_cls.__name__}, creating it dynamically so it might not work as expected." ) pipelined_cls = type(f"Pipelined{model_cls.__name__}", (model_cls, PipelineMixin), {}) return pipelined_cls.from_model(model) else: raise KeyError(f"{model_cls.__name__} pipelined version not found in registry.") class PipelineMixin: @classmethod def from_transformers(cls, model: PreTrainedModel, ipu_config: IPUConfig): """ Creates a pipelined version of model from a [`~transformers.PreTrainedModel`] instance. Args: model ([`~transformers.PreTrainedModel`]): The model to convert to a pipelined model. ipu_config ([`IPUConfig`]): The `IPUConfig` instance of the pipelined model. Returns: The pipelined version of the model. """ config = copy.deepcopy(model.config) generation_config = copy.deepcopy(model.generation_config) pipelined_model = cls(config) pipelined_model.load_state_dict(model.state_dict()) pipelined_model.ipu_config = copy.deepcopy(ipu_config) pipelined_model.generation_config = generation_config pipelined_model.training = model.training return pipelined_model @classmethod def from_pretrained_transformers(cls, model_name_or_path: str, ipu_config: IPUConfig, *model_args, **kwargs): """ Creates a pipelined version of a model by using the `from_pretrained` function. Args: model_name_or_path (`str`): The model name or path. ipu_config ([`IPUConfig`]): The `IPUConfig` of the pipelined model. model_args (`Tuple[Any]`): The positional arguments to use when instantiating the model. kwargs (`Dict[str, Any]`): The keyword arguments to use when instantiating the model. Returns: The pipelined model. """ pipelined_model = cls.from_pretrained(model_name_or_path, *model_args, **kwargs) pipelined_model.ipu_config = copy.deepcopy(ipu_config) return pipelined_model @classmethod def from_peft(cls, model: PeftModel, ipu_config: IPUConfig): """ Creates a pipelined version of model from a [`~peft.PeftModel`] instance. Currently, only `peft.PeftType.LORA` is supported. Args: model ([`~peft.PeftModel`]): The model to convert to a pipelined model. ipu_config ([`IPUConfig`]): The `IPUConfig` instance of the pipelined model. Returns: An instance of `peft.PeftModel` wrapping a pipelined version of the base model. """ # Technically speaking, instead of returning an instance of a `PipelineMixin`, such as Pipelined<Model>For<Task>, # we return an instance of a `peft.PeftModel` which wraps such a pipelined model and defers attribute access. if len(model.peft_config) > 1 or model.active_adapter != "default": raise ValueError("Currently only `PeftModel` instances with the `'default'` adapter are supported.") if model.peft_type != PeftType.LORA: raise ValueError(f"Currently only LoRA is supported, received {model.peft_type}.") pretrained = model.get_base_model() config = copy.deepcopy(pretrained.config) generation_config = copy.deepcopy(pretrained.generation_config) peft_config = model.active_peft_config pipelined_model = cls(config) pipelined_model.ipu_config = copy.deepcopy(ipu_config) pipelined_model.generation_config = generation_config peft_pipelined_model = get_peft_model(pipelined_model, peft_config) peft_pipelined_model.load_state_dict(model.state_dict()) peft_pipelined_model.training = model.training return peft_pipelined_model @classmethod def from_model(cls, model: nn.Module): clone = copy.deepcopy(model) clone.__class__ = cls # Just needed so that .parallelize() does not throw an error clone.ipu_config = IPUConfig() return clone def _has_ipu_config_check(self): _ipu_config = getattr(self, "_ipu_config", None) if _ipu_config is None: raise AttributeError("No IPUConfig was found. Please set the ipu_config attribute") @property def ipu_config(self): """Checks that the model has an [`IPUConfig`] attached, and returns it.""" self._has_ipu_config_check() return self._ipu_config @ipu_config.setter def ipu_config(self, value: IPUConfig): if not isinstance(value, IPUConfig): raise TypeError(f"ipu_config must be an instance of IPUConfig, but {type(value)} was provided") self._ipu_config = value def parallelize(self): """Transforms the model to run in an IPU pipeline.""" self._hooks = [] self._has_ipu_config_check() return self def deparallelize(self): """ Undoes the changes to the model done by `parallelize`. You should call this function before calling `save_pretrained` so that the `model.state_dict` dictionary is fully compatible with the original model. """ # Remove hooks if hasattr(self, "_hooks"): for h in self._hooks: h.remove() # Remove poptorch Blocks for m in self.modules(): if m is not self: poptorch.removeBlocks(m) return self def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: """ Gets the number of (optionally, trainable or non-embeddings) parameters in the module. Args: only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): If `True`, only returns the number of trainable parameters. exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`): If `True`, only returns the number of non-embeddings parameters. Returns: :obj:`int`: The number of parameters. """ # TODO: actually overwrite this to handle SerializedEmbedding. if exclude_embeddings: embedding_param_names = [ f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) ] non_embedding_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) def _expand_layers_per_ipu_wildcard( ipu_config: IPUConfig, target_number_of_layers: Optional[Union[int, List]] = None ) -> List[int]: """ Expands any wildcard values in `layers_per_ipu` of the IPU configuration. For example, if we have: ``` layers_per_ipu = [-1, -1] target_number_of_layers = 9 ``` this function will expand the wildcard values to `layers_per_ipu = [4, 5]` Args: ipu_config ([`IPUConfig`]): The `IPUConfig` instance of the model. target_number_of_layers (:obj:`int` or `List[int]`, `optional`): The total number of target layers. Returns: :obj:`List[int]`: The `layers_per_ipu` with wildcards replaced by the number of layers per IPU. """ layers_per_ipu = copy.deepcopy(ipu_config._layers_per_ipu) layers_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name("layers_per_ipu") ipus_per_replica = ipu_config._ipus_per_replica ipus_per_replica_mode_str = ipu_config._get_managed_attr_mode_name("ipus_per_replica") # Check inputs are valid if not all(isinstance(n, int) and n >= -1 for n in layers_per_ipu): raise IncompatibleIPUConfigError( f"Invalid values in {layers_per_ipu_mode_str}. {layers_per_ipu_mode_str}={layers_per_ipu}" ) if ipus_per_replica < 1: raise IncompatibleIPUConfigError( f"Invalid value for {ipus_per_replica_mode_str}. {ipus_per_replica_mode_str}={ipus_per_replica}" ) if target_number_of_layers is not None: if not isinstance(target_number_of_layers, int): target_number_of_layers = len(target_number_of_layers) # if ipus_per_replica is 1, then put everything on IPU0, ignoring layers_per_ipu if ipus_per_replica == 1: return [target_number_of_layers] elif ipus_per_replica > 1: # default/wildcards - split layers evenly over all ipus if layers_per_ipu in ([-1], [-1] * ipus_per_replica): quotient, remainder = divmod(target_number_of_layers, ipus_per_replica) layers_per_ipu = [quotient] * ipus_per_replica if remainder > 0: # add any remainder layers to last wildcard IPU layers_per_ipu[-1] += remainder # combination of wildcards and integers elif -1 in layers_per_ipu and len(layers_per_ipu) == ipus_per_replica: wildcard_idxs = [idx for idx, v in enumerate(layers_per_ipu) if v == -1] num_wildcard_ipus = len(wildcard_idxs) # wildcard_layers = target_num_layers - num_non_wildcard_layers num_wildcard_layers = target_number_of_layers - sum([l for l in layers_per_ipu if l != -1]) quotient, remainder = divmod(num_wildcard_layers, num_wildcard_ipus) for idx in wildcard_idxs: layers_per_ipu[idx] = quotient if remainder > 0: # add any remainder layers to last wildcard IPU layers_per_ipu[wildcard_idxs[-1]] += remainder elif len(layers_per_ipu) != ipus_per_replica: raise IncompatibleIPUConfigError( f"{layers_per_ipu_mode_str} has a non-default value set, but its length does not match {ipus_per_replica_mode_str}" f"{layers_per_ipu_mode_str}={layers_per_ipu}, {ipus_per_replica_mode_str}={ipus_per_replica}. " ) # no wildcards used elif sum(layers_per_ipu) != target_number_of_layers: raise IncompatibleIPUConfigError( f"{layers_per_ipu_mode_str} does not define the correct number of layers for the current model." " The current IPU Config specifies IPU assignments for " f"{sum(layers_per_ipu)} layers but there are {target_number_of_layers} layers " f"in the model. {layers_per_ipu_mode_str}={layers_per_ipu}" ) return layers_per_ipu def split_encoder_decoder_ipu_config( ipu_config: IPUConfig, num_encoder_layers: int, num_decoder_layers: int ) -> List[IPUConfig]: """ Splits an `IPUConfig` instance for an encoder-decoder model into a configuration for the encoder part and a configuration for the decoder part. It also splits `layers_per_ipu` into two given the numbers of encoder and decoder layers. Example: ``` >> ipu_config = IPUConfig(layers_per_ipu=[12, 12], ipus_per_replica=2) >> encoder_ipu_config, decoder_ipu_config = split_encoder_decoder_ipu_config(ipu_config, 12, 12) >> encoder_ipu_config => IPUConfig(layers_ler_ipu=[12], ipus_per_replica=1) >> decoder_ipu_config => IPUConfig(layers_ler_ipu=[12], ipus_per_replica=1) ``` Args: ipu_config: The `IPUConfig` instance for the the whole encoder-decoder model. num_encoder_layers: The number of encoder layers in the model. num_decoder_layers: The number of decoder layers in the model. Returns: The configuration for the encoder part, `encoder_ipu_config`, and the configuration for the decoder part, `decoder_ipu_config`. """ layers_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name("layers_per_ipu") ipus_per_replica_mode_str = ipu_config._get_managed_attr_mode_name("ipus_per_replica") # Need at least two IPUs to do the split if ipu_config._ipus_per_replica < 2: raise IncompatibleIPUConfigError( f"Need {ipus_per_replica_mode_str} to be at least 2 to split ipu_config into encoder and decoder configs" ) ipu_configs = {name: copy.deepcopy(ipu_config) for name in ["encoder", "decoder"]} # Split layers_per_ipu between the given num layers layers_per_ipu = _expand_layers_per_ipu_wildcard(ipu_config, num_encoder_layers + num_decoder_layers) cumsum = [sum(layers_per_ipu[: i + 1]) for i in range(len(layers_per_ipu))] try: cut = [i + 1 for i, c in enumerate(cumsum) if c == num_encoder_layers] # Choose the cut index that's the highest power of 2 cut = max([num for num in cut if num & (num - 1) == 0]) except Exception: raise IncompatibleIPUConfigError( f"Unable to find a valid split of ipu_config.{layers_per_ipu_mode_str}\n" "Arguments: \n" f"\tipu_config.{layers_per_ipu_mode_str}={ipu_config._layers_per_ipu}\n" f"\tnum_encoder_layers={num_encoder_layers}\n" f"\tnum_decoder_layers={num_decoder_layers}\n" "Possible causes: \n" "Encoder and decoder layers cannot be placed on the same IPUs.\n" f"The encoder and decoder {layers_per_ipu_mode_str} splits each need a number of devices that's a power of 2." ) ipu_configs["encoder"]._layers_per_ipu = layers_per_ipu[:cut] ipu_configs["decoder"]._layers_per_ipu = layers_per_ipu[cut:] # Split the per ipu configurations for SerializedEmbedding and SplitProjection if they have been provided # Note that serialized layers across IPUs cannot be present in both the encoder and decoder if ipu_config._serialized_embedding_splits_per_ipu is not None: ipu_configs["encoder"]._serialized_embedding_splits_per_ipu = ipu_config._serialized_embedding_splits_per_ipu[ :cut ] ipu_configs["decoder"]._serialized_embedding_splits_per_ipu = None # dec_split must contain all zeros, this layer cannot be split across the encoder and decoder if sum(dec_split := ipu_config._serialized_embedding_splits_per_ipu[cut:]): serialized_embedding_splits_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name( "serialized_embedding_splits_per_ipu" ) raise ValueError( "The `SerializedEmbedding` must have all splits placed on the encoder, but" f" {serialized_embedding_splits_per_ipu_mode_str}={ipu_config._serialized_embedding_splits_per_ipu} results in" f" {dec_split} being placed on the decoder" ) if ipu_config._serialized_projection_splits_per_ipu is not None: ipu_configs["encoder"]._serialized_projection_splits_per_ipu = None ipu_configs[ "decoder" ]._serialized_projection_splits_per_ipu = ipu_config._serialized_projection_splits_per_ipu[cut:] if sum(enc_split := ipu_config._serialized_projection_splits_per_ipu[:cut]): serialized_projection_splits_per_ipu_mode_str = ipu_config._get_managed_attr_mode_name( "serialized_projection_splits_per_ipu" ) raise ValueError( "The `SplitProjection` must have all splits placed on the decoder, but" f" {serialized_projection_splits_per_ipu_mode_str}={ipu_config._serialized_projection_splits_per_ipu} results in" f" {enc_split} being placed on the encoder" ) # Modify the ipus_per_replica ipu_configs["encoder"]._ipus_per_replica = len(ipu_configs["encoder"]._layers_per_ipu) ipu_configs["decoder"]._ipus_per_replica = len(ipu_configs["decoder"]._layers_per_ipu) # Split matmul_proportion between the given num layers matmul_proportion = ipu_config._matmul_proportion if isinstance(matmul_proportion, list): ipu_configs["encoder"]._matmul_proportion = matmul_proportion[:cut] ipu_configs["decoder"]._matmul_proportion = matmul_proportion[cut:] return ipu_configs.values() def get_layer_ipu(ipu_config: IPUConfig, target_number_of_layers: Optional[Union[int, List]] = None) -> List[int]: layers_per_ipu = _expand_layers_per_ipu_wildcard(ipu_config, target_number_of_layers) # List of the IPU Id for each layer layer_ipu: List[int] = [] for ipu, n_layers in enumerate(layers_per_ipu): layer_ipu += [ipu] * n_layers return layer_ipu def recomputation_checkpoint(module: nn.Module) -> torch.utils.hooks.RemovableHandle: """Annotates the output of a module to be checkpointed instead of recomputed.""" def recompute_outputs(module, inputs, outputs): if isinstance(outputs, torch.Tensor): return poptorch.recomputationCheckpoint(outputs) elif isinstance(outputs, tuple): return tuple(poptorch.recomputationCheckpoint(y) for y in outputs) return module.register_forward_hook(recompute_outputs) def outline_attribute(module: nn.Module, value: str): """Adds an attribute to a module. This attribute will be used when comparing operation equivalence in outlining. For example: ``` layer1 = nn.Linear(...) layer2 = nn.Linear(...) layer3 = nn.Linear(...) layer4 = nn.Linear(...) outline_attribute(layer1, "A") outline_attribute(layer2, "A") outline_attribute(layer3, "B") ``` The code for `layer1` can be reused for `layer2`, but it can't be used for `layer3` or `layer4`. """ context = poptorch.Attribute(__outline={"layer": value}) def enable(*args): context.__enter__() def disable(*args): context.__exit__(None, None, None) handles = [] handles.append(module.register_forward_pre_hook(enable)) handles.append(module.register_forward_hook(disable)) return handles class SerializedEmbedding(nn.Module): """ Wrapper for an `nn.Embedding` layer that performs the embedding look-up into smaller serialized steps in order to reduce memory in the embedding gradient calculation. Args: embedding: An `nn.Embedding` instance to wrap. serialization_factor: The number of serialized embedding look-ups. """ def __init__(self, embedding: nn.Embedding, serialization_factor: int): super().__init__() self.serialization_factor = serialization_factor self.num_embeddings = embedding.num_embeddings # Num embeddings should be divisible by the serialization factor assert self.num_embeddings % self.serialization_factor == 0 self.split_size = self.num_embeddings // self.serialization_factor freeze = not embedding.weight.requires_grad self.padding_idx = embedding.padding_idx boundaries = torch.linspace( self.split_size - 1, self.num_embeddings - 1, self.serialization_factor, dtype=torch.int ) self.split_embeddings = nn.ModuleList( [ nn.Embedding.from_pretrained( embedding.weight[i * self.split_size : (i + 1) * self.split_size, :].detach(), freeze=freeze, padding_idx=self.padding_idx - i * self.split_size if self.padding_idx and i == torch.bucketize(self.padding_idx, boundaries).item() else None, ) for i in range(self.serialization_factor) ] ) @classmethod def from_model(cls, embedding: nn.Embedding, serialization_factor: int) -> SerializedEmbedding: return cls(embedding, serialization_factor) def parallelize(self, splits_per_ipu: List[int]): for ipu_id, splits in enumerate(splits_per_ipu): if splits: from_split = sum(splits_per_ipu[:ipu_id]) to_split = from_split + splits - 1 shard_range = f"{from_split}-{to_split}" if from_split != to_split else f"{from_split}" logger.info(f"Embedding splits: {shard_range} --> IPU {ipu_id}") self.split_embeddings[from_split] = poptorch.BeginBlock( self.split_embeddings[from_split], ipu_id=ipu_id, user_id=f"Embedding-{shard_range}" ) return self def to_model(self) -> nn.Embedding: """ Deserialize the internal wrapped embedding layer and return it as an `nn.Embedding` object. Returns: An `nn.Embedding` layer. """ freeze = not self.split_embeddings[0].weight.requires_grad return nn.Embedding.from_pretrained( torch.vstack([l.weight for l in self.split_embeddings]), padding_idx=self.padding_idx, freeze=freeze ) def forward(self, indices): # iterate through the splits x_sum = None for i in range(self.serialization_factor): # mask out the indices not in this split split_indices = indices - i * self.split_size mask = (split_indices >= 0) * (split_indices < self.split_size) mask = mask.detach() split_indices *= mask # do the embedding lookup x = self.split_embeddings[i](split_indices) # multiply the output by mask x *= mask.unsqueeze(-1) # add to partial if x_sum is not None: x_sum += x else: x_sum = x return x_sum class SerializedLinear(nn.Linear): """ Exactly equivalent to `nn.Linear` layer, but with the matrix multiplication replaced with a serialized matrix multiplication: `poptorch.serializedMatMul`. The matrix multiplication is split into separate smaller multiplications, calculated one after the other, to reduce the memory requirements of the multiplication and its gradient calculation. Args: in_features: Size of each input sample out_features: Size of each output sample factor: Number of serialized multiplications. Must be a factor of the dimension to serialize on. bias: If set to `False`, the layer will not learn an additive bias. Default: `True`. mode: The dimension of the matmul to serialize on. For matrix A (m by n) multiplied by matrix B (n by p). * InputChannels: Split across the input channels (dimension m). * ReducingDim: Split across the reducing dimension (n). * OutputChannels: Split across the output channels (dimension p). * Disabled: Same as an ordinary matrix multiplication. """ def __init__( self, in_features, out_features, factor, bias=False, mode=poptorch.MatMulSerializationMode.OutputChannels, ): super().__init__(in_features, out_features, bias) self.mode = mode self.factor = factor @classmethod def from_model( cls, model: nn.Linear, factor: int, mode=poptorch.MatMulSerializationMode.OutputChannels ) -> SerializedLinear: clone = copy.deepcopy(model) clone.__class__ = cls clone.factor = factor clone.mode = mode return clone def to_model(self) -> nn.Linear: del self.factor del self.mode original = copy.deepcopy(self) original.__class__ = nn.Linear return original def forward(self, x): output = poptorch.serializedMatMul(x, self.weight.t(), self.mode, self.factor) if self.bias is not None: output += self.bias return output class SplitProjection(torch.nn.Module): """ Exactly equivalent to `nn.Linear` layer, but with the linear layer split into partial linear layers in order to reduce resident memory. The linear layer is split along the reducing dimension `nn.Linear.in_features` in equal parts. The forward call aggregates the partial sums obtained from each linear layer. Args: linear: A `nn.Linear` to wrap serialization_factor: The number of partitions of the linear layer. This must be a factor of linear.in_features serialization_mode: The dimension of the matmul to serialize on. For matrix A (m by n) multiplied by matrix B (n by p): * ReducingDim: Split across the reducing dimension (n). * OutputChannels: Split across the output channels (dimension p). """ def __init__( self, linear: torch.nn.Linear, serialization_factor: int, serialization_mode=poptorch.MatMulSerializationMode.OutputChannels, ) -> None: super().__init__() self.in_features = linear.in_features self.out_features = linear.out_features self.serialization_mode = serialization_mode self.split_linear_layers = torch.nn.ModuleList() if serialization_mode is poptorch.MatMulSerializationMode.OutputChannels: if self.out_features % serialization_factor != 0: raise ValueError(f"{linear.out_features=} must be divisible by {serialization_factor=}") self.split_size = self.out_features // serialization_factor for i in range(0, self.out_features, self.split_size): split_linear = torch.nn.Linear( self.in_features, self.split_size, bias=False, dtype=linear.weight.dtype ) with torch.no_grad(): split_linear.weight.copy_(linear.weight[i : i + self.split_size, :].detach()) self.split_linear_layers.append(split_linear) elif serialization_mode is poptorch.MatMulSerializationMode.ReducingDim: if self.in_features % serialization_factor != 0: raise ValueError(f"{linear.in_features=} must be divisible by {serialization_factor=}") self.split_size = self.in_features // serialization_factor for i in range(0, self.in_features, self.split_size): split_linear = torch.nn.Linear( self.split_size, self.out_features, bias=False, dtype=linear.weight.dtype ) with torch.no_grad(): split_linear.weight.copy_(linear.weight[:, i : i + self.split_size].detach()) self.split_linear_layers.append(split_linear) else: raise ValueError( f"`SplitProjection` `serialization_mode` can only be {poptorch.MatMulSerializationMode.OutputChannels} or {poptorch.MatMulSerializationMode.ReducingDim}." f" You provided: {serialization_mode=}" ) self.bias = None if linear.bias is not None: self.bias = torch.nn.Parameter(torch.zeros_like(linear.bias)) with torch.no_grad(): self.bias.copy_(linear.bias.detach()) def forward(self, x): if self.serialization_mode is poptorch.MatMulSerializationMode.OutputChannels: out = [] for i, split_linear_layer in enumerate(self.split_linear_layers): out.append(split_linear_layer(x)) out = torch.concat(out, -1) elif self.serialization_mode is poptorch.MatMulSerializationMode.ReducingDim: out = self.split_linear_layers[0](x[..., 0 : self.split_size]) for i, split_linear_layer in enumerate(self.split_linear_layers[1:]): out += split_linear_layer(x[..., i * self.split_size : (i + 1) * self.split_size]) if self.bias is not None: out += self.bias return out @classmethod def from_model( cls, linear: torch.nn.Linear, serialization_factor: int, serialization_mode=poptorch.MatMulSerializationMode.OutputChannels, ) -> SplitProjection: return cls(linear, serialization_factor, serialization_mode) def to_model(self) -> nn.Linear: """ Merge the sub linear layers into one Returns: `nn.Linear` layer """ dtype = self.split_linear_layers[0].weight.dtype layer = nn.Linear(self.in_features, self.out_features, bias=self.bias is not None) if dtype == torch.float16: layer = layer.half() if self.serialization_mode is poptorch.MatMulSerializationMode.OutputChannels: with torch.no_grad(): layer.weight.copy_(torch.vstack([l.weight.detach() for l in self.split_linear_layers])) elif self.serialization_mode is poptorch.MatMulSerializationMode.ReducingDim: with torch.no_grad(): layer.weight.copy_(torch.hstack([l.weight.detach() for l in self.split_linear_layers])) if self.bias is not None: with torch.no_grad(): layer.bias.copy_(self.bias) return layer def parallelize(self, splits_per_ipu: List[int]): for ipu_id, splits in enumerate(splits_per_ipu): if splits: from_split = sum(splits_per_ipu[:ipu_id]) to_split = from_split + splits - 1 shard_range = f"{from_split}-{to_split}" if from_split != to_split else f"{from_split}" logger.info(f"Linear splits: {shard_range} --> IPU {ipu_id}") self.split_linear_layers[from_split] = poptorch.BeginBlock( self.split_linear_layers[from_split], ipu_id=ipu_id, user_id=f"Linear-{shard_range}" ) return self class SharedEmbedding(nn.Module): """ Wrapper around the shared embedding between the encoder and the decoder stacks. Attributes: shared: The shared embedding layer. """ def __init__(self, shared: nn.Embedding): super().__init__() self.shared = shared def _combine_inputs(self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor) -> Tuple[int, torch.Tensor]: idx = input_ids.size(1) return idx, torch.cat([input_ids, decoder_input_ids], dim=1) def _separate_inputs(self, idx: int, embeds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return embeds[:, :idx, :], embeds[:, idx:, :] def forward( self, input_ids: torch.Tensor, decoder_input_ids: torch.Tensor, encoder_embed_scale: Optional[float] = None, decoder_embed_scale: Optional[float] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: use this once the TiedGather pattern issue is solved. # encoder_inputs_embeds, decoder_inputs_embeds = None, None # if input_ids is not None and encoder_embed_scale is not None: # encoder_inputs_embeds = self.shared(input_ids) * encoder_embed_scale # if decoder_input_ids is not None and decoder_embed_scale is not None: # decoder_inputs_embeds = self.shared(decoder_input_ids) * decoder_embed_scale # combined, n1, n2 = self._combine_inputs(input_ids, decoder_input_ids) # encoder_inputs_embeds, decoder_inputs_embeds = self._separate_inputs(self.shared(combined), n1, n2) encoder_inputs_embeds, decoder_inputs_embeds = None, None if input_ids is None: # call on decoder_input_ids only decoder_inputs_embeds = self.shared(decoder_input_ids) elif decoder_input_ids is None: # call on input_ids only encoder_inputs_embeds = self.shared(input_ids) else: # Call on the combined case # This case is assuming input_ids and decoder_input_ids are not None idx, combined = self._combine_inputs(input_ids, decoder_input_ids) encoder_inputs_embeds, decoder_inputs_embeds = self._separate_inputs(idx, self.shared(combined)) if encoder_embed_scale: encoder_inputs_embeds = encoder_inputs_embeds * encoder_embed_scale if decoder_embed_scale: decoder_inputs_embeds = decoder_inputs_embeds * decoder_embed_scale return encoder_inputs_embeds, decoder_inputs_embeds class OnehotGather(nn.Module): """ Gathers selected indices from a tensor by transforming the list of indices into a one-hot matrix and then multiplying the tensor by that matrix. """ def forward(self, sequence, positions): """ Gathers the vectors at the specific positions over a batch. """ num_classes = int(sequence.shape[1]) one_hot_positions = F.one_hot(positions, num_classes).to(dtype=sequence.dtype) return torch.matmul(one_hot_positions.detach(), sequence) def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): """ Shift input ids one token to the right. """ # Upstream version: # shifted_input_ids = input_ids.new_zeros(input_ids.shape) # shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() # shifted_input_ids[:, 0] = decoder_start_token_id # Change to fix slice assignment: shifted_input_ids = torch.cat( [torch.ones(input_ids.shape[0], 1, dtype=input_ids.dtype) * decoder_start_token_id, input_ids[:, :-1]], 1 ) if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") # replace possible -100 values in labels by `pad_token_id` shifted_input_ids = torch.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) return shifted_input_ids