# maga_transformer/models/minicpmv_embedding/minicpmv_embedding.py
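"""MiniCPM-V based multimodal embedding model for the maga_transformer runtime."""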
import json
import math
import os
# for faster batch inference
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Type, Union
import timm
import torch
from PIL import Image
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from torchvision import transforms
from transformers import AutoTokenizer, LlamaTokenizer
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.config.task_type import TaskType
from maga_transformer.distribute.worker_info import ParallelInfo, g_parallel_info
from maga_transformer.model_factory_register import register_model
from maga_transformer.models.downstream_modules.custom_module import CustomModule
from maga_transformer.models.downstream_modules.embedding.minicpmv_embedding_module import MiniCPMVModule, slice_image
from maga_transformer.models.llama import Llama
from maga_transformer.models.llama_weight import LlamaWeightInfo
from maga_transformer.models.minicpmv.minicpmv import encode_video
# from maga_transformer.models.minicpmv.modeling_navit_siglip import SiglipVisionTransformer, SiglipVisionConfig
from maga_transformer.models.minicpmv_embedding.resampler import Resampler
from maga_transformer.models.multimodal.multimodal_common import MultiModalEmbeddingInterface, mm_lock
from maga_transformer.models.multimodal.multimodal_mixin import BaseMultiModalWeightInfo, BaseVitWeights, MultiModalMixin
from maga_transformer.utils.multimodal_util import MMUrlType, get_bytes_io_from_url, vit_emb_cache_
class LlamaTokenizerWrapper(LlamaTokenizer):
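    """LlamaTokenizer variant carrying MiniCPM-V's marker tokens (<image>, <ref>,
    <box>, <quad>, <point>, <slice>) and exposing the image marker ids as properties."""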
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.im_start = "<image>"
self.im_end = "</image>"
self.ref_start = "<ref>"
self.ref_end = "</ref>"
self.box_start = "<box>"
self.box_end = "</box>"
self.quad_start = "<quad>"
self.quad_end = "</quad>"
self.point_start = "<point>"
self.point_end = "</point>"
self.slice_start = "<slice>"
self.slice_end = "</slice>"
@property
def eos_id(self):
return self.sp_model.eos_id()
@property
def bos_id(self):
return self.sp_model.bos_id()
@property
def unk_id(self):
return self.sp_model.unk_id()
@property
def im_start_id(self):
return self._convert_token_to_id(self.im_start)
@property
def im_end_id(self):
return self._convert_token_to_id(self.im_end)
class ImageEmbeddingInterface(MultiModalEmbeddingInterface):
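    """Vision front end for MiniCPM-V: a timm ViT backbone (vpm) plus a Resampler
    that compresses each image or slice into query_num tokens of the LLM hidden size."""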
    def __init__(self, config: GptInitModelParameters):
        self.config = config
        vit_config = config.mm_related_params.config
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN,
                                 std=IMAGENET_INCEPTION_STD),
        ])
        self.vision_encoder = vit_config['vision_encoder']
        self.drop_vision_last_layer = vit_config['drop_vision_last_layer']
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.embed_dim = vit_config['llm_hidden_size']
        self.query_num = vit_config['query_num']
        self.max_slice_nums = vit_config['max_slice_nums']
        self.scale_resolution = vit_config['scale_resolution']
        self.patch_size = vit_config['patch_size']
        self.slice_mode = vit_config['slice_mode']
self.resampler = Resampler(grid_size=int(math.sqrt(self.query_num)),
embed_dim=self.embed_dim,
num_heads=self.embed_dim // 128,
kv_dim=self.vision_dim,
adaptive=True)
@property
def _device(self):
return next(self.vpm.parameters()).device
def init_vision_module(self):
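        # Build the timm ViT backbone; its attention pool is replaced by Identity
        # (the Resampler takes that role) and the last block is optionally dropped.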
model = timm.create_model(self.vision_encoder,
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True)
if isinstance(model, timm.models.VisionTransformer):
if model.attn_pool is not None:
model.attn_pool = torch.nn.Identity()
if self.drop_vision_last_layer:
model.blocks = model.blocks[:-1]
return model
@torch.inference_mode()
def mm_embedding(self, url: str, mm_type: MMUrlType, **kwargs):
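        # Non-zero TP ranks return an empty tensor; rank 0 computes the vision
        # embedding (serialized by mm_lock) or fetches it from the URL-keyed cache.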
dtype = self._data_type
if g_parallel_info.tp_rank > 0:
return torch.Tensor([])
cached_res = vit_emb_cache_.check_cache(url)
if cached_res is None:
cached_url_res = get_bytes_io_from_url(url)
cached_url_res = self._mm_preprocess(cached_url_res, mm_type)
with mm_lock:
features = self.mm_process(cached_url_res,
mm_type=mm_type,
**kwargs)
if isinstance(features, list):
features = torch.stack(features).to(dtype).contiguous()
vit_emb_cache_.insert_cache(url, features)
return (features, None)
else:
return (cached_res, None)
    def _mm_preprocess(self, data, type, **kwargs):
        if type == MMUrlType.IMAGE:
            return Image.open(data).convert("RGB")
        elif type == MMUrlType.VIDEO:
            return encode_video(data)
        else:
            raise Exception("unknown mm url type")
@torch.inference_mode()
def mm_process(self, mm_input, **kwargs):
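        # Dispatch on the declared url type; single images and lists of video
        # frames are both routed through image_embedding().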
mm_type = kwargs.get("mm_type")
if mm_type == MMUrlType.DEFAULT:
if isinstance(mm_input, list):
return self.image_embedding(mm_input)
else:
return self.image_embedding([mm_input])
elif mm_type == MMUrlType.IMAGE:
if isinstance(mm_input, list):
                raise Exception("expected a single image input, but got a list")
return self.image_embedding([mm_input])
elif mm_type == MMUrlType.VIDEO:
if not isinstance(mm_input, list):
                raise Exception("expected video input (a list of frames), but got a single image")
return self.image_embedding(mm_input)
else:
raise Exception("unknown mm url type")
def get_vision_embedding(self, pixel_values):
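        # pixel_values[0] is the resized source image and is encoded on its own;
        # any remaining entries are its slices, which share one spatial size and
        # are therefore run through the ViT and Resampler as a single batch.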
res = []
dtype = self._data_type
# first slice
H, W = pixel_values[0].shape[-2:]
tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]))
vision_embedding = self.vpm.forward_features(
pixel_values[0].unsqueeze(0).type(dtype))
res.append(self.resampler(vision_embedding, tgt_size)[0])
# remaining slices as a batch
if len(pixel_values) > 1:
H, W = pixel_values[1].shape[-2:]
tgt_size = (math.ceil(H / self.vpm.patch_embed.patch_size[0]),
math.ceil(W / self.vpm.patch_embed.patch_size[0]))
vision_embedding = self.vpm.forward_features(
torch.stack(pixel_values[1:], dim=0).type(dtype))
vision_embedding = self.resampler(vision_embedding, tgt_size)
for i in range(len(pixel_values) - 1):
res.append(vision_embedding[i])
return res
@torch.no_grad()
def image_embedding(self, images: List[Any]) -> List[torch.Tensor]:
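        # Per image: optionally cut it into a source view plus grid slices, run the
        # ToTensor/Normalize transform in a thread pool, then encode every view with
        # the ViT + Resampler. Returns a flat list with one embedding per view.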
new_images_list = []
for image in images:
if self.slice_mode:
source_image, patches, best_grid = slice_image(
image,
self.max_slice_nums,
self.scale_resolution,
self.patch_size,
)
slice_images = [source_image]
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
slice_images.append(patches[i][j])
new_images_list.append(slice_images)
else:
new_images_list.append([image])
pixel_values_list = []
with ThreadPoolExecutor(max_workers=8) as executor:
for img_batch in new_images_list:
img_inps = list(executor.map(self.transform, img_batch))
for i in range(len(img_inps)):
img_inps[i] = img_inps[i].to(self._device)
pixel_values_list.append(img_inps if img_inps else [])
vision_hidden_states = []
for pixel_values in pixel_values_list:
if len(pixel_values) > 0:
vision_hidden_states.extend(
self.get_vision_embedding(pixel_values))
else:
vision_hidden_states.append([])
return vision_hidden_states
class MiniCPMVVitWeight(BaseVitWeights):
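    """ViT weight mapping: checkpoint tensor names are used as-is and bound to
    runtime attributes under self.mm_part."""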
def _set_weight_prefix(self):
self._ckpt_prefix = ""
self._ft_prefix = "self.mm_part."
class MiniCPMVWeightInfo(LlamaWeightInfo, BaseMultiModalWeightInfo):
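    """Llama weight layout read under the "llm." checkpoint prefix, extended with
    the ViT/Resampler weight info."""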
def __init__(self, config, tp_size, tp_rank):
LlamaWeightInfo.__init__(self, config, tp_size, tp_rank, prefix="llm.")
BaseMultiModalWeightInfo.__init__(self, config)
def _get_weight_info(self):
llama_vl_weight = super()._get_weight_info()
self._get_vit_info(llama_vl_weight)
return llama_vl_weight
class MiniCPMVEmbedding(Llama, MultiModalMixin):
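    """Embedding variant of MiniCPM-V: a Llama language model combined with the
    image front end; <image> ... </image> token pairs mark where image embeddings
    are placed. Registered as the 'minicpmv_embedding' model type."""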
def __init__(self, config: GptInitModelParameters):
Llama.__init__(self, config)
self.im_start = "<image>"
self.im_end = "</image>"
self.slice_start = "<slice>"
self.slice_end = "</slice>"
# self.im_start_id = self.tokenizer._convert_token_to_id(self.im_start)
# self.im_end_id = self.tokenizer._convert_token_to_id(self.im_end)
# self.slice_start_id = self.tokenizer._convert_token_to_id(self.slice_start)
# self.slice_end_id = self.tokenizer._convert_token_to_id(self.slice_end)
self.im_start_id = self.tokenizer.im_start_id
self.im_end_id = self.tokenizer.im_end_id
self.slice_start_id = self.tokenizer._convert_token_to_id(
self.slice_start)
self.slice_end_id = self.tokenizer._convert_token_to_id(self.slice_end)
        self.config.mm_sep_tokens = [
            [self.im_start_id, self.im_end_id],
            # [self.slice_start_id, self.slice_end_id],
        ]
def _init_multimodal(self, config: GptInitModelParameters):
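        # Build the vision front end and register its submodules (vpm + resampler)
        # so the multimodal weight loader can populate them from the checkpoint.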
self.mm_part = ImageEmbeddingInterface(config)
config.mm_related_params.vit_weights = MiniCPMVVitWeight({
"vpm":
self.mm_part.vpm,
"resampler":
self.mm_part.resampler
})
@staticmethod
def get_weight_cls():
return MiniCPMVWeightInfo
@classmethod
def _create_config(cls, ckpt_path: str):
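        # Placeholder sizes; the real values are filled in from the checkpoint's config.json below.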
config = GptInitModelParameters(
head_num=0,
size_per_head=0,
layer_num=0,
max_seq_len=0,
vocab_size=0,
ckpt_path=ckpt_path,
activation_type='SiGLU',
norm_type='rmsnorm',
rotary_embedding_dim=128,
rotary_embedding_style=1,
has_post_decoder_layernorm=True,
)
config_path = os.path.join(ckpt_path, 'config.json')
if os.path.exists(config_path):
with open(config_path) as reader:
content = reader.read()
config_json = json.loads(content)
Llama.from_huggingface(config, config_json)
config.input_embedding_scalar = config_json.get("scale_emb", 1)
config.residual_scalar = config_json.get("scale_depth", 1.4) / math.sqrt(config.layer_num)
# config.activation_type = config_json["hidden_act"]
MiniCPMVEmbedding._init_vit_params(config, config_json)
else:
            raise Exception(f"no config.json found in {ckpt_path}")
return config
@staticmethod
def _init_vit_params(config: GptInitModelParameters,
config_json: Dict[str, Any]):
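        # Copy the vision-related fields from config.json into mm_related_params.config,
        # which ImageEmbeddingInterface reads at construction time.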
# config.mm_related_params.config = config_json["vision_config"]
        vit_config = config.mm_related_params.config
        vit_config["llm_hidden_size"] = config_json["hidden_size"]
        vit_config["query_num"] = config_json["query_num"]
        vit_config["ckpt_path"] = config.ckpt_path
        vit_config["max_slice_nums"] = config_json["max_slice_nums"]
        vit_config["scale_resolution"] = config_json["scale_resolution"]
        vit_config["patch_size"] = config_json["patch_size"]
        vit_config["slice_mode"] = config_json["slice_mode"]
        vit_config["vision_encoder"] = config_json["vision_encoder"]
        vit_config["drop_vision_last_layer"] = config_json["drop_vision_last_layer"]
def load_custom_module(self) -> Optional[CustomModule]:
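        # Serve embedding requests through the MiniCPM-V downstream module rather
        # than the base class's default custom module.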
return MiniCPMVModule(self.config, self.tokenizer)
# return super().load_custom_module()
@classmethod
def get_tokenizer(cls, config: GptInitModelParameters):
return LlamaTokenizerWrapper.from_pretrained(config.tokenizer_path)
register_model('minicpmv_embedding', MiniCPMVEmbedding, ["MiniCPMVEmbedding"])