# coding=utf-8
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Llava model."""
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.cache_utils import Cache
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
from transformers.utils import logging

logger = logging.get_logger(__name__)


def _pad_inputs(
input_ids, attention_mask, image_token_index, num_patches, pad_token_id, vision_feature_select_strategy=None
):
"""
pad inputs for static shape
"""
if vision_feature_select_strategy == "default":
num_patches = num_patches
elif vision_feature_select_strategy == "full":
num_patches = num_patches + 1
else:
raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
image_offset = 0
new_input_ids = []
new_attention_mask = []
tokens_pos = []
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
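        # Positions of the image placeholder tokens, plus the sequence length as a final sentinel
        # so the text after the last image is also copied.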
image_token_indices = torch.where(cur_input_ids == image_token_index)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_extend = []
cur_attention_mask_extend = []
cur_token_pos = []
start = 0
for i, image_token_indice in enumerate(image_token_indices):
token_pos = len(cur_input_ids_extend)
cur_input_ids_extend.extend(cur_input_ids[start:image_token_indice].cpu().tolist())
cur_attention_mask_extend.extend(cur_attention_mask[start:image_token_indice].cpu().tolist())
cur_token_pos.extend(list(range(token_pos, len(cur_input_ids_extend))))
if i != len(image_token_indices) - 1:
cur_input_ids_extend.extend([image_token_index] * num_patches)
cur_attention_mask_extend.extend([1] * num_patches)
cur_token_pos.append(image_token_indice)
start = image_token_indice + 1
new_input_ids.append(cur_input_ids_extend)
new_attention_mask.append(cur_attention_mask_extend)
tokens_pos.append(cur_token_pos)
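    # Pad every example to the longest expanded sequence; image_offset is how much longer that is
    # than the original sequence length.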
max_len = max(len(x) for x in new_input_ids)
image_offset += max_len - input_ids.shape[1]
# padding
new_input_ids_padded = []
new_attention_mask_padded = []
tokens_pos_padded = []
    # Use left padding so that, for examples with no image, we don't need to change token_idx.
for cur_new_ids, cur_attention_mask, cur_token_pos in zip(new_input_ids, new_attention_mask, tokens_pos):
pad_len = max_len - len(cur_new_ids)
new_input_ids_padded.append([pad_token_id] * pad_len + cur_new_ids)
new_attention_mask_padded.append([0] * pad_len + cur_attention_mask)
tokens_pos_padded.append([x + pad_len for x in cur_token_pos])
input_ids = torch.tensor(new_input_ids_padded).to(input_ids.device)
attention_mask = torch.tensor(new_attention_mask_padded).to(input_ids.device)
tokens_pos = torch.tensor(tokens_pos_padded).to(input_ids.device)
    return input_ids, attention_mask, image_offset, tokens_pos


def _merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, image_token_index):
"""
Merge text and images
"""
batch_size, sequence_length, embed_dim = inputs_embeds.shape
batch_indices, image_indices = torch.where(input_ids == image_token_index)
inputs_embeds[batch_indices, image_indices] = image_features.contiguous().reshape(-1, embed_dim)
    return inputs_embeds.contiguous()


class GaudiLlavaForConditionalGeneration(LlavaForConditionalGeneration):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
vision_feature_layer: Optional[Union[int, List[int]]] = None,
vision_feature_select_strategy: Optional[str] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
image_sizes: Optional[torch.Tensor] = None,
token_idx: Optional[torch.Tensor] = None,
image_offset: Optional[int] = None,
tokens_pos: Optional[torch.LongTensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
**lm_kwargs,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/llava/modeling_llava.py#L362
The only differences are:
        - add new arg token_idx
        - add new arg image_offset
        - add new arg tokens_pos
        - add new args use_flash_attention and flash_attention_recompute
        - when token_idx and tokens_pos are provided, gather logits at tokens_pos so they align with the original input_ids
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_feature_layer = (
vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
)
vision_feature_select_strategy = (
vision_feature_select_strategy
if vision_feature_select_strategy is not None
else self.config.vision_feature_select_strategy
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if pixel_values is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)
image_features = None
# 2. Merge text and images
if pixel_values is not None and input_ids.shape[1] != 1:
image_outputs = self.vision_tower(
pixel_values,
output_hidden_states=True,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
)
            # This is not memory efficient at all: output_hidden_states=True saves all the hidden states.
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
if vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
image_features = self.multi_modal_projector(selected_image_feature)
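            # Scatter the projected image features into the embedding slots reserved by the expanded
            # image placeholder tokens.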
inputs_embeds = _merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, self.config.image_token_index
)
if token_idx is not None:
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
token_idx=token_idx + image_offset,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
**lm_kwargs,
)
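            # When the inputs were expanded to a static shape, gather the logits back at the original
            # token positions (tokens_pos) so they align with the caller's input_ids.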
if input_ids.shape[1] != 1 and pixel_values is not None and tokens_pos is not None:
batch_size, seq_len = tokens_pos.shape
batch_indices = torch.arange(batch_size).repeat_interleave(seq_len)
logits = outputs[0][batch_indices, tokens_pos.reshape(-1), :].reshape(batch_size, seq_len, -1)
else:
logits = outputs[0]
loss = None
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return LlavaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
)
else:
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
logits_to_keep=logits_to_keep,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
**lm_kwargs,
)
logits = outputs[0]
loss = None
if labels is not None:
# Shift so that tokens < n predict n
if attention_mask is not None:
shift_attention_mask = attention_mask[..., 1:]
shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return LlavaCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=image_features if pixel_values is not None else None,
            )

    def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
attention_mask=None,
cache_position=None,
logits_to_keep=None,
**kwargs,
):
"""
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llava/modeling_llava.py
The only differences are:
        - add new arg token_idx
        - add new arg image_offset
        - add new arg tokens_pos
        - from the second step on, when the KV cache is enabled, slice next_input_ids from input_ids based on token_idx
        - from the second step on, when the KV cache is enabled, slice next_position_ids from position_ids based on token_idx
"""
token_idx = kwargs.get("token_idx", None)
image_offset = 0
tokens_pos = None
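        # legacy_processing is True when the prompt still contains a single placeholder token per image
        # (i.e. it was not expanded by the processor) or on the first decoding step with pixel values;
        # in that case the inputs are expanded and left-padded to a static shape below.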
legacy_processing = (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
) or ((input_ids.shape[-1] == 1 if token_idx is None else token_idx == 1) and pixel_values is not None)
if token_idx is not None and pixel_values is not None and legacy_processing:
input_ids, attention_mask, image_offset, tokens_pos = _pad_inputs(
input_ids,
attention_mask,
self.config.image_token_index,
self.vision_tower.vision_model.embeddings.num_patches,
self.pad_token_id,
vision_feature_select_strategy=self.config.vision_feature_select_strategy,
)
past_length = 0
if past_key_values is not None:
if token_idx is None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
else:
cache_length = past_length = past_key_values[0][0].shape[2]
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif self.config.image_token_index in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
# If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
# older attention values, as their corresponding values are not part of the input.
if cache_length < past_length and attention_mask is not None:
attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
else:
# past_length += token_idx
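                # With a static-shape KV cache, only the token at index token_idx + image_offset - 1
                # of input_ids is fed to the model at this step.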
input_ids = torch.index_select(input_ids, 1, token_idx + image_offset - 1)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
if token_idx is not None:
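                    # The position of the current token is the number of attended tokens so far minus one.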
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
else:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
use_flash_attention = kwargs.get("use_flash_attention", False)
flash_attention_recompute = kwargs.get("flash_attention_recompute", False)
if logits_to_keep is not None:
model_inputs["logits_to_keep"] = logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"token_idx": token_idx,
"image_offset": image_offset,
"tokens_pos": tokens_pos,
"use_flash_attention": use_flash_attention,
"flash_attention_recompute": flash_attention_recompute,
}
)
return model_inputs