optimum/habana/transformers/modeling_rope_utils.py

import torch
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

from optimum.utils import logging


logger = logging.get_logger(__name__)


class GaudiRotaryEmbedding(torch.nn.Module):
    """
    Referred from FalconRotaryEmbedding:
    https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/models/falcon/modeling_falcon.py#L167

    The only differences are:
    - modify forward function to use seq_len instead of position_ids
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config=None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`GaudiRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False)

    def _dynamic_frequency_update(self, seq_len, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        # seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            # This .to() is needed if the model has been moved to a device after being initialized (because
            # the buffer is automatically moved, but not the original copy)
            self.original_inv_freq = self.original_inv_freq.to(device)
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(seq_len, device=x.device)

        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        if self.attention_scaling == 1.0:
            return (
                self._cos_cached[:seq_len].to(dtype=x.dtype),
                self._sin_cached[:seq_len].to(dtype=x.dtype),
            )
        else:
            return (
                self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
                self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
            )
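# A minimal usage sketch, assuming the backward-compatible config-less constructor path and an
# illustrative query tensor shaped [bs, num_attention_heads, seq_len, head_size]. The names
# `rotary_emb` and `query` and the dimensions below are placeholders, not part of the module.
if __name__ == "__main__":
    # Config-less construction emits the deprecation warning above but still works via rope_kwargs.
    rotary_emb = GaudiRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000)

    query = torch.zeros(1, 8, 128, 64)  # [bs, num_attention_heads, seq_len, head_size]
    # Unlike the upstream Falcon forward, this forward takes seq_len rather than position_ids
    # and returns cos/sin tables sliced to that length.
    cos, sin = rotary_emb(query, seq_len=query.shape[-2])

    # cos/sin are [seq_len, dim] and cast to the dtype of the input tensor.
    assert cos.shape == (128, 64) and sin.shape == (128, 64)
    assert cos.dtype == query.dtype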