# optimum/habana/transformers/modeling_rope_utils.py

import torch
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from optimum.utils import logging
logger = logging.get_logger(__name__)


class GaudiRotaryEmbedding(torch.nn.Module):
"""
Referred from FalconRotaryEmbedding: https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/models/falcon/modeling_falcon.py#L167
The only differences are:
- modify forward function to use seq_len instead of position_ids
"""
def __init__(
self,
dim=None,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
rope_type="default",
config=None,
):
super().__init__()
# TODO (joao): remove the `if` below, only used for BC
self.rope_kwargs = {}
if config is None:
logger.warning_once(
"`GaudiRotaryEmbedding` can now be fully parameterized by passing the model config through the "
"`config` argument. All other arguments will be removed in v4.46"
)
self.rope_kwargs = {
"rope_type": rope_type,
"factor": scaling_factor,
"dim": dim,
"base": base,
"max_position_embeddings": max_position_embeddings,
}
self.rope_type = rope_type
self.max_seq_len_cached = max_position_embeddings
self.original_max_seq_len = max_position_embeddings
else:
# BC: "rope_type" was originally "type"
if config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings
self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

def _set_cos_sin_cache(self, seq_len, device, dtype):
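        """
        Precomputes the cos/sin tables for positions [0, seq_len) and caches them as
        non-persistent buffers (moved with the module but excluded from the state dict).
        """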
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
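        # Outer product of positions and inverse frequencies gives per-position angles of shape [seq_len, dim // 2]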
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(dtype), persistent=False)
def _dynamic_frequency_update(self, seq_len, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
1 - growing beyond the cached sequence length (allow scaling)
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
"""
        # Upstream derives the length from `position_ids` (seq_len = torch.max(position_ids) + 1);
        # here the caller passes `seq_len` directly.
if seq_len > self.max_seq_len_cached: # growth
inv_freq, self.attention_scaling = self.rope_init_fn(
self.config, device, seq_len=seq_len, **self.rope_kwargs
)
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
self.max_seq_len_cached = seq_len
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
# This .to() is needed if the model has been moved to a device after being initialized (because
# the buffer is automatically moved, but not the original copy)
self.original_inv_freq = self.original_inv_freq.to(device)
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

@torch.no_grad()
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
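        # `seq_len` is expected to be provided by the caller; the returned cos/sin
        # tables have shape [seq_len, dim] and are cast to the dtype of `x`.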
if "dynamic" in self.rope_type:
self._dynamic_frequency_update(seq_len, device=x.device)
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
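        # Skip the elementwise multiply when no attention scaling is applied (e.g. the default RoPE type)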
if self.attention_scaling == 1.0:
return (
self._cos_cached[:seq_len].to(dtype=x.dtype),
self._sin_cached[:seq_len].to(dtype=x.dtype),
)
else:
return (
self._cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
self._sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
)
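

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of how
# this rotary embedding is typically built and queried when no model config is
# available. The `dim`/`base` values and tensor shapes below are assumptions
# chosen for demonstration only.
#
#   rope = GaudiRotaryEmbedding(dim=64, max_position_embeddings=2048, base=10000)
#   x = torch.zeros(1, 8, 128, 64)           # [bs, num_attention_heads, seq_len, head_size]
#   cos, sin = rope(x, seq_len=x.shape[-2])  # each has shape [seq_len, dim]
# ---------------------------------------------------------------------------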