optimum/habana/transformers/models/xglm/modeling_xglm.py (375 lines of code) (raw):

from typing import List, Optional, Tuple, Union import torch from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from transformers.models.xglm.modeling_xglm import XGLMForCausalLM from transformers.utils import logging from ...modeling_attn_mask_utils import ( _gaudi_prepare_4d_causal_attention_mask, ) logger = logging.get_logger(__name__) def gaudi_xglm_attention_forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from XGLMAttention.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py The only differences are: - add new args token_idx - optimize KV cache """ # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if token_idx is not None: past_key_value[0].index_copy_(2, token_idx - 1, key_states) past_key_value[1].index_copy_(2, token_idx - 1, value_states) key_states = past_key_value[0] value_states = past_key_value[1] else: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = torch.max( attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437 if attn_weights.dtype == torch.float16: attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) else: attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped, past_key_value def gaudi_xglm_decoder_layer_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, token_idx: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Copied from XGLMDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py The only differences are: - add new args token_idx """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, token_idx=token_idx, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights, cross_attn_weights) if use_cache: outputs += (present_key_value,) return outputs def gaudi_xglm_model_forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: """ Copied from XGLMModel.forward: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py The only differences are: - add new args token_idx - replace _prepare_4d_causal_attention_mask with _gaudi_prepare_4d_causal_attention_mask """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 if position_ids is None: position_ids = torch.arange( past_key_values_length, input_shape[-1] + past_key_values_length, dtype=torch.long, device=input_ids.device if input_ids is not None else inputs_embeds.device, ) position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _gaudi_prepare_4d_causal_attention_mask( encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] ) hidden_states = inputs_embeds + self.embed_positions(position_ids, past_key_values_length).to(inputs_embeds.device) hidden_states = nn.functional.dropout(hidden_states, p=float(self.dropout), training=self.training) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache = True` is incompatible with gradient checkpointing`. Setting `use_cache = False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != len(self.layers): raise ValueError( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, None, output_attentions, use_cache, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) if encoder_hidden_states is not None: all_cross_attentions += (layer_outputs[2],) hidden_states = self.layer_norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, cross_attentions=all_cross_attentions, ) class GaudiXGLMForCausalLM(XGLMForCausalLM): def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: """ Inherits from XGLMForCausalLM: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py The only differences are: - add new args token_idx """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, head_mask=head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, ) logits = self.lm_head(outputs[0]) loss = None if labels is not None: loss = self.loss_function( logits, labels, vocab_size=self.config.vocab_size, pad_token_id=self.config.pad_token_id, **kwargs, ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs ): """ Inherits from XGLMForCausalLM: https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/models/xglm/modeling_xglm.py The only differences are: - add new args token_idx - add token_idx into model_inputs - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ token_idx = kwargs.get("token_idx", None) if past_key_values is not None: if token_idx is None: past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default to old behavior: keep only final ID remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] else: input_ids = torch.index_select(input_ids, 1, token_idx - 1) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: if token_idx is not None: position_ids = torch.index_select(position_ids, 1, token_idx - 1) else: position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed "attention_mask": attention_mask, "position_ids": position_ids, "past_key_values": past_key_values, "use_cache": use_cache, "token_idx": token_idx, }