server/text_generation_server/models/custom_modeling/flash_gemma3_modeling.py

# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import List, Optional, Tuple

import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN

from text_generation_server.layers import (
    SpeculativeHead,
    TensorParallelAdapterRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelMultiAdapterLinear,
    TensorParallelRowLinear,
    get_linear,
)
from text_generation_server.layers.attention import (
    Seqlen,
    attention,
    paged_attention,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
    load_vision_model,
)
from text_generation_server.models.globals import ATTENTION
from text_generation_server.utils.weights import UnquantizedWeight

ATTENTION_TYPE_GLOBAL = "global"
ATTENTION_TYPE_LOCAL = "local_sliding"


class Gemma3FastRMSNorm(FastRMSNorm):
    @classmethod
    def load(cls, prefix: str, weights, eps=1e-6):
        dtype = weights.dtype
        weights.dtype = torch.float32
        weight = weights.get_tensor(f"{prefix}.weight") + 1
        weights.dtype = dtype
        new = cls(weight, eps)
        new.dtype = dtype
        return new

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * self.weight
        return hidden_states.to(self.dtype), residual


def load_attention(config, prefix: str, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )
        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(get_linear(weight, bias=None))
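
# Layout note (added comment): the q/k/v projections are fused into a single
# column-parallel weight and sharded across tensor-parallel ranks, so each rank
# holds (num_heads/tp + 2 * num_kv_heads/tp) * head_dim output rows; this is
# what the shape assert in _load_gqa checks. Illustrative numbers (assumed,
# not read from any particular checkpoint): 8 query heads, 4 KV heads,
# head_dim=256 and tp=2 give (4 + 2 * 2) * 256 = 2048 rows per rank.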
class FlashGemma3Attention(torch.nn.Module):
    def __init__(
        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal
        if is_sliding:
            self.window_size = config.sliding_window
            # TODO: remove this hack to support local sliding window
            config = copy.deepcopy(config)
            config.rope_scaling = dict(rope_type="default")
            self.rotary_emb = PositionRotaryEmbedding.static(
                config=config,
                dim=config.head_dim,
                base=config.rope_local_base_freq,
                device=weights.device,
            )
        else:
            self.window_size = -1
            self.rotary_emb = PositionRotaryEmbedding.static(
                config=config,
                dim=config.head_dim,
                base=config.rope_theta,
                device=weights.device,
            )

        self.softmax_scale = (
            config.query_pre_attn_scalar**-0.5
            if config.query_pre_attn_scalar is not None
            else None
        )

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.softcap = None  # config.attn_logit_softcapping

        query_key_value = load_attention(config, prefix, weights)
        self.query_key_value = TensorParallelMultiAdapterLinear.load(
            query_key_value,
            layer_id,
            ["q_proj", "k_proj", "v_proj"],
            sizes=[
                self.head_size * config.num_attention_heads,
                self.head_size * config.num_key_value_heads,
                self.head_size * config.num_key_value_heads,
            ],
            process_group=weights.process_group,
        )
        self.kv_scales = get_kv_scales(weights, f"{prefix}")

        o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            layer_id,
            "o_proj",
            process_group=weights.process_group,
        )

        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)
        self.q_norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.k_norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
        )
        self.enable_gqa = self.num_heads != self.num_key_value_heads

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
        adapter_data,
        attention_mask,
    ):
        qkv = self.query_key_value(hidden_states, adapter_data)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )

        kv = kv.view(-1, 2, self.num_key_value_heads * self.head_size)
        key = kv[:, 0]
        value = kv[:, 1]

        query = query.reshape(-1, self.head_size)
        key = key.reshape(-1, self.head_size)
        query, _ = self.q_norm(query.contiguous())
        key, _ = self.k_norm(key.contiguous())

        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_key_value_heads, self.head_size)
        value = value.view(-1, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, key, cos, sin)

        kv_cache.store(
            key=key,
            value=value,
            slots=slots,
            kv_scales=self.kv_scales,
        )
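
        # Dispatch note (added comment): during prefill the fused attention
        # kernel is used when no dense mask is required (attention_mask is
        # None) or when the flashinfer backend is selected; otherwise the code
        # falls back to a padded scaled_dot_product_attention so the
        # image-token bidirectional mask can be applied. During decode
        # (cu_seqlen_prefill is None) paged attention reads past K/V from the
        # block-table-indexed cache.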
        # Prefill
        if cu_seqlen_prefill is not None:
            if attention_mask is None or ATTENTION == "flashinfer":
                # flash attention
                attn_output = attention(
                    query=query,
                    key=key,
                    value=value,
                    kv_cache=kv_cache,
                    kv_scales=self.kv_scales,
                    seqlen=seqlen,
                    block_tables=block_tables,
                    softmax_scale=self.softmax_scale,
                    window_size_left=self.window_size,
                    softcap=self.softcap,
                )
            else:
                lengths = cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]

                # Split tensors using vectorized split
                query_list = torch.split(query, lengths.tolist(), dim=0)
                key_list = torch.split(key, lengths.tolist(), dim=0)
                value_list = torch.split(value, lengths.tolist(), dim=0)

                padded_query = torch.nn.utils.rnn.pad_sequence(
                    query_list, batch_first=True
                )
                padded_key = torch.nn.utils.rnn.pad_sequence(key_list, batch_first=True)
                padded_value = torch.nn.utils.rnn.pad_sequence(
                    value_list, batch_first=True
                )

                padded_query = padded_query.transpose(1, 2).contiguous()
                padded_key = padded_key.transpose(1, 2).contiguous()
                padded_value = padded_value.transpose(1, 2).contiguous()

                # Compute attention
                attn_output = F.scaled_dot_product_attention(
                    padded_query,
                    padded_key,
                    padded_value,
                    attn_mask=attention_mask,
                    scale=self.softmax_scale,
                    enable_gqa=self.enable_gqa,
                )

                attn_output = attn_output.transpose(
                    1, 2
                )  # [batch_size, seq_len, num_heads, head_dim]

                max_seq_len = padded_query.size(2)
                seq_range = torch.arange(
                    max_seq_len, device=padded_query.device
                ).unsqueeze(0)
                lengths_tensor = torch.tensor(
                    lengths, device=padded_query.device
                ).unsqueeze(1)
                mask = seq_range < lengths_tensor  # [batch, max_seq_len]
                attn_output = attn_output[mask]  # [total_seq_len, num_heads, head_dim]

        # Decode
        else:
            attn_output = paged_attention(
                query,
                kv_cache,
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                seqlen,
                max_s,
                softcap=self.softcap,
                kv_scales=self.kv_scales,
                window_size_left=self.window_size,
            )

        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
        )
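
# Gated MLP note (added comment): gate_proj and up_proj are fused into one
# column-parallel matmul and the forward pass computes
# down_proj(act(gate) * up), using a tanh-approximated GELU when
# config.hidden_activation is "gelu_pytorch_tanh". Illustrative shapes
# (assumed): intermediate_size=10240 with tp=2 gives a fused output of
# [tokens, 2 * 5120], viewed below as [tokens, 2, 5120] before gating.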
class Gemma3MLP(nn.Module):
    def __init__(self, prefix, config, weights, layer_id):
        super().__init__()
        act = config.hidden_activation
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            gate_up_proj,
            layer_id,
            ["gate_proj", "up_proj"],
            sizes=[
                config.intermediate_size,
                config.intermediate_size,
            ],
            process_group=weights.process_group,
        )

        down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.down_proj = TensorParallelAdapterRowLinear.load(
            down_proj,
            layer_id,
            "down_proj",
            process_group=weights.process_group,
        )

        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states, adapter_data):
        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
        )


class FlashGemma3Layer(nn.Module):
    def __init__(
        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
    ):
        super().__init__()
        self.self_attn = FlashGemma3Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            layer_id=layer_id,
            causal=causal,
            is_sliding=is_sliding,
        )
        self.mlp = Gemma3MLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
        )

        self.input_layernorm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.pre_feedforward_layernorm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.pre_feedforward_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_feedforward_layernorm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.post_feedforward_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        seqlen,
        max_s,
        adapter_data,
        attention_mask,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
            adapter_data,
            attention_mask,
        )

        # faster post attention rms norm
        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
        normed_attn_res_output = normed_attn_res_output + res
        res = normed_attn_res_output

        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
        mlp_output = self.mlp(pre_normed, adapter_data)
        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)

        return post_hidden_states, normed_attn_res_output
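
# Layer pattern note (added comment): is_sliding is derived from
# (layer_id + 1) % config.sliding_window_pattern, so with a pattern of 6
# (a typical Gemma3 setting, assumed here) layers 0-4 use local sliding-window
# attention with the local RoPE base, and every sixth layer (5, 11, ...) uses
# global attention with the standard rope_theta base.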
class FlashGemma3Model(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.layers = nn.ModuleList(
            [
                FlashGemma3Layer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    layer_id=layer_id,
                    causal=causal,
                    is_sliding=bool((layer_id + 1) % config.sliding_window_pattern),
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        adapter_data: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_local: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        residual = None
        for i, layer in enumerate(self.layers):
            cos, sin = self.layers[i].self_attn.rotary_emb.get_cos_sin(
                position_ids, max_s, hidden_states.dtype
            )
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                seqlen,
                max_s,
                adapter_data,
                (
                    attention_mask
                    if self.layers[i].self_attn.window_size == -1
                    else attention_mask_local
                ),
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashGemma3ForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, *, causal: bool = True):
        super().__init__()

        embed_norm = config.hidden_size**0.5
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.embed_tokens.weight *= embed_norm

        self.model = FlashGemma3Model(
            prefix=prefix, config=config, weights=weights, causal=causal
        )
        self.lm_head = SpeculativeHead.load(
            prefix=(
                f"{prefix}.embed_tokens"
                if config.tie_word_embeddings
                else f"{prefix}.lm_head"
            ),
            config=config,
            weights=weights,
        )
        # self.softcap = config.attn_logit_softcapping
        # assert isinstance(self.softcap, float)
        self.softcap = None

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
            input_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            seqlen,
            max_s,
            adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
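
# Projector note (added comment): the vision tower emits one embedding per
# image patch; the projector below average-pools that patch grid down to
# config.mm_tokens_per_image soft tokens before projecting into the text
# embedding space. Illustrative numbers (assumed): image_size=896 and
# patch_size=14 give a 64x64 patch grid; mm_tokens_per_image=256 gives 16
# tokens per side, hence an AvgPool2d kernel/stride of 64 // 16 = 4.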
class Gemma3MultimodalInputProjection(torch.nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()

        self.mm_input_projection_weight = weights.get_tensor(
            "multi_modal_projector.mm_input_projection_weight"
        )

        self.mm_soft_emb_norm = Gemma3FastRMSNorm.load(
            prefix=f"{prefix}.mm_soft_emb_norm",
            weights=weights,
            eps=config.vision_config.layer_norm_eps,
        )

        self.patches_per_image = int(
            config.vision_config.image_size // config.vision_config.patch_size
        )
        self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
        self.kernel_size = self.patches_per_image // self.tokens_per_side
        self.avg_pool = nn.AvgPool2d(
            kernel_size=self.kernel_size, stride=self.kernel_size
        )

    def forward(self, vision_outputs: torch.Tensor):
        batch_size, _, seq_length = vision_outputs.shape

        reshaped_vision_outputs = vision_outputs.transpose(1, 2)
        reshaped_vision_outputs = reshaped_vision_outputs.reshape(
            batch_size, seq_length, self.patches_per_image, self.patches_per_image
        )
        reshaped_vision_outputs = reshaped_vision_outputs.contiguous()

        pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
        pooled_vision_outputs = pooled_vision_outputs.flatten(2)
        pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)

        normed_vision_outputs, _ = self.mm_soft_emb_norm(pooled_vision_outputs)

        projected_vision_outputs = torch.matmul(
            normed_vision_outputs, self.mm_input_projection_weight
        )
        return projected_vision_outputs.type_as(vision_outputs)


class Gemma3ForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config

        if config.vision_config is not None:
            config.vision_config.quantize = config.quantize

            self.post_vision_model_layernorm = nn.LayerNorm.load(
                prefix="vision_tower.vision_model.post_layernorm",
                weights=weights,
                eps=config.vision_config.layer_norm_eps,
            )

            self.multimodal_projector = Gemma3MultimodalInputProjection(
                prefix="multi_modal_projector",
                config=config,
                weights=weights,
            )

            text_config = config.text_config
            text_config.speculator = config.speculator
            text_config.quantize = config.quantize

            self.vision_model = load_vision_model(
                prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
                config=config.vision_config,
                weights=weights,
            )

            self.text_model = load_text_model(
                prefix="language_model" if not prefix else f"{prefix}.language_model",
                config=config.text_config,
                weights=weights,
            )
        else:
            config.text_config.quantize = config.quantize
            config.text_config.speculator = config.speculator

            self.text_model = load_text_model(
                prefix=prefix,
                config=config.text_config,
                weights=weights,
            )

        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.dtype = weights.dtype
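
    # Mask note (added comment): get_attention_mask builds the dense prefill
    # mask consumed by the SDPA fallback: a per-sequence causal mask restricted
    # to valid (non-padded) positions, OR-ed with a bidirectional block over
    # image-token positions so all soft tokens of one image attend to each
    # other. The boolean mask is converted to additive form (0 where allowed,
    # dtype-min where masked) unless bool_mask=True.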
    def get_attention_mask(
        self,
        input_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        dtype: torch.dtype,
        bool_mask: bool = False,
    ):
        image_token_mask = (input_ids == self.config.image_token_index).to(
            input_ids.device
        )

        device = input_ids.device
        min_dtype = torch.finfo(dtype).min

        lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
        batch_size = len(lengths)

        sequence_length = max(lengths)
        target_length = sequence_length
        # Create the padding mask from the computed lengths.
        # pad_mask: [batch, sequence_length] where True indicates valid tokens.
        seq_range = torch.arange(sequence_length, device=device).unsqueeze(0)
        lengths_tensor = torch.tensor(lengths, device=device).unsqueeze(1)
        pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]

        # Build the base causal mask (for non-image tokens):
        causal_mask = torch.tril(
            torch.ones(
                (sequence_length, sequence_length), dtype=torch.bool, device=device
            )
        )
        base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
            1
        )  # [batch, sequence_length, sequence_length]
        base_mask = base_mask & causal_mask.unsqueeze(0)  # apply causal constraint

        image_token_mask = torch.nn.utils.rnn.pad_sequence(
            torch.split(image_token_mask, lengths), batch_first=True, padding_value=0
        )
        bidirectional_mask = image_token_mask.unsqueeze(2) & image_token_mask.unsqueeze(
            1
        )

        # Combine the causal base mask and the bidirectional mask.
        combined_mask = torch.logical_or(
            base_mask.unsqueeze(1), bidirectional_mask.unsqueeze(1)
        ).to(device)
        # combined_mask now has shape [batch, 1, sequence_length, sequence_length]

        full_attention_mask = torch.zeros(
            (batch_size, 1, sequence_length, target_length),
            device=device,
            dtype=torch.bool,
        )
        full_attention_mask[:, :, :, :sequence_length] = combined_mask

        if bool_mask:
            return full_attention_mask
        else:
            return torch.where(full_attention_mask, 0, min_dtype).to(device)

    def get_vision_embeds(
        self,
        pixel_values: torch.FloatTensor,
        pixel_attention_mask: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
    ):
        pixel_values = pixel_values.to(dtype=self.dtype)
        image_outputs = self.vision_model(pixel_values)
        vision_outputs = self.post_vision_model_layernorm(
            image_outputs.last_hidden_state
        )
        image_features = self.multimodal_projector(vision_outputs)
        image_features = image_features.view(-1, image_features.shape[-1])
        return image_features

    def get_inputs_embeds(
        self,
        input_ids: torch.Tensor,
        vision_embeds: torch.Tensor = None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)

        if vision_embeds is not None:
            # Replace the image token embeddings with the vision features
            image_token_mask = (input_ids == self.config.image_token_index).to(
                input_ids.device
            )
            inputs_embeds[image_token_mask] = vision_embeds.view(
                -1, vision_embeds.shape[-1]
            )

        return inputs_embeds

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor] = None,
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        # Unused here
        attention_mask: Optional[torch.BoolTensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if cu_seqlen_prefill is not None:
            max_s += 1
            position_ids += 1

        # Use flash attention for text-only input
        # else:
        #     if cu_seqlen_prefill is not None:
        #         min_dtype = torch.finfo(inputs_embeds.dtype).min
        #         lengths = (cu_seqlen_prefill[1:] - cu_seqlen_prefill[:-1]).tolist()
        #         # Determine the maximum sequence length (after padding) from query.
        #         sequence_length = max(lengths)
        #         target_length = sequence_length
        #         # Create the padding mask from the computed lengths.
        #         # pad_mask: [batch, sequence_length] where True indicates valid tokens.
        #         seq_range = torch.arange(
        #             sequence_length, device=input_ids.device
        #         ).unsqueeze(0)
        #         lengths_tensor = torch.tensor(
        #             lengths, device=input_ids.device
        #         ).unsqueeze(1)
        #         pad_mask = seq_range < lengths_tensor  # shape: [batch, sequence_length]
        #         # Build the base causal mask (for non-image tokens):
        #         causal_mask = torch.tril(
        #             torch.ones(
        #                 (sequence_length, sequence_length),
        #                 dtype=torch.bool,
        #                 device=input_ids.device,
        #             )
        #         )
        #         base_mask = pad_mask.unsqueeze(2) & pad_mask.unsqueeze(
        #             1
        #         )  # [batch, sequence_length, sequence_length]
        #         base_mask = base_mask & causal_mask.unsqueeze(0)
        #         attention_mask = base_mask.unsqueeze(
        #             1
        #         )  # [batch, 1, sequence_length, sequence_length]
        #         full_attention_mask = torch.zeros(
        #             (len(lengths), 1, sequence_length, target_length),
        #             device=input_ids.device,
        #             dtype=torch.bool,
        #         )
        #         full_attention_mask[:, :, :, :sequence_length] = attention_mask
        #         attention_mask = torch.where(full_attention_mask, 0, min_dtype).to(
        #             input_ids.device
        #         )

        if attention_mask is not None:
            min_dtype = torch.finfo(inputs_embeds.dtype).min
            # prefill may be larger than sliding window
            effective_seq_len = max(
                position_ids.shape[0], self.config.text_config.sliding_window
            )
            sliding_window_mask = torch.tril(
                torch.ones_like(attention_mask, dtype=torch.bool),
                diagonal=-self.config.text_config.sliding_window,
            )
            attention_mask_local = torch.where(
                sliding_window_mask, min_dtype, attention_mask
            )
            offset = max(0, position_ids.shape[0] - effective_seq_len)
            attention_mask_local = attention_mask_local[
                :, :, :, offset : offset + effective_seq_len
            ]
        else:
            attention_mask_local = None

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            attention_mask=attention_mask,
            attention_mask_local=attention_mask_local,
        )

        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)

        # pad logit with 1 zero logit for the image token
        if pixel_values is not None:
            logits = torch.cat(
                [logits, torch.zeros(logits.size(0), 1, device=logits.device)], dim=1
            )
            if speculative_logits is not None:
                speculative_logits = torch.cat(
                    [
                        speculative_logits,
                        torch.zeros(
                            speculative_logits.size(0),
                            1,
                            device=speculative_logits.device,
                        ),
                    ],
                    dim=1,
                )

        return logits, speculative_logits