optimum/habana/transformers/models/paligemma/modeling_paligemma.py (121 lines of code) (raw):

# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Paligemma model (Gaudi variant of the Hugging Face Transformers implementation)."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.paligemma.modeling_paligemma import (
    PaliGemmaCausalLMOutputWithPast,
    PaliGemmaForConditionalGeneration,
)
from transformers.utils import is_torchdynamo_compiling, logging


logger = logging.get_logger(__name__)


class GaudiPaliGemmaForConditionalGeneration(PaliGemmaForConditionalGeneration):
    """
    Subclass of `PaliGemmaForConditionalGeneration` used for Gaudi.

    Only `forward` is overridden. It mirrors the upstream implementation and threads an extra
    `token_idx` argument through to the language model.
    """

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        token_idx: Optional[torch.Tensor] = None,
        **lm_kwargs,
    ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
        """
        Inherits from PaliGemmaForConditionalGeneration::forward https://github.com/huggingface/transformers/blob/v4.45.1/src/transformers/models/paligemma/modeling_paligemma.py#L402
        The only differences are:
        - add new arg `token_idx`, which is forwarded unchanged to `self.language_model`
        - `logits_to_keep` is accepted for signature compatibility but is NOT forwarded to the
          language model (see the TODO at the `self.language_model(...)` call)

        Args:
            input_ids / inputs_embeds: exactly one of the two must be provided.
            pixel_values: if given, image features from `self.get_image_features` are scattered
                into the embeddings at the image-token positions.
            labels: if given, a shifted next-token cross-entropy loss is computed.
            token_idx: passed through to the language model. Its semantics are defined there
                (presumably the index of the current token for static-shape decoding; TODO confirm).
            **lm_kwargs: extra keyword arguments forwarded to the language model.

        Returns:
            A `PaliGemmaCausalLMOutputWithPast` if `return_dict` (which defaults to
            `config.use_return_dict`) is truthy, otherwise its `to_tuple()` form.

        Raises:
            ValueError: if both or neither of `input_ids` and `inputs_embeds` are given.
            ValueError: if the number of image-token embedding elements does not match the number of
                image-feature elements. This check is skipped while torchdynamo is compiling.
        """
        # XOR is True when both are None or both are set; exactly one must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Training mode for the causal mask is inferred from both token_type_ids and labels being present.
        is_training = token_type_ids is not None and labels is not None

        # Replace the image token id with 0 if the image token is OOV (>= vocab_size), to avoid index errors
        # in the embedding lookup. A clone is made so the caller's input_ids are not modified.
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        # Default cache positions continue from however many tokens the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            if input_ids is None:
                # Without ids, image positions are found by comparing every embedding element with the
                # embedding of image_token_index. This yields a mask with the same shape as inputs_embeds.
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                # Build the per-token mask from ids and broadcast it over the embedding dimension.
                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                # Count image tokens: sum over the sequence, then the batch, and read one embedding channel.
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            # Writes image_features, in order, into the positions where the mask is True.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # mask out pad-token-ids in labels for BC
        if labels is not None and self.pad_token_id in labels:
            logger.warning_once(
                "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
                "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
            )
            # NOTE(review): the mask is computed from `input_ids`, not `labels`. When the caller supplied
            # `inputs_embeds` instead, `input_ids` is None and this comparison is a plain bool, so
            # `torch.where` will likely fail. Confirm whether that path needs guarding.
            labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training
        )
        outputs: CausalLMOutputWithPast = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            # TODO: from Transformers v4.45, `generate` sets `num_logits_to_keep` to 1 if not given, which we don't want here
            # logits_to_keep=logits_to_keep,
            token_idx=token_idx,
            **lm_kwargs,
        )

        # labels are not passed to the language model, so outputs[0] is presumably the logits tensor.
        logits = outputs[0]
        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Next-token prediction: logits at position t are scored against the label at t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                # Keep only positions where the (cropped) mask is nonzero. NOTE(review): shift_attention_mask
                # is already on logits.device, so the extra .to() below is a no-op.
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)
            loss = loss_fct(flat_logits, flat_labels)

        output = PaliGemmaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            # image_features only exists when pixel_values was given.
            image_hidden_states=image_features if pixel_values is not None else None,
        )
        return output if return_dict else output.to_tuple()