optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py (1,773 lines of code) (raw):

# coding=utf-8 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch DeepSeekV2 model. Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/modeling_deepseek.py""" import math import os import warnings from typing import List, Optional, Tuple, Union import habana_frameworks.torch.core as htcore import torch import torch.distributed as dist # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. import torch.fx import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.generation import GenerationMixin from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, ) from transformers.modeling_outputs import ( BaseModelOutputWithPast, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, ) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, logging, ) from ....distributed.tensorparallel import _all_reduce from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_deepseek_v2 import DeepseekV2Config _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" # default expert number per slice for dynamic MoE SLICE_MAX_EXPERT = 80 try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE print("Using HPU fused kernel for apply_rotary_pos_emb") except ImportError: print("Not using HPU fused kernel for apply_rotary_pos_emb") FusedRoPE = None try: from habana_frameworks.torch.hpex.normalization import FusedRMSNorm print("Using HPU fused kernel for RMSNorm") except ImportError: print("Not using HPU fused kernel for RMSNorm") FusedRMSNorm = None try: from habana_frameworks.torch.hpex.kernels import FusedSDPA except ImportError: print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func def load_balancing_loss_func( gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, int]: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced. Args: gate_logits: Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of shape [batch_size X sequence_length, num_experts]. num_experts: Number of experts top_k: The number of experts to route per-token, can be also interpreted as the `top-k` routing parameter. attention_mask (`torch.Tensor`, *optional*): The attention_mask used in forward function shape [batch_size X sequence_length] if not None. Returns: The auxiliary loss. """ if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts class DeepseekV2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ DeepseekV2RMSNorm is equivalent to T5LayerNorm """ super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps def forward(self, hidden_states): if hidden_states.device.type == "hpu" and FusedRMSNorm: # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype if hidden_states.dtype != self.weight.dtype: orig_dtype = hidden_states.dtype hidden_states = FusedRMSNorm.apply( hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon ) return hidden_states.to(orig_dtype) else: hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) return hidden_states else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) class DeepseekV2RotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self.max_seq_len_cached = max_position_embeddings self._set_cos_sin_cache( seq_len=self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq.to(t.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len is not None and seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), self.sin_cached[:seq_len].to(dtype=x.dtype), ) # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): """DeepseekV2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) t = t / self.scaling_factor freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, ): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048): return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) # Find dim range bounds based on rotations def yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048): low = math.floor(yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) high = math.ceil(yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) return max(low, 0), min(high, dim - 1) # Clamp values just in case def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: return 1.0 return 0.1 * mscale * math.log(scale) + 1.0 def yarn_linear_ramp_mask(min, max, dim): if min == max: max += 0.001 # Prevent singularity linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): def __init__( self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0, original_max_position_embeddings=4096, beta_fast=32, beta_slow=1, mscale=1, mscale_all_dim=0, ): self.scaling_factor = scaling_factor self.original_max_position_embeddings = original_max_position_embeddings self.beta_fast = beta_fast self.beta_slow = beta_slow self.mscale = mscale self.mscale_all_dim = mscale_all_dim super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len dim = self.dim freq_extra = 1.0 / (self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) freq_inter = 1.0 / ( self.scaling_factor * self.base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) ) low, high = yarn_find_correction_range( self.beta_fast, self.beta_slow, dim, self.base, self.original_max_position_embeddings, ) inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(device=device, dtype=torch.float32) inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) _mscale = float( yarn_get_mscale(self.scaling_factor, self.mscale) / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) ) emb = torch.cat((freqs, freqs), dim=-1) emb_cos = (emb.cos() * _mscale).to(dtype) emb_sin = (emb.sin() * _mscale).to(dtype) self.register_buffer("cos_cached", emb_cos, persistent=False) self.register_buffer("sin_cached", emb_sin, persistent=False) def apply_customized_rope(q, k, cos, sin, position_ids): if q.device.type == "hpu" and FusedRoPE: return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) else: return apply_rotary_pos_emb(q, k, cos, sin, position_ids) # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb def apply_rotary_pos_emb(q: torch.Tensor, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): The position indices of the tokens corresponding to the query and key tensors. For example, this can be used to pass offsetted position ids when working with a KV-cache. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) if q.device.type == "hpu" and FusedRoPE: return FusedRoPE.apply( q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids ) else: cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) return q_embed class DeepseekV2MLP(nn.Module): def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config self.hidden_size = config.hidden_size if hidden_size is None else hidden_size self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj class MoEGate(nn.Module): def __init__(self, config): super().__init__() self.config = config self.top_k = config.num_experts_per_tok self.n_routed_experts = config.n_routed_experts self.routed_scaling_factor = config.routed_scaling_factor self.scoring_func = config.scoring_func self.alpha = config.aux_loss_alpha self.seq_aux = config.seq_aux self.topk_method = config.topk_method self.n_group = config.n_group self.topk_group = config.topk_group # topk selection algorithm self.norm_topk_prob = config.norm_topk_prob self.gating_dim = config.hidden_size self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim))) self.reset_parameters() def reset_parameters(self) -> None: import torch.nn.init as init init.kaiming_uniform_(self.weight, a=math.sqrt(5)) def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape ### compute gating score hidden_states = hidden_states.view(-1, h) logits = F.linear(hidden_states.type(torch.bfloat16), self.weight.type(torch.bfloat16), None).to( dtype=torch.float32 ) if self.scoring_func == "softmax": scores = logits.softmax(dim=-1, dtype=torch.float32) else: raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") ### select top-k experts if self.topk_method == "greedy": topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=True) elif self.topk_method == "group_limited_greedy": group_scores = scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values # [n, n_group] group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( group_mask.unsqueeze(-1) .expand(bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group) .reshape(bsz * seq_len, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=True) ### norm gate to sum 1 if self.top_k > 1 and self.norm_topk_prob: denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 topk_weight = topk_weight / denominator else: topk_weight = topk_weight * self.routed_scaling_factor ### expert-level computation auxiliary loss if self.training and self.alpha > 0.0: scores_for_aux = scores aux_topk = self.top_k # always compute aux loss based on the naive greedy topk method topk_idx_for_aux_loss = topk_idx.view(bsz, -1) if self.seq_aux: scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device) ce.scatter_add_( 1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), ).div_(seq_len * aux_topk / self.n_routed_experts) aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha else: mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts) ce = mask_ce.float().mean(0) Pi = scores_for_aux.mean(0) fi = ce * self.n_routed_experts aux_loss = (Pi * fi).sum() * self.alpha else: aux_loss = None return topk_idx, topk_weight, aux_loss class AddAuxiliaryLoss(torch.autograd.Function): """ The trick function of adding auxiliary (aux) loss, which includes the gradient of the aux loss during backpropagation. """ @staticmethod def forward(ctx, x, loss): assert loss.numel() == 1 ctx.dtype = loss.dtype ctx.required_aux_loss = loss.requires_grad return x @staticmethod def backward(ctx, grad_output): grad_loss = None if ctx.required_aux_loss: grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) return grad_output, grad_loss class DeepseekV2MoE(nn.Module): """ A mixed expert module containing shared experts. """ def __init__(self, config): super().__init__() self.config = config self.num_experts_per_tok = config.num_experts_per_tok self.experts_per_rank = config.n_routed_experts if hasattr(config, "ep_size") and config.ep_size > 1: assert config.ep_size == dist.get_world_size() self.ep_size = config.ep_size self.experts_per_rank = config.n_routed_experts // config.ep_size self.ep_rank = dist.get_rank() self.experts = nn.ModuleList( [ ( DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size) if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank else None ) for i in range(config.n_routed_experts) ] ) else: self.ep_size = 1 self.experts_per_rank = config.n_routed_experts self.ep_rank = 0 self.experts = nn.ModuleList( [ DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size) for i in range(config.n_routed_experts) ] ) self.gate = MoEGate(config) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) self.expert_slice = math.ceil(self.experts_per_rank / SLICE_MAX_EXPERT) self.expert_chunk = math.ceil(self.experts_per_rank / self.expert_slice) def forward(self, hidden_states): identity = hidden_states orig_shape = hidden_states.shape topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) # we cast back to the input dtype topk_weight = topk_weight.to(hidden_states.dtype) batch = orig_shape[0] sequence_length = orig_shape[1] hidden_dim = orig_shape[2] if self.training: padded_weights = torch.zeros( (batch * sequence_length, self.config.n_routed_experts), dtype=topk_weight.dtype, device=topk_weight.device, ) padded_weights.scatter_(-1, topk_idx, topk_weight) padded_weights = padded_weights.reshape(-1, sequence_length, self.config.n_routed_experts) padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) final_hidden_states = torch.zeros( (batch, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) for i, expert in enumerate(self.experts): current_hidden_state = expert(hidden_states) current_padded_weight = padded_weights[i] final_hidden_states = ( final_hidden_states + current_hidden_state.reshape(-1, sequence_length, hidden_dim) * current_padded_weight ) final_hidden_states = final_hidden_states.type(hidden_states.dtype) final_hidden_states = final_hidden_states.view(*orig_shape) final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss) else: final_hidden_states = torch.zeros( (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) for idx in range(self.expert_slice): experts_min = (self.ep_rank * self.experts_per_rank) + (self.expert_chunk * idx) experts_max = min((experts_min + self.expert_chunk), (self.ep_rank + 1) * self.experts_per_rank) experts_range = range(experts_min, experts_max) gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range] down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range] up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range] hidden_states_slice = torch.ops.hpu.mixture_of_experts( hidden_states=hidden_states, expert_routing_table=topk_idx, router_weights=topk_weight, w1=gate_proj_list, w2=up_proj_list, w3=down_proj_list, permuted_weights=True, activation="silu", experts_min=experts_min, experts_max=experts_max - 1, ) final_hidden_states = final_hidden_states + hidden_states_slice htcore.mark_step() if self.ep_size > 1: final_hidden_states = _all_reduce(final_hidden_states) elif is_deepspeed_available(): from deepspeed import comm as dist if dist.is_initialized(): dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM) final_hidden_states = final_hidden_states.type(hidden_states.dtype) final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + self.shared_experts(identity) return final_hidden_states class Matmul(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.matmul(x, y) def gaudi_deepseekv2_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask: torch.Tensor, n_rep: int, ): """ Copied from repeat_kv: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py The only differences are: - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) """ batch, num_key_value_heads, kv_len, head_dim = key_states.shape if n_rep == 1 or num_key_value_heads == 1: return query_states, key_states, value_states, attention_mask new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) key_states = key_states.reshape(new_kv_shape) value_states = value_states.reshape(new_kv_shape) batch, q_heads, q_len, head_dim = query_states.shape new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) query_states = query_states.reshape(new_q_shape) if attention_mask is not None: # Add groups dim and set to 1 attention_mask = attention_mask.unsqueeze(1) return query_states, key_states, value_states, attention_mask class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() self.cache = None self.inp_seq_len = -1 def allocate(self, inp_seq_len, dtype, device, shape): if self.cache is None or self.cache.shape != shape: self.inp_seq_len = inp_seq_len self.cache = torch.zeros(shape, dtype=dtype, device=device) else: assert self.inp_seq_len == inp_seq_len, ( f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" ) self.cache.fill_(0) def update(self, prev, cur, dim, idx, inp_seq_len): orig_cur = cur if prev.shape == cur.shape: prev.copy_(cur) return orig_cur if cur.shape[1] > 1 and cur.shape[1] <= prev.shape[1]: # Initialize prev[:, :inp_seq_len, :].copy_(cur) return orig_cur assert cur.shape[1] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}" if idx is not None: prev.index_copy_(dim, idx - 1, cur) return prev else: return torch.cat((prev, cur), dim=dim) def get_shape(self): if self.cache is None: return None return self.cache.shape def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) class ModuleFusedSDPA(torch.nn.Module): def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): super().__init__() self._hpu_kernel_fsdpa = fusedSDPA self.scale = scale self.attention_dropout = attention_dropout self.enable_recompute = enable_recompute self.flash_attention_fp8 = flash_attention_fp8 def forward( self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side="left", ): return self._hpu_kernel_fsdpa.apply( query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side, ) # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx if layer_idx is None: logger.warning_once( f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." ) self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.q_lora_rank = config.q_lora_rank self.qk_rope_head_dim = config.qk_rope_head_dim self.kv_lora_rank = config.kv_lora_rank self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim self.is_causal = True if self.q_lora_rank is None: self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False) else: self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias=config.attention_bias) self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False) self.kv_a_proj_with_mqa = nn.Linear( self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias=config.attention_bias, ) self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) self.kv_b_proj = nn.Linear( config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias=False, ) self.o_proj = nn.Linear( self.num_heads * self.v_head_dim, self.hidden_size, bias=config.attention_bias, ) self._init_rope() self.num_key_value_groups = self.num_heads // config.num_key_value_heads self.matmul_qk = Matmul() self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() self.inp_seq_len = -1 self.softmax_scale = self.q_head_dim ** (-0.5) if self.config.rope_scaling is not None: mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) scaling_factor = self.config.rope_scaling["factor"] if mscale_all_dim: mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale self.norm_factor = self.softmax_scale self.fused_scaled_dot_product_attention = ( ModuleFusedSDPA( FusedSDPA, scale=self.norm_factor, attention_dropout=self.attention_dropout, enable_recompute=False, flash_attention_fp8=getattr(config, "flash_attention_fp8", False), ) if FusedSDPA else None ) def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = DeepseekV2RotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, ) elif scaling_type == "yarn": kwargs = { key: self.config.rope_scaling[key] for key in [ "original_max_position_embeddings", "beta_fast", "beta_slow", "mscale", "mscale_all_dim", ] if key in self.config.rope_scaling } self.rotary_emb = DeepseekV2YarnRotaryEmbedding( self.qk_rope_head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor, base=self.rope_theta, **kwargs, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): compressed_kv_cache_shape = (batch_size, max_seq_len, self.kv_lora_rank) k_pe_cache_shape = (batch_size, max_seq_len, self.qk_rope_head_dim) device = self.kv_a_proj_with_mqa.weight.device dtype = self.config.torch_dtype self.k_cache.allocate(inp_seq_len, dtype, device, compressed_kv_cache_shape) self.v_cache.allocate(inp_seq_len, dtype, device, k_pe_cache_shape) def update_sincos_cache(self, seq_len): # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings # This helps in avoiding creation of these caches during actual model forward pass and # reduce memory consumption and improve performance. if seq_len > self.max_position_embeddings: self.max_position_embeddings = seq_len _, _ = self.rotary_emb(self.k_b_proj.weight, seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) tensor.copy_(updated) def reorder_kv_cache(self, beam_idx: torch.LongTensor): if self.k_cache.cache is None: return (None, None) head_dim = self.k_cache.cache.size(-1) seq_length = self.k_cache.cache.size(-2) self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) return (self.k_cache.cache.shape, self.v_cache.cache.shape) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() def split_kv_b_proj(self): kv_b_proj_weight = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) self.q_absorb = kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1) self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0) def compress_kv( self, hidden_states_kv: torch.Tensor, kv_position_ids: torch.LongTensor, ) -> torch.Tensor: # return the RoPE'ed & compressed kv bsz, kv_seq_len, _ = hidden_states_kv.size() compressed_kv = self.kv_a_proj_with_mqa(hidden_states_kv) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2) cos, sin = self.rotary_emb.cos_cached, self.rotary_emb.sin_cached k_pe = apply_rotary_pos_emb(k_pe, cos, sin, kv_position_ids).view(bsz, kv_seq_len, self.qk_rope_head_dim) return compressed_kv, k_pe def train_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, token_idx: Optional[torch.Tensor] = None, cache_position: Optional[torch.LongTensor] = None, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) if token_idx is None: if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: kv_seq_len += past_key_value[0].shape[-2] else: if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len else: kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids) query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # optimization if use_flash_attention and FusedSDPA is not None: softmax_mode = "fast" if flash_attention_fast_softmax else "None" if flash_attention_causal_mask: attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, None, 0.0, True, None, softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left", ) else: attn_output = self.fused_scaled_dot_product_attention( query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, flash_attention_recompute, None, "None", ) else: query_states, key_states, value_states, attention_mask = gaudi_deepseekv2_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.softmax_scale htcore.mark_step() if attention_mask is not None: # no matter the length, we just slice it causal_mask = attention_mask if cache_position is not None: causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] attn_weights = attn_weights + causal_mask.float() if attn_softmax_bf16: attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) else: # upcast attention to fp32 attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = self.matmul_av(attn_weights, value_states) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value def prefill_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_proj_with_mqa(hidden_states) compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) compressed_kv = self.kv_a_layernorm(compressed_kv) k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) cos, sin = self.rotary_emb.cos_cached, self.rotary_emb.sin_cached k_pe_cache = apply_rotary_pos_emb(k_pe, cos, sin, position_ids).view(bsz, q_len, self.qk_rope_head_dim) kv = ( self.kv_b_proj(compressed_kv) # self.kv_a_layernorm(compressed_kv)) .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .transpose(1, 2) ) k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) kv_seq_len = value_states.shape[-2] assert kv_seq_len == q_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) if token_idx is None: if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: kv_seq_len += past_key_value[0].shape[-2] else: if reuse_cache: kv_seq_len = past_key_value[0][-2] else: kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids) # update & get all compressed_kv, k_pe if use_cache: k_pe_cache = k_pe_cache.view(bsz, q_len, self.qk_rope_head_dim) # k_pe.squeeze(1) if reuse_cache: compressed_kv = self.k_cache(compressed_kv, 1, token_idx) k_pe_cache = self.v_cache(k_pe_cache, 1, token_idx) past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: dtype_1 = hidden_states.dtype device_1 = hidden_states.device past_key = torch.zeros(compressed_kv.shape, dtype=dtype_1, device=device_1) past_value = torch.zeros(k_pe_cache.shape, dtype=dtype_1, device=device_1) past_key_value = (past_key, past_value) compressed_kv = self.k_cache.update(past_key_value[0], compressed_kv, 1, token_idx, self.inp_seq_len) k_pe_cache = self.v_cache.update(past_key_value[1], k_pe_cache, 1, token_idx, self.inp_seq_len) if token_idx is None: past_key_value = (compressed_kv, k_pe_cache) else: past_key_value = None query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) assert attention_mask is not None if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) return attn_output, attn_weights, past_key_value def decode_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: hidden_states_q = hidden_states hidden_states_kv = hidden_states self.split_kv_b_proj() q_position_ids = position_ids kv_position_ids = position_ids bsz, q_len, _ = hidden_states_q.size() if self.q_lora_rank is None: q = self.q_proj(hidden_states_q) else: q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states_q))) q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) kv_seq_len = q_pe.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) if token_idx is None: if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: kv_seq_len += past_key_value[0].shape[-2] else: if reuse_cache: kv_seq_len = past_key_value[0][-2] else: kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len) q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids) q_nope = torch.matmul(q_nope.transpose(0, 1), self.q_absorb).transpose(0, 1) compressed_kv, k_pe = self.compress_kv(hidden_states_kv, kv_position_ids) # update & get all compressed_kv, k_pe if use_cache: if reuse_cache: if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor): # prefix tuning case. attach past_key_value to generate first token. compressed_kv = torch.cat((past_key_value[0], compressed_kv), -2) k_pe = torch.cat((past_key_value[1], k_pe), -2) compressed_kv = self.k_cache(compressed_kv, 1, token_idx) k_pe = self.v_cache(k_pe, 1, token_idx) past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) else: if past_key_value is None: dtype_1 = hidden_states.dtype device_1 = hidden_states.device past_key = torch.zeros(compressed_kv.shape, dtype=dtype_1, device=device_1) past_value = torch.zeros(k_pe.shape, dtype=dtype_1, device=device_1) past_key_value = (past_key, past_value) compressed_kv = self.k_cache.update(past_key_value[0], compressed_kv, 1, token_idx, self.inp_seq_len) k_pe = self.v_cache.update(past_key_value[1], k_pe, 1, token_idx, self.inp_seq_len) if token_idx is None: past_key_value = (compressed_kv, k_pe) if cache_idx is not None and q_len == 1: compressed_kv = compressed_kv[:, :cache_idx, :] k_pe = k_pe[:, :cache_idx, :] if attention_mask is not None: attention_mask = attention_mask[:, :, :, :cache_idx] kv_seq_len = compressed_kv.shape[-2] else: past_key_value = None kv_seq_len = compressed_kv.size(1) k_pe = k_pe.view(bsz, 1, kv_seq_len, self.qk_rope_head_dim) attn_weights = ( torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT) ) * self.softmax_scale assert attention_mask is not None if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" f" {attn_weights.size()}" ) attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_nope.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.einsum("bhql,blc->bhqc", attn_weights, compressed_kv) attn_output = torch.matmul(attn_output.permute(2, 1, 0, 3), self.out_absorb.mT).permute(2, 1, 0, 3) return attn_output, attn_weights, past_key_value def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, cache_position: Optional[torch.LongTensor] = None, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, first_token: Optional[bool] = True, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Attention masks and past cache are removed. Input: - hidden_states: [bsz, q_len, hidden_size] - position_ids: [bsz, q_len] """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) bsz, q_len, _ = hidden_states.size() if self.training: attn_output, attn_weights, past_key_value = self.train_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, token_idx=token_idx, cache_position=cache_position, attn_softmax_bf16=attn_softmax_bf16, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, valid_sequence_lengths=valid_sequence_lengths, num_virtual_tokens=num_virtual_tokens, **kwargs, ) elif first_token: # prefill MHA attn_output, attn_weights, past_key_value = self.prefill_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, token_idx=token_idx, reuse_cache=reuse_cache, ) else: # decode MLA attn_output, attn_weights, past_key_value = self.decode_forward( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, use_cache=use_cache, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, ) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class DeepseekV2DecoderLayer(nn.Module): def __init__(self, config: DeepseekV2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = DeepseekV2Attention(config=config, layer_idx=layer_idx) self.mlp = ( DeepseekV2MoE(config) if ( config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0 ) else DeepseekV2MLP(config) ) self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.self_attn.reorder_kv_cache(beam_idx) def update_sincos_cache(self, seq_len): self.self_attn.update_sincos_cache(seq_len) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, cache_position: Optional[torch.LongTensor] = None, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, first_token: Optional[bool] = True, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, cache_position=cache_position, attn_softmax_bf16=attn_softmax_bf16, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, valid_sequence_lengths=valid_sequence_lengths, num_virtual_tokens=num_virtual_tokens, first_token=first_token, **kwargs, ) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) if isinstance(self.mlp, DeepseekV2MoE): hidden_states = self.mlp(hidden_states) else: hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs DeepseekV2_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. Parameters: config ([`DeepseekV2Config`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ @add_start_docstrings( "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", DeepseekV2_START_DOCSTRING, ) class DeepseekV2PreTrainedModel(PreTrainedModel): config_class = DeepseekV2Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False _supports_cache_class = True def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() DeepseekV2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see `past_key_values`). If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. Two formats are allowed: - a [`~cache_utils.Cache`] instance; - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy cache format. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the legacy cache format will be returned. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @add_start_docstrings( "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", DeepseekV2_START_DOCSTRING, ) class DeepseekV2Model(DeepseekV2PreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`] Args: config: DeepseekV2Config """ def __init__(self, config: DeepseekV2Config): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList( [DeepseekV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) self._attn_implementation = "eager" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): for layer in self.layers: layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) def reorder_kv_cache(self, beam_idx: torch.LongTensor): return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) def update_sincos_cache(self, seq_len): for layer in self.layers: layer.update_sincos_cache(seq_len) def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, first_token: Optional[bool] = True, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." ) use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) ignore_cache_position = True # Ignoring cache position for HPU use_new_cache = False # Ignoring new Cache path for HPU past_seen_tokens = 0 if past_key_values is not None and use_cache: # kept for BC (cache positions) if reuse_cache: if isinstance(past_key_values[0][0], torch.Tensor): past_seen_tokens = past_key_values[0][0].shape[2] else: past_seen_tokens = past_key_values[0][0][2] else: if use_new_cache: if not isinstance(past_key_values, StaticCache): past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_seen_tokens = past_key_values.get_seq_length() else: if past_key_values[0] is not None: ##added for (None, None) past_seen_tokens = past_key_values[0][0].shape[2] if ignore_cache_position is False: if cache_position is None: if isinstance(past_key_values, StaticCache): raise ValueError("cache_position is a required argument when using StaticCache.") cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None and cache_position: position_ids = cache_position.unsqueeze(0) else: if position_ids is None: position_ids = torch.arange( past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device ) position_ids = position_ids.unsqueeze(0) cache_position = None causal_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, input_ids.shape if input_ids is not None else (batch_size, seq_length), inputs_embeds, past_seen_tokens, ) # embed positions hidden_states = inputs_embeds # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_router_logits = () if output_router_logits else None next_decoder_cache = () if not use_new_cache else None if lazy_mode: htcore.mark_step() for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_mask, position_ids, past_key_values, output_attentions, use_cache, token_idx, reuse_cache, cache_idx, cache_position, attn_softmax_bf16, use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, flash_attention_fast_softmax, valid_sequence_lengths, num_virtual_tokens, ) else: if ( lazy_mode and not self.training and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) ): htcore.mark_step() layer_outputs = decoder_layer( hidden_states, attention_mask=causal_mask, position_ids=position_ids, past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, num_virtual_tokens=num_virtual_tokens, first_token=first_token, ) if ( lazy_mode and not self.training and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) ): htcore.mark_step() hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) if output_router_logits: all_router_logits += (layer_outputs[-1],) hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = ( next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ) if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, router_logits=all_router_logits, ) class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = DeepseekV2Model(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) self.kv_cache_len = max_seq_len def reorder_kv_cache(self, beam_idx: torch.LongTensor): return self.model.reorder_kv_cache(beam_idx) def update_sincos_cache(self, seq_len): self.model.update_sincos_cache(seq_len) @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, cache_dir: Optional[Union[str, os.PathLike]] = None, ignore_mismatched_sizes: bool = False, force_download: bool = False, local_files_only: bool = False, token: Optional[Union[str, bool]] = None, revision: str = "main", use_safetensors: bool = None, **kwargs, ): # Load config if we don't provide a configuration if not isinstance(config, PretrainedConfig): config_path = config if config is not None else pretrained_model_name_or_path config, model_kwargs = cls.config_class.from_pretrained( config_path, cache_dir=cache_dir, return_unused_kwargs=True, force_download=force_download, resume_download=False, proxies=None, local_files_only=local_files_only, token=token, revision=revision, subfolder="", _from_auto=False, _from_pipeline=None, **kwargs, ) return super(DeepseekV2ForCausalLM, cls).from_pretrained( pretrained_model_name_or_path, *model_args, config=config, cache_dir=cache_dir, ignore_mismatched_sizes=ignore_mismatched_sizes, force_download=force_download, local_files_only=local_files_only, token=token, revision=revision, use_safetensors=use_safetensors, **kwargs, ) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = False, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = None, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, cache_position: Optional[torch.LongTensor] = None, trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, valid_sequence_lengths: torch.Tensor = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, first_token: Optional[bool] = True, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. Returns: Example: ```python >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, cache_position=cache_position, token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, lazy_mode=lazy_mode, valid_sequence_lengths=valid_sequence_lengths, num_virtual_tokens=num_virtual_tokens, first_token=first_token, ) hidden_states = outputs[0] _, seq_len, _ = hidden_states.shape if seq_len > 1 and trim_logits and not self.training: if token_idx is not None: hidden_states = hidden_states.index_select(1, token_idx - 1) else: hidden_states = hidden_states[:, -1, :] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model parallelism shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device if not return_dict: output = (logits,) + outputs[1:] if output_router_logits: output = (aux_loss,) + output return (loss,) + output if loss is not None else output return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, position_ids=None, use_cache=True, num_logits_to_keep=None, token_idx=None, **kwargs, ): reuse_cache = kwargs.get("reuse_cache") bucket_internal = kwargs.get("bucket_internal") if past_key_values is not None: if token_idx is not None: idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 input_ids = torch.index_select(input_ids, 1, idx) else: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif ( input_ids.shape[1] != cache_position.shape[0] ): # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] elif (reuse_cache or bucket_internal) and token_idx is not None: # KV cache is pre allocated with reuse cache or will be padded with bucket internal # hence for the 1st token we can slice the inputs till token idx for the fwd pass. input_ids = input_ids[:, :token_idx] attention_mask = attention_mask[:, :token_idx] if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: if token_idx is not None: position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] # keep cache_position implementation as None for HPU cache_position = None # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} if num_logits_to_keep is not None: model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, "token_idx": token_idx, "trim_logits": kwargs.get("trim_logits"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), "reuse_cache": reuse_cache, "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"), "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"), "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), "num_virtual_tokens": kwargs.get("num_virtual_tokens"), "first_token": kwargs.get("first_token", True), } ) return model_inputs @add_start_docstrings( """ The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models (e.g. GPT-2) do. Since it does classification on the last token, it requires to know the position of the last token. If a `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in each row of the batch). """, DeepseekV2_START_DOCSTRING, ) class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels self.model = DeepseekV2Model(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) if input_ids is not None: batch_size = input_ids.shape[0] else: batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( logits.device ) else: sequence_lengths = -1 pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) if not return_dict: output = (pooled_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutputWithPast( loss=loss, logits=pooled_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, )