optimum/habana/transformers/models/qwen2_vl/modeling_qwen2_vl.py (548 lines of code) (raw):

# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Gaudi Qwen2-VL model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, StaticCache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    Qwen2VisionTransformerPretrainedModel,
    Qwen2VLCausalLMOutputWithPast,
    Qwen2VLConfig,
    Qwen2VLDecoderLayer,
    Qwen2VLForConditionalGeneration,
    Qwen2VLModel,
    Qwen2VLSdpaAttention,
    Qwen2VLVisionBlock,
    VisionSdpaAttention,
    apply_multimodal_rotary_pos_emb,
    apply_rotary_pos_emb_vision,
    repeat_kv,
)
from transformers.utils import is_torchdynamo_compiling, logging


# The fused kernel is optional: when the Habana package is missing, `FusedSDPA`
# is set to None and every attention module below falls back to PyTorch SDPA.
try:
    from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
    print("Not using HPU fused scaled dot-product attention kernel.")
    FusedSDPA = None


logger = logging.get_logger(__name__)


class ModuleFusedSDPA(torch.nn.Module):
    """`nn.Module` wrapper that forwards its arguments unchanged to the wrapped kernel's `apply`."""

    def __init__(self, fusedSDPA):
        """Store the kernel object (`fusedSDPA`) whose `apply` is invoked in `forward`."""
        super().__init__()
        self._hpu_kernel_fsdpa = fusedSDPA

    def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
        """Call the wrapped kernel with the arguments in the order received and return its result."""
        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L383
class GaudiVisionSdpaAttention(VisionSdpaAttention):
    """Vision-tower SDPA attention that can optionally dispatch to the HPU `FusedSDPA` kernel."""

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        """Build the parent attention and, if the kernel was imported, a `ModuleFusedSDPA` around it."""
        super().__init__(dim, num_heads)
        # None when the Habana kernel could not be imported.
        self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_flash_attention: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
        The only differences are:
        - add new args use_flash_attention
        - add FusedSDPA

        Args:
            hidden_states: input whose first dimension is the sequence length.
            cu_seqlens: segment boundaries; consecutive entries delimit the token ranges that may
                attend to each other.
            rotary_pos_emb: legacy rotary values, used only when `position_embeddings` is None.
            position_embeddings: `(cos, sin)` tuple; takes precedence over `rotary_pos_emb`.
            use_flash_attention: when truthy and the kernel is available, use `FusedSDPA`.

        Returns:
            The projected attention output, reshaped to `(seq_length, -1)` before `self.proj`.
        """
        seq_length = hidden_states.shape[0]
        # Split the fused qkv projection into three tensors of shape (seq_length, num_heads, head_dim).
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
        q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)

        # Block-diagonal boolean mask: tokens attend only within their own
        # [cu_seqlens[i - 1], cu_seqlens[i]) segment.
        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[:, cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        # Move heads in front of the sequence dimension; a leading batch dim of 1 is added at the call.
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        if FusedSDPA is not None and use_flash_attention:
            # Positional arguments: attn_mask, dropout_p=0.0, is_causal=False, scale=None, softmax_mode="None".
            attn_output = self.fused_scaled_dot_product_attention(
                q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, 0.0, False, None, "None"
            )
        else:
            attn_output = F.scaled_dot_product_attention(
                q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attention_mask, dropout_p=0.0
            )
        attn_output = attn_output.squeeze(0).transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        # Drop the local reference to the (seq_length x seq_length) mask before returning.
        del attention_mask
        return attn_output


# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L418
class GaudiQwen2VLVisionBlock(Qwen2VLVisionBlock):
    """Vision block whose attention is replaced by `GaudiVisionSdpaAttention`."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        """Build the parent block, then overwrite `self.attn` with the Gaudi attention module."""
        super().__init__(config, attn_implementation)
        self.attn = GaudiVisionSdpaAttention(config.embed_dim, num_heads=config.num_heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_flash_attention: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L430
        The only differences are:
        - add new args use_flash_attention

        Applies attention on `norm1(hidden_states)` and the MLP on `norm2(hidden_states)`, each
        added back to the running `hidden_states` as a residual.
        """
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            use_flash_attention=use_flash_attention,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L821
class GaudiQwen2VLSdpaAttention(Qwen2VLSdpaAttention):
    """
    Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.

    On Gaudi, the forward pass can additionally dispatch to the HPU `FusedSDPA` kernel when it is
    available and `use_flash_attention` is truthy.
    """

    def __init__(self, *args, **kwargs):
        """Build the parent attention and, if the kernel was imported, a `ModuleFusedSDPA` around it."""
        super().__init__(*args, **kwargs)
        # None when the Habana kernel could not be imported.
        self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        use_flash_attention: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L829
        The only differences are:
        - add new args use_flash_attention
        - add FusedSDPA

        Returns:
            `(attn_output, None, past_key_value)` on the SDPA / FusedSDPA path (attention weights are
            never returned there). When `output_attentions` is truthy, the parent implementation's
            result is returned instead and `use_flash_attention` is not forwarded to it.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2VLModel is using Qwen2VLSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # NOTE(review): `kv_seq_len` is computed but not read again in this method.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand key/value heads by `num_key_value_groups` so they line up with the query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

        if FusedSDPA is not None and use_flash_attention:
            attn_output = self.fused_scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                causal_mask,
                self.attention_dropout if self.training else 0.0,
                is_causal,
                None,  # scale
                "None",  #'fast'
            )
        else:
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=causal_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=is_causal,
            )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L930
class GaudiQwen2VLDecoderLayer(Qwen2VLDecoderLayer):
    """Decoder layer whose self-attention is replaced by `GaudiQwen2VLSdpaAttention`."""

    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
        """Build the parent layer, then overwrite `self.self_attn` with the Gaudi attention module."""
        super().__init__(config, layer_idx)
        self.self_attn = GaudiQwen2VLSdpaAttention(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L946
        The only differences are:
        - add new kwargs use_flash_attention
        """
        # NOTE: the string literal below is a second, separate statement; only the string above
        # becomes this method's `__doc__`.
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        # The flag is only reachable through **kwargs; it is None (falsy) when the caller omits it.
        use_flash_attention = kwargs.get("use_flash_attention", None)

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            use_flash_attention=use_flash_attention,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        # Output tuple layout: (hidden_states[, attn_weights][, present_key_value]).
        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


# from:
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1058
class GaudiQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel):
    """Qwen2-VL vision tower that threads `use_flash_attention` through every block."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thw: torch.Tensor,
        use_flash_attention: Optional[bool] = False,
    ) -> torch.Tensor:
        """
        Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
        The only differences are:
        - add new args use_flash_attention

        Args:
            hidden_states: input passed to `self.patch_embed`.
            grid_thw: per-item grid; column 0 is used as a repeat count and columns 1 and 2 are
                multiplied to obtain each segment's length.
            use_flash_attention: forwarded to every block.

        Returns:
            The output of `self.merger` applied to the final block output.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Cumulative segment boundaries, with a leading 0 so that consecutive pairs delimit segments.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Positional order matches the block's forward:
                # (hidden_states, cu_seqlens, rotary_pos_emb=None, position_embeddings, use_flash_attention).
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings, use_flash_attention
                )
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens,
                    position_embeddings=position_embeddings,
                    use_flash_attention=use_flash_attention,
                )

        return self.merger(hidden_states)


# from: https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1137
class GaudiQwen2VLModel(Qwen2VLModel):
    """Qwen2-VL text decoder stack that threads `use_flash_attention` through every layer."""

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_flash_attention: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Copied from Qwen2VLModel https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1161
        The only differences are:
        - add new arg use_flash_attention
        - fixes graph recompilation due to torch.arange
        - use_flash_attention is also forwarded on the gradient-checkpointing path

        Raises:
            ValueError: if not exactly one of `input_ids` / `inputs_embeds` is provided.

        Returns:
            A `BaseModelOutputWithPast`, or (when `return_dict` is falsy) a tuple of the non-None
            values among (hidden_states, cache, all_hidden_states, all_self_attns).
        """
        # Local import: only needed to bind a keyword argument for the checkpointed call below.
        import functools

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            # causes graph recompilations
            # cache_position = torch.arange(
            #    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            # )
            # Keep the arange bounds independent of `past_seen_tokens` and add the offset afterwards.
            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.dim() == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # The decoder layer reads `use_flash_attention` from **kwargs only, so it cannot be
                # passed positionally. Binding it on the callable keeps the checkpointed forward
                # (and its recomputation) consistent with the non-checkpointed branch below.
                layer_outputs = self._gradient_checkpointing_func(
                    functools.partial(decoder_layer.__call__, use_flash_attention=use_flash_attention),
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                    use_flash_attention=use_flash_attention,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache sits after the attention weights when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


# from:
# https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1420
class GaudiQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
    """Qwen2-VL conditional-generation head adapted for Gaudi (`token_idx`, `use_flash_attention`)."""

    # todo: change when the following gets fixed https://github.com/huggingface/transformers/blame/66f29aaaf55c8fe0c3dbcd24beede2ca4effac56/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L390C5-L390C27
    _supports_static_cache = True

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        token_idx: Optional[torch.Tensor] = None,
        use_flash_attention: Optional[bool] = False,
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
        r"""
        Copied from Qwen2VLForConditionalGeneration https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1623
        The only differences are:
        - add new arg token_idx
        - add new arg use_flash_attention
        - add Gaudi Example

        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            token_idx: accepted for interface compatibility; it is not read inside this method.
            use_flash_attention: forwarded to the vision tower (for both images and videos) and to
                the text model.

        Raises:
            ValueError: if the number of video placeholder tokens in `input_ids` differs from the
                number of video features produced by the vision tower.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor
        >>> from optimum.habana.transformers.models import GaudiQwen2VLForConditionalGeneration
        >>> from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
        >>> from habana_frameworks.torch.hpu import wrap_in_hpu_graph
        >>> adapt_transformers_to_gaudi()
        >>> model = GaudiQwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        >>> model = model.to("hpu")
        >>> wrap_in_hpu_graph(model)
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
        >>> inputs = inputs.to("hpu")
        >>> generate_kwargs = {
                "lazy_mode": True,
                "hpu_graphs": True,
                "static_shapes": True,
                "use_cache": True,
                "cache_implementation": "static",
                "use_flash_attention": True
            }
        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=30, **generate_kwargs)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene in what appears to be a Chinatown area. The focal point is a red stop sign on the left side of the..."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.get_dtype())
                image_embeds = self.visual(
                    pixel_values, grid_thw=image_grid_thw, use_flash_attention=use_flash_attention
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                # HPU WA (masked_scatter has perf issue, flatten for hpu graphs)
                # original code: https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1690-L1694
                image_mask = input_ids == self.config.image_token_id
                mbatch, mtokens = image_mask.size()
                image_mask = image_mask.flatten(0, -1)
                inputs_embeds = inputs_embeds.flatten(0, -2)
                if self.training:
                    # Clone before the in-place indexed write below when training.
                    inputs_embeds = inputs_embeds.clone()
                inputs_embeds[image_mask] = image_embeds
                inputs_embeds = inputs_embeds.unflatten(0, [mbatch, mtokens])

            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
                # Forward the flag so videos go through the same attention path as images;
                # without it the vision tower falls back to its default (`False`).
                video_embeds = self.visual(
                    pixel_values_videos, grid_thw=video_grid_thw, use_flash_attention=use_flash_attention
                )
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids, image_grid_thw, video_grid_thw, attention_mask
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                    delta = delta.to(position_ids.device)
                position_ids = position_ids.add(delta)
                # Same positions for the three rotary axes.
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            use_flash_attention=use_flash_attention,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return Qwen2VLCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        """
        Copied from https://github.com/huggingface/transformers/blob/53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1748
        The only differences are:
        - handle new args token_idx
        - handle new args use_flash_attention

        Returns:
            A dict of keyword arguments for `forward`, including `token_idx` and
            `use_flash_attention` read from `kwargs`.
        """
        token_idx = kwargs.get("token_idx", None)
        use_flash_attention = kwargs.get("use_flash_attention", False)
        if token_idx is not None:
            if isinstance(past_key_values, StaticCache):
                if cache_position.shape[0] > 1:
                    # Multi-position step: trim the inputs to the first `token_idx` positions.
                    input_ids = input_ids[:, :token_idx]
                    attention_mask = attention_mask[:, :token_idx]
                    cache_position = cache_position[:token_idx]
                else:
                    # over-write with token idx
                    cache_position[0] = token_idx - 1

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        #              (we can't check exception 3 while compiling)
        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
        # generate the first token for each sequence. Later use the generated Input ids for continuation.
        if past_key_values is not None:
            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
            elif (
                inputs_embeds is not None  # Exception 1
                or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1])  # Exception 3
            ):
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # Pixel inputs are only passed on the step whose first cache position is 0.
        if cache_position[0] != 0:
            pixel_values = None
            pixel_values_videos = None

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

        # With a static cache, a 2D mask is expanded to 4D over the cache's maximum length.
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = inputs_embeds.shape
                device = inputs_embeds.device
            else:
                batch_size, sequence_length = input_ids.shape
                device = input_ids.device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
                config=self.config,
                past_key_values=past_key_values,
            )

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "pixel_values_videos": pixel_values_videos,
                "image_grid_thw": image_grid_thw,
                "video_grid_thw": video_grid_thw,
                "cache_position": cache_position,
                "token_idx": token_idx,
                "use_flash_attention": use_flash_attention,
            }
        )
        return model_inputs