# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2022-2023 Habana Labs, Ltd. an Intel Company
###############################################################################
from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.mpt.modeling_mpt import (
    MptAttention,
    MptBlock,
    MptConfig,
    MptForCausalLM,
    MptModel,
)
from transformers.utils import logging

from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask

try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None
logger = logging.get_logger(__name__)
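
# These Gaudi-specific subclasses mirror the upstream transformers MPT classes and are intended
# to replace them at runtime, so that text generation on HPU can use a pre-allocated
# (static-shape) KV cache driven by `token_idx`, optional fused scaled dot-product attention
# when the habana_frameworks FusedSDPA kernel is available, and internal bucketing via `cache_idx`.
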
class GaudiMptAttention(MptAttention):
def __init__(self, config: MptConfig):
        super().__init__(config)

    def forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: Optional[torch.Tensor] = None,
):
"""
Copied from MptAttention.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/mpt/modeling_mpt.py
The only differences are:
        - add new arg token_idx
        - optimize KV cache
        - add new arg use_flash_attention
        - add new arg flash_attention_recompute
        - add new arg cache_idx
"""
batch_size, seq_length = hidden_states.shape[:2]
mixed_qkv = self.Wqkv(hidden_states)
if self.clip_qkv:
mixed_qkv = mixed_qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
query_states, key_states, value_states = mixed_qkv.chunk(3, dim=2)
query_states = query_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2)
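        # KV cache handling on HPU:
        # - with token_idx, the current key/value is written in place at position token_idx - 1 of a
        #   pre-allocated, fixed-size cache, so tensor shapes stay static across decoding steps;
        # - with cache_idx, only the first cache_idx positions of the cache are used (internal bucketing);
        # - otherwise the standard concatenation-based cache from transformers is kept.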
if past_key_value is not None:
if len(past_key_value) != 0:
if token_idx is not None:
# copy kv to kv_cache
past_key_value[0].index_copy_(2, token_idx - 1, key_states)
past_key_value[1].index_copy_(2, token_idx - 1, value_states)
if cache_idx is not None:
key_states = past_key_value[0][:, :, :cache_idx, :]
value_states = past_key_value[1][:, :, :cache_idx, :]
else:
key_states = past_key_value[0]
value_states = past_key_value[1]
else:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = [key_states, value_states]
else:
if cache_idx is not None:
cache_shape = (batch_size, self.n_heads, self.kv_cache_len, self.head_dim)
past_key_value = [
torch.empty(cache_shape, dtype=key_states.dtype, device=key_states.device),
torch.empty(cache_shape, dtype=key_states.dtype, device=key_states.device),
]
past_key_value[0][:, :, : key_states.shape[2], :] = key_states
past_key_value[1][:, :, : key_states.shape[2], :] = value_states
else:
past_key_value = [key_states.clone(), value_states.clone()]
query_length = seq_length if past_key_value is None else seq_length + past_key_value[0].shape[2]
if position_bias is not None:
if len(position_bias.shape) != 3:
raise ValueError(f"Expecting position_bias shape to be 3 dimensions, got {len(position_bias.shape)}")
key_length = key_states.shape[-2]
position_bias_query_index = max(0, position_bias.size(1) - query_length)
position_bias_key_index = max(0, position_bias.size(2) - key_length)
position_bias = position_bias[:, position_bias_query_index:, position_bias_key_index:]
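        # Fused SDPA path: the boolean attention mask is turned into an additive bias
        # (masked positions get the dtype minimum) and combined with the alibi position bias.
        # Dropout is 0.0 and is_causal is False because causality is already encoded in the mask.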
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = FusedSDPA.apply(
query_states,
key_states,
value_states,
attention_mask * torch.finfo(query_states.dtype).min + position_bias.to(query_states.dtype),
0.0,
False,
None,
)
attn_weights = None
else:
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale
if position_bias is not None:
attention_scores = attention_scores + position_bias
if attention_mask is not None:
attention_scores = attention_scores.masked_fill(attention_mask, torch.finfo(query_states.dtype).min)
# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).to(value_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attn_dropout_p, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)
attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights, past_key_value


class GaudiMptBlock(MptBlock):
def __init__(self, config: MptConfig):
super().__init__(config)
        self.attn = GaudiMptAttention(config)

    def forward(
self,
hidden_states: torch.Tensor,
position_bias: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_attentions: bool = False,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: Optional[torch.Tensor] = None,
):
"""
Copied from MptBlock.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
The only differences are:
        - add new arg token_idx
        - add new arg use_flash_attention
        - add new arg flash_attention_recompute
        - add new arg cache_idx
"""
# hidden_states: [batch_size, seq_length, hidden_size]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.norm_1(hidden_states)
residual = hidden_states
# Self attention.
attn_outputs, attn_weights, past_key_value = self.attn(
layernorm_output,
position_bias=position_bias,
attention_mask=attention_mask,
past_key_value=layer_past,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)
hidden_states = self.resid_attn_dropout(attn_outputs) + residual
layernorm_output = self.norm_2(hidden_states)
# Get residual
residual = hidden_states
# MLP.
output = self.ffn(layernorm_output, residual)
outputs = (output,)
if use_cache:
outputs += (past_key_value,)
if output_attentions:
outputs += (attn_weights,)
        return outputs  # hidden_states, present, attentions


class GaudiMptModel(MptModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: Optional[torch.Tensor] = None,
**kwargs, # NOOP kwargs, for now
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
"""
Copied from MptModel.forward: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
The only differences are:
        - add new arg token_idx
        - add new arg use_flash_attention
        - add new arg flash_attention_recompute
        - add new arg cache_idx
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.blocks))
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
hidden_states = inputs_embeds
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
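        # With token_idx (static-shape KV cache), the attention_mask passed in already covers the
        # whole pre-allocated cache, so no extra past length is added below.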
if past_key_values[0] is not None and token_idx is None: # because RW-cache, not standard format
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = self.build_mpt_alibi_tensor(self.num_heads, self.config.max_seq_len, device=hidden_states.device)
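        # alibi is built once for config.max_seq_len; GaudiMptAttention slices it down to the
        # actual query/key lengths before adding it to the attention scores.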
causal_mask = _gaudi_prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
causal_mask = causal_mask.bool()
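        # The attention layers expect a boolean mask where True marks positions to mask out.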
for block, layer_past in zip(self.blocks, past_key_values):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
alibi,
causal_mask,
layer_past,
use_cache,
output_attentions,
None,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
use_cache=use_cache,
output_attentions=output_attentions,
position_bias=alibi,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.norm_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
        )


class GaudiMptForCausalLM(MptForCausalLM):
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
**kwargs,
) -> dict:
"""
Inherits from MptForCausalLM: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
The only differences are:
        - add new arg token_idx
        - add token_idx to model_inputs
        - when the KV cache is enabled, from the second step onwards slice the next input_ids from input_ids based on token_idx
        - support for internal bucketing
"""
bucket_internal = kwargs.get("bucket_internal")
cache_idx = kwargs.get("cache_idx") # kv_cache index for slice operations
use_flash_attention = kwargs.get("use_flash_attention", False)
flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
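        # Internal bucketing: cache_idx marks how much of the pre-allocated KV cache is currently
        # valid, so inputs, masks and the cache itself can be sliced to the bucket boundary while
        # shapes stay static within a bucket.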
# only last tokens for input_ids if past is not None
if past_key_values is not None:
if token_idx is None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
else:
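                # With the static KV cache, keep only the token at position token_idx - 1
                # (shifted by any inputs_embeds offset) as the input for this decoding step.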
idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1
input_ids = torch.index_select(input_ids, 1, idx)
if bucket_internal and token_idx is not None:
attention_mask = attention_mask[:, :cache_idx]
elif bucket_internal and token_idx is not None:
# for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
# for the 1st token the cache_idx is None
cache_idx = token_idx
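            # Propagate the total cache length to every attention layer so each one can
            # pre-allocate its KV cache on this first forward pass.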
for block in self.transformer.blocks:
block.attn.kv_cache_len = kwargs["kv_cache_len"]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values, # NITS should it be layer_past?
"use_cache": use_cache,
"attention_mask": attention_mask,
"token_idx": token_idx,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
"cache_idx": cache_idx,
}
)
        return model_inputs

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
"""
Inherits from MptForCausalLM: https://github.com/huggingface/transformers/blob/v4.32.0/src/transformers/models/mpt/modeling_mpt.py
The only differences are:
        - add new arg token_idx
        - add new arg use_flash_attention
        - add new arg flash_attention_recompute
        - add new arg cache_idx
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# Flatten the tokens
loss = self.loss_function(
lm_logits,
labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)