optimum/habana/transformers/models/idefics2/modeling_idefics2.py

# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.idefics2.modeling_idefics2 import (
    Idefics2BaseModelOutputWithPast,
    Idefics2CausalLMOutputWithPast,
    Idefics2ForConditionalGeneration,
    Idefics2Model,
    Idefics2VisionEmbeddings,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


class GaudiIdefics2VisionEmbeddings(Idefics2VisionEmbeddings):
    def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor:
        """
        Inherits from Idefics2VisionEmbeddings::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L159
        The only differences are:
        - add int() around nb_patches_h and nb_patches_w to avoid an overflow in torch.arange, which otherwise sometimes returns nb_patches_h/nb_patches_w + 1 values
        - delete the .to("cpu") on p_attn_mask
        """
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size
        boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w),
            fill_value=0,
            device=self.position_embedding.weight.device,
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            nb_patches_h = int(p_attn_mask[:, 0].sum())
            nb_patches_w = int(p_attn_mask[0].sum())

            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
            bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

            pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids.to(self.position_embedding.weight.device)

        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings
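

# Note on the int() casts in GaudiIdefics2VisionEmbeddings.forward above:
# p_attn_mask.sum() returns a 0-dim tensor, so 1 / nb_patches would also be a
# 0-dim tensor, and torch.arange with such a step can reportedly come out one
# element too long (the overflow the docstring mentions). Casting to int keeps
# the step a plain Python float. A sketch of the intended invariant
# (illustrative only):
#
#     n = int(torch.tensor(14))
#     assert torch.arange(0, 1 - 1e-6, 1 / n).numel() == n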


class GaudiIdefics2Model(Idefics2Model):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics2BaseModelOutputWithPast]:
        """
        Inherits from Idefics2Model::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1303
        The only differences are:
        - ignoring the new Cache path for HPU
        - unfold is not supported on HPU, so it is replaced with conv2d
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # retrieve input_ids and inputs_embeds
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        past_seen_tokens = 0
        return_legacy_cache = True
        use_new_cache = False  # Ignoring new Cache path for HPU
        if use_cache and use_new_cache:
            if not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
                return_legacy_cache = True
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()
        else:
            if past_key_values is not None:
                past_seen_tokens = past_key_values[0][0].shape[2]

        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            patch_size = self.config.vision_config.patch_size
            conv_kernel = torch.ones(
                [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
            )
            patches_subgrid = torch.nn.functional.conv2d(
                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
            ).squeeze(1)
            patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))
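
            # The conv2d above stands in for `unfold`, which is not supported on
            # HPU: convolving the 0/1 mask with an all-ones patch_size x patch_size
            # kernel at stride=patch_size computes the same per-patch sums that
            # unfold followed by sum would, so a patch is kept iff all of its
            # pixels are valid. A toy sketch (illustrative only):
            #
            #     mask = torch.ones(1, 4, 4)
            #     kernel = torch.ones(1, 1, 2, 2)
            #     counts = torch.nn.functional.conv2d(mask.unsqueeze(1), kernel, stride=2).squeeze(1)
            #     full = counts == 4  # patches whose 2x2 pixels are all valid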

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state

            # Modality projection & resampling
            image_hidden_states = self.connector(
                image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
            )

        elif image_hidden_states is not None:
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=return_dict,
        )

        if return_legacy_cache and use_cache:
            if return_dict:
                outputs.past_key_values = (
                    outputs.past_key_values.to_legacy_cache()
                    if isinstance(outputs.past_key_values, Cache)
                    else outputs.past_key_values
                )
            else:
                outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1]

        if not return_dict:
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return Idefics2BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: Optional[torch.Tensor],
        image_hidden_states: Optional[torch.Tensor],
    ):
        """
        Inherits from Idefics2Model::inputs_merger https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1268
        The only differences are:
        - replace boolean-mask indexing (`==`) with torch.where to fix an issue with HPU graphs
        """
        num_images, _, vision_hidden_size = image_hidden_states.shape
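        # torch.where with a single condition returns a tuple of per-dimension
        # index tensors, which can be used for advanced indexing just like the
        # boolean mask it replaces; per the docstring, the mask form misbehaved
        # under HPU graphs. A sketch of the equivalence (illustrative only):
        #
        #     t = torch.tensor([[1, 99, 2]])
        #     t[t == 99]               # boolean-mask indexing
        #     t[torch.where(t == 99)]  # same elements via index tensors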
        special_image_token_mask = torch.where(input_ids == self.image_token_id)
        new_inputs_embeds = inputs_embeds.clone()
        reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size)
        new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states.to(new_inputs_embeds.device)
        return new_inputs_embeds


class GaudiIdefics2ForConditionalGeneration(Idefics2ForConditionalGeneration):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        token_idx: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Idefics2CausalLMOutputWithPast]:
        """
        Inherits from Idefics2ForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1505
        The only differences are:
        - add the new arg token_idx
        """
        if token_idx is not None:
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
            if input_ids is not None:
                batch_size, seq_length = input_ids.shape
            elif inputs_embeds is not None:
                batch_size, seq_length, _ = inputs_embeds.shape
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")

            past_seen_tokens = 0
            return_legacy_cache = True
            use_new_cache = False  # Ignoring new Cache path for HPU
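            # The DynamicCache path below is kept only for parity with upstream
            # transformers; since use_new_cache stays False on HPU, past_key_values
            # are always handled in the legacy tuple-of-tuples layout, where
            # past_key_values[layer][0] is the key tensor of shape
            # (batch, num_heads, seq_len, head_dim), so [0][0].shape[2] is the
            # number of tokens already processed.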
            if use_cache and use_new_cache:
                if not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
                    return_legacy_cache = True
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()
            else:
                if past_key_values is not None:
                    past_seen_tokens = past_key_values[0][0].shape[2]

            if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
                raise ValueError(
                    "When first calling the model, if input_embeds are passed, input_ids should not be None."
                )

            if inputs_embeds is None:
                inputs_embeds = self.model.text_model.get_input_embeddings()(input_ids)

            # START VISUAL INPUTS INTEGRATION
            if pixel_values is not None and image_hidden_states is not None:
                raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
            elif image_hidden_states is not None:
                image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)

            if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
                # When we generate, we don't want to replace the potential image_token_id that we generated by images
                # that simply don't exist
                inputs_embeds = self.model.inputs_merger(
                    input_ids=input_ids,
                    inputs_embeds=inputs_embeds,
                    image_hidden_states=image_hidden_states,
                )

            outputs = self.model.text_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                token_idx=token_idx,
            )

            if return_legacy_cache and use_cache:
                if return_dict:
                    outputs.past_key_values = (
                        outputs.past_key_values.to_legacy_cache()
                        if isinstance(outputs.past_key_values, Cache)
                        else outputs.past_key_values
                    )
                else:
                    outputs[1] = outputs[1].to_legacy_cache() if isinstance(outputs[1], Cache) else outputs[1]

            hidden_states = outputs[0]
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])

            loss = None
            if labels is not None:
                labels = labels.to(logits.device)
                # Shift so that tokens < n predict n
                if attention_mask is not None:
                    shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
                    shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
                    shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
                else:
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            if not return_dict:
                output = (logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output

            return Idefics2CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                image_hidden_states=image_hidden_states,
            )
        else:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                image_hidden_states=image_hidden_states,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                return_dict=return_dict,
            )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """
        Inherits from Idefics2ForConditionalGeneration::prepare_inputs_for_generation https://github.com/huggingface/transformers/blob/v4.43.4/src/transformers/models/idefics2/modeling_idefics2.py#L1622
        The only differences are:
        - add the new arg token_idx
        - support past_key_values that are not a `Cache` (None or legacy tuples)
        - move the vision model call into prepare_inputs_for_generation
        """
        past_length = 0
        token_idx = kwargs.get("token_idx", None)
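        # token_idx is the static-shape generation mechanism used on HPU: inputs are
        # padded to a fixed length up front and token_idx tracks the current write
        # position, so each step picks the token at token_idx - 1 via index_select
        # instead of slicing to a dynamically growing length.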

        # Omit tokens covered by past_key_values
        if past_key_values is not None:
            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            if isinstance(past_key_values, Cache):
                past_length = past_key_values.get_seq_length()
                max_cache_length = past_key_values.get_max_length()
            else:
                past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if token_idx is not None:
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
            elif attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and past_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                if token_idx is not None:
                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                else:
                    position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        image_hidden_states = kwargs.get("image_hidden_states", None)
        if image_hidden_states is not None:
            pixel_values = None
            pixel_attention_mask = None
        else:
            pixel_values = kwargs.get("pixel_values", None)
            pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
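
        # The vision encoder runs here rather than in forward (see the docstring
        # above): with token_idx set, images are encoded once on the first call and
        # the resulting image_hidden_states are expected to be carried over between
        # steps via model kwargs, so the per-step decode graph never re-runs the
        # vision model.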
        if token_idx is not None and pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            patch_size = self.config.vision_config.patch_size
            conv_kernel = torch.ones(
                [1, 1, patch_size, patch_size], dtype=pixel_values.dtype, device=pixel_values.device
            )
            patches_subgrid = torch.nn.functional.conv2d(
                pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype), conv_kernel, stride=patch_size
            ).squeeze(1)
            patch_attention_mask = torch.eq(patches_subgrid, (patch_size * patch_size))

            # Get sequence from the vision encoder
            image_hidden_states = self.model.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state

            # Modality projection & resampling
            image_hidden_states = self.model.connector(
                image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
            )
            pixel_values = None
            pixel_attention_mask = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "pixel_attention_mask": pixel_attention_mask,
                "image_hidden_states": image_hidden_states,
                "token_idx": token_idx,
            }
        )
        return model_inputs
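

# A minimal usage sketch (illustrative only; assumes a Gaudi environment with
# optimum-habana installed, which patches transformers with the Gaudi classes
# defined above):
#
#     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
#     from transformers import AutoProcessor, Idefics2ForConditionalGeneration
#
#     adapt_transformers_to_gaudi()
#     processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
#     model = Idefics2ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics2-8b").to("hpu")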