optimum/graphcore/generation/attention_mixin.py (300 lines of code) (raw):

# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from typing import Callable, Optional, Tuple

import poptorch
import torch

from .utils import assert_poptorch_supports_cond


# Magnitude of the additive mask value used for masked-out positions
# (see `update_attention_mask`, which fills with `-FLOAT16_LIMIT`).
FLOAT16_LIMIT = 1e4


class IPUAttentionMixin:
    """
    The aim of this class is to provide common, model-agnostic functionality such as
    KV caching and attention serialization to transformer attention layers.

    The intended usage is best demonstrated with an existing example, Whisper. There are roughly two steps:
    1. subclass the parent attention layer to inject this mixin, for example,
       `class IPUWhisperAttention(WhisperAttention, IPUAttentionMixin)` and use the
       `add_to_kv_cache` and `update_attention_mask` methods to add the KV values at the
       current time step to the cache, or `serialized_attention` to serialize attention
       across the batch or sequence dimensions.
    2. replace the existing attention layers with above via the provided class method `from_model`, e.g.
       `decoder_layer.self_attn = IPUWhisperAttention.from_model(decoder_layer.self_attn, use_cache=True, **kwargs)`.
    """

    # Class-level defaults. The `_create_*` / `from_model` methods shadow these with
    # instance attributes, and the `_delete_*` methods remove the instance attributes
    # again so that lookups fall back to the defaults below.
    _kv_cache_initialized: bool = False
    _cross_kv_cache_initialized: bool = False
    _num_beams: int = 1
    _batch_serialization_factor: int = 1
    _sequence_serialization_factor: int = 1

    @property
    def kv_cache_initialized(self) -> bool:
        """Whether the self-attention KV cache buffers have been created."""
        return self._kv_cache_initialized

    @property
    def cross_kv_cache_initialized(self) -> bool:
        """Whether the cross-attention KV cache buffers have been created."""
        return self._cross_kv_cache_initialized

    def _create_kv_cache(self, cache_shape: Tuple[int, ...], dtype: torch.dtype, num_beams=1):
        """
        Registers the non-persistent buffers backing the self-attention KV cache.

        Args:
            cache_shape: shape of each of the key and value caches. `from_model` passes
                `(batch_size * num_beams, num_heads, max_length, head_dim)`.
            dtype: dtype of the key and value caches.
            num_beams: when greater than 1, a `_beam_idx` buffer of length `cache_shape[0]`
                is also registered.
        """
        self.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False)
        self.register_buffer("_k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        self.register_buffer("_v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        if num_beams > 1:
            self.register_buffer("_beam_idx", torch.arange(cache_shape[0], dtype=torch.int32), persistent=False)
        self._num_beams = num_beams
        self._kv_cache_initialized = True

    def _delete_shared_cache_buffers(self):
        """Removes the buffers shared by the self- and cross-attention caches, if present."""
        for buffer_name in ("_generation_step", "_beam_idx"):
            if hasattr(self, buffer_name):
                delattr(self, buffer_name)

    def _delete_kv_cache(self):
        """
        Removes the self-attention KV cache. Does nothing if it was never created.

        `_generation_step` and `_beam_idx` are also read by the cross-attention cache,
        so they are only removed when the cross-attention cache is not in use.
        """
        if not self._kv_cache_initialized:
            return

        del self._k_cache
        del self._v_cache
        if not self._cross_kv_cache_initialized:
            self._delete_shared_cache_buffers()
        del self._num_beams
        del self._kv_cache_initialized

    def _create_cross_kv_cache(self, cache_shape: Tuple[int, ...], dtype: torch.dtype, num_beams=1):
        """
        Registers the non-persistent buffers backing the cross-attention KV cache.

        `_generation_step` and `_beam_idx` are reused if they already exist (for
        example because the self-attention cache was created first).

        Args:
            cache_shape: shape of each of the key and value caches. `from_model` passes
                `(batch_size * num_beams, num_heads, encoder_max_length, head_dim)`.
            dtype: dtype of the key and value caches.
            num_beams: when greater than 1, a `_beam_idx` buffer of length `cache_shape[0]`
                is registered unless one already exists.
        """
        if not hasattr(self, "_generation_step"):
            self.register_buffer("_generation_step", torch.tensor([0], dtype=torch.int32), persistent=False)
        self.register_buffer("_cross_k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        self.register_buffer("_cross_v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        if num_beams > 1 and not hasattr(self, "_beam_idx"):
            self.register_buffer("_beam_idx", torch.arange(cache_shape[0], dtype=torch.int32), persistent=False)
        self._cross_kv_cache_initialized = True

    def _delete_cross_kv_cache(self):
        """
        Removes the cross-attention KV cache. Does nothing if it was never created.

        `_generation_step` and `_beam_idx` are also read by the self-attention cache,
        so they are only removed when the self-attention cache is not in use.
        """
        if not self._cross_kv_cache_initialized:
            return

        del self._cross_k_cache
        del self._cross_v_cache
        if not self._kv_cache_initialized:
            self._delete_shared_cache_buffers()
        del self._cross_kv_cache_initialized

    @classmethod
    def from_model(
        cls,
        attention_layer: torch.nn.Module,
        use_cache: bool = False,
        batch_size: int = 1,
        max_length: int = 128,
        num_beams: int = 1,
        num_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dtype: torch.dtype = torch.float16,
        batch_serialization_factor: int = 1,
        sequence_serialization_factor: int = 1,
        use_cross_cache: bool = False,
        encoder_max_length: int = 128,
    ):
        """
        Returns an instance of the provided `attention_layer` with functionality provided by `IPUAttentionMixin`.

        If `use_cache=True`, instantiates the self-attention KV caches, each of shape
        `(batch_size * num_beams, num_heads, max_length, head_dim)`.

        If `use_cross_cache=True`, instantiates the cross-attention KV caches, each of shape
        `(batch_size * num_beams, num_heads, encoder_max_length, head_dim)`.

        If `num_heads` or `head_dim` are not provided they are read from the attributes
        of the same name on `attention_layer`.

        If `batch_serialization_factor > 1` or `sequence_serialization_factor > 1`, attention will
        be serialized along the batch or sequence dimension respectively.

        Raises:
            AttributeError: if a cache is requested and `num_heads` / `head_dim` were neither
                provided nor found on `attention_layer`.
            ValueError: if a serialization factor is < 1, if both factors are > 1, or if
                serialization is requested together with `use_cache=True`.
        """
        # The input layer is left untouched: work on a deep copy and re-class it.
        clone = copy.deepcopy(attention_layer)
        clone.__class__ = cls

        def infer_attribute_from_layer(attr: str):
            # `{{0}}` survives the f-string as `{0}` and is filled in by `.format(attr)` below.
            err_msg = (
                f"Attempting to replace attention class `{attention_layer.__class__.__name__}` with `{cls.__name__}`."
                f" However unable to infer `{{0}}` from `{attention_layer.__class__.__name__}`."
                " Provide the `{0}` argument to `IPUAttentionMixin.from_model`."
            )
            try:
                value = getattr(clone, attr)
                return value
            except AttributeError as e:
                raise AttributeError(err_msg.format(attr)) from e

        if use_cache or use_cross_cache:
            num_heads = infer_attribute_from_layer("num_heads") if num_heads is None else num_heads
            head_dim = infer_attribute_from_layer("head_dim") if head_dim is None else head_dim

        if use_cache:
            clone._create_kv_cache(
                (batch_size * num_beams, num_heads, max_length, head_dim),
                dtype=dtype,
                num_beams=num_beams,
            )

        if use_cross_cache:
            assert_poptorch_supports_cond(
                context="Cross-attention KV caching has been enabled with `use_cross_cache=True`."
            )
            clone._create_cross_kv_cache(
                (batch_size * num_beams, num_heads, encoder_max_length, head_dim),
                dtype=dtype,
                num_beams=num_beams,
            )

        if batch_serialization_factor < 1 or sequence_serialization_factor < 1:
            raise ValueError(
                "`batch_serialization_factor` and `sequence_serialization_factor` must be > 0 if provided."
            )
        elif batch_serialization_factor > 1 and sequence_serialization_factor > 1:
            raise ValueError(
                "If serializing attention, only one of `batch_serialization_factor` "
                "and `sequence_serialization_factor` should be greater than 1, not both."
            )
        elif batch_serialization_factor > 1 or sequence_serialization_factor > 1:
            if use_cache:
                raise ValueError("Attention serialization is redundant when KV caching is enabled.")

            clone._batch_serialization_factor = batch_serialization_factor
            clone._sequence_serialization_factor = sequence_serialization_factor

        return clone

    def to_model(self, cls) -> torch.nn.Module:
        """
        Returns an instance of the `attention_layer` provided to `from_model` with functionality
        provided by `IPUAttentionMixin` removed.

        Note that the caches and serialization factors are removed from `self` in place
        before `self` is deep-copied and the copy is re-classed as `cls`.
        """
        self._delete_kv_cache()
        self._delete_cross_kv_cache()
        self._delete_serialization_factors()

        original = copy.deepcopy(self)
        original.__class__ = cls
        return original

    def add_to_kv_cache(self, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Copies the key-value pair into their corresponding key-value caches.

        Args:
            key (`torch.FloatTensor`):
                key tensor of shape `(batch_size * num_beams, num_heads, 1, head_dim)`.
            value (`torch.FloatTensor`):
                value tensor of shape `(batch_size * num_beams, num_heads, 1, head_dim)`.

        Returns:
            The updated `(key cache, value cache)` buffers.

        Raises:
            ValueError: if the self-attention KV cache has not been created, or if `key` /
                `value` do not match the cache shape with the sequence dimension set to 1.
            RuntimeError: if the module is in training mode.
        """
        if not self.kv_cache_initialized:
            raise ValueError(
                f"{self.__class__.__name__} assumes that self-attention has KV caching enabled. "
                f"Please instantiate using `{self.__class__.__name__}.from_model()` so the KV "
                "cache can be created."
            )

        if self.training:
            raise RuntimeError("KV caching is currently only supported for inference.")

        # Inputs must look like one cache slot: the cache shape with sequence length 1.
        expected_key_shape, expected_value_shape = list(self._k_cache.shape), list(self._v_cache.shape)
        expected_key_shape[-2] = 1
        expected_value_shape[-2] = 1
        if list(key.shape) != expected_key_shape:
            raise ValueError(f"Expected key shape {expected_key_shape}, received {list(key.shape)}.")
        if list(value.shape) != expected_value_shape:
            raise ValueError(f"Expected value shape {expected_value_shape}, received {list(value.shape)}.")

        # For now assume that generation will always start from step 0.
        # Multiplying by (1 - reset) zeroes the caches at step 0 and is a no-op otherwise.
        reset_kv_cache = self._generation_step == 0
        self._k_cache *= 1 - reset_kv_cache.to(self._k_cache.dtype)
        self._v_cache *= 1 - reset_kv_cache.to(self._v_cache.dtype)

        if hasattr(self, "_beam_idx"):
            # For beam search, permute the cache since inputs are permuted on host.
            _k_cache = torch.index_select(self._k_cache, 0, self._beam_idx)
            _v_cache = torch.index_select(self._v_cache, 0, self._beam_idx)
            self._k_cache.copy_(_k_cache)
            self._v_cache.copy_(_v_cache)

        # Dynamic update leads to uneven tile placement, and scatter leads to large re-arrangements,
        # so use a brute force matmul approach which empirically seems best for now.
        # `mm_mask` is a (src_len, 1) column that is 1 only at the current generation step, so the
        # matmul places `key` / `value` at that position and zeros everywhere else.
        bsz, heads, src_len, head_dim = self._k_cache.shape
        mm_mask = (torch.arange(src_len) == self._generation_step).view(src_len, 1)
        _key = torch.matmul(mm_mask.to(key.dtype), key.view(bsz * heads, 1, head_dim))
        _value = torch.matmul(mm_mask.to(value.dtype), value.view(bsz * heads, 1, head_dim))
        self._k_cache += _key.view(self._k_cache.shape)
        self._v_cache += _value.view(self._v_cache.shape)

        return self._k_cache, self._v_cache

    def update_attention_mask(self, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Creates a default additive attention mask intended for use with KV caches.

        Positions up to and including the current generation step, i.e. those for which the
        caches have been populated, are left at 0. All later positions are set to
        `-FLOAT16_LIMIT`. The mask is cast to the key cache dtype and expanded to
        `(cache_shape[0], 1, 1, src_len)`.

        Args:
            attention_mask: optional additive mask that is summed with the default mask.
                Its size must equal that of the default mask.

        Raises:
            ValueError: if `attention_mask` is provided with a different size.
        """
        bsz, _, src_len, _ = self._k_cache.shape
        mask = torch.full((1, src_len), -FLOAT16_LIMIT)
        mask_cond = torch.arange(src_len).view(1, src_len)
        mask.masked_fill_(mask_cond < self._generation_step + 1, 0)
        mask = mask.to(self._k_cache.dtype)
        mask = mask.expand(bsz, 1, 1, src_len)

        if attention_mask is not None:
            if attention_mask.size() != mask.size():
                raise ValueError(
                    f"Attention mask does not match expected KV cache mask dimensions. "
                    f"Received: {attention_mask.size()}, expected {mask.size()}."
                )
            mask = mask + attention_mask

        return mask

    def add_to_cross_kv_cache(
        self,
        cross_input: torch.Tensor,
        key_fn: Callable[[torch.Tensor], torch.Tensor],
        value_fn: Callable[[torch.Tensor], torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Populates the cross-attention key-value caches and returns them.

        The caches are written through `poptorch.cond` on `generation_step == 0`: the
        "then" branch computes `key_fn(cross_input)` / `value_fn(cross_input)` and the
        "else" branch yields the existing cache contents.

        Args:
            cross_input: tensor handed to `key_fn` and `value_fn`.
            key_fn: callable producing the cross-attention keys from `cross_input`. Its output
                is copied into the key cache, so it must be compatible with the cache shape.
            value_fn: callable producing the cross-attention values from `cross_input`, with
                the same requirement with respect to the value cache.

        Returns:
            The `(cross key cache, cross value cache)` buffers.

        Raises:
            ValueError: if the cross-attention KV cache has not been created.
            RuntimeError: if the module is in training mode.
        """
        if not self.cross_kv_cache_initialized:
            raise ValueError(
                f"{self.__class__.__name__} assumes that cross-attention has cross KV caching enabled. "
                f"Please instantiate using `{self.__class__.__name__}.from_model()` so the cross KV "
                "cache can be created."
            )

        if self.training:
            raise RuntimeError("Cross KV caching is currently only supported for inference.")

        assert_poptorch_supports_cond(
            context="Cross-attention KV caching has been enabled with `use_cross_cache=True`."
        )

        # For now assume that generation will always start from step 0.
        # Multiplying by (1 - reset) zeroes the caches at step 0 and is a no-op otherwise.
        reset_kv_cache = self._generation_step == 0
        self._cross_k_cache *= 1 - reset_kv_cache.to(self._cross_k_cache.dtype)
        self._cross_v_cache *= 1 - reset_kv_cache.to(self._cross_v_cache.dtype)

        if hasattr(self, "_beam_idx"):
            # For beam search, permute the cache since inputs are permuted on host.
            _cross_k_cache = torch.index_select(self._cross_k_cache, 0, self._beam_idx)
            _cross_v_cache = torch.index_select(self._cross_v_cache, 0, self._beam_idx)
            self._cross_k_cache.copy_(_cross_k_cache)
            self._cross_v_cache.copy_(_cross_v_cache)

        # Branch bodies for `poptorch.cond`: compute fresh keys/values at step 0,
        # otherwise pass the cached tensors through unchanged.
        def then_k_body(x):
            return key_fn(x)

        def else_k_body(_):
            return self._cross_k_cache

        def then_v_body(x):
            return value_fn(x)

        def else_v_body(_):
            return self._cross_v_cache

        self._cross_k_cache.copy_(
            poptorch.cond(reset_kv_cache, then_k_body, [cross_input], else_k_body, [cross_input])[0]
        )
        self._cross_v_cache.copy_(
            poptorch.cond(reset_kv_cache, then_v_body, [cross_input], else_v_body, [cross_input])[0]
        )
        return self._cross_k_cache, self._cross_v_cache

    @property
    def is_attention_serialized(self) -> bool:
        """Whether either serialization factor is greater than 1."""
        return self._batch_serialization_factor > 1 or self._sequence_serialization_factor > 1

    def _delete_serialization_factors(self):
        """Removes the instance-level serialization factors so the class defaults (1) apply again."""
        if not self.is_attention_serialized:
            return

        del self._batch_serialization_factor
        del self._sequence_serialization_factor

    def serialized_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: Optional[float] = 1.0,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Serializes the attention operation either across the batch (if `batch_serialization_factor > 1`)
        or the sequence (if `sequence_serialization_factor > 1`) dimensions to reduce peak memory usage.

        NB: if serializing across the batch, this will include `num_heads` wherein we expect the
        leading dimension to be of size `batch_size * num_heads`.

        Args:
            query, key, value: 3D tensors interpreted as
                `(batch_size * num_heads, sequence_length, head_dim)`.
            scale: factor the query-key product is multiplied by before the softmax.
            attention_mask: optional additive mask applied before the softmax.

        Raises:
            ValueError: if any of `query`, `key`, `value` is not 3D, or if neither
                serialization factor is greater than 1.
        """
        if query.ndim != 3 or key.ndim != 3 or value.ndim != 3:
            raise ValueError(
                "Expected query, key, value all to be 3D, which we will interpret "
                "as (batch_size * num_heads, sequence_length, head_dim). Received "
                f"{query.shape}, {key.shape}, {value.shape}."
            )

        if self._batch_serialization_factor > 1:
            return self._batch_serialized_attention(
                query, key, value, scale, attention_mask, self._batch_serialization_factor
            )
        elif self._sequence_serialization_factor > 1:
            return self._sequence_serialized_attention(
                query, key, value, scale, attention_mask, self._sequence_serialization_factor
            )
        else:
            # Report the concrete class so the hint is correct for any model using the mixin.
            raise ValueError(
                "Attempting to serialize attention but neither serialization factor is >1. "
                "To serialize attention, please provide either a `batch_serialization_factor` or "
                f"`sequence_serialization_factor` kwarg to `{self.__class__.__name__}.from_model` with "
                "values greater than 1."
            )

    def _batch_serialized_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: Optional[float] = 1.0,
        attention_mask: Optional[torch.Tensor] = None,
        serialization_factor: Optional[int] = 1,
    ) -> torch.Tensor:
        """
        Computes attention in `serialization_factor` equal slices of the leading dimension
        of `query`, `key`, `value` (and `attention_mask`, if given) and concatenates the
        results along dimension 0.

        Raises:
            ValueError: if `query.shape[0]` is not divisible by `serialization_factor`.
        """
        if query.shape[0] % serialization_factor != 0:
            raise ValueError(
                f"Cannot evenly divide query batch dim: {query.shape[0]} by `serialization_factor`: {serialization_factor}."
            )
        slice_size = query.shape[0] // serialization_factor

        hidden_states = []
        # Transpose once up front so each slice is a plain `query @ key^T`.
        key = key.transpose(1, 2)
        for i in range(serialization_factor):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx]) * scale
            if attention_mask is not None:
                attn_slice = attn_slice + attention_mask[start_idx:end_idx]
            attn_slice = attn_slice.softmax(dim=-1)
            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
            hidden_states.append(attn_slice)

        hidden_states = torch.cat(hidden_states, dim=0)
        return hidden_states

    def _sequence_serialized_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        scale: Optional[float] = 1.0,
        attention_mask: Optional[torch.Tensor] = None,
        serialization_factor: Optional[int] = 1,
    ) -> torch.Tensor:
        """
        Computes attention in `serialization_factor` equal slices of dimension 1 of `query`
        (and `attention_mask`, if given), each attending over the full `key` / `value`, and
        concatenates the results along dimension 1.

        Raises:
            ValueError: if `query.shape[1]` is not divisible by `serialization_factor`.
        """
        if query.shape[1] % serialization_factor != 0:
            raise ValueError(
                f"Cannot evenly divide query sequence dim: {query.shape[1]} by `serialization_factor`: {serialization_factor}."
            )
        slice_size = query.shape[1] // serialization_factor

        hidden_states = []
        # Transpose once up front so each slice is a plain `query @ key^T`.
        key = key.transpose(1, 2)
        for i in range(serialization_factor):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = torch.matmul(query[:, start_idx:end_idx], key) * scale
            if attention_mask is not None:
                attn_slice = attn_slice + attention_mask[:, start_idx:end_idx]
            attn_slice = attn_slice.softmax(dim=-1)
            attn_slice = torch.matmul(attn_slice, value)
            hidden_states.append(attn_slice)

        hidden_states = torch.cat(hidden_states, dim=1)
        return hidden_states