# maga_transformer/models/minicpmv/minicpmv.py
import json
import os
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer

from maga_transformer.config.gpt_init_model_parameters import \
    GptInitModelParameters
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.model_factory_register import register_model
from maga_transformer.models.minicpmv.modeling_navit_siglip import (
    SiglipVisionConfig, SiglipVisionTransformer)
from maga_transformer.models.minicpmv.resampler import Resampler
from maga_transformer.models.multimodal.multimodal_common import (
    MultiModalEmbeddingInterface, mm_lock)
from maga_transformer.models.multimodal.multimodal_mixin import (
    BaseMultiModalWeightInfo, BaseVitWeights, MultiModalMixin)
from maga_transformer.models.qwen_v2 import QWenV2, QWenV2Weight
from maga_transformer.utils.multimodal_util import (MMUrlType,
                                                    get_bytes_io_from_url,
                                                    vit_emb_cache_)

# decord is optional; it is only needed to decode video inputs
try:
    from decord import VideoReader, cpu
except ModuleNotFoundError:
    VideoReader = None
    cpu = None


def encode_video(video_path, max_num_frames: int = 32):
    """Decode a video into at most ``max_num_frames`` PIL images, sampled at
    roughly one frame per second."""
    if VideoReader is None:
        raise RuntimeError("decord is required for video inputs but is not installed")

    def uniform_sample(l, n):
        # pick n indices spread evenly across l
        gap = len(l) / n
        idxs = [int(i * gap + gap / 2) for i in range(n)]
        return [l[i] for i in idxs]

    vr = VideoReader(video_path, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # sample ~1 frame per second
    frame_idx = [i for i in range(0, len(vr), sample_fps)]
    if len(frame_idx) > max_num_frames:
        frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames
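
# Example of the sampling arithmetic above (values are illustrative): a 2-minute
# clip at ~30 fps has len(vr) == 3600 and sample_fps == 30, so frame_idx first
# keeps every 30th frame (120 candidates) and uniform_sample then thins that
# list down to the 32 frames handed to the image pipeline.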


class ImageEmbeddingInterface(MultiModalEmbeddingInterface):
    def __init__(self, config: GptInitModelParameters):
        self.config = config
        # mm_related_params.config is the "vision_config" section of config.json,
        # extended with llm_hidden_size / query_num / ckpt_path by
        # MiniCPMV._init_vit_params below
        vit_config = config.mm_related_params.config
        self.vision_config = SiglipVisionConfig(**vit_config)
        self.processor = AutoProcessor.from_pretrained(vit_config['ckpt_path'],
                                                       trust_remote_code=True)
        self.vpm = SiglipVisionTransformer(self.vision_config)
        self.embed_dim = vit_config['llm_hidden_size']
        self.query_num = vit_config['query_num']
        self.vision_dim = self.vision_config.hidden_size
        # the resampler compresses each image's patch features into query_num
        # tokens of the LLM's hidden size, using 128-dim attention heads
        self.resampler = Resampler(num_queries=self.query_num,
                                   embed_dim=self.embed_dim,
                                   num_heads=self.embed_dim // 128,
                                   kv_dim=self.vision_dim,
                                   adaptive=True)
@property
def _device(self):
return self.vpm.device
    @torch.inference_mode()
    def mm_embedding(self, url: str, mm_type: MMUrlType, **kwargs):
        """Download, preprocess and embed one multimodal URL, memoizing the
        features in vit_emb_cache_ so repeated URLs skip the ViT forward."""
        if g_parallel_info.tp_rank > 0:
            # only the first tensor-parallel rank runs the vision tower
            return torch.Tensor([])
        dtype = self._data_type
        cached_res = vit_emb_cache_.check_cache(url)
        if cached_res is None:
            data = get_bytes_io_from_url(url)
            data = self._mm_preprocess(data, mm_type)
            with mm_lock:
                features = self.mm_process(data,
                                           mm_type=mm_type,
                                           **kwargs)
            if isinstance(features, list):
                features = torch.stack(features).to(dtype).contiguous()
            vit_emb_cache_.insert_cache(url, features)
            return (features, None)
        else:
            return (cached_res, None)
    def _mm_preprocess(self, data, type, **kwargs):
        if type == MMUrlType.IMAGE:
            return Image.open(data).convert("RGB")
        elif type == MMUrlType.VIDEO:
            return encode_video(data)
        raise Exception("unsupported mm url type: %s" % type)
    @torch.inference_mode()
    def mm_process(self, mm_input, **kwargs):
        mm_type = kwargs.get("mm_type")
        if mm_type == MMUrlType.DEFAULT:
            if isinstance(mm_input, list):
                return self.image_embedding(mm_input)
            else:
                return self.image_embedding([mm_input])
        elif mm_type == MMUrlType.IMAGE:
            if isinstance(mm_input, list):
                raise Exception("expected a single image input, but got a list")
            return self.image_embedding([mm_input])
        elif mm_type == MMUrlType.VIDEO:
            if not isinstance(mm_input, list):
                raise Exception("expected a list of video frames, but got a single image")
            return self.image_embedding(mm_input)
        else:
            raise Exception("unknown mm url type")
@torch.no_grad()
def image_embedding(self, images: List[Any]) -> List[torch.Tensor]:
data = self.processor.image_processor(images, return_tensors="pt")
dtype = self._data_type
tgt_sizes = data['tgt_sizes']
pixel_values_list = data['pixel_values']
vision_hidden_states = []
all_pixel_values = []
img_cnt = []
for pixel_values in pixel_values_list:
img_cnt.append(len(pixel_values))
all_pixel_values.extend([
i.flatten(end_dim=1).permute(1, 0).to(self._device)
for i in pixel_values
])
        assert all_pixel_values
        # at least one image produced patches
        if all_pixel_values:
tgt_sizes = [
tgt_size for tgt_size in tgt_sizes
if isinstance(tgt_size, torch.Tensor)
]
tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
all_pixel_values = torch.nn.utils.rnn.pad_sequence(
all_pixel_values, batch_first=True, padding_value=0.0)
B, L, _ = all_pixel_values.shape
all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(
B, 3, -1, L)
            # mark the real (non-padded) patches of each image in the mask
            patch_attn_mask = torch.zeros((B, 1, max_patches),
                                          dtype=torch.bool,
                                          device=self._device)
            for i in range(B):
                patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
            # run the ViT in micro-batches of at most 16 inputs
            vision_batch_size = 16
all_pixel_values = all_pixel_values.type(dtype)
if B > vision_batch_size:
hs = []
for i in range(0, B, vision_batch_size):
start_idx = i
end_idx = i + vision_batch_size
tmp_hs = self.vpm(all_pixel_values[start_idx:end_idx],
patch_attention_mask=patch_attn_mask[
start_idx:end_idx],
tgt_sizes=tgt_sizes[start_idx:end_idx]
).last_hidden_state
hs.append(tmp_hs)
vision_embedding = torch.cat(hs, dim=0)
else:
vision_embedding = self.vpm(
all_pixel_values,
patch_attention_mask=patch_attn_mask,
tgt_sizes=tgt_sizes).last_hidden_state
            vision_embedding = self.resampler(vision_embedding, tgt_sizes)

            # hand each input back its share of the resampled embeddings
            start = 0
            for pixel_values in pixel_values_list:
                cur_cnt = len(pixel_values)
                if cur_cnt > 0:
                    for i in range(cur_cnt):
                        vision_hidden_states.append(vision_embedding[start + i])
                    start += cur_cnt
                else:
                    vision_hidden_states.append([])
        return vision_hidden_states
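

# Weight mapping: the ViT tensors ("vpm.*", "resampler.*") keep their checkpoint
# names (empty ckpt prefix) and are attached to self.mm_part at load time, while
# the language-model weights are read under the "llm." prefix via
# MiniCPMVWeightInfo below.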
class MiniCPMVVitWeight(BaseVitWeights):
def _set_weight_prefix(self):
self._ckpt_prefix = ""
self._ft_prefix = "self.mm_part."


class MiniCPMVWeightInfo(QWenV2Weight, BaseMultiModalWeightInfo):
def __init__(self, config, tp_size, tp_rank):
QWenV2Weight.__init__(self, config, tp_size, tp_rank, prefix="llm.")
BaseMultiModalWeightInfo.__init__(self, config)
def _get_weight_info(self):
weights = super()._get_weight_info()
weights = self._get_vit_info(weights)
return weights
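

# MiniCPMV combines the Qwen2-based LLM with the SigLIP ViT defined above. The
# mm_sep_tokens pairs set in __init__ are the tokenizer's image and slice
# delimiter ids (im_start/im_end, slice_start/slice_end), marking where the
# vision embeddings are spliced into the prompt.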
class MiniCPMV(QWenV2, MultiModalMixin):
def __init__(self, config: GptInitModelParameters):
QWenV2.__init__(self, config)
self.config.mm_sep_tokens = [
[self.tokenizer.im_start_id, self.tokenizer.im_end_id],
[self.tokenizer.slice_start_id, self.tokenizer.slice_end_id]
]
    def _init_multimodal(self, config: GptInitModelParameters):
        self.mm_part = ImageEmbeddingInterface(config)
        config.mm_related_params.vit_weights = MiniCPMVVitWeight({
            "vpm": self.mm_part.vpm,
            "resampler": self.mm_part.resampler
        })
@staticmethod
def get_weight_cls():
return MiniCPMVWeightInfo
@classmethod
def get_tokenizer(cls, config: GptInitModelParameters):
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path,
verbose=False,
trust_remote_code=True,
use_fast=True)
return tokenizer
@classmethod
def _create_config(cls, ckpt_path: str):
config = GptInitModelParameters(head_num=0,
head_num_kv=0,
size_per_head=0,
layer_num=0,
inter_size=0,
vocab_size=0,
max_seq_len=8192,
ckpt_path=ckpt_path,
rotary_embedding_dim=128,
rotary_embedding_style=1,
activation_type='SiGLU',
has_pre_decoder_layernorm=False,
has_post_decoder_layernorm=True,
norm_type='rmsnorm')
config_path = os.path.join(ckpt_path, 'config.json')
if os.path.exists(config_path):
with open(config_path) as reader:
content = reader.read()
config_json = json.loads(content)
QWenV2._from_config_json(config, config_json)
MiniCPMV._init_vit_params(config, config_json)
        else:
            raise Exception("no config.json found under %s" % ckpt_path)
return config
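    # _init_vit_params copies the "vision_config" section of config.json into
    # mm_related_params and adds the LLM hidden size, resampler query count and
    # checkpoint path that ImageEmbeddingInterface.__init__ reads back out.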
    @staticmethod
    def _init_vit_params(config: GptInitModelParameters,
                         config_json: Dict[str, Any]):
        config.mm_related_params.config = config_json["vision_config"]
        config.mm_related_params.config["llm_hidden_size"] = config_json["hidden_size"]
        config.mm_related_params.config["query_num"] = config_json["query_num"]
        config.mm_related_params.config["ckpt_path"] = config.ckpt_path


register_model('minicpmv', MiniCPMV, ["MiniCPMV"])
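
# For reference, _create_config / _init_vit_params expect a MiniCPM-V style
# config.json whose relevant fields look roughly like this (values are
# illustrative, not taken from any specific checkpoint):
#
#   {
#     "hidden_size": 3584,
#     "query_num": 64,
#     "vision_config": {
#       "hidden_size": 1152,
#       "image_size": 980,
#       ...
#     },
#     ...
#   }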