optimum/habana/transformers/models/blip/modeling_blip_text.py (413 lines of code) (raw):

import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)
from transformers.modeling_utils import apply_chunking_to_forward
from transformers.utils import logging


logger = logging.get_logger(__name__)


def gaudi_BlipTextSelfAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
    token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
    """
    Copied from BlipTextSelfAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L143
    The only differences are:
    - add token_idx

    When `token_idx` is given together with `past_key_value`, the new key/value
    projections are written in place into the cached tensors at position
    `token_idx - 1` along dim 2, and the cached tensors themselves are used for
    attention. Without `token_idx`, the new projections are concatenated to the
    cache along dim 2.

    Returns:
        `(context_layer, [attention_probs], (key_layer, value_layer))`, where
        `attention_probs` is only present when `output_attentions` is truthy.
    """
    mixed_query_layer = self.query(hidden_states)

    # If this is instantiated as a cross-attention module, the keys
    # and values come from an encoder; the attention mask needs to be
    # such that the encoder's padding tokens are not attended to.
    is_cross_attention = encoder_hidden_states is not None

    if is_cross_attention:
        key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
        value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
        attention_mask = encoder_attention_mask
    elif past_key_value is not None:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        if token_idx is not None:
            # In-place update of the cache slot for the current step; the cache
            # tensors keep their size instead of growing by concatenation.
            past_key_value[0].index_copy_(2, token_idx - 1, key_layer)
            past_key_value[1].index_copy_(2, token_idx - 1, value_layer)
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
        else:
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
    else:
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

    query_layer = self.transpose_for_scores(mixed_query_layer)

    # Returned as the last element of the outputs so callers can reuse it as cache.
    past_key_value = (key_layer, value_layer)

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        seq_length = hidden_states.size()[1]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
        # Apply the attention mask (precomputed for all layers in BlipTextModel forward() function)
        attention_scores = attention_scores + attention_mask.to(attention_scores.device)

    # Normalize the attention scores to probabilities.
    attention_probs = nn.Softmax(dim=-1)(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs_dropped = self.dropout(attention_probs)

    # Mask heads if we want to
    if head_mask is not None:
        attention_probs_dropped = attention_probs_dropped * head_mask

    context_layer = torch.matmul(attention_probs_dropped, value_layer)

    # Move the head axis next to the per-head feature axis and merge the two
    # into a single axis of size `all_head_size`.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    # Note: the returned probabilities are the ones *before* dropout/head masking.
    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

    outputs = outputs + (past_key_value,)
    return outputs


def gaudi_BlipTextAttention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
    token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
    """
    Copied from BlipTextAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L265
    The only differences are:
    - add token_idx

    Runs `self.self` (forwarding `token_idx`) and feeds its first output,
    together with the input `hidden_states`, through `self.output`.

    Returns:
        `(attention_output, *rest)` where `rest` is everything `self.self`
        returned after its first element.
    """
    self_outputs = self.self(
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states,
        encoder_attention_mask,
        past_key_value,
        output_attentions,
        token_idx=token_idx,
    )
    attention_output = self.output(self_outputs[0], hidden_states)
    outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
    return outputs


def gaudi_BlipTextLayer_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    output_attentions: Optional[bool] = False,
    token_idx: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor]:
    """
    Copied from BlipTextLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L333
    The only differences are:
    - add token_idx

    Returns:
        `(layer_output, [self_attn_probs], [cross_attn_probs], present_key_value)`.
        The attention entries are only present when the attention modules
        return them, and the cross-attention entry additionally requires
        `encoder_hidden_states` to be given. `present_key_value` is always the
        last element and comes from the self-attention call.
    """
    # Only the first two cache entries are passed to self-attention.
    self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
    self_attention_outputs = self.attention(
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions=output_attentions,
        past_key_value=self_attn_past_key_value,
        token_idx=token_idx,
    )
    attention_output = self_attention_outputs[0]

    # Everything between the first and last element (attention weights, if any).
    outputs = self_attention_outputs[1:-1]
    present_key_value = self_attention_outputs[-1]

    if encoder_hidden_states is not None:
        # Cross-attention is called without a cache and without token_idx.
        cross_attention_outputs = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions=output_attentions,
        )
        attention_output = cross_attention_outputs[0]
        outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
    layer_output = apply_chunking_to_forward(
        self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
    )
    outputs = (layer_output,) + outputs

    outputs = outputs + (present_key_value,)

    return outputs


def gaudi_BlipTextEncoder_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = False,
    output_hidden_states: Optional[bool] = False,
    return_dict: Optional[bool] = True,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
    """
    Copied from BlipTextEncoder.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L391
    The only differences are:
    - add token_idx
    - cross-attention weights are only collected when the layers actually
      computed cross-attention (i.e. `encoder_hidden_states` is given) and the
      accumulator exists; otherwise `layer_outputs[2]` is not a cross-attention
      tensor and must not be reported as one

    Returns:
        A `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is
        truthy, otherwise a tuple of the non-`None` values among
        `(hidden_states, next_decoder_cache, all_hidden_states,
        all_self_attentions, all_cross_attentions)`.
    """
    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False
    all_hidden_states = () if output_hidden_states else None
    all_self_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions and self.config.is_decoder else None

    next_decoder_cache = () if use_cache else None

    # A layer only emits cross-attention weights when it ran cross-attention,
    # which requires encoder_hidden_states; the accumulator is additionally
    # None when the config is not a decoder.
    collect_cross_attentions = all_cross_attentions is not None and encoder_hidden_states is not None

    for i in range(self.config.num_hidden_layers):
        layer_module = self.layer[i]
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        layer_head_mask = head_mask[i] if head_mask is not None else None
        past_key_value = past_key_values[i] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:
            # NOTE(review): token_idx is not forwarded on the checkpointing
            # path — confirm this is intentional.
            layer_outputs = self._gradient_checkpointing_func(
                layer_module.__call__,
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )
        else:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
                token_idx=token_idx,
            )

        hidden_states = layer_outputs[0]
        if use_cache:
            # The layer's present key/value is always its last output.
            next_decoder_cache += (layer_outputs[-1],)
        if output_attentions:
            all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if collect_cross_attentions:
                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(
            v
            for v in [
                hidden_states,
                next_decoder_cache,
                all_hidden_states,
                all_self_attentions,
                all_cross_attentions,
            ]
            if v is not None
        )
    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=next_decoder_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
        cross_attentions=all_cross_attentions,
    )


def gaudi_BlipTextModel_forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    is_decoder: Optional[bool] = False,
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
    """
    Copied from BlipTextModel.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L666
    The only differences are:
    - add token_idx

    Exactly one of `input_ids`, `inputs_embeds` or `encoder_embeds` determines
    the input shape and device; `input_ids` and `inputs_embeds` are mutually
    exclusive.

    Raises:
        ValueError: if both `input_ids` and `inputs_embeds` are given, or if
            none of `input_ids`, `inputs_embeds`, `encoder_embeds` is given.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Caching is only meaningful when running as a decoder.
    if is_decoder:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
    else:
        use_cache = False

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape
        device = input_ids.device
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
        batch_size, seq_length = input_shape
        device = inputs_embeds.device
    elif encoder_embeds is not None:
        input_shape = encoder_embeds.size()[:-1]
        batch_size, seq_length = input_shape
        device = encoder_embeds.device
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if attention_mask is None:
        attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length))).to(device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
        attention_mask, input_shape, device, is_decoder
    )

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if encoder_hidden_states is not None:
        if isinstance(encoder_hidden_states, list):
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
        else:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

        if isinstance(encoder_attention_mask, list):
            encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
        elif encoder_attention_mask is None:
            # No mask supplied: attend to every encoder position.
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    if encoder_embeds is None:
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
    else:
        # Pre-computed embeddings bypass the embedding layer entirely.
        embedding_output = encoder_embeds

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        token_idx=token_idx,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

    if not return_dict:
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=sequence_output,
        pooler_output=pooled_output,
        past_key_values=encoder_outputs.past_key_values,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        cross_attentions=encoder_outputs.cross_attentions,
    )


def gaudi_BlipTextLMHead_forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.Tensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    return_logits: Optional[bool] = False,
    is_decoder: Optional[bool] = True,
    reduction: Optional[str] = "mean",
    token_idx: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
    """
    Copied from BlipTextLMHeadModel.forward: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L820
    The only differences are:
    - add token_idx

    Returns:
        - the prediction scores with the last position dropped, when
          `return_logits` is truthy;
        - otherwise a tuple `([lm_loss], prediction_scores, *outputs[2:])` when
          `return_dict` is falsy;
        - otherwise a `CausalLMOutputWithCrossAttentions`.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    if labels is not None:
        # Caching is disabled whenever a loss is computed.
        use_cache = False

    outputs = self.bert(
        input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        is_decoder=is_decoder,
        token_idx=token_idx,
    )

    sequence_output = outputs[0]
    prediction_scores = self.cls(sequence_output)

    if return_logits:
        return prediction_scores[:, :-1, :].contiguous()

    lm_loss = None
    if labels is not None:
        # we are doing next-token prediction; shift prediction scores and input ids by one
        shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous().to(shifted_prediction_scores.device)
        loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=self.label_smoothing)
        lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
        if reduction == "none":
            # Per-token losses are summed per sample (first dim of the scores).
            lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)

    if not return_dict:
        output = (prediction_scores,) + outputs[2:]
        return ((lm_loss,) + output) if lm_loss is not None else output

    return CausalLMOutputWithCrossAttentions(
        loss=lm_loss,
        logits=prediction_scores,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        cross_attentions=outputs.cross_attentions,
    )


def gaudi_BlipTextLMHead_prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, token_idx=None, **model_kwargs
):
    """
    Copied from BlipTextLMHeadModel.prepare_inputs_for_generation: https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/blip/modeling_blip_text.py#L910
    The only differences are:
    - add token_idx support, add position_ids

    Returns:
        A dict of keyword arguments for the model forward: `input_ids`,
        `attention_mask`, `past_key_values`, `encoder_hidden_states`,
        `encoder_attention_mask`, `is_decoder` (always `True`), `token_idx`
        and `position_ids` (`None` unless `token_idx` is given).
    """
    input_shape = input_ids.shape
    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_shape)

    # cut decoder_input_ids if past_key_values is used
    if past_key_values is not None:
        if token_idx is not None:
            # Select the single column at token_idx - 1, shifted by the optional
            # "inputs_embeds_offset" entry of model_kwargs.
            idx = token_idx + model_kwargs.get("inputs_embeds_offset", 0) - 1
            input_ids = torch.index_select(input_ids, 1, idx)
        else:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

    position_ids = None
    if token_idx is not None and attention_mask is not None:
        # Positions are the running count of attended tokens; masked positions
        # are overwritten with 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = torch.index_select(position_ids, 1, token_idx - 1)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
        "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
        "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
        "is_decoder": True,
        "token_idx": token_idx,
        "position_ids": position_ids,
    }