optimum/habana/transformers/models/opt/modeling_opt.py (413 lines of code) (raw):

from typing import List, Optional, Tuple, Union import torch from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.opt.configuration_opt import OPTConfig from transformers.models.opt.modeling_opt import ( OPT_ATTENTION_CLASSES, OPTForCausalLM, OPTLearnedPositionalEmbedding, logger, ) from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask class GaudiOPTLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding): """ Inherits from OPTLearnedPositionalEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py The only differences are: - add new args token_idx - compute embedding using token_idx if past_key_values_length not 0 """ def forward( self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ): attention_mask = attention_mask.long() if past_key_values_length == 0: # first step or kv cache disabled positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 positions = positions[:, past_key_values_length:] return torch.nn.Embedding.forward(self, positions + self.offset) else: # if not 0, kv cache is enabled and from step = 2, past_key_values_length is equal to the final length of outputs return torch.nn.Embedding.forward(self, token_idx + self.offset) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor: return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def gaudi_opt_attention_forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, # isn't needed in normal attention, but needed in flash attention so to keep the signature same position_ids: Optional[torch.Tensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Copied from OPTAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py The only differences are: - add new args token_idx - optimize KV cache """ # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj if is_cross_attention and past_key_value is not None: # reuse k,v, cross_attentions key_states = past_key_value[0] value_states = past_key_value[1] elif is_cross_attention: # cross_attentions key_states = _shape(self, self.k_proj(key_value_states), -1, bsz) value_states = _shape(self, self.v_proj(key_value_states), -1, bsz) elif past_key_value is not None: # reuse k, v, self_attention key_states = _shape(self, self.k_proj(hidden_states), -1, bsz) value_states = _shape(self, self.v_proj(hidden_states), -1, bsz) if token_idx is not None: past_key_value[0].index_copy_(2, token_idx - 1, key_states) past_key_value[1].index_copy_(2, token_idx - 1, value_states) key_states = past_key_value[0] value_states = past_key_value[1] else: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention key_states = _shape(self, self.k_proj(hidden_states), -1, bsz) value_states = _shape(self, self.v_proj(hidden_states), -1, bsz) past_key_value = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = _shape(self, query_states, tgt_len, bsz).view(*proj_shape) key_states = key_states.view(*proj_shape) value_states = value_states.view(*proj_shape) src_len = key_states.size(1) attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): raise ValueError( f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" f" {attn_weights.size()}" ) if attention_mask is not None: if attention_mask.size() != (bsz, 1, tgt_len, src_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = torch.max( attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: if layer_head_mask.size() != (self.num_heads,): raise ValueError( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" ) attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if output_attentions: # this operation is a bit awkward, but it's required to # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to be reshaped # twice and have to be reused in the following attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) else: attn_weights_reshaped = None attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_output = torch.bmm(attn_probs, value_states) if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): raise ValueError( f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned aross GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, attn_weights_reshaped, past_key_value class GaudiOPTDecoderLayer(torch.nn.Module): def __init__(self, config: OPTConfig, layer_idx: Optional[int] = None): """ Attention implementation is set to "eager" (default in Transformers is "sdpa"). """ super().__init__() self.embed_dim = config.hidden_size self.self_attn = OPT_ATTENTION_CLASSES["eager"](config=config, layer_idx=layer_idx) self.do_layer_norm_before = config.do_layer_norm_before self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.self_attn_layer_norm = torch.nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine ) self.fc1 = torch.nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) self.fc2 = torch.nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) self.final_layer_norm = torch.nn.LayerNorm( self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, position_ids: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Copied from OPTDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py The only differences are: - add new args token_idx """ residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, past_key_value=past_key_value, position_ids=position_ids, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, token_idx=token_idx, ) hidden_states = torch.nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: hidden_states = self.self_attn_layer_norm(hidden_states) # Fully Connected hidden_states_shape = hidden_states.shape hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) residual = hidden_states # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention if self.do_layer_norm_before: hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) hidden_states = self.fc2(hidden_states) hidden_states = torch.nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = (residual + hidden_states).view(hidden_states_shape) # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: hidden_states = self.final_layer_norm(hidden_states) outputs = (hidden_states,) if output_attentions: outputs += (self_attn_weights,) if use_cache: outputs += (present_key_value,) return outputs def gaudi_opt_decoder_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from OPTDecoder.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py The only differences are: - add new args token_idx - update calculation of mask_seq_length """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] mask_seq_length = past_key_values_length else: past_key_values_length = 0 mask_seq_length = seq_length # embed positions # 4d mask is passed through the layers if attention_mask is None: attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) elif attention_mask.shape[1] != mask_seq_length: raise ValueError( f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " f"{mask_seq_length} (sum of the lengths of current and past inputs)" ) causal_attention_mask = _gaudi_prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length ) pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids, token_idx) if self.project_in is not None: inputs_embeds = self.project_in(inputs_embeds) hidden_states = inputs_embeds + pos_embeds if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = () if use_cache else None # check if head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask], ["head_mask"]): if attn_mask is not None: if attn_mask.size()[0] != (len(self.layers)): raise ValueError( f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" f" {head_mask.size()[0]}." ) for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) if output_hidden_states: all_hidden_states += (hidden_states,) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, causal_attention_mask, head_mask[idx] if head_mask is not None else None, None, output_attentions, use_cache, position_ids, None, ) else: layer_outputs = decoder_layer( hidden_states, attention_mask=causal_attention_mask, position_ids=position_ids, layer_head_mask=(head_mask[idx] if head_mask is not None else None), past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) if self.final_layer_norm is not None: hidden_states = self.final_layer_norm(hidden_states) if self.project_out is not None: hidden_states = self.project_out(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def gaudi_opt_model_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """ Copied from OPTModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py The only differences are: - add new args token_idx """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, ) if not return_dict: return decoder_outputs return BaseModelOutputWithPast( last_hidden_state=decoder_outputs.last_hidden_state, past_key_values=decoder_outputs.past_key_values, hidden_states=decoder_outputs.hidden_states, attentions=decoder_outputs.attentions, ) class GaudiOPTForCausalLM(OPTForCausalLM): """ Inherits from OPTForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py The only differences are: - add new args token_idx - add token_idx into model_inputs - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx """ def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, position_ids: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model.decoder( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, ) logits = self.lm_head(outputs[0]).contiguous() loss = None if labels is not None: # move labels to correct device to enable model parallelism labels = labels.to(logits.device) loss = self.loss_function( logits, labels, vocab_size=self.config.vocab_size, **kwargs, ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, token_idx=None, inputs_embeds=None, **kwargs ): if past_key_values is not None: if token_idx is not None: idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 input_ids = torch.index_select(input_ids, 1, idx) else: past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: remove_prefix_length = past_length else: # Default to old behavior: keep only final ID remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, } ) return model_inputs