optimum/habana/transformers/models/bloom/modeling_bloom.py

# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# Copyright (C) 2022-2023 Habana Labs, Ltd. an Intel Company
###############################################################################

import math
import os
import warnings
from typing import Optional, Tuple, Union

import torch
from torch.nn import functional as F
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomMLP, dropout_add
from transformers.utils import logging

from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask


logger = logging.get_logger(__name__)


def gaudi_bloom_build_alibi_tensor(
    attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype, training: bool
) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409

    The ALiBi tensor is not causal as the original paper mentions: it relies on a translation invariance of softmax
    for a quick implementation, i.e. with l a tensor and a a fixed value, `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`):
            Number of heads.
        dtype (`torch.dtype`):
            Dtype of the output tensor.
        training (`bool`):
            Whether the model is being trained or not.

    Returns:
        A tensor of shape (batch_size * num_heads, 1, max_seq_len).
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
    # => the query_length dimension will then be broadcast correctly
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    if training:
        arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
        alibi = slopes[..., None] * arange_tensor
        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
    else:
        # code taken from Megatron transformer.py
        alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length, device=attention_mask.device).unsqueeze(
            0
        ).unsqueeze(0).expand(num_heads, -1, -1)

        # Select the part of the tensor that corresponds to our tensor parallel index.
        # If inference_tp_size is set, use it instead of the world size.
        world = int(os.environ.get("WORLD_SIZE", 1))
        tp_world_size = GaudiBloomForCausalLM.inference_tp_size if GaudiBloomForCausalLM.inference_tp_size else world
        tp_index = 0
        # If world size == 1, ignore the rank and use 0 (for cases where WORLD_SIZE is not equal to the tp size).
        if tp_world_size > 1:
            tp_index = int(os.environ.get("RANK", 0))
        alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index]

        alibi = alibi.repeat(batch_size, 1, 1)
        return alibi.to(dtype)


def update(prev, cur, dim, idx):
    # Updates the KV cache `prev` with the current step's `cur` along dimension `dim`.
    if idx is not None:
        if os.environ.get("WA_INDEX_COPY", "1") == "1":
            # Workaround (WA_INDEX_COPY=1): express the cache update as a batched matmul with one-hot
            # selectors (prepared in gaudi_bloom_model_forward) instead of an in-place index_copy_.
            past_selector, value_selector = idx
            if dim == 1:
                sel = torch.cat([past_selector, value_selector.unsqueeze(2)], dim=2)
                val = torch.cat([prev, cur], dim=1)
                return torch.bmm(sel, val)
            else:
                sel = torch.cat([past_selector, value_selector.unsqueeze(1)], dim=1)
                val = torch.cat([prev, cur], dim=2)
                return torch.bmm(val, sel)
        else:
            return prev.index_copy_(dim, idx - 1, cur)
    else:
        # No token_idx: append the new entries to the cache.
        return torch.cat((prev, cur), dim=dim)
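
# Minimal sketch of the update paths (hypothetical tensors; the value cache is laid out as
# [batch_size * num_heads, kv_length, head_dim] and updated along dim=1):
#
#     prev = torch.zeros(1, 4, 2)                      # pre-allocated cache with room for 4 tokens
#     cur = torch.ones(1, 1, 2)                        # key/value of the token being decoded
#     update(prev, cur, dim=1, idx=None)               # no token_idx: concat, cache grows to length 5
#     update(prev, cur, dim=1, idx=torch.tensor([3]))  # with WA_INDEX_COPY=0: in-place write into row 2 (token_idx is 1-indexed)
#
# With the default WA_INDEX_COPY=1, `idx` is instead the (past_selector, value_selector) pair built in
# gaudi_bloom_model_forward, and the same write is expressed as a batched matmul over one-hot selectors.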


def gaudi_bloom_attention_forward(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Cache] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    token_idx: Optional[torch.Tensor] = None,
):
    batch_size, q_length, _ = hidden_states.shape
    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
    # 3 x [batch_size, num_heads, seq_length, head_dim]
    query_layer, key_layer, value_layer = self._reshape(fused_qkv)

    query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
    key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(1, 2)
    value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)

    # Collapse views to improve performance on HPU
    query_layer = query_layer.contiguous()
    key_layer = key_layer.contiguous()
    value_layer = value_layer.contiguous()

    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        #  - key: [batch_size * self.num_heads, head_dim, kv_length]
        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
        key_layer = update(past_key, key_layer, 2, token_idx)
        value_layer = update(past_value, value_layer, 1, token_idx)

    _, _, kv_length = key_layer.shape

    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    # [batch_size * num_heads, q_length, kv_length]
    attention_scores = alibi.baddbmm(
        batch1=query_layer,
        batch2=key_layer,
        beta=self.beta,
        alpha=self.inv_norm_factor,
    )

    # change view to [batch_size, num_heads, q_length, kv_length]
    attention_scores = attention_scores.view(batch_size, self.num_heads, q_length, -1)

    # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
    # [batch_size, num_heads, q_length, kv_length]
    input_dtype = attention_scores.dtype
    attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
    attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)

    # [batch_size, num_heads, q_length, kv_length]
    attention_probs = self.attention_dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # change view [batch_size x num_heads, q_length, kv_length]
    attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)

    # matmul: [batch_size * num_heads, q_length, head_dim]
    context_layer = torch.bmm(attention_probs_reshaped, value_layer)

    # change view [batch_size, q_length, num_heads * head_dim]
    context_layer = self._merge_heads(context_layer)

    # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
    if self.pretraining_tp > 1 and self.slow_but_exact:
        slices = self.hidden_size / self.pretraining_tp
        output_tensor = torch.zeros_like(context_layer)
        for i in range(self.pretraining_tp):
            output_tensor = output_tensor + F.linear(
                context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
            )
    else:
        output_tensor = self.dense(context_layer)

    output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

    outputs = (output_tensor, present)
    if output_attentions:
        outputs += (attention_probs,)

    return outputs
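
# Shape sketch of the score computation above (hypothetical sizes: batch_size=2, num_heads=4, head_dim=8,
# q_length=1, kv_length=16). `baddbmm` fuses the ALiBi bias with the scaled QK^T product:
#
#     alibi = torch.zeros(2 * 4, 1, 16)                # bias broadcast over the query dimension
#     q = torch.randn(2 * 4, 1, 8)
#     k_t = torch.randn(2 * 4, 8, 16)                  # key already transposed to [.., head_dim, kv_length]
#     scores = alibi.baddbmm(batch1=q, batch2=k_t, beta=1.0, alpha=1.0 / 8**0.5)
#     assert scores.shape == (2 * 4, 1, 16)            # [batch_size * num_heads, q_length, kv_length]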


class GaudiBloomMLP(BloomMLP):
    def __init__(self, config):
        super().__init__(config)
        self.gelu_impl = torch.nn.GELU(approximate="tanh")


def gaudi_bloom_block_forward(
    self,
    hidden_states: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Cache] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    token_idx: Optional[torch.Tensor] = None,
):
    # hidden_states: [batch_size, seq_length, hidden_size]

    # Layer norm at the beginning of the transformer layer.
    layernorm_output = self.input_layernorm(hidden_states)

    # Layer norm post the self attention.
    if self.apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = hidden_states

    # Self attention.
    attn_outputs = self.self_attention(
        layernorm_output,
        residual,
        layer_past=layer_past,
        attention_mask=attention_mask,
        alibi=alibi,
        head_mask=head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        cache_position=cache_position,
        token_idx=token_idx,
    )

    attention_output = attn_outputs[0]

    outputs = attn_outputs[1:]

    layernorm_output = self.post_attention_layernorm(attention_output)

    # Get residual
    if self.apply_residual_connection_post_layernorm:
        residual = layernorm_output
    else:
        residual = attention_output

    # MLP.
    output = self.mlp(layernorm_output, residual)

    if use_cache:
        outputs = (output,) + outputs
    else:
        outputs = (output,) + outputs[1:]

    return outputs  # hidden_states, present, attentions


def gaudi_bloom_convert_to_standard_cache(
    self, past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int, training: bool
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Standardizes the format of the cache so as to match most implementations, i.e. to
    tuple(tuple([batch_size, num_heads, ...]))
    """
    batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
    if training:
        num_heads = batch_size_times_num_heads // batch_size
    else:
        world = int(os.environ.get("WORLD_SIZE", 1))
        tp_world_size = GaudiBloomForCausalLM.inference_tp_size if GaudiBloomForCausalLM.inference_tp_size else world
        num_heads = self.config.n_head // tp_world_size
        batch_size = batch_size_times_num_heads // num_heads
    # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
    # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
    return tuple(
        (
            layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
            layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
        )
        for layer_past in past_key_value
    )


def gaudi_bloom_convert_to_bloom_cache(
    self, past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
    """
    batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
    batch_size_times_num_heads = batch_size * num_heads
    # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
    # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
    return tuple(
        (
            layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
            layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
        )
        for layer_past in past_key_value
    )
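
# Round-trip sketch of the two cache layouts (hypothetical sizes: batch_size=2, num_heads=4, head_dim=8,
# seq_length=16). Keys are stored transposed so that `baddbmm` can consume them directly:
#
#     bloom_cache = (
#         (torch.randn(2 * 4, 8, 16), torch.randn(2 * 4, 16, 8)),   # one (key, value) pair per layer
#     )
#     standard = tuple((k.view(2, 4, 8, 16), v.view(2, 4, 16, 8)) for k, v in bloom_cache)
#
# The two helpers above perform exactly these reshapes, except that at inference time the number of heads
# per device is divided by the tensor-parallel world size.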


def gaudi_bloom_model_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    token_idx: Optional[torch.Tensor] = None,
    **deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
    if deprecated_arguments.pop("position_ids", False) is not False:
        # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
        warnings.warn(
            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
            " passing `position_ids`.",
            FutureWarning,
        )
    if len(deprecated_arguments) > 0:
        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if past_key_values is None:
        past_key_values = tuple([None] * len(self.h))

    # Prepare head mask if needed
    # 1.0 in head_mask indicates we keep the head
    # attention_probs has shape batch_size x num_heads x N x N
    # head_mask has shape n_layer x batch x num_heads x N x N
    head_mask = self.get_head_mask(head_mask, self.config.n_layer)

    if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)

    hidden_states = self.word_embeddings_layernorm(inputs_embeds)

    presents = () if use_cache else None
    all_self_attentions = () if output_attentions else None
    all_hidden_states = () if output_hidden_states else None

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # Compute alibi tensor: check gaudi_bloom_build_alibi_tensor
    seq_length_with_past = seq_length
    past_key_values_length = 0
    if past_key_values[0] is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
    else:
        attention_mask = attention_mask.to(hidden_states.device)

    alibi = gaudi_bloom_build_alibi_tensor(attention_mask, self.num_heads, hidden_states.dtype, self.training)

    causal_mask = _gaudi_prepare_4d_causal_attention_mask(
        attention_mask,
        input_shape=(batch_size, seq_length),
        inputs_embeds=inputs_embeds,
        past_key_values_length=past_key_values_length,
    )
    causal_mask = causal_mask.bool()

    # Build the one-hot selectors consumed by `update` when the index_copy_ workaround is enabled:
    # `cur` selects the cache slot written at this step, `past` keeps every other slot.
    if token_idx is not None and past_key_values[0] is not None and os.environ.get("WA_INDEX_COPY", "1") == "1":
        pkv = past_key_values[0][0]
        cur = torch.nn.functional.one_hot(torch.tile(token_idx - 1, (pkv.shape[0],)), pkv.shape[-1]).to(pkv.dtype)
        past = torch.diag_embed(1 - cur)
        token_idx = (past, cur)

    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:
            outputs = self._gradient_checkpointing_func(
                block.__call__,
                hidden_states,
                alibi,
                causal_mask,
                layer_past,
                head_mask[i],
                use_cache,
                output_attentions,
                cache_position,
                None,
            )
        else:
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
                cache_position=cache_position,
                token_idx=token_idx,
            )

        hidden_states = outputs[0]
        if use_cache is True:
            presents = presents + (outputs[1],)

        if output_attentions:
            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

    # Add last hidden state
    hidden_states = self.ln_f(hidden_states)

    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

    return BaseModelOutputWithPastAndCrossAttentions(
        last_hidden_state=hidden_states,
        past_key_values=presents,
        hidden_states=all_hidden_states,
        attentions=all_self_attentions,
    )
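
# Worked sketch of the selector trick used above (hypothetical sizes: one batch*head row, a cache of 4 slots,
# head_dim=2, writing position token_idx=3):
#
#     token_idx = torch.tensor(3)
#     cur = torch.nn.functional.one_hot(torch.tile(token_idx - 1, (1,)), 4).float()  # [[0., 0., 1., 0.]]
#     past = torch.diag_embed(1 - cur)               # identity with the written slot zeroed out
#     prev = torch.zeros(1, 4, 2)                    # value cache [batch*heads, kv_length, head_dim]
#     new = torch.ones(1, 1, 2)
#     out = torch.bmm(torch.cat([past, cur.unsqueeze(2)], dim=2), torch.cat([prev, new], dim=1))
#     # out[:, 2, :] == new and every other row of `prev` is preserved: the same effect as the in-place
#     # index_copy_ path of `update`, but expressed as a matmul.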


class GaudiBloomForCausalLM(BloomForCausalLM):
    inference_tp_size = None

    def set_tp_for_inference(tp_for_inference: int):
        world = int(os.environ.get("WORLD_SIZE", 1))
        assert tp_for_inference == 1 or tp_for_inference == world, "only setting 1 (no tp) or world size is supported"
        GaudiBloomForCausalLM.inference_tp_size = tp_for_inference

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        token_idx: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        # only last tokens for input_ids if past is not None
        if past_key_values is not None:
            if token_idx is None:
                input_ids = input_ids[:, -1].unsqueeze(-1)
            else:
                idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1
                input_ids = torch.index_select(input_ids, 1, idx)

            # the cache may be in the standard format (e.g. in contrastive search), convert to bloom's format if needed
            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                past_key_values = self._convert_to_bloom_cache(past_key_values)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the
            # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
            # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        model_inputs.update(
            {
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "token_idx": token_idx,
            }
        )

        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        token_idx: Optional[torch.Tensor] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        # Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
        num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            token_idx=token_idx,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Flatten the tokens
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx), training=self.training)

        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in standardized_past
        )
        return self._convert_to_bloom_cache(reordered_past)
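
# Minimal single-device usage sketch (illustrative only; the checkpoint name and generation flow are assumptions).
# The gaudi_* functions above are bound to the corresponding transformers Bloom modules by optimum-habana's model
# adaptation step, and `token_idx` is supplied by its generation utilities for static-shape decoding:
#
#     from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
#
#     adapt_transformers_to_gaudi()
#     GaudiBloomForCausalLM.set_tp_for_inference(1)    # no tensor parallelism
#     model = GaudiBloomForCausalLM.from_pretrained("bigscience/bloom-560m").to("hpu")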