optimum/habana/transformers/models/stablelm/modeling_stablelm.py (372 lines of code) (raw):
import math
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.stablelm.configuration_stablelm import StableLmConfig
from transformers.models.stablelm.modeling_stablelm import (
StableLmAttention,
StableLmForCausalLM,
StableLmMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging
from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)
from ...modeling_rope_utils import GaudiRotaryEmbedding
logger = logging.get_logger(__name__)
class GaudiStableLmAttention(StableLmAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: StableLmConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
self.rotary_emb = GaudiRotaryEmbedding(config=self.config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from StableLmAttention.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py
The only differences are:
- add new args token_idx
- optimize KV cache
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if self.qk_layernorm:
query_states = self.q_layernorm(query_states)
key_states = self.k_layernorm(key_states)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
"with a layer index."
)
if token_idx is not None and past_key_value.get_usable_length(kv_seq_len, self.layer_idx) > 0:
# When token_idx is used, static seq len = (input token len + max output token len)
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
else:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# Partial rotary embedding
query_rot, query_pass = (
query_states[..., : self.rotary_ndims],
query_states[..., self.rotary_ndims :],
)
key_rot, key_pass = (
key_states[..., : self.rotary_ndims],
key_states[..., self.rotary_ndims :],
)
# [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos[position_ids], sin[position_ids])
# [batch_size, seq_length, num_heads, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
if past_key_value is not None:
if token_idx is not None:
if 0 <= self.layer_idx < len(past_key_value.key_cache):
past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states)
past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states)
key_states = past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx]
else:
past_key_value.key_cache.append(key_states)
past_key_value.value_cache.append(value_states)
else:
# Specific to RoPE models with partial rotation
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_ndims,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# Repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights += causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
attn_weights = self.attention_dropout(attn_weights)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class GaudiStableLmDecoderLayer(torch.nn.Module):
def __init__(self, config: StableLmConfig, layer_idx: int):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.hidden_size = config.hidden_size
self.self_attn = GaudiStableLmAttention(config, layer_idx=layer_idx)
self.mlp = StableLmMLP(config)
self.input_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = None
if not self.use_parallel_residual:
self.post_attention_layernorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(config.hidden_dropout)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Copied from StableLmDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py
The only differences are:
- add new args token_idx
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
self_attn_output, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
token_idx=token_idx,
)
# copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer.forward
if self.use_parallel_residual:
# x = x + attn(ln1(x)) + mlp(ln1(x))
# Fully Connected
mlp_output = self.mlp(hidden_states)
mlp_output = self.dropout(mlp_output)
hidden_states = residual + self_attn_output + mlp_output
else:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
residual = residual + self_attn_output
# Fully Connected
mlp_output = self.mlp(self.post_attention_layernorm(residual))
mlp_output = self.dropout(mlp_output)
hidden_states = residual + mlp_output
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def gaudi_stablelm_model_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
token_idx: Optional[torch.Tensor] = None,
) -> BaseModelOutputWithPast:
"""
Copied from StableLmModel.forward: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py
The only differences are:
- add new args token_idx
- replace _prepare_4d_causal_attention_mask with _gaudi_prepare_4d_causal_attention_mask
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if token_idx is None:
past_key_values_length = past_key_values.get_usable_length(seq_length)
seq_length_with_past = seq_length_with_past + past_key_values_length
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
if self._attn_implementation == "flash_attention_2":
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
token_idx=token_idx,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class GaudiStableLmForCausalLM(StableLmForCausalLM):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
token_idx: Optional[torch.Tensor] = None,
**kwargs,
) -> CausalLMOutputWithPast:
"""
Inherits from StableLmForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py
The only differences are:
- add new args token_idx
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
token_idx=token_idx,
)
hidden_states = outputs.last_hidden_state
# No upscaling to float was ever done for StableLm
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=None,
**kwargs,
):
"""
Inherits from StableLmForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/stablelm/modeling_stablelm.py
The only differences are:
- add new args token_idx
- add token_idx into model_inputs
- from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx
- from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx
"""
token_idx = kwargs.get("token_idx", None)
if past_key_values is not None:
if token_idx is None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif (
input_ids.shape[1] != cache_position.shape[0]
): # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
else:
idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1
input_ids = torch.index_select(input_ids, 1, idx)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
position_ids = torch.index_select(position_ids, 1, token_idx - 1)
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {
"input_ids": input_ids.clone(memory_format=torch.contiguous_format)
} # `contiguous()` needed for compilation use cases
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"token_idx": token_idx,
}
)
return model_inputs