# server/text_generation_server/models/custom_modeling/qwen2_5_vl.py
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2.5 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
else:
import flash_attn_2_cuda
import numpy as np
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import (
Seqlen,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import (
ProcessingKwargs,
ProcessorMixin,
Unpack,
VideosKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
fps: Union[List[float], float]
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
"videos_kwargs": {"fps": 2.0},
}
class Qwen2_5_VLProcessor(ProcessorMixin):
r"""
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
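
    Example (illustrative sketch, not taken from this repository; it assumes the public
    `Qwen/Qwen2.5-VL-7B-Instruct` checkpoint and a `PIL.Image.Image` named `image` supplied
    by the caller):

        from transformers import AutoImageProcessor, AutoTokenizer

        image_processor = AutoImageProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        processor = Qwen2_5_VLProcessor(image_processor=image_processor, tokenizer=tokenizer)

        # one <|image_pad|> marker per image; __call__ expands it to the real number of vision tokens
        inputs = processor(images=[image], text="<|image_pad|> Describe this image.")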
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
):
self.image_token = (
"<|image_pad|>"
if not hasattr(tokenizer, "image_token")
else tokenizer.image_token
)
self.video_token = (
"<|video_pad|>"
if not hasattr(tokenizer, "video_token")
else tokenizer.video_token
)
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
] = None,
videos: VideoInput = None,
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
) -> BatchFeature:
"""
        Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
                The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
            - **image_grid_thw** -- List of 3D image grids (temporal, height, width) fed to the LLM. Returned when `images` is not `None`.
            - **video_grid_thw** -- List of 3D video grids (temporal, height, width) fed to the LLM. Returned when `videos` is not `None`.
            - **second_per_grid_ts** -- List of seconds covered by each temporal grid step of a video. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Qwen2_5_VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(
images=images, videos=None, **output_kwargs["images_kwargs"]
)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
image_grid_thw = None
if videos is not None:
videos_inputs = self.image_processor(
images=None, videos=videos, **output_kwargs["images_kwargs"]
)
video_grid_thw = videos_inputs["video_grid_thw"]
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if isinstance(fps, (int, float)):
second_per_grid_ts = [
self.image_processor.temporal_patch_size / fps
] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [
self.image_processor.temporal_patch_size / tmp for tmp in fps
]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
else:
videos_inputs = {}
video_grid_thw = None
if not isinstance(text, list):
text = [text]
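        # Each image placeholder is expanded to one token per merged patch. As an illustrative
        # example (numbers not from the source): image_grid_thw = (1, 28, 28) with
        # merge_size = 2 yields 1 * 28 * 28 // 2**2 = 196 <|image_pad|> tokens for that image.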
if image_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
text[i] = text[i].replace(
self.image_token,
"<|placeholder|>"
* (image_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.image_token)
if video_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.video_token in text[i]:
text[i] = text[i].replace(
self.video_token,
"<|placeholder|>"
* (video_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.video_token)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
names_from_processor = list(
dict.fromkeys(tokenizer_input_names + image_processor_input_names)
)
return names_from_processor + ["second_per_grid_ts"]
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
def __init__(
self,
depth=32,
hidden_size=3584,
hidden_act="silu",
intermediate_size=3420,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
spatial_patch_size=14,
temporal_patch_size=2,
tokens_per_second=4,
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_patch_size = spatial_patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if vision_config is not None:
self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
        # Validate the rotary position embedding parameters.
        # BC: if there is a 'type' field, copy it to 'rope_type',
        # and change 'mrope' to 'default' because `mrope` performs the default RoPE calculation;
        # it can be set to "linear"/"dynamic" etc. to use scaled RoPE instead.
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
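    # `freqs` covers half of the head dimension (one angle per rotated pair); the
    # `repeat(1, 1, 2)` below duplicates cos/sin so they broadcast over the full head_dim
    # before the standard rotate-half formula is applied.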
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2_5VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size // weights.process_group.size()
self.head_dim = config.hidden_size // config.num_heads
self.num_heads = config.num_heads // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv",
weights=weights,
bias=False,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
)
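        # `load_qkv` above is called with bias=False, so the fused qkv bias is loaded
        # separately below and sharded along the output dimension to match the tensor-parallel split.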
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
self.proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.proj",
weights=weights,
bias=True,
)
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
def forward(
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state)
query, key, value = qkv.split(
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
)
# reshape the query, key, and value tensors
_shape = (
hidden_state.shape[0],
self.num_heads,
self.embed_dim // self.num_heads,
)
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
        # make the tensors contiguous as expected by the attention kernels
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
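        # vision attention is bidirectional within each (windowed) sequence, so no causal mask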
# execute flash attention
if SYSTEM == "ipex":
attn_output = torch.empty_like(query)
ipex.llm.functional.varlen_attention(
(query.contiguous() if query.device.type == "xpu" else query),
(key.contiguous() if key.device.type == "xpu" else key),
(value.contiguous() if value.device.type == "xpu" else value),
attn_output,
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
0.0,
self.softmax_scale,
False,
causal,
False,
None,
)
else:
attn_output = flash_attn_2_cuda.varlen_fwd(
query,
key,
value,
None, # tmp buffer (auto-allocated)
cu_seqlens, # cu_seqlens_q
cu_seqlens, # cu_seqlens_k
None, # max_seqlen_q (auto-computed)
None, # max_seqlen_k (auto-computed)
None, # block_tables
None, # broadcast_mask
max_seqlen, # max_seqlen
max_seqlen, # max_seqlen
0.0, # dropout_p
self.softmax_scale,
False, # zero_tensors
causal, # causal attention within each sequence
-1, # window_size_left
-1, # window_size_right
0.0, # softmax_cap
False, # deterministic
None, # rng_state
)[0]
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5VLVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.up = TensorParallelColumnLinear.load(
prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
)
self.gate = TensorParallelColumnLinear.load(
prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
)
self.down = TensorParallelRowLinear.load(
prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
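        # gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), with `act` taken from
        # config.hidden_act ("silu" by default)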
gate_states = self.gate(hidden_states)
up_states = self.up(hidden_states)
activated_states = self.activation_fn(gate_states) * up_states
down_states = self.down(activated_states)
return down_states
class Qwen2_5VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.attn = Qwen2_5VLAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.norm1 = FastRMSNorm.load(
prefix=f"{prefix}.norm1",
weights=weights,
eps=1e-6,
)
self.norm2 = FastRMSNorm.load(
prefix=f"{prefix}.norm2",
weights=weights,
eps=1e-6,
)
self.mlp = Qwen2_5VLVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
mlp_out = self.mlp(norm2_out)
hidden_states = hidden_states + mlp_out
return hidden_states
class Qwen2_5VLPatchMerger(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
self.patch_merger_ln_q = FastRMSNorm.load(
prefix=f"{prefix}.ln_q",
weights=weights,
eps=1e-6,
)
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
)
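    # Concatenates spatial_merge_size**2 neighboring patch embeddings into one vector
    # (config.hidden_size * 4 with the default merge size of 2) and projects it through a
    # two-layer GELU MLP down to `out_hidden_size`, the text model's hidden size.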
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2_5VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
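        # non-overlapping 3D patchify: kernel == stride, so every
        # (temporal_patch_size, patch_size, patch_size) voxel block becomes one patch embedding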
self.patch_embedding = nn.Conv3d(
in_channels=config.in_channels,
out_channels=config.hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
)
head_dim = config.hidden_size // config.num_heads
theta = 10000.0
dim = head_dim // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
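        # `inv_freq` holds head_dim // 4 rotary frequencies; in forward() the height and width
        # position ids are each looked up against it and concatenated (head_dim // 2 values),
        # then apply_rotary_pos_emb_vision duplicates cos/sin to cover the full head_dim.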
self.blocks = nn.ModuleList(
[
Qwen2_5VLVisionBlock(
prefix=f"{prefix}.blocks.{i}",
config=config,
weights=weights,
)
for i in range(config.depth)
]
)
self.merger = Qwen2_5VLPatchMerger(
prefix=f"{prefix}.merger",
config=config,
weights=weights,
)
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
self.window_size = config.window_size
self.patch_size = config.patch_size
self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
self.fullatt_block_indexes = config.fullatt_block_indexes
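    # NOTE: `apply_class_embedding` is not called in this file and `self.class_embedding` is
    # never initialized here; the helper appears to be carried over from a CLIP-style vision tower.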
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def get_window_index(self, grid_thw):
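        # Groups the flattened patch grid into local attention windows. With the default config
        # (window_size=112, spatial_merge_size=2, patch_size=14) each window spans
        # 112 // 2 // 14 = 4 merged patches per side. Returns a permutation that orders patches
        # window-by-window plus the cumulative sequence lengths delimiting each window.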
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (
self.window_size // self.spatial_merge_size // self.patch_size
)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size,
grid_w // self.spatial_merge_size,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
grid_t, llm_grid_h, llm_grid_w
)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = (
seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
)
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# reshape the input tensor for processing
shape = (
-1,
self.in_channels,
self.temporal_patch_size,
self.spatial_patch_size,
self.spatial_patch_size,
)
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
# TODO: revisit to see if we can avoid some of these reshapes
# find the position ids for the input tensor based on the grid_thw
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
        # build the rotary frequency table and gather the per-patch rotary embeddings
seq = torch.arange(
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
seq_len = hidden_states.shape[0]
patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
og_shape = (seq_len, -1)
hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
og_shape
)
rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
og_shape
)
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
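        # `cu_seqlens` delimits each full frame (grid_h * grid_w patches, repeated grid_t times),
        # while `cu_window_seqlens` delimits the smaller local-attention windows computed above.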
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
        # iteratively apply the blocks to the hidden states
for layer_num, block in enumerate(self.blocks):
            # NOTE: Qwen2.5-VL uses full (per-image) attention only at the layers listed in
            # `fullatt_block_indexes`; every other layer attends within the local windows.
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
class Qwen2_5VLForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
# returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
if (
hasattr(config, "rope_scaling")
and config.rope_scaling is not None
and config.rope_scaling.get("type", None) == "default"
):
config.rope_scaling.update({"rope_type": "mrope"})
self.hidden_size = config.hidden_size
self.vision_start_token_id = config.vision_start_token_id
self.vision_end_token_id = config.vision_end_token_id
self.image_token_id = config.image_token_id
self.video_token_id = config.video_token_id
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.visual = Qwen2_5VisionModel(
prefix="visual", config=config.vision_config, weights=weights
)
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
self.device = weights.device
    # based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
    # modified to first find segments then initialize position ids for each segment
    # Steps:
    #   locate all vision and text segments
    #   calculate `vision_segment_lengths` for each vision segment to be used as offsets
    #   calculate `text_segment_lengths` for each text segment to be used as offsets
    #   create position ids for each vision segment based on the image grid
    #   create position ids for each text segment
    #   the final text segment runs from the last vision segment to the end of the input
    #   combine all the position ids, reshape to (3, input_ids_len), then swap dimensions to (input_ids_len, 3)
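    # Illustrative example (numbers are not from the source): for an input of
    #   5 text tokens, <|vision_start|>, 196 image tokens (a merged 14x14 grid), <|vision_end|>, 7 text tokens
    # the text tokens receive identical, simply incrementing (t, h, w) ids, while the image
    # tokens receive t/h/w ids that span the merged grid, offset by the preceding text length.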
def get_position_ids(
self,
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if image_grid_thw is None:
return (
torch.arange(input_ids.shape[0], device=input_ids.device)
.unsqueeze(1)
.repeat(1, 3)
)
spatial_merge_size = self.spatial_merge_size
vision_start_token_id = self.vision_start_token_id
vision_end_token_id = self.vision_end_token_id
device = input_ids.device
dtype = input_ids.dtype
input_ids_len = input_ids.shape[0]
vision_starts = torch.where(input_ids == vision_start_token_id)[0]
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
prev_vision_end = torch.cat(
[torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
)
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
vision_widths_max = torch.cat(
[
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
image_grid_thw[:-1, 2] // spatial_merge_size,
]
)
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
# create position ids for each vision segment based on the image grid
llm_pos_ids_list = []
for i, _ in enumerate(vision_segments):
t, h, w = (
image_grid_thw[i][0],
image_grid_thw[i][1] // spatial_merge_size,
image_grid_thw[i][2] // spatial_merge_size,
)
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
w_indices = torch.arange(w, device=device).repeat(t * h)
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
# offset by the position of the last vision segment
im = image_position_ids + vision_segment_lengths[i]
llm_pos_ids_list.append(im)
# create position ids for each text segment
text_ranges = [
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ text_segment_lengths[i]
for i, seq_len in enumerate(text_lengths_between_vision)
]
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
full_llm_pos_ids_list.append(m + max_s)
position_ids = (
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
return position_ids
def get_vision_embeds(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).squeeze(0)
return image_embeds
def get_inputs_embeds(
self,
input_ids: torch.Tensor,
vision_embeds: torch.Tensor = None,
):
inputs_embeds = self.embed_tokens(input_ids)
        # replace image-token embeddings with the precomputed vision embeddings, if provided
if vision_embeds is not None:
inputs_embeds[input_ids == self.image_token_id] = vision_embeds
return inputs_embeds
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor],
# Unused in this model
attention_mask: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_indices=None,
):
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits