# maga_transformer/models/llava_vit.py
from typing import Optional, Tuple, Union, Dict, List, Any
import os
import re
import math
from dataclasses import dataclass
from functools import partial, reduce
import numpy as np
import torch
import torch.utils.checkpoint
import torch.nn as nn
import logging
import copy
from PIL import Image
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from maga_transformer.models.multimodal.multimodal_common import MultiModalEmbeddingInterface
from maga_transformer.utils.multimodal_util import MMUrlType
from maga_transformer.models.llava_utils import expand2square, process_anyres_image, unpad_image, get_anyres_image_grid_shape
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
try:
import av
from decord import VideoReader, cpu
except ImportError:
print("Please install pyav to use video processing functions.")
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import (
ChannelDimension,
PILImageResampling,
to_numpy_array,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers import PretrainedConfig
from transformers.utils import ModelOutput
class LlavaImageEmbedding(MultiModalEmbeddingInterface):
def __init__(self, config: GptInitModelParameters):
self.config = config
        if config.mm_related_params.config.get("vision_config", None) is not None:
raise Exception("llava-hf style config is not implemented yet")
else:
self.vision_tower = self.build_vision_tower(config.mm_related_params.config)
self.mm_projector = self.build_vision_projector(config.mm_related_params.config)
if "unpad" in self.config.mm_related_params.config.get("mm_patch_merge_type", "flat"):
self.image_newline = nn.Parameter(
torch.empty(self.config.mm_related_params.config["hidden_size"])
)
@torch.inference_mode()
def mm_process(self, mm_input, **kwargs):
mm_type = kwargs.get("mm_type")
if mm_type == MMUrlType.DEFAULT:
if isinstance(mm_input, list):
return torch.cat(self.image_embedding(mm_input))
else:
return self.image_embedding([mm_input])[0]
elif mm_type == MMUrlType.IMAGE:
if isinstance(mm_input, list):
raise Exception("expect single image input, but get a list")
return self.image_embedding([mm_input])[0]
elif mm_type == MMUrlType.VIDEO:
if not isinstance(mm_input, list):
raise Exception("expect video input, but get a single image")
return torch.cat(self.image_embedding(mm_input, MMUrlType.VIDEO))
else:
raise Exception("unknown mm url type")
def _mm_preprocess(self, data, **kwargs):
mm_type = kwargs.get("mm_type")
if mm_type == MMUrlType.DEFAULT:
origin_data = copy.copy(data)
try:
return self.load_image(data, **kwargs)
            except Exception:
                try:
                    return self.load_video(origin_data, **kwargs)
                except Exception as e:
                    raise Exception(f"failed to load input as image or video: {e}")
elif mm_type == MMUrlType.IMAGE:
return self.load_image(data, **kwargs)
elif mm_type == MMUrlType.VIDEO:
return self.load_video(data, **kwargs)
else:
raise Exception("unknown mm url type")
def load_image(self, data, **kwargs):
return Image.open(data).convert("RGB")
def load_video(self, data, configs, **kwargs):
fps = 1 if configs.fps == -1 else configs.fps
vr = VideoReader(data, ctx=cpu(0), num_threads=1)
total_frame_num = len(vr)
video_time = total_frame_num / vr.get_avg_fps()
frame_num = round(video_time * fps)
        # clamp the sampled frame count to [min_frames, max_frames] (defaults: 1 and 100)
max_frame_num = configs.max_frames if configs.max_frames != -1 else 100
min_frame_num = configs.min_frames if configs.min_frames != -1 else 1
frame_num = max(min_frame_num, min(max_frame_num, frame_num))
frame_idx = np.linspace(0, total_frame_num - 1, frame_num).tolist()
frame_idx = [int(idx) for idx in frame_idx]
video = vr.get_batch(frame_idx).asnumpy()
        vr.seek(0)
return [Image.fromarray(frame) for frame in video]
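    # Worked example (illustrative): a 12 s clip sampled at fps=1 targets round(12 * 1) = 12
    # frames; with the defaults (min_frames=1, max_frames=100) that stays at 12, and
    # np.linspace then picks 12 evenly spaced frame indices across the clip.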
@property
def _device(self):
return self.vision_tower.device
@torch.no_grad()
def image_embedding(self, images: List[Image.Image], mm_type = MMUrlType.IMAGE):
config = self.config.mm_related_params.config
image_aspect_ratio = config["image_aspect_ratio"]
mm_patch_merge_type = config.get("mm_patch_merge_type", "flat")
mm_newline_position = config.get("mm_newline_position", "one_token")
processed_images = process_images(images,
image_aspect_ratio,
self.vision_tower.image_processor,
self._device,
self._data_type,
mm_type,
image_grid_pinpoints = config.get("image_grid_pinpoints", []))
processed_images = [image.unsqueeze(0) if image.ndim == 3 else image for image in processed_images]
split_sizes = [processed_image.shape[0] for processed_image in processed_images]
processed_images = torch.cat(processed_images)
image_features = self.encode_images(processed_images)
image_features = list(torch.split(image_features, split_sizes, dim=0))
if mm_type == MMUrlType.VIDEO:
image_features = [self.get_2dPool(feature) for feature in image_features]
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith("spatial"):
image_sizes = [image.size for image in images]
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if mm_type == MMUrlType.VIDEO: # video operations
if mm_newline_position == "grid":
image_feature = self.add_token_per_grid(image_feature)
if self.config.mm_related_params.config["add_faster_video"]:
raise Exception("add_faster_video is not implemented")
# faster_video_feature = self.add_token_per_grid(all_faster_video_features[image_idx])
# concat_slow_fater_token = []
# for _ in range(image_feature.shape[0]):
# if _ % self.config.faster_token_stride == 0:
# concat_slow_fater_token.append(torch.cat((image_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
# else:
# concat_slow_fater_token.append(torch.cat((faster_video_feature[_], self.model.faster_token[None].to(image_feature.device)), dim=0))
# image_feature = torch.cat(concat_slow_fater_token)
new_image_features.append(image_feature)
elif mm_newline_position == "frame":
image_feature = self.add_token_per_frame(image_feature)
new_image_features.append(image_feature.flatten(0, 1))
elif mm_newline_position == "one_token":
# one-token
image_feature = image_feature.flatten(0, 1)
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat((
image_feature,
self.image_newline[None].to(image_feature.device)
), dim=0)
new_image_features.append(image_feature)
elif mm_newline_position == "no_token":
new_image_features.append(image_feature.flatten(0, 1))
else:
raise ValueError(f"Unexpected mm_newline_position: {mm_newline_position}")
elif image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.vision_tower.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if "anyres_max" in image_aspect_ratio:
matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
if matched_anyres_max_num_patches:
max_num_patches = int(matched_anyres_max_num_patches.group(1))
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
try:
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], config["image_grid_pinpoints"], self.vision_tower.config.image_size)
except Exception as e:
logging.error(f"exception {str(e)}, set num_path_width and num_patch_height to 2")
num_patch_width, num_patch_height = 2, 2
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
image_feature = image_feature.view(2, 2, height, width, -1)
if "maxpool2x2" in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = nn.functional.max_pool2d(image_feature, 2)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
unit = image_feature.shape[2]
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
c, h, w = image_feature.shape
times = math.sqrt(h * w / (max_num_patches * unit**2))
if times > 1.1:
image_feature = image_feature[None]
image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
image_feature = torch.cat((image_feature, self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
elif 'unpad' in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat((
image_feature,
self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
if "nobase" in mm_patch_merge_type:
pass
else:
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat((
image_feature,
self.image_newline[None].to(image_feature.device)
), dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
return image_features
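    # Shape sketch for the "spatial"/anyres path (assuming a CLIP ViT-L/14-336 tower, so
    # num_patches_per_side = 24 and 576 patch tokens per view): a 2x2 anyres grid gives
    # image_feature of shape [5, 576, hidden]; view 0 is kept as base_image_feature, the
    # remaining 4 views are rearranged into a 2x2 spatial grid, optionally unpadded and
    # interleaved with image_newline, and finally concatenated after the base features.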
def encode_images(self, images):
if images.shape[0] == 0:
return images
image_features = self.vision_tower(images)
image_features = self.mm_projector(image_features)
return image_features
def add_token_per_grid(self, image_feature):
resize_h = int(math.sqrt(image_feature.shape[1]))
num_frames = image_feature.shape[0]
feature_dim = image_feature.shape[-1]
image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = torch.cat((image_feature, self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
if self.config.mm_related_params.config["add_faster_video"]:
image_feature = image_feature.view(feature_dim, num_frames,resize_h, -1)
image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
image_feature = image_feature.flatten(1, 2)
return image_feature
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
return image_feature
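    # Shape sketch for add_token_per_grid (assuming 144 pooled tokens per frame, i.e. a
    # 24 x 24 CLIP grid pooled 2x in get_2dPool, so resize_h = 12):
    # [num_frames, 144, dim] -> [dim, num_frames * 12, 12] -> append one image_newline
    # column -> [dim, num_frames * 12, 13] -> flatten/transpose -> [num_frames * 156, dim].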
def get_2dPool(self, image_feature, stride=2):
height = width = self.vision_tower.num_patches_per_side
num_frames, num_tokens, num_dim = image_feature.shape
image_feature = image_feature.view(num_frames, height, width, -1)
image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
# image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
mm_spatial_pool_mode = self.config.mm_related_params.config["mm_spatial_pool_mode"]
if mm_spatial_pool_mode == "average":
image_feature = nn.functional.avg_pool2d(image_feature, stride)
elif mm_spatial_pool_mode == "max":
image_feature = nn.functional.max_pool2d(image_feature, stride)
elif mm_spatial_pool_mode == "bilinear":
height, width = image_feature.shape[2:]
scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
else:
raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
image_feature = image_feature.permute(0, 2, 3, 1)
image_feature = image_feature.view(num_frames, -1, num_dim)
return image_feature
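    # Worked example (illustrative): a CLIP ViT-L/14-336 frame gives 24 x 24 = 576 tokens;
    # with stride=2 the pool reduces that to 12 x 12 = 144 tokens per frame, so a 32-frame
    # clip shrinks from 18432 to 4608 visual tokens.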
def build_vision_tower(self, vision_tower_cfg: Dict[str, Any], **kwargs: Any):
vision_tower_name = os.environ.get('EXTRA_DATA_PATH', '')
vision_tower = os.environ.get('LOCAL_EXTRA_DATA_PATH', None)
if vision_tower is None:
vision_tower_name = vision_tower_cfg['vit_tower_path']
vision_tower = vision_tower_cfg['vit_tower_path']
if "siglip" in vision_tower_name:
return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
else:
return CLIPVisionTower(vision_tower,
select_layer=vision_tower_cfg.get("mm_vision_select_layer", -2),
select_feature=vision_tower_cfg.get("mm_vision_select_feature", "patch"),
**kwargs)
def add_token_per_frame(self, image_feature):
image_feature = image_feature.permute(2, 0, 1).contiguous()
image_feature = torch.cat((image_feature, self.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
image_feature = image_feature.permute(1, 2, 0).contiguous()
return image_feature
def build_vision_projector(self, config, delay_load=False, **kwargs):
projector_type = config.get('mm_projector_type', 'linear')
if projector_type == 'linear':
return torch.nn.Linear(config['mm_hidden_size'], config['hidden_size'])
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [torch.nn.Linear(config['mm_hidden_size'], config['hidden_size'])]
for _ in range(1, mlp_depth):
modules.append(torch.nn.GELU())
modules.append(torch.nn.Linear(config['hidden_size'], config['hidden_size']))
return torch.nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
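    # Supported projector_type values (as parsed above, shown for illustration):
    #   'linear'      -> Linear(mm_hidden_size, hidden_size)
    #   'mlp2x_gelu'  -> Linear -> GELU -> Linear
    #   'mlp3x_gelu'  -> Linear -> GELU -> Linear -> GELU -> Linear
    #   'identity'    -> IdentityMap (features passed through unchanged)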
# ViT
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, select_layer=-2, select_feature="patch", delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = select_layer
self.select_feature = select_feature
if not delay_load:
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self):
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
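# Example (illustrative): with a CLIP ViT-L/14 tower at 336 px, num_patches_per_side is
# 336 // 14 = 24 and num_patches is 576; with the default select_feature='patch' the CLS
# token is dropped, so forward() returns [batch, 576, hidden] features taken from
# hidden_states[select_layer] (default -2).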
# Projector
class IdentityMap(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
def process_images(images, image_aspect_ratio, image_processor, device, data_type, mm_type = MMUrlType.IMAGE, **kwargs):
if mm_type == MMUrlType.VIDEO:
return image_processor.preprocess(images, return_tensors='pt')['pixel_values'].to(device, dtype=data_type)
new_images = []
if image_aspect_ratio == "pad":
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
new_images.append(image)
elif "anyres" in image_aspect_ratio:
for image in images:
image = process_anyres_image(image, image_processor, kwargs.get('image_grid_pinpoints', []))
new_images.append(image)
else:
return image_processor.preprocess(images, return_tensors='pt')['pixel_values'].to(device, dtype=data_type)
if type(new_images) is list:
new_images = [image.to(device, dtype=data_type) for image in new_images]
else:
new_images = new_images.to(device, dtype=data_type)
return new_images
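# A minimal call sketch for process_images (variable names are illustrative):
#
# >>> pixel_values = process_images(
# ...     [pil_image],
# ...     image_aspect_ratio="anyres",
# ...     image_processor=tower.image_processor,
# ...     device=tower.device,
# ...     data_type=torch.float16,
# ...     image_grid_pinpoints=[[336, 672], [672, 336], [672, 672]],
# ... )
#
# For 'anyres' each image typically becomes a [num_tiles + 1, 3, H, W] tensor (base view
# plus tiles); 'pad' yields a list of per-image [3, H, W] tensors; any other aspect ratio
# falls through to a single [N, 3, H, W] batch.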
class SigLipImageProcessor:
    def __init__(self, image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5), size=(384, 384), crop_size: Optional[Dict[str, int]] = None, resample=PILImageResampling.BICUBIC, rescale_factor=1 / 255, data_format=ChannelDimension.FIRST):
crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
self.image_mean = image_mean
self.image_std = image_std
self.size = size
self.resample = resample
self.rescale_factor = rescale_factor
self.data_format = data_format
self.crop_size = crop_size
def preprocess(self, images, return_tensors):
if isinstance(images, Image.Image):
images = [images]
else:
# to adapt video data
images = [to_numpy_array(image) for image in images]
assert isinstance(images, list)
transforms = [
convert_to_rgb,
to_numpy_array,
partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
]
images = reduce(lambda x, f: [*map(f, x)], transforms, images)
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
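    # A minimal call sketch (illustrative): this class mirrors just the subset of the HF
    # image-processor API that SigLipVisionTower needs.
    #
    # >>> processor = SigLipImageProcessor()
    # >>> batch = processor.preprocess(pil_image, return_tensors="pt")
    # >>> batch["pixel_values"].shape   # torch.Size([1, 3, 384, 384]) for one PIL image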
class SigLipVisionConfig(PretrainedConfig):
model_type = "siglip_vision_model"
def __init__(
self,
hidden_size=1152,
image_mean=(0.5, 0.5, 0.5),
intermediate_size=4304,
num_hidden_layers=27,
num_attention_heads=16,
num_channels=3,
image_size=384,
patch_size=14,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.image_mean = image_mean
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
cls._set_token_in_kwargs(kwargs)
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from SigLipConfig
if config_dict.get("model_type") == "siglip":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
print(f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " f"{cls.model_type}. This is not supported for all configurations of models and can yield errors.")
return cls.from_dict(config_dict, **kwargs)
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
class SigLipVisionModelOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class SigLipVisionEmbeddings(nn.Module):
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
embeddings = patch_embeds.flatten(2).transpose(1, 2)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
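    # Shape note for the default config (384 px images, patch 14): the conv produces a
    # 27 x 27 grid, so forward() maps pixel_values [B, 3, 384, 384] to embeddings of shape
    # [B, 729, 1152]; there is no CLS token, only the 729 patch positions.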
class SigLipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
def __init__(self, config):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
batch_size, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" f" {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}")
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
raise ValueError(f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
class SigLipMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
class SigLipEncoderLayer(nn.Module):
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SigLipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SigLipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class SigLipPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = SigLipVisionConfig
base_model_prefix = "siglip"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
pass
# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
class SigLipEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`SigLipEncoderLayer`].
Args:
config: SigLipVisionConfig
"""
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions)
class SigLipVisionTransformer(nn.Module):
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SigLipVisionEmbeddings(config)
self.encoder = SigLipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = SigLipMultiheadAttentionPoolingHead(config)
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = self.head(last_hidden_state)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class SigLipMultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: SigLipVisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = SigLipMLP(config)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
class SigLipVisionModel(SigLipPreTrainedModel):
config_class = SigLipVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["SigLipEncoderLayer"]
def __init__(self, config: SigLipVisionConfig):
super().__init__(config)
self.vision_model = SigLipVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
def forward(
self,
pixel_values,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
Examples:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, SigLipVisionModel
>>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
>>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class SigLipVisionTower(nn.Module):
def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
super().__init__()
self.is_loaded = False
self.config = SigLipVisionConfig()
self.vision_tower_name = vision_tower
self.image_processor = SigLipImageProcessor()
if not delay_load:
self.load_model()
elif getattr(vision_tower_cfg, "unfreeze_mm_vision_tower", False):
self.load_model()
elif hasattr(vision_tower_cfg, "mm_tunable_parts") and "mm_vision_tower" in vision_tower_cfg.mm_tunable_parts:
self.load_model()
else:
self.cfg_only = self.config
def load_model(self, device_map=None):
if self.is_loaded:
return
self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
del self.vision_tower.vision_model.encoder.layers[-1:]
self.vision_tower.vision_model.head = nn.Identity()
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
                assert image_feature.shape[-2] == 729
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
assert image_features.shape[-2] == 729
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
for p in self.vision_tower.parameters():
return p.dtype
@property
def device(self):
for p in self.vision_tower.parameters():
return p.device
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def image_size(self):
return self.config.image_size
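# A minimal end-to-end sketch for the SigLip tower (illustrative; `siglip_path` stands in
# for a local checkpoint directory compatible with SigLipVisionConfig):
#
# >>> tower = SigLipVisionTower(siglip_path, vision_tower_cfg={})
# >>> pixels = tower.image_processor.preprocess(pil_image, return_tensors="pt")["pixel_values"]
# >>> feats = tower(pixels.to(device=tower.device, dtype=tower.dtype))
# >>> feats.shape   # torch.Size([1, 729, 1152]); load_model drops the last encoder layer
# >>> # and the pooling head, so these are penultimate-layer patch features.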