optimum/habana/transformers/models/phi/modeling_phi.py (458 lines of code) (raw):
# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Phi model."""
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.phi.configuration_phi import PhiConfig
from transformers.models.phi.modeling_phi import (
KwargsForCausalLM,
PhiAttention,
PhiForCausalLM,
PhiMLP,
PhiModel,
apply_rotary_pos_emb,
)
from transformers.processing_utils import Unpack
from transformers.utils import logging
from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)
from ...modeling_rope_utils import GaudiRotaryEmbedding
from ..modeling_all_models import KVCache, Matmul
logger = logging.get_logger(__name__)
def gaudi_phi_repeat_kv(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    n_rep: int,
):
    """
    Regroup query/key/value (and the mask) so grouped-query attention can rely on matmul
    broadcasting instead of physically repeating the key/value heads.

    Adapted from repeat_kv: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/phi/modeling_phi.py

    Behaviour:
        - When `n_rep == 1` or there is a single key/value head, all four inputs are returned
          untouched (no regrouping is needed in those cases).
        - Otherwise:
            query:      (batch, num_heads, q_len, head_dim)  -> (batch, num_key_value_heads, n_rep, q_len, head_dim)
            key/value:  (batch, num_key_value_heads, kv_len, head_dim) -> (batch, num_key_value_heads, 1, kv_len, head_dim)
            mask:       gets an extra axis inserted at position 1 (left as `None` if it was `None`).

    Returns:
        The tuple `(query_states, key_states, value_states, attention_mask)` in that order.
    """
    bsz, kv_heads, kv_len, kv_head_dim = key_states.shape

    needs_grouping = n_rep != 1 and kv_heads != 1
    if not needs_grouping:
        return query_states, key_states, value_states, attention_mask

    # Keys and values receive a singleton "group" axis that broadcasts against n_rep.
    grouped_kv_shape = (bsz, kv_heads, 1, kv_len, kv_head_dim)
    grouped_keys = key_states.reshape(grouped_kv_shape)
    grouped_values = value_states.reshape(grouped_kv_shape)

    # Queries are split into (kv head, repetition) using their own batch/head_dim sizes.
    q_bsz, _, q_len, q_head_dim = query_states.shape
    grouped_queries = query_states.reshape((q_bsz, kv_heads, n_rep, q_len, q_head_dim))

    # The mask gets a matching axis so it still broadcasts over the grouped scores.
    grouped_mask = None if attention_mask is None else attention_mask.unsqueeze(1)

    return grouped_queries, grouped_keys, grouped_values, grouped_mask
def gaudi_eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Eager (matmul + softmax) attention used by the Gaudi Phi attention module.

    Args:
        module: attention module; this function reads `num_key_value_groups`, `matmul_qk`,
            `matmul_av`, `training` and `head_dim` from it.
        query: query states, laid out as (batch, num_heads, q_len, head_dim).
        key: key states, laid out as (batch, num_key_value_heads, kv_len, head_dim).
        value: value states, same layout as `key`.
        attention_mask: optional additive mask; only its first `kv_len` entries along the
            last axis are used.
        scaling: factor applied to the raw attention scores.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: must contain `input_shape`, a 2-element `(batch, q_len)` sequence.

    Returns:
        `(attn_output, attn_weights)` where `attn_output` has shape
        (batch, num_heads, q_len, head_dim).

    Raises:
        KeyError: if `input_shape` is missing from `kwargs`.
    """
    bsz, q_len = kwargs["input_shape"]

    # Depending on the head configuration this returns either the untouched 4-D tensors or
    # 5-D grouped tensors (see `gaudi_phi_repeat_kv`), so everything below must be rank-agnostic.
    query_states, key_states, value_states, attention_mask = gaudi_phi_repeat_kv(
        query, key, value, attention_mask, module.num_key_value_groups
    )

    # Always swap the last two axes (kv_len <-> head_dim). A hard-coded `transpose(2, 3)`
    # is only right for 4-D keys; on the 5-D grouped layout it would swap the singleton
    # group axis with kv_len and make the matmul fail.
    attn_weights = module.matmul_qk(query_states, key_states.transpose(-2, -1)) * scaling

    if attention_mask is not None:
        # Trim the mask along its last (key) axis whatever its rank is.
        causal_mask = attention_mask[..., : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # Softmax is computed in float32 and cast back to the query dtype.
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = module.matmul_av(attn_weights, value_states)
    # Collapse any (kv_head, group) split back into a single head axis.
    attn_output = attn_output.reshape(bsz, -1, q_len, module.head_dim)
    return attn_output, attn_weights
class GaudiPhiAttention(PhiAttention):
    """
    Phi attention layer extended with Gaudi-specific members: module wrappers for the two
    attention matmuls, per-layer key/value caches and a Gaudi rotary embedding.
    """

    def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Matmul wrappers used by `gaudi_eager_attention_forward` for Q@K^T and probs@V.
        self.matmul_qk = Matmul()
        self.matmul_av = Matmul()
        # Per-layer key/value caches (used when `reuse_cache` is enabled, and for their
        # `update` helper otherwise).
        self.k_cache = KVCache()
        self.v_cache = KVCache()
        # Passed as the last argument of `KVCache.update`; -1 is the initial value.
        self.inp_seq_len = -1
        self.rotary_emb = GaudiRotaryEmbedding(config=self.config)
        self.num_key_value_heads = config.num_key_value_heads

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        """
        Allocate the key and value caches of this layer.

        Both caches are allocated with shape
        (batch_size, num_key_value_heads, max_seq_len, head_dim), on the device of the
        key projection weights and with dtype `config.torch_dtype`.
        """
        cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim)
        device = self.k_proj.weight.device
        dtype = self.config.torch_dtype
        self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape)
        self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Copied from PhiAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py
        The only differences are:
        - add new args token_idx
        - optimize KV cache
        - add new args reuse_cache
        - add new args cache_idx

        Notes on the arguments as they are used in this method:
        - `position_embeddings` and `cache_position` are accepted but never read here; the
          rotary cos/sin tables come from `self.rotary_emb` and are indexed with
          `kwargs["position_ids"]`, which therefore must be supplied (KeyError otherwise).
        - `past_key_value` may be a tuple (tensors, or shapes when `reuse_cache` is set) or a
          `Cache` object.
        - `token_idx` selects the in-place cache update path; when it is `None` the
          updated key/value tensors are returned as the new past.
        - `cache_idx`, when given and the query length is 1, truncates keys, values and the
          mask to the first `cache_idx` positions.

        Returns:
            `(attn_output, attn_weights, past_key_value)`; `past_key_value` is `None` when
            `use_cache` is false.
        """
        # (batch, seq_len); the hidden dimension is dropped.
        input_shape = hidden_states.shape[:-1]
        q_len = input_shape[1]
        hidden_shape = (*input_shape, -1, self.head_dim)
        # Project, then move the head axis forward: (batch, heads, seq_len, head_dim).
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        if self.qk_layernorm:
            query_states = self.q_layernorm(query_states)
            key_states = self.k_layernorm(key_states)
        # Sequence length handed to the rotary embedding; starts as the current key length.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            # Length already held by the cache:
            #   tuple + reuse_cache -> entry [0] is a stored shape, take its [-2] element
            #   tuple otherwise     -> entry [0] is a tensor, take its shape[-2]
            #   Cache object        -> ask the cache
            kv_shape = (
                (past_key_value[0][-2] if reuse_cache else past_key_value[0].shape[-2])
                if isinstance(past_key_value, tuple)
                else past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
            )
            if token_idx is not None:
                # With token_idx the cached length is used as-is.
                kv_seq_len = kv_shape
            else:
                # Otherwise the new tokens extend the cached length.
                kv_seq_len += kv_shape
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        # Partial rotary embedding: only the first `rotary_ndims` features of each head are
        # rotated, the remaining features pass through unchanged.
        query_rot, query_pass = (
            query_states[..., : self.rotary_ndims],
            query_states[..., self.rotary_ndims :],
        )
        key_rot, key_pass = (
            key_states[..., : self.rotary_ndims],
            key_states[..., self.rotary_ndims :],
        )
        # Rotated slices: [batch_size, num_heads, seq_length, rotary_ndims]
        query_rot, key_rot = apply_rotary_pos_emb(
            query_rot, key_rot, cos[kwargs["position_ids"]], sin[kwargs["position_ids"]]
        )
        # Re-assembled states: [batch_size, num_heads, seq_length, head_dim]
        query_states = torch.cat((query_rot, query_pass), dim=-1)
        key_states = torch.cat((key_rot, key_pass), dim=-1)
        if use_cache:
            # reuse k, v, self_attention
            if reuse_cache:
                # The layer-owned caches are updated along dim 2 and only their shapes are
                # handed back as the "past".
                key_states = self.k_cache(key_states, 2, token_idx)
                value_states = self.v_cache(value_states, 2, token_idx)
                past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape())
            else:
                if past_key_value is None:
                    # No past yet: start from zero tensors shaped like the current keys
                    # (the value placeholder uses the key shape as well).
                    past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device)
                    past_value = torch.zeros(
                        key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device
                    )
                    past_key_value = (past_key, past_value)
                key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len)
                value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len)
                if token_idx is None:
                    # Without token_idx the updated tensors become the returned past;
                    # with token_idx the incoming `past_key_value` object is returned.
                    past_key_value = (key_states, value_states)
            if cache_idx is not None and q_len == 1:
                # Single-token step: attend only to the first `cache_idx` cache positions.
                key_states = key_states[:, :, :cache_idx, :]
                value_states = value_states[:, :, :cache_idx, :]
                if attention_mask is not None:
                    attention_mask = attention_mask[:, :, :, :cache_idx]
                # NOTE(review): this value is not read again in this method.
                kv_seq_len = key_states.shape[-2]
        else:
            past_key_value = None
        attn_output, attn_weights = gaudi_eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            input_shape=input_shape,
        )
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim) before the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.dense(attn_output)
        return attn_output, attn_weights, past_key_value
class GaudiPhiDecoderLayer(torch.nn.Module):
    """
    Single Phi decoder block for Gaudi: one shared input LayerNorm feeds both the attention
    branch and the MLP branch, whose (dropped-out) outputs are summed with the residual.
    """

    def __init__(self, config: PhiConfig, layer_idx: int):
        super().__init__()
        self.self_attn = GaudiPhiAttention(config, layer_idx=layer_idx)
        self.mlp = PhiMLP(config)
        self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.resid_dropout = torch.nn.Dropout(config.resid_pdrop)

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        """Forward the KV-cache allocation request to the attention sub-module."""
        self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Copied from PhiDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py
        Differences with the reference implementation:
        - accepts `token_idx`, `reuse_cache` and `cache_idx` and hands them to the attention module.

        Returns:
            A tuple starting with the new hidden states, followed by the attention weights
            when `output_attentions` is truthy, followed by the layer's key/value cache
            entry when `use_cache` is truthy.
        """
        skip_connection = hidden_states
        normed_states = self.input_layernorm(hidden_states)

        # Attention branch.
        attn_branch, attn_probs, layer_cache = self.self_attn(
            hidden_states=normed_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            token_idx=token_idx,
            reuse_cache=reuse_cache,
            cache_idx=cache_idx,
        )
        attn_branch = self.resid_dropout(attn_branch)

        # MLP branch runs on the same normalized input as attention.
        mlp_branch = self.resid_dropout(self.mlp(normed_states))

        block_output = attn_branch + mlp_branch + skip_connection

        optional_outputs = []
        if output_attentions:
            optional_outputs.append(attn_probs)
        if use_cache:
            optional_outputs.append(layer_cache)
        return (block_output, *optional_outputs)
class GaudiPhiModel(PhiModel):
    """Phi decoder stack with the Gaudi-specific `token_idx` / `reuse_cache` / `cache_idx` plumbing."""

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        """Ask every decoder layer to allocate its key/value cache."""
        for layer in self.layers:
            layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        cache_idx: Optional[int] = None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        """
        Copied from PhiModel.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py
        The only differences are:
        - add new args token_idx
        - add new args reuse_cache
        - add new args cache_idx

        Exactly one of `input_ids` / `inputs_embeds` must be given. With `reuse_cache`, the
        per-layer entries of `past_key_values` are cache shapes rather than tensors, and the
        number of already-seen tokens is read from index 2 of that shape.

        Returns:
            `BaseModelOutputWithPast` with the final hidden states, the per-layer cache
            (only when `use_cache`), and optionally all hidden states / attention weights.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # The new `Cache` path is switched off here, so legacy tuple caches are always used.
        use_legacy_cache = True
        use_new_cache = False
        past_seen_tokens = 0
        if past_key_values is not None and use_cache:
            if reuse_cache:
                # Entries are shapes: index 2 is the sequence length.
                past_seen_tokens = past_key_values[0][0][2]
            else:
                if use_new_cache:
                    use_legacy_cache = not isinstance(past_key_values, Cache)
                    if use_legacy_cache:
                        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                    past_seen_tokens = past_key_values.get_seq_length()
                else:
                    past_seen_tokens = past_key_values[0][0].shape[2]

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            # The seen-token count is re-derived here regardless of `use_cache`.
            if past_key_values is None:
                past_seen_tokens = 0
            elif isinstance(past_key_values, Cache):
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                first_entry = past_key_values[0][0]
                # A tensor exposes `.shape`; with `reuse_cache` the entry already is a shape
                # sequence, on which `.shape` would raise AttributeError.
                past_seen_tokens = (
                    first_entry.shape[2] if isinstance(first_entry, torch.Tensor) else first_entry[2]
                )
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # 4d mask is passed through the layers
        attention_mask = _gaudi_prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens
        )

        inputs_embeds = self.embed_dropout(inputs_embeds)  # diff with Llama
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if not use_new_cache else None

        for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the decoder layer's forward signature; the
                # trailing `None` is `position_embeddings`.
                layer_outputs = self._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **kwargs),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None if past_key_values is None else past_key_values[layer_idx],
                    output_attentions,
                    use_cache,
                    cache_position,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=None if past_key_values is None else past_key_values[layer_idx],
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    token_idx=token_idx,
                    reuse_cache=reuse_cache,
                    cache_idx=cache_idx,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry comes after the attention weights when those are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.final_layernorm(hidden_states)  # diff with Llama

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class GaudiPhiForCausalLM(PhiForCausalLM):
    """Phi causal LM head with the Gaudi-specific generation arguments."""

    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        """Ask the underlying decoder stack to allocate its key/value caches."""
        self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        token_idx: Optional[torch.Tensor] = None,
        reuse_cache: Optional[bool] = False,
        trim_logits: Optional[bool] = False,
        cache_idx: Optional[int] = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> CausalLMOutputWithPast:
        """
        Inherits from PhiForCausalLM: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py
        The only differences are:
        - add new args token_idx
        - add new args reuse_cache
        - add new args cache_idx
        - add new args trim_logits

        Args specific to this implementation:
            logits_to_keep: an int keeps the last `logits_to_keep` positions (0 keeps all);
                a tensor is used directly as an index along the sequence axis.
            trim_logits: outside training and for sequences longer than one token, only the
                hidden state at `token_idx - 1` (or the last position when `token_idx` is
                `None`) is sent through the LM head.
            **kwargs: only forwarded to `self.loss_function` when `labels` is given.

        Returns:
            `CausalLMOutputWithPast` with the logits, the optional loss and the decoder outputs.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            token_idx=token_idx,
            reuse_cache=reuse_cache,
            cache_idx=cache_idx,
        )

        hidden_states = outputs.last_hidden_state
        _, seq_len, _ = hidden_states.shape
        if seq_len > 1 and trim_logits and not self.training:
            if token_idx is not None:
                hidden_states = hidden_states.index_select(1, token_idx - 1)
            else:
                # Keep the sequence axis (size 1): the 3-index slicing below needs a 3-D
                # tensor, and this matches the layout produced by `index_select` above.
                hidden_states = hidden_states[:, -1:, :]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        token_idx=None,
        **kwargs,
    ):
        """
        Inherits from PhiForCausalLM: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py
        The only differences are:
        - add new args token_idx
        - add token_idx into model_inputs
        - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx
        - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx

        `reuse_cache`, `trim_logits`, `cache_idx` and `inputs_embeds_offset` are read from
        `kwargs`; the first three are copied into the returned dict (as `None` when absent).

        Returns:
            The dict of keyword arguments for the next `forward` call.
        """
        reuse_cache = kwargs.get("reuse_cache")

        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            if token_idx is not None:
                # Pick the single token that sits just before `token_idx`.
                idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1
                input_ids = torch.index_select(input_ids, 1, idx)
            else:
                if inputs_embeds is not None:  # Exception 1
                    input_ids = input_ids[:, -cache_position.shape[0] :]
                elif (
                    input_ids.shape[1] != cache_position.shape[0]
                ):  # Default case (the "else", a no op, is Exception 2)
                    input_ids = input_ids[:, cache_position]
        elif reuse_cache and token_idx is not None:
            # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
            input_ids = input_ids[:, :token_idx]
            if attention_mask is not None:
                attention_mask = attention_mask[:, :token_idx]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                if token_idx is not None:
                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                else:
                    position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {
                "input_ids": input_ids.clone(memory_format=torch.contiguous_format)
            }  # `contiguous()` needed for compilation use cases

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "token_idx": token_idx,
                "reuse_cache": reuse_cache,
                "trim_logits": kwargs.get("trim_logits"),
                "cache_idx": kwargs.get("cache_idx"),
            }
        )
        return model_inputs