modules/SwissArmyTransformer/sat/model/position_embedding/rotary_embeddings.py

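# Overview (added comment, not in the upstream file): rotary position
# embeddings (RoPE, Su et al., RoFormer 2021) rotate pairs of query/key
# channels by a position-dependent angle, so attention scores depend only on
# relative positions. RotaryEmbedding below caches the cos/sin tables; the
# apply_* helpers come in jit-scripted, plain-torch (bf16-safe) and
# fused-CUDA variants, plus *_index variants that gather the tables by
# explicit position ids.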
# Extracted from: https://github.com/EleutherAI/gpt-neox
import torch
import torch.nn.functional as F


class RotaryEmbedding(torch.nn.Module):

    def __init__(self, dim, base=10000, precision=torch.half, learnable=False, device=torch.device('cpu')):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        # inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys,
                              unexpected_keys, error_msgs):
        # Intentionally a no-op: inv_freq is recomputed from `dim` and `base`,
        # so nothing needs to be restored from a checkpoint.
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Differs from the paper: a different channel permutation is used,
            # but it yields the same calculation.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()
            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            cos_cached = cos_cached.to(x.dtype)
            sin_cached = sin_cached.to(x.dtype)
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


class RotaryPositionalEmbeddingFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, cos, sin):
        import rotary_positional_embedding_cuda

        q_ = q.contiguous()
        cos_ = cos.contiguous()
        sin_ = sin.contiguous()
        output = rotary_positional_embedding_cuda.forward(*q.shape, q_, cos_, sin_)
        ctx.save_for_backward(cos_, sin_)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        import rotary_positional_embedding_cuda

        cos_, sin_ = ctx.saved_tensors
        grad_q = rotary_positional_embedding_cuda.backward(*grad_output.shape, grad_output, cos_, sin_)
        return grad_q, None, None


# rotary pos emb helpers:

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0):  # jitting fails with bf16
    cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


def apply_rotary_pos_emb_fused(q, k, cos, sin, offset: int = 0):
    cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
    q = RotaryPositionalEmbeddingFunction.apply(q, cos, sin)
    k = RotaryPositionalEmbeddingFunction.apply(k, cos, sin)
    return q, k


@torch.jit.script
def apply_rotary_pos_emb_index_single(q, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
               F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    return (q * cos) + (rotate_half(q) * sin)


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
               F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k


def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id):  # jitting fails with bf16
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
               F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k


def apply_rotary_pos_emb_index_fused(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
               F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q = RotaryPositionalEmbeddingFunction.apply(q, cos, sin)
    k = RotaryPositionalEmbeddingFunction.apply(k, cos, sin)
    return q, k
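

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original module):
# wires RotaryEmbedding together with the pure-PyTorch appliers above. The
# layouts follow this file's own comments -- [sq, b*np, hn] for the plain
# applier, [sq, b, np, hn] for the *_index variants -- but the concrete sizes
# below are arbitrary assumptions chosen only for the demo.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    sq, b, np_, hn = 8, 2, 4, 32  # assumed: seq len, batch, heads, head dim (hn must be even)

    rotary = RotaryEmbedding(hn, precision=torch.float)

    # Plain applier: q, k flattened to [sq, b*np, hn]; cos/sin come back as [sq, 1, hn].
    q = torch.randn(sq, b * np_, hn)
    k = torch.randn(sq, b * np_, hn)
    cos, sin = rotary(q, seq_dim=0)
    q_rot, k_rot = apply_rotary_pos_emb_torch(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape

    # Index applier: q, k kept as [sq, b, np, hn], with explicit position ids [sq, b]
    # gathered from the cached tables via F.embedding.
    q4 = torch.randn(sq, b, np_, hn)
    k4 = torch.randn(sq, b, np_, hn)
    position_id = torch.arange(sq).unsqueeze(1).repeat(1, b)
    q_idx, k_idx = apply_rotary_pos_emb_index_torch(q4, k4, cos, sin, position_id)
    assert q_idx.shape == q4.shape and k_idx.shape == k4.shape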