optimum/habana/transformers/models/video_llava/modeling_video_llava.py (336 lines of code) (raw):

# coding=utf-8 # Copyright 2024 the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """PyTorch VideoLlava model.""" from typing import List, Optional, Tuple, Union import torch from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPooling from transformers.models.video_llava.modeling_video_llava import ( VideoLlavaCausalLMOutputWithPast, VideoLlavaConfig, VideoLlavaForConditionalGeneration, ) from transformers.utils import logging logger = logging.get_logger(__name__) class GaudiVideoLlavaForConditionalGeneration(VideoLlavaForConditionalGeneration): def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.feature_offset = 0 def _merge_input_ids_with_visual_features( self, visual_features, inputs_embeds, input_ids, attention_mask, labels, token_idx, num_frames=1 ): r""" Copied from VideoLlavaForConditionalGeneration._merge_input_ids_with_visual_features: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/video_llava/modeling_video_llava.py The only differences are: - add new args token_idx - add self.feature_offset param """ num_images, num_image_patches, embed_dim = visual_features.shape batch_size, sequence_length = input_ids.shape last_token_idx = token_idx + self.feature_offset left_padding = not torch.sum(input_ids[:, last_token_idx - 1] == torch.tensor(self.pad_token_id)) special_vision_token = self.config.video_token_index if num_frames > 1 else self.config.image_token_index # 1. Create a mask to know where special image tokens are special_image_token_mask = input_ids == special_vision_token num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) # Compute the maximum embed dimension max_seq_len = (num_special_image_tokens.max() * (num_image_patches * num_frames - 1)) + sequence_length self.feature_offset = self.feature_offset + max_seq_len - sequence_length batch_indices, non_image_indices = torch.where(input_ids != special_vision_token) # 2. Compute the positions where text should be written # Calculate new positions for text tokens in merged image-text sequence. # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens. # `torch.cumsum` computes how each image token shifts subsequent text token positions. # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. new_token_positions = ( torch.cumsum((special_image_token_mask * (num_image_patches * num_frames - 1) + 1), dim=-1) - 1 ) nb_image_pad = max_seq_len - 1 - new_token_positions[:, -1] if left_padding: new_token_positions += nb_image_pad[:, None] # offset for left padding text_to_overwrite = new_token_positions[batch_indices, non_image_indices] # 3. Create the full embedding, already padded to the maximum position # expand input ids so that the second "merge" with videos does not fail final_embedding = torch.zeros( batch_size, max_seq_len, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device ) final_attention_mask = torch.zeros( batch_size, max_seq_len, dtype=attention_mask.dtype, device=inputs_embeds.device ) final_input_ids = torch.full( (batch_size, max_seq_len), self.pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device ) # In case the Vision model or the Language model has been offloaded to CPU, we need to manually # set the corresponding tensors into their correct target device. target_device = inputs_embeds.device batch_indices, non_image_indices, text_to_overwrite = ( batch_indices.to(target_device), non_image_indices.to(target_device), text_to_overwrite.to(target_device), ) attention_mask = attention_mask.to(target_device) # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"] # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] if labels is not None: final_labels = torch.full( (batch_size, max_seq_len), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device ) final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] else: final_labels = None # 5. Fill the embeddings corresponding to the images. Anything that is still zeros needs filling image_to_overwrite = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=inputs_embeds.device) image_to_overwrite[batch_indices, text_to_overwrite] = False image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device) if image_to_overwrite.sum() != visual_features.shape[:-1].numel(): visual_type = "videos" if num_frames == 8 else "images" num_images //= num_frames raise ValueError( f"The input provided to the model are wrong. The number of {visual_type} tokens is {torch.sum(special_image_token_mask)} while" f" the number of {visual_type} given to the model is {num_images}. This prevents correct indexing and breaks batch generation." ) final_embedding[image_to_overwrite] = visual_features.contiguous().reshape(-1, embed_dim).to(target_device) final_attention_mask |= image_to_overwrite position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids def _get_vision_features( self, pixel_values_images: Optional[torch.FloatTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[int] = None, vision_feature_select_strategy: Optional[str] = None, ) -> Union[Tuple, BaseModelOutputWithPooling]: if pixel_values_images is None and pixel_values_videos is None: raise ValueError("You have to specify `pixel_values_images` or `pixel_values_videos`") # videos do not need to select features and it's always "full" (as it is done in the orig implementation) if pixel_values_videos is not None: batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width) video_outputs = self.video_tower(pixel_values, output_hidden_states=True) video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1) else: video_outputs = None num_frames = 0 if pixel_values_images is not None: image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1) if vision_feature_select_strategy == "default": image_outputs = image_outputs[:, 1:] elif vision_feature_select_strategy == "full": image_outputs = image_outputs else: raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") else: image_outputs = None return image_outputs, video_outputs, num_frames def forward( self, input_ids: Optional[torch.LongTensor] = None, pixel_values_images: Optional[torch.FloatTensor] = None, pixel_values_videos: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, List[int]]] = None, vision_feature_select_strategy: Optional[str] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, VideoLlavaCausalLMOutputWithPast]: r""" Copied from VideoLlavaForConditionalGeneration.forward: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/video_llava/modeling_video_llava.py The only differences are: - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache - add new args use_flash_attention - add new args flash_attention_recompute - add new args flash_attention_causal_mask - add new args flash_attention_fast_softmax """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, logits_to_keep=0, token_idx=token_idx, **kwargs, ) logits = outputs[0] if logits.shape[1] > 1: logits = logits[:, self.feature_offset :, :] loss = None if labels is not None: # Shift so that tokens < n predict n if attention_mask is not None: # we use the input attention mask to shift the logits and labels, because it is 2D. # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device) shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous() shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous() else: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = nn.CrossEntropyLoss() loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device) ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return VideoLlavaCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, image_hidden_states=kwargs.get("image_features", None) if pixel_values_images is not None else None, video_hidden_states=kwargs.get("video_features", None) if pixel_values_videos is not None else None, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values_images=None, pixel_values_videos=None, attention_mask=None, cache_position=None, logits_to_keep=None, **kwargs, ): token_idx = kwargs.get("token_idx", None) if token_idx is None: return super().prepare_inputs_for_generation( input_ids=input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, pixel_values_images=pixel_values_images, pixel_values_videos=pixel_values_videos, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) # Else, we need to update token_idx when merging features from videos/images with input embeddings labels = kwargs.get("labels", None) if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) if (pixel_values_images is not None or pixel_values_videos is not None) and inputs_embeds is not None: raise ValueError( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) legacy_processing = False inputs_not_expanded = False if input_ids is not None: img_token_not_enough = (input_ids == self.config.image_token_index).sum( 1 ).max() < self.config.image_seq_length video_token_not_enough = (input_ids == self.config.video_token_index).sum( 1 ).max() < self.config.video_seq_length # if the number of image/video tokens is more than image embeddings seq length, then prob we expanded it in processing # not very reliable, but we don't expect one to actually pass 500+ images for one prompt inputs_not_expanded = (img_token_not_enough and pixel_values_images is not None) or ( video_token_not_enough and pixel_values_videos is not None ) model_inputs = self.language_model.prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, cache_position=cache_position, logits_to_keep=logits_to_keep, **kwargs, ) position_ids = model_inputs["position_ids"] cache_position = model_inputs["cache_position"] attention_mask = model_inputs["attention_mask"] inputs_embeds = model_inputs.get("inputs_embeds", None) input_ids = model_inputs.get("input_ids", None) if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) pixels_present = input_ids.shape[-1] == 1 and ( pixel_values_images is not None or pixel_values_videos is not None ) legacy_processing = inputs_not_expanded or pixels_present vision_feature_layer = kwargs.get("vision_feature_layer", None) vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) vision_feature_select_strategy = kwargs.get("vision_feature_select_strategy", None) vision_feature_select_strategy = ( vision_feature_select_strategy if vision_feature_select_strategy is not None else self.config.vision_feature_select_strategy ) if pixel_values_images is not None or pixel_values_videos is not None: image_outputs, video_outputs, num_frames = self._get_vision_features( pixel_values_images=pixel_values_images, pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) image_features = video_features = None if image_outputs is not None: image_features = self.multi_modal_projector(image_outputs) if video_outputs is not None: video_features = self.multi_modal_projector(video_outputs) if legacy_processing: logger.warning_once( "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." ) if input_ids.shape[1] != 1: self.feature_offset = 0 for features, frames in ((image_features, 1), (video_features, num_frames)): if features is not None: ( inputs_embeds, attention_mask, labels, position_ids, input_ids, ) = self._merge_input_ids_with_visual_features( features, inputs_embeds, input_ids, attention_mask, labels, token_idx, num_frames=frames, ) cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) else: # Retrieve the first layer to inspect the logits and mask out the hidden states # that are set to 0 first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) target_length = input_ids.shape[1] past_length = first_layer_past_key_value.shape[-1] extended_attention_mask = torch.ones( (attention_mask.shape[0], past_length), dtype=attention_mask.dtype, device=attention_mask.device, ) # Filter out only the tokens that can be un-attended, this can happen # if one uses Llava + Fused modules where the cache on the # first iteration is already big enough, or if one passes custom cache valid_indices = non_attended_tokens < extended_attention_mask.size(-1) new_batch_index = batch_index[valid_indices] new_non_attended_tokens = non_attended_tokens[valid_indices] # Zero-out the places where we don't need to attend extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 new_token_idx = token_idx + self.feature_offset extended_attention_mask[:, new_token_idx - 1 + target_length :] = 0 attention_mask = extended_attention_mask.clone() position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 cache_position = new_token_idx # TODO: @raushan retain only the new behavior after v4.47 else: if image_outputs is not None: special_image_mask = ( (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds) ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) if video_outputs is not None: special_image_mask = ( (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds) ) video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "attention_mask": attention_mask, "token_idx": token_idx + self.feature_offset, "inputs_embeds": inputs_embeds, } ) if legacy_processing or (cache_position is not None and cache_position[0]) == 0: # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore # Otherwise we need pixel values to be passed to model model_inputs["pixel_values_images"] = pixel_values_images model_inputs["pixel_values_videos"] = pixel_values_videos model_inputs["image_features"] = image_features model_inputs["video_features"] = video_features return model_inputs