megatron_patch/model/chatglm/positional_embeddings.py (60 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI and Nvidia Megatron-LM Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F


class RotaryEmbedding(torch.nn.Module):
    """Produce cos/sin tables for rotary position embeddings.

    The inverse-frequency vector is always stored in half precision,
    whatever ``precision`` is; ``precision`` only decides whether the
    cos/sin tables take the bfloat16 path in :meth:`forward`.

    Args:
        dim: Size of the rotary dimension; ``dim // 2`` (rounded up)
            frequencies are generated.
        base: Base of the geometric frequency progression.
        precision: When equal to ``torch.bfloat16`` the angles are evaluated
            in float32 and the tables are cast to bfloat16.
        learnable: If true, ``inv_freq`` is an ``nn.Parameter`` and no table
            is cached between calls; otherwise it is a buffer and the tables
            are cached.
    """

    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base**(torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
        else:
            self.register_buffer('inv_freq', inv_freq)
        # The cache attributes must exist in *both* modes: ``_apply`` reads
        # ``cos_cached`` / ``sin_cached`` unconditionally, so leaving them
        # undefined for ``learnable=True`` made ``.to()`` / ``.half()`` /
        # ``.cuda()`` raise AttributeError.
        self.max_seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        """Ignore checkpoint contents; ``inv_freq`` keeps its constructed value.

        NOTE(review): this also means a learnable ``inv_freq`` is never
        restored from a checkpoint -- confirm that this is intended.
        """
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        """Return ``(cos, sin)`` tables covering ``seq_len`` positions.

        Args:
            x: Tensor used for its device and, when ``seq_len`` is omitted,
                for its length along ``seq_dim``.
            seq_dim: Axis of ``x`` that holds the sequence length.
            seq_len: Explicit number of positions; overrides ``x.shape``.

        Returns:
            Two tensors of shape ``[seq_len, 1, 2 * len(inv_freq)]``.
        """
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        needs_rebuild = (self.max_seq_len_cached is None
                         or seq_len > self.max_seq_len_cached)
        if needs_rebuild:
            # In learnable mode the marker stays None so the tables are
            # recomputed on every call from the current parameter value.
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device,
                             dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation
            # in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            use_bf16 = self.precision == torch.bfloat16
            if use_bf16:
                # Evaluate cos/sin in float32, then cast down to bfloat16.
                emb = emb.float()
            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if use_bf16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        """Apply ``fn`` to the cached tables as well as to params/buffers.

        The caches are plain attributes rather than registered buffers, so
        they have to be converted by hand when the module is moved or cast.
        """
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last axis in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    # dim=-1 triggers a bug in earlier torch versions
    return torch.cat((-x2, x1), dim=x1.ndim - 1)


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    """Rotate ``q`` and ``k`` using table rows selected by ``position_id``.

    Shapes: position_id [sq, b]; q, k [sq, b, np, hn];
    cos, sin [sq, 1, hn], gathered to [sq, b, 1, hn].

    Returns the rotated ``(q, k)`` pair.
    """
    # Look up one table row per (position, batch) entry, then add a
    # singleton axis so it broadcasts against the ``np`` axis of q/k.
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
        F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), \
        (k * cos) + (rotate_half(k) * sin)
    return q, k