optimum/intel/openvino/modeling_visual_language.py (3,548 lines of code) (raw):
import copy
import enum
import inspect
import logging
import math
import os
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np
import openvino as ov
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
from transformers import (
AutoConfig,
AutoImageProcessor,
AutoModelForCausalLM,
AutoModelForVision2Seq,
GenerationConfig,
GenerationMixin,
PretrainedConfig,
PreTrainedTokenizer,
)
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.utils import ModelOutput
from ...exporters.openvino import main_export
from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
from ...exporters.openvino.utils import save_config
from ..utils.import_utils import is_transformers_version
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
from .utils import (
OV_LANGUAGE_MODEL_NAME,
OV_TEXT_EMBEDDINGS_MODEL_NAME,
OV_VISION_EMBEDDINGS_MODEL_NAME,
TemporaryDirectory,
)
try:
from transformers import LlavaForConditionalGeneration
except ImportError:
LlavaForConditionalGeneration = None
try:
from transformers import LlavaNextForConditionalGeneration
except ImportError:
LlavaNextForConditionalGeneration = None
if TYPE_CHECKING:
from PIL.Image import Image
if is_transformers_version(">=", "4.42.0"):
from transformers.image_utils import VideoInput
else:
VideoInput = List[Image]
logger = logging.getLogger(__name__)
core = ov.Core()
class InputMode(enum.Enum):
LANGUAGE = 0
VISION = 1
SPEECH = 2
VISION_SPEECH = 3
class OVModelWithEmbedForCausalLM(OVModelForCausalLM):
def __init__(
self,
model: ov.Model,
text_embeds_model: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
self.model = model
self.text_emb_model = text_embeds_model
self.request = None
self.text_emb_request = None
compile_only = kwargs.get("compile_only", False)
if compile_only:
self.text_emb_request = self.text_emb_model
self.request = self.model.create_infer_request()
super().__init__(
model, config, device, dynamic_shapes, ov_config, model_save_dir, quantization_config, **kwargs
)
def compile(self):
if self.request is None:
logger.info(f"Compiling the Language model to {self._device} ...")
super().compile()
self._compile_text_emb()
def _compile_text_emb(self):
if self.text_emb_request is None:
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
if self._compile_only:
self.text_emb_request = self.text_emb_model
else:
logger.info(f"Compiling the Text embeddings model to {self._device} ...")
self.text_emb_request = self._compile_model(
self.text_emb_model, self._device, self.ov_config, self.model_save_dir
)
def clear_requests(self):
if self._compile_only:
raise ValueError(
"`clear_requests()` is not supported with `compile_only` mode, please initialize model without this option"
)
self.request = None
self.text_emb_request = None
def embed_tokens(self, input_ids: torch.LongTensor):
self._compile_text_emb()
res = self.text_emb_request(input_ids, share_inputs=True)
return res[0]
def prepare_inputs(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
**kwargs,
):
batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
inputs = {}
# past_key_values are not used explicitly, instead they are handled inside the model
if past_key_values is None:
# This is the first iteration in a sequence, reset all states
if self.request is not None:
self.request.reset_state()
# Set initial value for the next beam_idx input that will be used at the current iteration
# and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
self.next_beam_idx = np.arange(batch_size, dtype=int)
self._past_length = 0
past_len = self._get_past_length(past_key_values)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids if past_key_values is None else input_ids[:, -1:])
if hasattr(self.config, "scale_emb"):
inputs_embeds = inputs_embeds * self.config.scale_emb
inputs["inputs_embeds"] = inputs_embeds
# Add the attention_mask inputs when needed
if "attention_mask" in self.input_names or "position_ids" in self.input_names:
if attention_mask is not None:
attention_mask = attention_mask.cpu().numpy()
else:
attention_mask = np.ones((inputs_embeds.shape[0], inputs_embeds.shape[1] + past_len), dtype=int)
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask
if "position_ids" in self.input_names:
if position_ids is not None:
position_ids = position_ids.cpu().numpy()
else:
position_ids = np.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
if past_len:
position_ids = position_ids[:, -inputs_embeds.shape[1] :]
if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3:
position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0)
inputs["position_ids"] = position_ids
if "token_type_ids" in self.input_names:
if token_type_ids is None:
token_type_ids = np.zeros(inputs_embeds.shape[:2], dtype=int)
inputs["token_type_ids"] = token_type_ids
if "beam_idx" in self.input_names:
inputs["beam_idx"] = (
self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
)
return inputs
def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
**kwargs,
):
self.compile()
inputs = self.prepare_inputs(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
# Run inference
self.request.start_async(inputs, share_inputs=True)
self.request.wait()
logits = self.request.get_tensor("logits").data
logits = torch.from_numpy(logits).clone().to(self.device)
past_key_values = ((),)
self._past_length += inputs["inputs_embeds"].shape[1]
return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
class OVVisionEmbedding(OVModelPart):
_model_name = "vision_embeddings"
def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
super().__init__(model, parent_model, model_name=self._model_name)
self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
self.hidden_states_output_names = []
if len(self.model.outputs) > 2:
self.hidden_states_output_names = [
key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
]
self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
if model_has_input_output_name(self.model, "images"):
self._main_input = "images"
elif model_has_input_output_name(self.model, "hidden_states"):
self._main_input = "hidden_states"
else:
self._main_input = "pixel_values"
def forward(self, pixel_values, **kwargs):
self._compile()
inputs = {self._main_input: pixel_values}
if len(self.input_names) > 1:
for name in self.input_names:
if name in kwargs:
inputs[name] = kwargs[name]
result = self.request(inputs)
last_hidden_state = result[0]
hidden_states = None
pooler_out = None
if len(result) > 1:
pooler_out = result[1]
if self.hidden_states_output_names:
hidden_states = []
for out in self.hidden_states_output_names:
hidden_states.append(result[out])
return BaseModelOutputWithPooling(
pooler_output=pooler_out, last_hidden_state=last_hidden_state, hidden_states=hidden_states
)
class OVResampler(OVModelPart):
_model_name = "resampler"
def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
super().__init__(model, parent_model, model_name=self._model_name)
self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs}
self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
def forward(self, image_feature, pos_embed, key_padding_mask):
self._compile()
result = self.request(
{"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
)[0]
return result
class OVVisionProjection(OVModelPart):
_model_name = "vision_projection"
def forward(self, img_features):
self._compile()
return self.request(img_features)[0]
class OVVisionResampler(OVVisionProjection):
_model_name = "vision_resampler"
class OVMultiModalProjector(OVVisionProjection):
_model_name = "multi_modal_projector"
class OVAudioEmbeddings(OVModelPart):
_model_name = "audio_embeddings"
def forward(self, audio_signal):
self._compile()
return self.request(audio_signal)[0]
class OVAudioEncoder(OVModelPart):
_model_name = "audio_encoder"
def forward(self, audio_feature, audio_mask):
self._compile()
return self.request({"audio_feature": audio_feature, "audio_mask": audio_mask})[0]
MODEL_PARTS_CLS_MAPPING = {
"resampler": OVResampler,
"language_model": OVModelWithEmbedForCausalLM,
"vision_embeddings": OVVisionEmbedding,
"vision_projection": OVVisionProjection,
"vision_resampler": OVVisionResampler,
"multi_modal_projector": OVMultiModalProjector,
"vision_embeddings_merger": OVVisionEmbedding,
"audio_embeddings": OVAudioEmbeddings,
"audio_forward_embeddings": OVAudioEmbeddings,
"audio_encoder": OVAudioEncoder,
"audio_vision_projection": OVAudioEmbeddings,
"audio_speech_projection": OVAudioEmbeddings,
}
class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
export_feature = "image-text-to-text"
additional_parts = []
auto_model_class = AutoModelForCausalLM
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
self.config = config
self.use_cache = kwargs.get("use_cache", True)
self._model_save_dir = model_save_dir
self._device = device.upper()
self.is_dynamic = dynamic_shapes
self.ov_config = {} if ov_config is None else {**ov_config}
self.preprocessors = kwargs.get("preprocessors", [])
self.lm_model = language_model
self.text_embeddings_model = text_embeddings
self.vision_embeddings_model = vision_embeddings
self._supports_cache_class = False
self.main_input_name = "input_ids"
self._compile_only = kwargs.get("compile_only", False)
for part in self.additional_parts:
setattr(self, f"{part}_model", kwargs.get(part))
enable_compilation = kwargs.get("compile", True)
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
self._openvino_config = None
if quantization_config:
self._openvino_config = OVConfig(quantization_config=quantization_config)
self._set_ov_config_parameters()
self.language_model = OVModelWithEmbedForCausalLM(
self.lm_model,
self.text_embeddings_model,
config=config,
device=device,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
compile=self._compile_only or enable_compilation,
compile_only=self._compile_only,
)
self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self)
for part in self.additional_parts:
model_part = getattr(self, f"{part}_model", None)
if model_part is not None:
model_part = MODEL_PARTS_CLS_MAPPING[part](model_part, self)
setattr(self, part, model_part)
if enable_compilation and not self._compile_only:
self.compile()
# Avoid warnings when creating a transformers pipeline
AutoConfig.register(self.base_model_prefix, AutoConfig)
try:
self.auto_model_class.register(AutoConfig, self.__class__)
except AttributeError:
pass
def clear_requests(self):
if self._compile_only:
raise ValueError(
"`clear_requests()` is not supported with `compile_only` mode, please initialize model without this option"
)
for _, component in self.components.items():
component.clear_requests()
def compile(self):
for _, component in self.components.items():
if isinstance(component, OVModelPart):
component._compile()
else:
component.compile()
def _save_config(self, save_directory):
"""
Saves a model configuration into a directory, so that it can be re-loaded using the
[`from_pretrained`] class method.
"""
save_config(self.config, save_directory)
def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model to the OpenVINO IR format so that it can be re-loaded using the
[`~optimum.intel.openvino.modeling.OVModel.from_pretrained`] class method.
Arguments:
save_directory (`str` or `Path`):
The directory where to save the model files.
"""
src_models = self.ov_submodels
dst_file_names = {
"lm_model": OV_LANGUAGE_MODEL_NAME,
"text_embeddings_model": OV_TEXT_EMBEDDINGS_MODEL_NAME,
"vision_embeddings_model": OV_VISION_EMBEDDINGS_MODEL_NAME,
}
for name in self._ov_submodel_names:
if name not in dst_file_names:
dst_file_names[name] = f"openvino_{name}.xml"
for name in self._ov_submodel_names:
model = src_models[name]
dst_file_name = dst_file_names[name]
dst_path = os.path.join(save_directory, dst_file_name)
ov.save_model(model, dst_path, compress_to_fp16=False)
self._save_openvino_config(save_directory)
if self.generation_config is not None:
try:
self.generation_config.save_pretrained(save_directory)
except Exception as exception:
logger.warning(
f"The generation config will not be saved, saving failed with following error:\n{exception}"
)
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
local_files_only: bool = False,
load_in_8bit: bool = False,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
"""
Loads a model and its configuration file from a directory or the HF Hub.
Arguments:
model_id (`str` or `Path`):
The directory from which to load the model.
Can be either:
- The model id of a pretrained model hosted inside a model repo on huggingface.co.
- The path to a directory containing the model weights.
use_auth_token (Optional[Union[bool, str]], defaults to `None`):
Deprecated. Please use `token` instead.
token (Optional[Union[bool, str]], defaults to `None`):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`):
The specific model version to use. It can be a branch name, a tag name, or a commit id.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
cache_dir (`Union[str, Path]`, *optional*):
The path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
encoder_file_name(`str`, *optional*):
The encoder model file name. Overwrites the default file name openvino_encoder_model.xml and allows one to
load the encoder model with a different name.
decoder_file_name(`str`, *optional*):
The decoder model file name. Overwrites the default file name openvino_decoder_model.xml and allows one to
load the decoder model with a different name.
decoder_with_past_file_name(`str`, *optional*):
The decoder with past key values model file name overwriting the default file name
openvino_decoder_with_past_model.xml, allowing to load the decoder model with a different name.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
"""
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed soon. Please use the `token` argument instead.",
FutureWarning,
)
if token is not None:
raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
token = use_auth_token
model_file_names = {
"language_model": OV_LANGUAGE_MODEL_NAME,
"language_model_bin": OV_LANGUAGE_MODEL_NAME.replace(".xml", ".bin"),
"text_embeddings": OV_TEXT_EMBEDDINGS_MODEL_NAME,
"text_embeddings_bin": OV_TEXT_EMBEDDINGS_MODEL_NAME.replace(".xml", ".bin"),
"vision_embeddings": OV_VISION_EMBEDDINGS_MODEL_NAME,
"vision_embeddings_bin": OV_VISION_EMBEDDINGS_MODEL_NAME.replace(".xml", ".bin"),
}
model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
for part in model_cls.additional_parts:
model_file_names[part] = f"openvino_{part}_model.xml"
model_file_names[part + "_bin"] = f"openvino_{part}_model.bin"
compile_only = kwargs.get("compile_only", False)
if os.path.isdir(model_id):
# Load model from a local directory
model_save_dir = Path(model_id)
file_names = {k: os.path.join(model_id, model_file_names[k]) for k in model_file_names}
else:
file_names = {}
for name, file_name in model_file_names.items():
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
file_names[name] = model_cache_path
model_save_dir = Path(model_cache_path).parent
if not compile_only:
language_model = model_cls.load_model(file_names["language_model"])
text_embeddings = model_cls.load_model(file_names["text_embeddings"])
vision_embeddings = model_cls.load_model(file_names["vision_embeddings"])
for part in model_cls.additional_parts:
kwargs[part] = model_cls.load_model(file_names[part])
else:
language_model = model_cls._compile_model(
file_names["language_model"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
text_embeddings = model_cls._compile_model(
file_names["text_embeddings"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
vision_embeddings = model_cls._compile_model(
file_names["vision_embeddings"],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
for part in model_cls.additional_parts:
kwargs[part] = model_cls._compile_model(
file_names[part],
kwargs.get("device", "CPU"),
kwargs.get("ov_config"),
model_save_dir,
)
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
)
kwargs["generation_config"] = generation_config
except Exception:
pass
quantization_config = model_cls._prepare_quantization_config(quantization_config, load_in_8bit)
to_quantize = not compile_only and quantization_config is not None
if to_quantize:
kwargs["compile"] = False
model = model_cls(
language_model=language_model,
text_embeddings=text_embeddings,
vision_embeddings=vision_embeddings,
config=config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
**kwargs,
)
if to_quantize:
from optimum.intel.openvino.quantization import OVQuantizer
quantization_config_copy = copy.deepcopy(quantization_config)
quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
potential_processor_id = config.mm_vision_tower if isinstance(model, _OVNanoLlavaForCausalLM) else model_id
quantization_config_copy.processor = quantization_config.processor or potential_processor_id
OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
return model
@classmethod
def _export(
cls,
model_id: str,
config: PretrainedConfig,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: str = HUGGINGFACE_HUB_CACHE,
subfolder: str = "",
local_files_only: bool = False,
task: Optional[str] = None,
use_cache: bool = True,
trust_remote_code: bool = False,
load_in_8bit: Optional[bool] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
compile_only = kwargs.pop("compile_only", False)
if compile_only:
logger.warning(
"`compile_only` mode will be disabled because it does not support model export."
"Please provide openvino model obtained using optimum-cli or saved on disk using `save_pretrained`"
)
compile_only = False
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
# This attribute is needed to keep one reference on the temporary directory, since garbage collecting
# would end-up removing the directory containing the underlying OpenVINO model
cls._model_save_dir_tempdirectory_instance = save_dir
if task is None:
task = cls.export_feature
# If load_in_8bit and quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
if load_in_8bit is None and not quantization_config:
ov_config = None
else:
# Export in fp32 if compression won't be applied later
ov_config = OVConfig(dtype="fp32" if load_in_8bit is False else "auto")
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
variant = kwargs.pop("variant", None)
main_export(
model_name_or_path=model_id,
output=save_dir_path,
task=task,
subfolder=subfolder,
revision=revision,
cache_dir=cache_dir,
token=token,
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
ov_config=ov_config,
stateful=stateful,
variant=variant,
)
config = AutoConfig.from_pretrained(save_dir_path, trust_remote_code=trust_remote_code)
return cls._from_pretrained(
model_id=save_dir_path,
config=config,
use_cache=use_cache,
load_in_8bit=load_in_8bit,
quantization_config=quantization_config,
**kwargs,
)
@property
def _component_names(self):
base_components = ["language_model", "vision_embeddings"]
additional_components = [part for part in self.additional_parts if getattr(self, part, None) is not None]
return base_components + additional_components
@property
def components(self):
return {component_name: getattr(self, component_name) for component_name in self._component_names}
@property
def _ov_submodel_names(self):
model_names = ["lm_model", "text_embeddings_model", "vision_embeddings_model"]
for part in self.additional_parts:
if getattr(self, part, None) is not None:
model_names.append(part + "_model")
return model_names
def reshape(self, batch_size: int, sequence_length: int):
logger.warning("Static shapes are not supported for causal language model.")
return self
def half(self):
"""
Converts all the model weights to FP16 for more efficient inference on GPU.
"""
for submodel in self.ov_submodels.values():
apply_moc_transformations(submodel, cf=False)
compress_model_transformation(submodel)
return self
def to(self, device):
self.language_model.to(device)
super().to(device)
return self
def forward(
self,
input_ids,
pixel_values=None,
past_key_values=None,
inputs_embeds=None,
image_sizes=None,
attention_mask=None,
position_ids=None,
image_bound=None,
tgt_sizes=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
rope_deltas=None,
images=None,
second_per_grid_ts=None,
token_type_ids=None,
pixel_attention_mask=None,
input_image_embeds: Optional[torch.FloatTensor] = None,
image_pixel_values: Optional[torch.FloatTensor] = None,
image_attention_mask=None,
audio_input_features: Optional[torch.FloatTensor] = None,
input_audio_embeds: Optional[torch.FloatTensor] = None,
audio_embed_sizes=None,
audio_attention_mask=None,
input_mode=None,
**kwargs,
):
if pixel_values is None:
pixel_values = images if images is not None else image_pixel_values
inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
input_ids,
pixel_values,
image_sizes=image_sizes,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
image_bound=image_bound,
tgt_sizes=tgt_sizes,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
rope_deltas=rope_deltas,
second_per_grid_ts=second_per_grid_ts,
pixel_attention_mask=pixel_attention_mask,
input_image_embeds=input_image_embeds,
image_attention_mask=image_attention_mask,
input_audio_embeds=input_audio_embeds if input_audio_embeds is not None else audio_input_features,
audio_embed_sizes=audio_embed_sizes,
audio_attention_mask=audio_attention_mask,
input_mode=input_mode,
**kwargs,
)
return self.language_model.forward(
input_ids=None,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
token_type_ids=token_type_ids,
past_key_values=past_key_values,
**kwargs,
)
def _reorder_cache(self, past_key_values, beam_idx):
return self.language_model._reorder_cache(past_key_values, beam_idx)
def get_vision_embeddings(self, pixel_values, **kwargs):
raise NotImplementedError
def get_text_embeddings(self, input_ids, **kwargs):
return self.language_model.embed_tokens(input_ids)
def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
raise NotImplementedError
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, **kwargs
):
inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
if pixel_values is not None:
vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs)
if vision_embeds is not None:
inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings(
vision_embeds,
inputs_embeds,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
**kwargs,
)
return inputs_embeds, attention_mask, position_ids
# Adopted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L521
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
inputs_embeds=None,
pixel_values=None,
image_sizes=None,
attention_mask=None,
**kwargs,
):
if past_key_values is not None:
past_length = self.language_model._get_past_length(past_key_values)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and past_length + 1 > input_ids.shape[1]:
input_discount = max(attention_mask.shape[1] - past_length, 1)
input_ids = input_ids[:, -input_discount:]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.llava
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
elif getattr(self.config, "image_token_index", -1) in input_ids:
input_ids = input_ids[:, input_ids.shape[1] - 1 :]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
# position_ids in Gemma3 are 1-indexed
if self.config.model_type == "gemma3":
position_ids += 1
if past_key_values is not None:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
if pixel_values is None:
pixel_values = kwargs.get("input_image_embeds", kwargs.get("images", kwargs.get("image_pixel_values")))
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"image_sizes": image_sizes,
"image_bound": kwargs.get("image_bound"),
"tgt_sizes": kwargs.get("tgt_sizes"),
"pixel_values_videos": kwargs.get("pixel_values_videos"),
"image_grid_thw": kwargs.get("image_grid_thw"),
"video_grid_thw": kwargs.get("video_grid_thw"),
"token_type_ids": kwargs.get("token_type_ids"),
"pixel_attetion_mask": kwargs.get("pixle_attetion_mask"),
"image_attention_mask": kwargs.get("image_attention_mask"),
"input_audio_embeds": kwargs.get("input_audio_embeds", kwargs.get("audio_input_features")),
"audio_embed_sizes": kwargs.get("audio_embed_sizes"),
"input_mode": kwargs.get("input_mode"),
}
)
return model_inputs
def can_generate(self):
"""Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
return True
@staticmethod
@abstractmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
"""
Preprocess input instruction and an image.
"""
class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
auto_model_class = LlavaForConditionalGeneration
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model=language_model,
text_embeddings=text_embeddings,
vision_embeddings=vision_embeddings,
config=config,
device=device,
dynamic_shapes=dynamic_shapes,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
**kwargs,
)
self._support_new_processing = hasattr(self.config, "image_seq_length")
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
if not isinstance(pixel_values, list):
image_features = self.vision_embeddings(pixel_values).last_hidden_state
else:
image_features = []
for patch in pixel_values:
if isinstance(patch, list):
patch_feats = []
for patch_value in patch:
patch_feats.append(self.vision_embeddings(np.expand_dims(patch_value, 0)).last_hidden_state)
patch_feats = np.concatenate(patch_feats, axis=1)
else:
patch_feats = self.vision_embeddings(patch).last_hidden_state
image_features.append(patch_feats)
image_features = np.concatenate(image_features, 0)
return image_features
# Adopted from https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/llava/modeling_llava.py#L297C9-L297C45
def merge_vision_text_embeddings(
self,
vision_embeds,
inputs_embeds,
input_ids,
attention_mask,
position_ids=None,
legacy_processing=False,
**kwargs,
):
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
if legacy_processing:
pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
num_images, num_image_patches, embed_dim = image_features.shape
batch_size, sequence_length = input_ids.shape
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(pad_token_id))
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == self.config.image_token_index
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# Compute the maximum embed dimension
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
if left_padding:
new_token_positions += nb_image_pad[:, None] # offset for left padding
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None]
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
raise ValueError(
f"The input provided to the model a/pre-releasesre wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
indices_to_mask = new_token_positions[batch_indices, pad_indices]
final_embedding[batch_indices, indices_to_mask] = 0
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
image_features = image_features.to(inputs_embeds.dtype)
final_embedding = inputs_embeds.masked_scatter(special_image_mask, image_features)
final_attention_mask = attention_mask
return final_embedding, final_attention_mask, position_ids
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
):
if pixel_values is not None and self._support_new_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
else:
legacy_processing = True
inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
)
if legacy_processing and pixel_values is not None and past_key_values is not None:
attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)
return inputs_embeds, attention_mask, position_ids
def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
# Get the target length
target_length = input_ids.shape[1]
past_length = self.language_model._get_past_length(past_key_values)
extended_attention_mask = torch.ones(
(attention_mask.shape[0], past_length),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
position_ids = torch.cumsum(attention_mask, axis=1) - 1
position_ids[attention_mask == 0] = 1
return attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
if getattr(processor, "chat_template", None) is not None:
chat_prompt = [{"role": "user", "content": [{"type": "text", "text": text}]}]
if image is not None:
chat_prompt[0]["content"].append({"type": "image"})
prompt = processor.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
else:
if image is not None and "<image>" not in text:
prompt = "<image>\n" + text
else:
prompt = text
if is_transformers_version(">", "4.47.99") and getattr(processor, "patch_size", None) is None:
if (
getattr(config, "vision_config", None) is not None
and getattr(config.vision_config, "patch_size", None) is not None
):
processor.patch_size = config.vision_config.patch_size
else:
raise ValueError(
"Processor does not have `patch_size` attribute. Please fix the processor or provide `patch_size` in the config."
)
inputs = processor(images=image, text=prompt, return_tensors="pt")
return inputs
class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
auto_model_class = LlavaNextForConditionalGeneration
# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
def pack_image_features(self, image_features, image_sizes, image_newline=None):
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
new_image_features = []
feature_lens = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
if height * width != base_image_feature.shape[0]:
raise ValueError("The number of patches is not consistent with the image size.")
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
def add_image_features(
self,
input_ids,
inputs_embeds,
pixel_values,
attention_mask,
position_ids,
image_sizes,
legacy_processing,
**kwargs,
):
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches
# ! infer image_num_patches from image_sizes
image_num_patches = [
image_size_to_num_patches(
image_size=imsize,
grid_pinpoints=self.config.image_grid_pinpoints,
patch_size=self.config.vision_config.image_size,
)
for imsize in image_sizes
]
# figure out if pixel_values is concatenated or stacked
if pixel_values.dim() == 5:
# stacking when input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs)
if vision_embeds is not None:
image_newline = torch.tensor(self.config.image_newline)
image_features = torch.split(torch.from_numpy(vision_embeds), image_num_patches, dim=0)
image_features, feature_lens = self.pack_image_features(
image_features,
image_sizes,
image_newline=image_newline,
)
inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings(
image_features,
inputs_embeds,
feature_lens=feature_lens,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
legacy_processing=legacy_processing,
**kwargs,
)
return inputs_embeds, attention_mask, position_ids
# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L416
def get_multimodal_embeddings(
self,
input_ids,
pixel_values=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
image_sizes=None,
**kwargs,
):
inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
if pixel_values is not None and self._support_new_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
else:
legacy_processing = True
if pixel_values is not None and pixel_values.size(0) > 0:
inputs_embeds, attention_mask, position_ids = self.add_image_features(
input_ids,
inputs_embeds,
pixel_values,
attention_mask,
position_ids,
image_sizes,
legacy_processing,
**kwargs,
)
if legacy_processing and pixel_values is not None and past_key_values is not None and input_ids.shape[1] == 1:
attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)
return inputs_embeds, attention_mask, position_ids
def merge_vision_text_embeddings(
self,
vision_embeds,
inputs_embeds,
feature_lens,
input_ids,
attention_mask,
position_ids=None,
legacy_processing=False,
image_token_index=None,
**kwargs,
):
image_token_index = self.config.image_token_index if image_token_index is None else image_token_index
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
if legacy_processing:
with torch.no_grad():
# ! in llava 1.6, number of patches is variable
num_images = feature_lens.size(0)
num_image_features, embed_dim = image_features.shape
if feature_lens.sum() != num_image_features:
raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
batch_size = input_ids.shape[0]
_left_padding = torch.any(attention_mask[:, 0] == 0)
_right_padding = torch.any(attention_mask[:, -1] == 0)
left_padding = True
if batch_size > 1:
if _left_padding and not _right_padding:
left_padding = True
elif not _left_padding and _right_padding:
left_padding = False
elif not _left_padding and not _right_padding:
left_padding = True
else:
# invalid attention_mask
raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
# Whether to turn off right padding
# 1. Create a mask to know where special image tokens are
special_image_token_mask = input_ids == image_token_index
# special_image_token_mask: [bsz, seqlen]
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
# num_special_image_tokens: [bsz]
# Reserve for padding of num_images
total_num_special_image_tokens = torch.sum(special_image_token_mask)
if total_num_special_image_tokens != num_images:
raise ValueError(
f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
)
# Compute the maximum embed dimension
# max_image_feature_lens is max_feature_lens per batch
feature_lens = feature_lens.to(input_ids.device)
feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device)
embed_sequence_lengths = (
(attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
)
max_embed_dim = embed_sequence_lengths.max()
batch_indices, non_image_indices = torch.where(
(input_ids != image_token_index) & (attention_mask == 1)
)
# 2. Compute the positions where text should be written
# Calculate new positions for text tokens in merged image-text sequence.
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
# ! instead of special_image_token_mask * (num_image_patches - 1)
# special_image_token_mask * (num_feature_len - 1)
special_image_token_mask = special_image_token_mask.long()
special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
if left_padding:
# shift right token positions so that they are ending at the same number
# the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
# 3. Create the full embedding, already padded to the maximum position
final_embedding = torch.zeros(
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
final_attention_mask = torch.zeros(
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
)
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
# set the corresponding tensors into their correct target device.
target_device = inputs_embeds.device
batch_indices, non_image_indices, text_to_overwrite = (
batch_indices.to(target_device),
non_image_indices.to(target_device),
text_to_overwrite.to(target_device),
)
attention_mask = attention_mask.to(target_device)
input_ids = input_ids.to(target_device)
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
with torch.no_grad():
image_to_overwrite = torch.full(
(batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
)
image_to_overwrite[batch_indices, text_to_overwrite] = False
embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
embed_indices = embed_indices.expand(batch_size, max_embed_dim)
embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
if left_padding:
# exclude padding on the left
max_embed_dim = max_embed_dim.to(target_device)
val = (max_embed_dim - embed_indices) <= embed_seq_lens
else:
# exclude padding on the right
val = embed_indices < embed_seq_lens
image_to_overwrite &= val
if image_to_overwrite.sum() != num_image_features:
raise ValueError(
f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
f" the number of image given to the model is {num_images}. "
f"This prevents correct indexing and breaks batch generation."
)
final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
final_attention_mask |= image_to_overwrite
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
else:
special_image_mask = (input_ids == image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
image_features = image_features.to(inputs_embeds.dtype)
final_embedding = inputs_embeds.masked_scatter(special_image_mask, image_features)
final_attention_mask = attention_mask
return final_embedding, final_attention_mask, position_ids
def get_text_embeddings(self, input_ids, **kwargs):
for_inputs_embeds_ids = input_ids.clone()
for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
return super().get_text_embeddings(for_inputs_embeds_ids, **kwargs)
class _OVLlavaNextVideoForCausalLM(_OVLlavaNextForCausalLM):
additional_parts = ["vision_resampler", "multi_modal_projector"]
auto_model_class = AutoModelForVision2Seq
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
image_features = self.vision_embeddings(pixel_values).last_hidden_state
image_features = self.multi_modal_projector(image_features)
return image_features
def pack_image_features(self, image_features, image_sizes, image_newline=None):
"""
Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors.
Args:
image_features (`List[torch.Tensor]` of length num_images, each of shape `(num_patches, image_length, embed_dim)`)
List of image feature tensor, each contains all the visual feature of all patches.
image_sizes (`torch.Tensor` of shape `(num_images, 2)`)
Actual image size of each images (H, W).
vision_feature_select_strategy (`str`)
The feature selection strategy used to select the vision feature from the vision backbone.
image_newline (`torch.Tensor` of shape `(embed_dim)`)
New line embedding vector.
Returns:
image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`)
feature_lens (`List[int]`)
token length of each image in image_features
"""
from transformers.models.llava_next_video.modeling_llava_next_video import (
get_anyres_image_grid_shape,
unpad_image,
)
new_image_features = []
feature_lens = []
vision_feature_select_strategy = self.config.vision_feature_select_strategy
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.config.vision_config.image_size // self.config.vision_config.patch_size
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
if (
np.prod(image_feature.shape) % (num_patch_height * num_patch_width * height * width) != 0
and vision_feature_select_strategy == "default"
):
logger.warning_once(
"Image feature shape does not line up with the provided patch size. "
"You may be using the `default` vision_feature_select_strategy with a"
" visual encoder that does not have CLS."
)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
if image_newline is not None:
image_feature = torch.cat(
(
image_feature,
image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device, image_feature.dtype),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if image_newline is not None:
image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0)
new_image_features.append(image_feature)
feature_lens.append(image_feature.size(0))
image_features = torch.cat(new_image_features, dim=0)
feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
return image_features, feature_lens
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if audio is not None:
raise ValueError("Audio input is not supported")
if getattr(processor, "chat_template", None) is not None:
chat_prompt = [{"role": "user", "content": [{"type": "text", "text": text}]}]
if image is not None:
chat_prompt[0]["content"].append({"type": "image"})
if video is not None:
chat_prompt[0]["content"].append({"type": "video"})
prompt = processor.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
else:
prompt = text
if image is not None and "<image>" not in prompt:
prompt = "<image>\n" + prompt
if video is not None and "<video>" not in prompt:
prompt = "<video>\n" + prompt
if is_transformers_version(">", "4.47.99") and getattr(processor, "patch_size", None) is None:
if (
getattr(config, "vision_config", None) is not None
and getattr(config.vision_config, "patch_size", None) is not None
):
processor.patch_size = config.vision_config.patch_size
else:
raise ValueError(
"Processor does not have `patch_size` attribute. Please fix the processor or provide `patch_size` in the config."
)
inputs = processor(images=image, text=prompt, videos=video, return_tensors="pt")
return inputs
def get_multimodal_embeddings(
self,
input_ids,
pixel_values=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
image_sizes=None,
pixel_values_videos=None,
**kwargs,
):
inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
if (
pixel_values is not None
and pixel_values.size(0) > 0
and self._support_new_processing
and past_key_values is None
):
legacy_processing = (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
).item()
elif (
pixel_values_videos is not None
and pixel_values_videos.size(0) > 0
and self._support_new_processing
and past_key_values is None
):
legacy_processing = (
(input_ids == self.config.video_token_index).sum(1).max() < self.config.video_seq_length
).item()
else:
legacy_processing = True
legacy_processing = (
legacy_processing.item() if isinstance(legacy_processing, torch.Tensor) else legacy_processing
)
if pixel_values is not None and pixel_values.size(0) > 0:
inputs_embeds, attention_mask, position_ids = self.add_image_features(
input_ids,
inputs_embeds,
pixel_values,
attention_mask,
position_ids,
image_sizes,
legacy_processing,
**kwargs,
)
if pixel_values_videos is not None and pixel_values_videos.size(0) > 0:
inputs_embeds, attention_mask, position_ids = self.add_video_features(
input_ids,
inputs_embeds,
pixel_values_videos,
attention_mask,
position_ids,
legacy_processing=legacy_processing,
**kwargs,
)
if legacy_processing and pixel_values is not None and past_key_values is not None and input_ids.shape[1] == 1:
attention_mask, position_ids = self._filter_unattended_tokens(input_ids, attention_mask, past_key_values)
return inputs_embeds, attention_mask, position_ids
def add_video_features(
self,
input_ids,
inputs_embeds,
pixel_values_videos,
attention_mask,
position_ids,
legacy_processing,
**kwargs,
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/llava_next_video/modeling_llava_next_video.py#L732-L751
video_features = self.get_video_features(pixel_values_videos, input_ids)
if video_features is not None and len(video_features) != 0:
video_features = [feature.flatten(0, 1) for feature in video_features]
video_feature_lens = [feature.size(0) for feature in video_features]
video_features = torch.cat(video_features, dim=0)
video_feature_lens = torch.tensor(video_feature_lens, dtype=torch.long, device=video_features.device)
if legacy_processing:
inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings(
video_features,
inputs_embeds,
video_feature_lens,
input_ids,
attention_mask,
position_ids,
legacy_processing,
self.config.video_token_index,
)
else:
inputs_embeds = (
torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
)
special_image_mask = (input_ids == self.config.video_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
if inputs_embeds[special_image_mask].numel() != video_features.numel():
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
return inputs_embeds, attention_mask, position_ids
def get_video_features(self, pixel_values, input_ids=None, **kwargs):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/llava_next_video/modeling_llava_next_video.py#L835
if input_ids is not None and input_ids.shape[1] == 1:
return None
batch_size, frames, channels, height, width = pixel_values.shape
pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width)
selected_video_features = self.vision_embeddings(pixel_values).last_hidden_state
video_features = self.vision_resampler(selected_video_features)
video_features = self.multi_modal_projector(video_features)
video_features = torch.split(torch.from_numpy(video_features), frames, dim=0)
return video_features
class _OVInternVLForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
image_features = self.vision_embeddings(pixel_values, **kwargs).last_hidden_state
return image_features
def merge_vision_text_embeddings(
self, vision_embeds, input_embeds, input_ids, attention_mask, position_ids=None, **kwargs
):
input_embeds = torch.from_numpy(input_embeds) if isinstance(input_embeds, np.ndarray) else input_embeds
vision_embeds = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)
input_ids = input_ids.reshape(B * N)
selected = input_ids == self.config.img_context_token_id
assert selected.sum() != 0
input_embeds[selected] = vision_embeds.reshape(-1, C)
input_embeds = input_embeds.reshape(B, N, C)
return input_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if tokenizer is None:
raise ValueError("Tokenizer is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD),
]
)
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=12, image_size=28, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = {
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
}
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image, input_size=448, max_num=12):
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
if image is not None and "<image>" not in text:
text = "<image>\n" + text
if tokenizer.chat_template is not None:
text = tokenizer.apply_chat_template(
[{"role": "user", "content": text}], add_generation_prompt=True, tokenize=False
)
inputs = {}
if image is not None:
if config is None:
raise ValueError("Config is required.")
pixel_values = load_image(image, input_size=config.vision_config.image_size)
num_patches = pixel_values.shape[0]
num_image_token = int(
(config.vision_config.image_size // config.vision_config.patch_size) ** 2
* (config.downsample_ratio**2)
)
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
text = text.replace("<image>", image_tokens, 1)
inputs.update({"pixel_values": pixel_values})
inputs.update(tokenizer(text, return_tensors="pt"))
return inputs
# internvl has issue with check _get_non_default_parameters, as wrkaraund overide _prepare_generation_config
def _prepare_generation_config(
self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
) -> Tuple[GenerationConfig, Dict]:
using_model_generation_config = False
if generation_config is None:
if (
self.generation_config._from_model_config # 1)
and self.generation_config._original_object_hash == hash(self.generation_config) # 2)
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config: # 4)
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed in v5."
" Please use and modify the model generation configuration (see"
" https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration )",
UserWarning,
)
self.generation_config = new_generation_config
generation_config = self.generation_config
using_model_generation_config = True
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
# If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
if not using_model_generation_config:
if generation_config.bos_token_id is None:
generation_config.bos_token_id = self.generation_config.bos_token_id
if generation_config.eos_token_id is None:
generation_config.eos_token_id = self.generation_config.eos_token_id
if generation_config.pad_token_id is None:
generation_config.pad_token_id = self.generation_config.pad_token_id
if generation_config.decoder_start_token_id is None:
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
return generation_config, model_kwargs
class _OVMiniCPMVForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["resampler"]
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model,
text_embeddings,
vision_embeddings,
config,
device,
dynamic_shapes,
ov_config,
model_save_dir,
quantization_config,
**kwargs,
)
self.embed_dim = self.language_model.config.hidden_size
max_size = self.config.vision_config.image_size // self.config.vision_config.patch_size
self._pos_embeds = torch.from_numpy(self._get_2d_sincos_pos_embed(self.embed_dim, max_size)).float()
self.max_size = (max_size, max_size)
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
tgt_sizes = kwargs["tgt_sizes"]
pixel_values_list = pixel_values
vision_hidden_states = []
all_pixel_values = []
img_cnt = []
for pixel_value in pixel_values_list:
img_cnt.append(len(pixel_value))
all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_value])
vision_embedding = None
# exist image
if all_pixel_values:
tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool)
for i in range(B):
patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
position_ids = self._prepare_vis_position_ids(
all_pixel_values,
patch_attn_mask,
tgt_sizes,
self.config.vision_config.patch_size,
self.config.vision_config.image_size // self.config.patch_size,
)
vision_embedding = torch.from_numpy(
self.vision_embeddings(
pixel_values=all_pixel_values, patch_attention_mask=patch_attn_mask, position_ids=position_ids
)[0]
)
vision_embedding = self.resampling(vision_embedding, tgt_sizes)
start = 0
for pixel_value in pixel_values_list:
img_cnt = len(pixel_value)
if img_cnt > 0:
vision_hidden_states.append(vision_embedding[start : start + img_cnt])
start += img_cnt
else:
vision_hidden_states.append([])
else: # no image
dummy_feature = []
for _ in range(len(pixel_values_list)):
vision_hidden_states.append(dummy_feature)
return vision_hidden_states
def resampling(self, x, tgt_sizes):
bs = x.shape[0]
patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
self._adjust_pos_cache(tgt_sizes)
max_patch_len = torch.max(patch_len)
key_padding_mask = torch.zeros((bs, max_patch_len), dtype=torch.bool)
pos_embed = []
for i in range(bs):
tgt_h, tgt_w = tgt_sizes[i]
pos_embed.append(self._pos_embeds[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1))) # patches * D
key_padding_mask[i, patch_len[i] :] = True
pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, batch_first=True, padding_value=0.0).permute(
1, 0, 2
) # BLD => L * B * D
res = torch.from_numpy(self.resampler(image_feature=x, pos_embed=pos_embed, key_padding_mask=key_padding_mask))
return res
def _set_2d_pos_cache(self, max_size):
pos_embed = torch.from_numpy(self._get_2d_sincos_pos_embed(self.embed_dim, max_size)).float()
self._pos_embeds = pos_embed
def _adjust_pos_cache(self, tgt_sizes):
max_h = torch.max(tgt_sizes[:, 0])
max_w = torch.max(tgt_sizes[:, 1])
if max_h > self.max_size[0] or max_w > self.max_size[1]:
self.max_size = [max(max_h, self.max_size[0]), max(max_w, self.max_size[1])]
self._set_2d_pos_cache(self.max_size)
def _get_2d_sincos_pos_embed(self, embed_dim, image_size):
"""
image_size: image_size or (image_height, image_width)
return:
pos_embed: [image_height, image_width, embed_dim]
"""
if isinstance(image_size, int):
grid_h_size, grid_w_size = image_size, image_size
else:
grid_h_size, grid_w_size = image_size[0], image_size[1]
grid_h = np.arange(grid_h_size, dtype=np.float32)
grid_w = np.arange(grid_w_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed
def _get_2d_sincos_pos_embed_from_grid(self, embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = self._get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[0]) # (H, W, D/2)
emb_w = self._get_1d_sincos_pos_embed_from_grid_new(embed_dim // 2, grid[1]) # (H, W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D)
return emb
def _get_1d_sincos_pos_embed_from_grid_new(self, embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (H, W)
out: (H, W, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product
emb_sin = np.sin(out) # (H, W, D/2)
emb_cos = np.cos(out) # (H, W, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D)
return emb
def _prepare_vis_position_ids(
self, pixel_values, patch_attention_mask, tgt_sizes, patch_size, num_patches_per_side
):
batch_size = pixel_values.size(0)
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
if tgt_sizes is not None:
nb_patches_h = tgt_sizes[batch_idx][0]
nb_patches_w = tgt_sizes[batch_idx][1]
else:
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
return position_ids
def merge_vision_text_embeddings(
self, vision_embeds, input_embeds, input_ids, attention_mask, position_ids=None, **kwargs
):
bs = input_ids.shape[0]
image_bound = kwargs["image_bound"]
vllm_embedding = torch.from_numpy(input_embeds)
for i in range(bs):
cur_vs_hs = vision_embeds[i]
if len(cur_vs_hs) > 0:
cur_vllm_emb = vllm_embedding[i]
cur_image_bound = image_bound[i]
if len(cur_image_bound) > 0:
image_indices = torch.stack([torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound])
cur_vllm_emb.scatter_(
0,
image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
)
return vllm_embedding, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
if getattr(processor, "chat_template", None) is not None:
messages = [{"role": "user", "content": text if image is None else "(<image>./</image>)\n" + text}]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
prompt = (
f"<|im_start|>user\n(<image>./</image>)\n{text}<|im_end|>\n<|im_start|>assistant\n"
if image is not None
else text
)
inputs = processor([prompt], [image], return_tensors="pt")
inputs.pop("image_sizes", None)
return inputs
class _OVNanoLlavaForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
if isinstance(pixel_values, list) or pixel_values.ndim == 5:
concat_images = torch.cat(pixel_values, dim=0) if isinstance(pixel_values, list) else pixel_values
image_features = torch.from_numpy(self.vision_embeddings(concat_images).last_hidden_state)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
else:
image_features = self.vision_embeddings(pixel_values).last_hidden_state
return image_features
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, **kwargs
):
vision_embeds = None
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
if pixel_values is None and "images" in kwargs:
pixel_values = kwargs["images"]
if pixel_values is not None:
vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs)
if vision_embeds is None:
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
past_len = self.language_model._get_past_length(kwargs.get("past_key_values"))
if attention_mask is not None and attention_mask.shape[1] < past_len + input_ids.shape[1]:
attention_mask = torch.cat(
[
attention_mask,
torch.ones(attention_mask.shape[0], past_len + input_ids.shape[1] - attention_mask.shape[1]),
],
dim=1,
)
position_ids = None
return inputs_embeds, attention_mask, position_ids
vision_embeds = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- TODO: double check
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask.bool())
]
labels = [
cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask.bool())
]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = vision_embeds[cur_image_idx]
cur_input_embeds_1 = torch.from_numpy(self.get_text_embeddings(cur_input_ids.unsqueeze(0))[0])
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = (
[-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = torch.from_numpy(
self.get_text_embeddings(torch.cat(cur_input_ids_noim).unsqueeze(0))[0]
)
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = vision_embeds[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device
)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
return new_input_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if tokenizer is None:
raise ValueError("Tokenizer is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
if image is not None and processor is None:
raise ValueError("Processor is required.")
text = f"<image>\n{text}" if image is not None else text
messages = [{"role": "user", "content": text}]
if tokenizer.chat_template is not None:
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
if image is not None:
text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
else:
input_ids = tokenizer(text, return_tensors="pt").input_ids
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
result = {"input_ids": input_ids, "attention_mask": attention_mask}
if image is not None:
result["images"] = processor(images=[image], return_tensors="pt")["pixel_values"]
return result
class _OVPhi3VisionForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["vision_projection"]
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model,
text_embeddings,
vision_embeddings,
config,
device,
dynamic_shapes,
ov_config,
model_save_dir,
quantization_config,
**kwargs,
)
self.sub_GN = torch.tensor(self.config.sub_GN)
self.glb_GN = torch.tensor(self.config.glb_GN)
self.image_dim_out = self.config.img_processor["image_dim_out"]
def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs):
num_images, num_crops, c, h, w = pixel_values.shape
img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(
num_images, num_crops, -1, self.image_dim_out
)
image_features_proj = self.hd_feature_transform(img_features, image_sizes)
return image_features_proj
def hd_feature_transform(self, image_features, image_sizes):
"""
image_features: (num_images, num_crops+1, 24*24, 1024)
"""
image_features = torch.from_numpy(image_features)
global_image_features = image_features[:, 0] # (num_images, 24*24, 1024)
# global feature can be viewed as a special HD case with num_crops 1x1
global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)
all_image_embeddings = []
# need a for loop to process each image because of different image sizes
# (patch arrangement is different for each image)
for i, img_size in enumerate(image_sizes):
h, w = img_size
h_crop = h // 336
w_crop = w // 336
num_crops = h_crop * w_crop
# NOTE: real num_crops is padded
# (num_crops, 24*24, 1024)
sub_image_features = image_features[i, 1 : 1 + num_crops]
sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)
sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)
# [sub features, separator, global features]
all_image_embeddings.extend(
[
sub_image_features_hd_newline.squeeze(0), # (h_crop*12*(w_crop*12+1), 4096)
self.glb_GN.squeeze(0),
global_image_features_hd_newline[i],
]
)
image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0]
return image_features_proj
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
"""
image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
"""
N, L, C = image_features.shape
assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
num_images = N // (h_crop * w_crop)
H = int(L**0.5)
image_features_hd = (
image_features.reshape(N, H, H, C) # N, 24, 24, 1024
.reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
.permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
.reshape(N, -1, 4 * C) # N, 144, 4096
.reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096
.permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
.reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096
)
return image_features_hd
def add_image_newline(self, image_features_hd):
"""
image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
"""
num_images, h, w, hid_dim = image_features_hd.shape
# add the newline token to the HD image feature patches
newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1) # (n_img, h, 1, hid_dim)
image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape(
num_images, -1, hid_dim
)
return image_features_hd_newline
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs
):
MAX_INPUT_ID = int(1e9)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
# positions for image tokens
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
has_image = len(positions[0].tolist()) > 0
input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size)
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs))
if has_image:
vision_embeds = self.get_vision_embeddings(
pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs
)
image_features_proj = torch.from_numpy(vision_embeds)
inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)
return inputs_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
if image is not None and "<|image_1|>" not in text:
text = "<|image_1|>\n" + text
if getattr(processor.tokenizer, "chat_template", None) is not None:
chat_prompt = [{"role": "user", "content": text}]
text = processor.tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text, return_tensors="pt")
return inputs
@dataclass
class QWen2VLModelOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
rope_deltas: Optional[torch.FloatTensor] = None
second_per_grid_ts: Optional[torch.FloatTensor] = None
class _OVQwen2VLForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["vision_embeddings_merger"]
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model=language_model,
text_embeddings=text_embeddings,
vision_embeddings=vision_embeddings,
config=config,
device=device,
dynamic_shapes=dynamic_shapes,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
**kwargs,
)
self.rope_deltas = None # cache rope_deltas here
if is_transformers_version(">=", "4.45.0"):
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
VisionRotaryEmbedding,
)
self._rotary_pos_emb = VisionRotaryEmbedding(
self.config.vision_config.embed_dim // self.config.vision_config.num_heads // 2
)
else:
raise ValueError(
f"Initialization model for {self.config.model_type} required at least transformers >= 4.45"
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
**kwargs,
):
# Overwritten -- in specific circumstances we don't want to forward image inputs to the model
if past_key_values is not None:
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif inputs_embeds is not None:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if cache_position[0] != 0:
pixel_values = None
pixel_values_videos = None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"image_grid_thw": image_grid_thw,
"video_grid_thw": video_grid_thw,
"cache_position": cache_position,
}
)
return model_inputs
# Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1602
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)
if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas
return model_kwargs
# Copied from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1423
def get_rope_index(
self,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [3, 4, 5, 6, 7]
text height position_ids: [3, 4, 5, 6, 7]
text width position_ids: [3, 4, 5, 6, 7]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
)
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def get_vision_embeddings(self, pixel_values, grid_thw, **kwargs):
hidden_states = self.vision_embeddings(pixel_values)[0]
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32
)
cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool)
causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf"))
res = self.vision_embeddings_merger(
pixel_values=hidden_states, attention_mask=causal_mask, rotary_pos_emb=rotary_pos_emb
)[0]
return res
# Adopted from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1089
# Use config values instead of model attributes, replace self.rotary_pos_emb -> self._rotary_pos_emb
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
w // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
w // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self._rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def get_multimodal_embeddings(
self,
input_ids,
pixel_values=None,
attention_mask=None,
position_ids=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
cache_position=None,
**kwargs,
):
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
if pixel_values is not None and input_ids.shape[1] != 1:
image_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values, image_grid_thw))
image_mask = input_ids == self.config.image_token_id
inputs_embeds[image_mask] = image_embeds
if pixel_values_videos is not None and input_ids.shape[1] != 1:
pixel_values_videos = pixel_values_videos
video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
video_mask = input_ids == self.config.video_token_id
inputs_embeds[video_mask] = video_embeds
# if we get 4D attention mask we cannot calculate rope deltas anymore.
if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
return inputs_embeds, attention_mask, position_ids
def forward(
self,
input_ids,
pixel_values=None,
past_key_values=None,
inputs_embeds=None,
image_sizes=None,
attention_mask=None,
position_ids=None,
image_bound=None,
tgt_sizes=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
rope_deltas=None,
**kwargs,
):
result = super().forward(
input_ids,
pixel_values,
past_key_values,
inputs_embeds,
image_sizes,
attention_mask,
position_ids,
image_bound,
tgt_sizes,
pixel_values_videos,
image_grid_thw,
video_grid_thw,
rope_deltas,
**kwargs,
)
final_result = QWen2VLModelOutputWithPast(
logits=result.logits, past_key_values=result.past_key_values, rope_deltas=rope_deltas
)
return final_result
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if audio is not None:
raise ValueError("Audio input is not supported")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
if image is not None:
conversation[0]["content"].insert(0, {"type": "image"})
if video is not None:
conversation[0]["content"].insert(0, {"type": "video"})
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt")
return inputs
class _OVQwen2_5_VLForCausalLM(OVModelForVisualCausalLM):
additional_parts = ["vision_embeddings_merger"]
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model=language_model,
text_embeddings=text_embeddings,
vision_embeddings=vision_embeddings,
config=config,
device=device,
dynamic_shapes=dynamic_shapes,
ov_config=ov_config,
model_save_dir=model_save_dir,
quantization_config=quantization_config,
**kwargs,
)
self.rope_deltas = None # cache rope_deltas here
class Qwen2_5_VisionRotaryEmbedding(torch.nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
head_dim = config.vision_config.hidden_size // config.vision_config.num_heads
self._rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
def get_rope_index(
self,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# modified from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
"""
spatial_merge_size = self.config.vision_config.spatial_merge_size
image_token_id = self.config.image_token_id
video_token_id = self.config.video_token_id
vision_start_token_id = self.config.vision_start_token_id
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
second_per_grid_ts=None,
**kwargs,
):
if past_key_values is not None:
if inputs_embeds is not None and input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif inputs_embeds is not None:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if cache_position[0] != 0:
pixel_values = None
pixel_values_videos = None
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"pixel_values_videos": pixel_values_videos,
"image_grid_thw": image_grid_thw,
"video_grid_thw": video_grid_thw,
"cache_position": cache_position,
"second_per_grid_ts": second_per_grid_ts,
}
)
return model_inputs
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
w // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
w // self.config.vision_config.spatial_merge_size,
self.config.vision_config.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self._rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def get_multimodal_embeddings(
self,
input_ids,
pixel_values=None,
attention_mask=None,
position_ids=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
cache_position=None,
second_per_grid_ts: Optional[torch.Tensor] = None,
**kwargs,
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1791-L1861
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
if pixel_values is not None and input_ids.shape[1] != 1:
image_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values, image_grid_thw))
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
n_image_features = image_embeds.shape[0]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
if pixel_values_videos is not None and input_ids.shape[1] != 1:
video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
n_video_features = video_embeds.shape[0]
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
mask = input_ids == self.config.video_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
video_mask = mask_expanded.to(inputs_embeds.device)
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
# if we get 4D attention mask we cannot calculate rope deltas anymore.
if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
# calculate RoPE index once per generation in the pre-fill stage only
if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
position_ids, rope_deltas = self.get_rope_index(
input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask
)
self.rope_deltas = rope_deltas
# then use the prev pre-calculated rope-deltas to get the correct position ids
else:
batch_size, seq_length, _ = inputs_embeds.shape
delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
position_ids = torch.arange(seq_length, device=inputs_embeds.device)
position_ids = position_ids.view(1, -1).expand(batch_size, -1)
if cache_position is not None: # otherwise `deltas` is an int `0`
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
return inputs_embeds, attention_mask, position_ids
def get_vision_embeddings(self, pixel_values, grid_thw, **kwargs):
hidden_states = self.vision_embeddings(pixel_values)[0]
rotary_pos_emb = self.rot_pos_emb(grid_thw)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
dtype=torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32
)
cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool)
causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf"))
window_attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool)
window_causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
for i in range(1, len(cu_window_seqlens)):
window_attention_mask[
..., cu_window_seqlens[i - 1] : cu_window_seqlens[i], cu_window_seqlens[i - 1] : cu_window_seqlens[i]
] = True
window_causal_mask.masked_fill_(torch.logical_not(window_attention_mask), float("-inf"))
res = self.vision_embeddings_merger(
pixel_values=hidden_states,
attention_mask=causal_mask,
window_attention_mask=window_causal_mask,
window_index=window_index,
rotary_pos_emb=rotary_pos_emb,
)[0]
return res
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (
self.config.vision_config.window_size
// self.config.vision_config.spatial_merge_size
// self.config.vision_config.patch_size
)
spatial_merge_unit = (
self.config.vision_config.spatial_merge_size * self.config.vision_config.spatial_merge_size
)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.config.vision_config.spatial_merge_size,
grid_w // self.config.vision_config.spatial_merge_size,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = torch.nn.functional.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(0) * spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if audio is not None:
raise ValueError("Audio input is not supported")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
if image is not None:
conversation[0]["content"].insert(0, {"type": "image"})
if video is not None:
conversation[0]["content"].insert(0, {"type": "video"})
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt")
return inputs
# Copied from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1602
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)
if getattr(outputs, "rope_deltas", None) is not None:
model_kwargs["rope_deltas"] = outputs.rope_deltas
return model_kwargs
class _OVMaira2ForCausalLM(_OVLlavaForCausalLM):
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("processor is required")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
if image is None:
return processor(text=text, return_tensors="pt")
processed_inputs = processor.format_and_preprocess_phrase_grounding_input(
frontal_image=image,
phrase=text,
return_tensors="pt",
)
return processed_inputs
class _OVGemma3ForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
return self.vision_embeddings(pixel_values).last_hidden_state
def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-Gemma-3/src/transformers/models/gemma3/modeling_gemma3.py#L1323-L1339
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
if input_ids is None:
special_image_mask = inputs_embeds == torch.from_numpy(
self.get_text_embeddings(torch.tensor([[self.config.image_token_index]], dtype=torch.long))[0]
)
else:
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)
image_features = image_features.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
return inputs_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
if image is not None:
conversation[0]["content"].insert(0, {"type": "image"})
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt")
return inputs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
num_new_tokens=num_new_tokens,
)
# Token type ids used only for first inference mask generation
model_kwargs.pop("token_type_ids", None)
return model_kwargs
class _OVGotOCR2ForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1 and kwargs.get("past_key_values") is not None:
return None
return self.vision_embeddings(pixel_values).last_hidden_state
def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/got_ocr2/modeling_got_ocr2.py#L836-L845
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
n_image_tokens = (input_ids == self.config.image_token_index).sum()
n_image_features = image_features.shape[0] * image_features.shape[1]
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
return inputs_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: Optional[str] = None,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("processor is required")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
if image is None:
raise ValueError("Image is required")
processed_inputs = processor(image, return_tensors="pt")
return processed_inputs
class _OVIdefics3ForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids, **kwargs):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-SmolVLM-2/src/transformers/models/smolvlm/modeling_smolvlm.py#L899-L942
if input_ids is not None and input_ids.shape[1] == 1 and kwargs.get("past_key_values") is not None:
return None
batch_size, num_images, num_channels, height, width = pixel_values.shape
pixel_values = pixel_values
pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are full 0.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image
if not any(real_images_inds):
# no images, leave one empty image.
real_images_inds[0] = True
pixel_values = pixel_values[real_images_inds].contiguous()
pixel_attention_mask = kwargs.get("pixel_attention_mask")
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=[pixel_values.shape[i] for i in (0, 2, 3)],
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = pixel_attention_mask.view(batch_size * num_images, *pixel_attention_mask.shape[2:])
pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()
patch_size = self.config.vision_config.patch_size
num_patches_per_side = self.config.vision_config.image_size // patch_size
patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
batch_size_, _, max_im_h, max_im_w = pixel_values.shape
max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
position_ids = torch.full(size=(batch_size_, max_nb_patches_h * max_nb_patches_w), fill_value=0)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
return self.vision_embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, patch_position_ids=position_ids
).last_hidden_state
def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-SmolVLM-2/src/transformers/models/idefics3/modeling_idefics3.py#L881
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
vision_hidden_size = image_features.shape[2]
special_image_token_mask = input_ids == self.config.image_token_id
# Fixes RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
new_inputs_embeds = inputs_embeds.clone()
reshaped_image_hidden_states = image_features.view(-1, vision_hidden_size)
# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.to(inputs_embeds.device, inputs_embeds.dtype)
new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states
inputs_embeds = new_inputs_embeds
return inputs_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
if image is not None:
conversation[0]["content"].insert(0, {"type": "image"})
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=text_prompt, return_tensors="pt")
return inputs
class _OVSmolVLForCasualLM(_OVIdefics3ForCausalLM):
def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
# Adopted from https://github.com/huggingface/transformers/blob/v4.49.0-SmolVLM-2/src/transformers/models/smolvlm/modeling_smolvlm.py#L803
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
_, patch_size, _ = image_features.shape
image_mask = input_ids == self.config.image_token_id
num_image_tokens = image_mask.sum(dim=1)
if not torch.all(num_image_tokens % patch_size == 0):
raise ValueError("At least one sample has <image> tokens not divisible by patch_size.")
blocks_per_sample = num_image_tokens // patch_size
offsets = torch.nn.functional.pad(blocks_per_sample.cumsum(dim=0), (1, 0), value=0)
block_offset = offsets[:-1]
row_cum = image_mask.cumsum(dim=-1)
chunk_idx = (row_cum - 1) // patch_size
local_idx = (row_cum - 1) % patch_size
block_idx = block_offset.unsqueeze(1) + chunk_idx
image_embeds = torch.zeros_like(inputs_embeds)
image_embeds[image_mask] = image_features[block_idx[image_mask], local_idx[image_mask], :]
inputs_embeds = torch.where(image_mask.unsqueeze(-1), image_embeds, inputs_embeds)
return inputs_embeds, attention_mask, position_ids
class _OVPhi4MMForCausalLM(OVModelForVisualCausalLM):
additional_parts = [
"vision_projection",
"audio_embeddings",
"audio_forward_embeddings",
"audio_encoder",
"audio_vision_projection",
"audio_speech_projection",
]
def __init__(
self,
language_model: ov.Model,
text_embeddings: ov.Model,
vision_embeddings: ov.Model,
config: PretrainedConfig = None,
device: str = "CPU",
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
super().__init__(
language_model,
text_embeddings,
vision_embeddings,
config,
device,
dynamic_shapes,
ov_config,
model_save_dir,
quantization_config,
**kwargs,
)
self.sub_GN = torch.tensor(self.config.sub_GN)
self.glb_GN = torch.tensor(self.config.glb_GN)
self.audio_config = (
config.audio_processor["config"] if hasattr(config, "audio_processor") else config.audio_config.to_dict()
)
self.chunk_size = self.audio_config.get("chunk_size", -1)
self.left_chunk = self.audio_config.get("left_chunk", 18)
self.time_reduction = self.audio_config.get("time_reduction", 8)
self.image_config = (
config.img_processor if hasattr(config, "img_processor") else config.vision_config.to_dict()
)
self.image_size = self.image_config.get("crop_size", 448)
self.patch_size = self.image_config.get("patch_size", 14)
self.num_patches_per_side = self.image_size // self.patch_size
self._IMAGE_SPECIAL_TOKEN_ID = (
200010 if "image_token_id" not in self.image_config else self.image_config["image_token_id"]
)
self._AUDIO_SPECIAL_TOKEN_ID = (
200011 if "audio_token_id" not in self.audio_config else self.audio_config["audio_token_id"]
)
self._COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE = [-9999, -1] # For backward compatibility
self._COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE = [float("-inf"), -10000] # For backward compatibility
self.image_dim_out = self.image_config.get("image_dim_out", self.image_config["hidden_size"])
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L669
def image_embed(
self,
input_ids: torch.LongTensor,
image_pixel_values: torch.FloatTensor,
image_attention_mask,
inputs_embeds,
image_sizes=None,
**kwargs,
):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
positions_tuple = torch.nonzero(input_ids == self._IMAGE_SPECIAL_TOKEN_ID, as_tuple=True)
if len(positions_tuple[-1]) == 0:
return None
batch_size = image_pixel_values.shape[0]
img_features = self.get_img_features(
image_pixel_values.flatten(0, 1),
image_attention_mask=image_attention_mask.flatten(0, 1).to(dtype=bool),
)
base_feat_size = int(np.sqrt(img_features.shape[1]))
img_features = img_features.view(batch_size, -1, base_feat_size**2, self.image_dim_out)
image_sizes = image_sizes.view(-1, 2)
output_imgs = []
for idx in range(batch_size):
height, width = image_sizes[idx]
height_ratio = height // self.image_size
width_ratio = width // self.image_size
area_ratio = height_ratio * width_ratio
global_img = img_features[idx, :1]
global_img = global_img.reshape(1, base_feat_size, base_feat_size, self.image_dim_out).contiguous()
temporary_extensor = self.sub_GN.repeat(1, base_feat_size, 1, 1)
global_img = torch.cat([global_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)
sub_img = img_features[idx, 1:]
sub_img = sub_img[:area_ratio]
sub_img = (
sub_img.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size, self.image_dim_out)
.transpose(1, 2)
.reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size, self.image_dim_out)
.contiguous()
)
if image_attention_mask is not None:
reshaped_image_attention_mask = (
image_attention_mask[idx, 1 : area_ratio + 1, 0::2, 0::2]
.reshape(height_ratio, width_ratio, base_feat_size, base_feat_size)
.transpose(1, 2)
.reshape(1, height_ratio * base_feat_size, width_ratio * base_feat_size)
)
useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item())
useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item())
sub_img = sub_img[:, :useful_height, :useful_width]
temporary_extensor = self.sub_GN.repeat(1, useful_height, 1, 1)
else:
temporary_extensor = self.sub_GN.repeat(1, height_ratio * base_feat_size, 1, 1)
sub_img = torch.cat([sub_img, temporary_extensor], dim=2).reshape(1, -1, self.image_dim_out)
# Merge global and sub
output_imgs.append(torch.cat([sub_img, self.glb_GN, global_img], dim=1))
img_set_tensor = []
for output_img in output_imgs:
img_feature_proj = torch.from_numpy(self.vision_projection(output_img))
img_set_tensor.append(img_feature_proj)
merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
image_embeds = inputs_embeds.index_put(indices=positions_tuple, values=merged_img_set_tensor, accumulate=False)
return image_embeds
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1241
def audio_embed(
self,
input_ids: torch.LongTensor,
audio_input_embeds: torch.FloatTensor,
inputs_embeds,
audio_embed_sizes=None,
audio_projection_mode="speech",
**kwargs,
):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
positions_tuple = torch.nonzero(input_ids == self._AUDIO_SPECIAL_TOKEN_ID, as_tuple=True)
if len(positions_tuple[-1]) == 0:
return None
audio_embeds = self.get_audio_features(audio_input_embeds, audio_projection_mode)
merged_audio_embeds = torch.cat(
[audio_embeds[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0
)
inputs_embeds = inputs_embeds.index_put(indices=positions_tuple, values=merged_audio_embeds, accumulate=False)
return inputs_embeds
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1165
def get_audio_features(
self,
input_embeds: torch.FloatTensor,
audio_projection_mode: str = "speech",
):
xs_pad = self.audio_embeddings(input_embeds)
input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad)
unfolded = False
ori_bz, seq_len, D = input_tensor.shape
max_seq_len = 500 # maximum position for absolute positional encoding
masks_unfold = None
if seq_len > max_seq_len:
# audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len
unfolded = True
# the unfold op will drop residual frames, pad it to the multiple of max_seq_len
if seq_len % max_seq_len > 0:
chunk_pad_size = max_seq_len - (seq_len % max_seq_len)
else:
chunk_pad_size = 0
if chunk_pad_size > 0:
input_tensor_pad = torch.nn.functional.pad(
torch.from_numpy(input_tensor), (0, 0, 0, chunk_pad_size), "constant", 0
)
input_tensor = input_tensor_pad
input_tensor = self.unfold_tensor(input_tensor, max_seq_len)
if masks is not None:
# revise hs_mask here because the previous calculated hs_mask did not consider extra pad
subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len]
extra_padded_subsamlped_pad_mask = torch.nn.functional.pad(
subsampled_pad_mask, (0, chunk_pad_size), "constant", False
) # extra padding to the pad mask
extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
masks_unfold = self.unfold_tensor(
extra_padded_subsamlped_pad_mask, max_seq_len
) # unfold the pad mask like we did to the input tensor
masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor
else:
masks_unfold = None
hs_mask = self.calculate_hs_mask(input_tensor, masks_unfold)
audio_features = self.audio_encoder(input_tensor, hs_mask)
if unfolded:
embed_dim = audio_features.shape[-1]
audio_features = np.reshape(audio_features, (ori_bz, -1, embed_dim))
# if we ever padded before unfolding, we need to remove the padding
if chunk_pad_size > 0:
audio_features = audio_features[:, :-chunk_pad_size, :]
audio_encoder = (
self.audio_vision_projection if audio_projection_mode == "vision" else self.audio_speech_projection
)
audio_set_tensor = audio_encoder(audio_features)
return torch.from_numpy(audio_set_tensor)
def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
"""If chunk size is a list, we will randomly select a chunk size."""
if isinstance(chunk_size, list):
# Variable chunk size during training
chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,)))
chunk_size_train_eff = chunk_size[chunk_size_index]
if not isinstance(left_chunk, list):
raise ValueError("Since chunk_size is a list, left_chunk must be a list")
if len(left_chunk) != len(chunk_size):
raise ValueError("The length of left_chunk must be the same as length of chunk_size.")
left_chunk_train_eff = left_chunk[chunk_size_index]
else:
chunk_size_train_eff = chunk_size
left_chunk_train_eff = left_chunk
return chunk_size_train_eff, left_chunk_train_eff
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1121
def forward_embeddings(self, xs_pad, masks=None, chunk_size_nc=None, left_chunk_nc=None):
"""Forwarding the inputs through the top embedding layers
Args:
xs_pad: torch.Tensor
input tensor
masks: torch.Tensor
input mask
chunk_size_nc: (optional, default is None) chunk size for non-causal layers
left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers
"""
seq_len = int(self.compute_lens_change(xs_pad.shape[1]))
if seq_len <= 0:
raise ValueError(
f"""The sequence length after time reduction is invalid: {seq_len}.
Your input feature is too short. Consider filtering out the very
short sentence from data loader""",
)
batch_size = xs_pad.shape[0]
enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.chunk_size, self.left_chunk)
input_tensor = xs_pad
input_tensor = self.audio_forward_embeddings(input_tensor)
streaming_mask = enc_streaming_mask
if streaming_mask is not None and masks is not None:
hs_mask = masks & streaming_mask
else:
hs_mask = streaming_mask
if chunk_size_nc is not None:
enc_streaming_mask_nc = self._streaming_mask(seq_len, batch_size, chunk_size_nc, left_chunk_nc)
if masks is not None:
hs_mask_nc = masks & enc_streaming_mask_nc
else:
hs_mask_nc = enc_streaming_mask_nc
else:
hs_mask_nc = None
if chunk_size_nc is None:
return input_tensor, None, None, hs_mask, None
return input_tensor, None, None, hs_mask, None, hs_mask_nc
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1101
def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(chunk_size, left_chunk)
# Create mask matrix for streaming
# S stores start index. if chunksize is 18, s is [0,18,36,....]
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
# avoid randomness when run evaluation or decoding
enc_streaming_mask = (
self.adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff)
.unsqueeze(0)
.expand([batch_size, -1, -1])
)
return enc_streaming_mask
def compute_lens_change(self, feature_lens):
"""feature_lens: int
return updated feature lens.
This used to return a different lambda function for each case that computed
the right thing. That does not work within Torchscript. If you really
need this to be faster, create nn.Module()-s for all the cases and return
one of them. Torchscript does support that.
"""
nemo_conv_settings = self.audio_config.get("nemo_conv_settings")
if nemo_conv_settings is None:
nemo_conv_settings = {"conv_channels": self.audio_config["nemo_conv_channels"]}
# Handle the special causal case
subsampling_causal_cond = nemo_conv_settings.get("subsampling", "dw_striding") in [
"dw_striding",
"striding",
"striding_conv1d",
]
is_causal = nemo_conv_settings.get("is_causal", False)
if is_causal and subsampling_causal_cond:
lens_change = (
torch.ceil(feature_lens / self.time_reduction).long()
if isinstance(feature_lens, torch.Tensor)
else math.ceil(feature_lens / self.time_reduction)
)
feature_lens_remainder = feature_lens % self.time_reduction
if isinstance(feature_lens, torch.Tensor):
lens_change[feature_lens_remainder != 1] += 1
elif feature_lens_remainder != 1:
lens_change += 1
return lens_change
ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
return ceil_func(feature_lens / self.time_reduction)
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1146
def calculate_hs_mask(self, xs_pad, mask):
max_audio_length = xs_pad.shape[1]
batch_size = xs_pad.shape[0]
enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, self.chunk_size, self.left_chunk)
if mask is None:
return enc_streaming_mask
feature_lens = mask.sum(1)
padding_length = feature_lens
pad_mask = torch.arange(0, max_audio_length).expand(padding_length.size(0), -1) < padding_length.unsqueeze(1)
pad_mask = pad_mask.unsqueeze(1)
pad_mask = pad_mask & enc_streaming_mask
return pad_mask
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1034
@staticmethod
def unfold_tensor(xs_pad, max_seq_len):
"""
For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len,
this function unfold it to a (NT', max_seq_len, D) where T' is T // max_seq_len.
Args:
xs_pad: N, T, D
"""
_, _, D = xs_pad.shape
xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T
# N x D x 1 x T => N x (D x max_seq_len) x T'
xs_pad = torch.nn.functional.unfold(
xs_pad[..., None, :],
kernel_size=(1, max_seq_len),
stride=(1, max_seq_len),
)
new_bsz, _, slen = xs_pad.shape
# N x D x max_seq_len x T'
xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen)
# N x T' x max_seq_len x D
xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous()
# NT' x max_seq_len x D
xs_pad = xs_pad.view(-1, max_seq_len, D)
return xs_pad
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1053
@staticmethod
def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
"""
The function is very important for Transformer Transducer Streaming mode
Args:
xs_len (int): sequence length
chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45]
left_window (int): how many left chunks can be seen
right_window (int): how many right chunks can be seen. It is used for chunk overlap model.
Returns:
mask (torch.Tensor): a mask tensor for streaming model
Torch 1.0.1
tensor([[1., 1., 0., 0.],
[0., 1., 1., 0.],
[0., 0., 1., 1.]])
Torch 1.4.1
tensor([[True., True., False., False.],
[False., True., True., False.],
[False., False., True., True.]])
"""
chunk_start_idx = torch.Tensor(chunk_start_idx).long() # first idx of each chunk, such as [0,18,36,48].
start_pad = torch.nn.functional.pad(
chunk_start_idx, (1, 0)
) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48]
end_pad = torch.nn.functional.pad(
chunk_start_idx, (0, 1), value=x_len
) # append x_len to the end, so it becomes [0,18,36,48, x_len]
seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1]
idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len]
end_pad[idx] # boundary size: [x_len]
seq_range_expand = (
torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
) # seq_range_expand size [x_len, x_len]
idx_left = idx - left_window
idx_left[idx_left < 0] = 0
boundary_left = start_pad[idx_left]
mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
idx_right = idx + right_window
idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
boundary_right = end_pad[idx_right]
mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
return mask_left & mask_right
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L494-L512
@staticmethod
def get_vision_position_ids(pixel_values, patch_attention_mask, patch_size=14, num_patches_per_side=32):
batch_size = pixel_values.shape[0]
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
position_ids = torch.full(
size=(
batch_size,
max_nb_patches_h * max_nb_patches_w,
),
fill_value=0,
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten()
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
return position_ids
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.3/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py#L1561
def embed_tokens_extend(
self,
input_ids: torch.LongTensor,
input_image_embeds: torch.FloatTensor = None,
input_audio_embeds: torch.FloatTensor = None,
image_sizes=None,
image_attention_mask=None,
audio_embed_sizes=None,
audio_projection_mode="speech",
past_key_values=None,
):
if past_key_values is not None:
return self.language_model.embed_tokens(input_ids)
new_input_ids = input_ids.clone()
new_input_ids[
(input_ids >= self._COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0])
& (input_ids <= self._COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1])
] = self._IMAGE_SPECIAL_TOKEN_ID
new_input_ids[
(input_ids >= self._COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0])
& (input_ids <= self._COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1])
] = self._AUDIO_SPECIAL_TOKEN_ID
input_ids = new_input_ids
image_position_mask = (input_ids == self._IMAGE_SPECIAL_TOKEN_ID).unsqueeze(-1)
non_image_position_mask = ~image_position_mask
hidden_states = torch.from_numpy(self.language_model.embed_tokens(input_ids))
vision_hidden_states = self.image_embed(
input_ids=input_ids,
inputs_embeds=hidden_states,
image_pixel_values=input_image_embeds,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
)
audio_hidden_states = self.audio_embed(
input_ids=input_ids,
inputs_embeds=hidden_states,
audio_input_embeds=input_audio_embeds,
audio_embed_sizes=audio_embed_sizes,
audio_projection_mode=audio_projection_mode,
)
if vision_hidden_states is not None and audio_hidden_states is not None:
hidden_states = vision_hidden_states * image_position_mask + audio_hidden_states * non_image_position_mask
elif vision_hidden_states is not None:
hidden_states = vision_hidden_states
elif audio_hidden_states is not None:
hidden_states = audio_hidden_states
return hidden_states
def get_multimodal_embeddings(
self,
input_ids,
pixel_values=None,
attention_mask=None,
position_ids=None,
input_image_embeds: Optional[torch.FloatTensor] = None,
image_sizes: Optional[torch.LongTensor] = None,
image_attention_mask=None,
input_audio_embeds: Optional[torch.FloatTensor] = None,
audio_embed_sizes=None,
input_mode=None,
**kwargs,
):
if pixel_values is not None and input_image_embeds is None:
input_image_embeds = pixel_values
audio_projection_mode = None
if input_audio_embeds is not None:
if isinstance(input_mode, torch.Tensor):
assert len(input_mode) == 1
input_mode = input_mode[0].item()
if input_mode is None:
input_mode = 1 if input_image_embeds is not None else 2
input_mode = InputMode(input_mode)
if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]:
audio_projection_mode = "vision"
elif input_mode == InputMode.SPEECH:
audio_projection_mode = "speech"
elif input_mode == InputMode.LANGUAGE:
audio_projection_mode = "speech"
else:
raise ValueError(f"Invalid input_mode: {input_mode}")
inputs_embeds = self.embed_tokens_extend(
input_ids=input_ids,
input_image_embeds=input_image_embeds,
input_audio_embeds=input_audio_embeds,
image_sizes=image_sizes,
image_attention_mask=image_attention_mask,
audio_embed_sizes=audio_embed_sizes,
audio_projection_mode=audio_projection_mode,
past_key_values=kwargs.get("past_key_values"),
)
return inputs_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("Video input is not supported")
user_prompt = "<|user|>"
assistant_prompt = "<|assistant|>"
prompt_suffix = "<|end|>"
image_token = getattr(processor.tokenizer, "image_token", "<|image_1|>")
audio_token = getattr(processor.tokenizer, "audio_token", "<|audio_1|>")
if audio is not None and audio_token not in text:
text = audio_token + text
if image is not None and image_token not in text:
text = image_token + text
if processor.tokenizer.chat_template is None:
if not text.startswith(user_prompt):
text = user_prompt + text + prompt_suffix + assistant_prompt
else:
text = processor.tokenizer.apply_chat_template(
[{"role": "user", "content": text}], tokenize=False, add_generation_prompt=True
)
audio_input = {}
if "audio" in inspect.signature(processor.__call__).parameters:
sample_rate = None
if isinstance(audio, tuple):
audio, sample_rate = audio
if isinstance(audio, list) and len(audio) == 1 and isinstance(audio[0], tuple):
audio, sample_rate = audio[0]
audio_input["audio"] = audio
if sample_rate is not None:
audio_input["sampling_rate"] = sample_rate
else:
audio_input["audios"] = audio
inputs = processor(text=text, images=image, **audio_input, return_tensors="pt")
return inputs
def get_img_features(self, pixel_values, image_attention_mask):
patch_position_ids = self.get_vision_position_ids(
pixel_values, image_attention_mask, self.patch_size, self.num_patches_per_side
)
return torch.from_numpy(
self.vision_embeddings(
pixel_values=pixel_values,
patch_attention_mask=image_attention_mask,
patch_position_ids=patch_position_ids,
)[0]
)
class _OVLlama4ForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
# Llama4 preprocessor creates bf16 tensor for pixel values, it can not be represented as numpy array
if pixel_values.dtype != torch.float32:
pixel_values.to(torch.float32)
return self.vision_embeddings(pixel_values.to(torch.float32)).last_hidden_state
# Adopted from https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/llama4/modeling_llama4.py#L1743-L1759
def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
original_inputs_embeds_shape = inputs_embeds.shape
special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
final_mask = special_image_mask.to(inputs_embeds.device)
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1))
final_mask_1d = final_mask[..., 0].reshape(-1)
num_tokens_to_fill = final_mask_1d.sum()
if num_tokens_to_fill != image_features.size(0):
raise ValueError(
f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, "
f"but multi_modal_projector returned {image_features.size(0)}"
)
expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1))
inputs_embeds.masked_scatter_(expanded_mask, image_features)
inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape)
return inputs_embeds, attention_mask, position_ids
@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("video input is not supported")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
if image is not None:
conversation[0]["content"].insert(0, {"type": "image"})
text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=text_prompt, return_tensors="pt")
return inputs
MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
"llava_next_video": _OVLlavaNextVideoForCausalLM,
"minicpmv": _OVMiniCPMVForCausalLM,
"llava-qwen2": _OVNanoLlavaForCausalLM,
"maira2": _OVMaira2ForCausalLM,
"phi3_v": _OVPhi3VisionForCausalLM,
"internvl_chat": _OVInternVLForCausalLM,
"qwen2_vl": _OVQwen2VLForCausalLM,
"qwen2_5_vl": _OVQwen2_5_VLForCausalLM,
"got_ocr2": _OVGotOCR2ForCausalLM,
"gemma3": _OVGemma3ForCausalLM,
"idefics3": _OVIdefics3ForCausalLM,
"smolvlm": _OVSmolVLForCasualLM,
"phi4mm": _OVPhi4MMForCausalLM,
"phi4_multimodal": _OVPhi4MMForCausalLM,
"llama4": _OVLlama4ForCausalLM,
}