maga_transformer/models/qwen_vl.py:

import torch
import os
import json
from typing import List, Any, Tuple, Dict, Union
from transformers import AutoTokenizer

from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.models.qwen import QWen
from maga_transformer.models.qwen_vl_weight import QWenVLWeightInfo, QwenVLVitWeight
from maga_transformer.models.qwen_vl_vit import VisionTransformer as QWen_VL_ViT
from maga_transformer.models.base_model import BaseModel, MultimodalInput
from maga_transformer.models.multimodal.multimodal_mixin import MultiModalMixin
from maga_transformer.models.multimodal.multimodal_common import ImageEmbeddingInterface
from maga_transformer.model_factory_register import register_model
from maga_transformer.utils.util import to_torch_dtype


class QwenVLImageEmbedding(ImageEmbeddingInterface):
    def __init__(self, config: GptInitModelParameters):
        self.vit = QWen_VL_ViT(**config.mm_related_params.config)
        self.config = config

    @property
    def _device(self):
        return self.vit.device

    @torch.no_grad()
    def image_embedding(self, images: List[Any]) -> torch.Tensor:
        # Encode the images with the ViT, on the ViT's device and data type.
        images = self.vit.encode(images, self._device, self._data_type)
        return images


class QWen_VL(QWen, MultiModalMixin):
    def _init_multimodal(self, config: GptInitModelParameters):
        self.mm_part = QwenVLImageEmbedding(config)
        config.mm_related_params.vit_weights = QwenVLVitWeight({"vit": self.mm_part.vit})

    def load(self, device: str):
        # When VIT_TRT=1, initialize the TensorRT path for the vision tower
        # before loading the LLM weights.
        if os.environ.get("VIT_TRT", "0") == "1":
            weights_info = self.get_weight_cls()(self.config, g_parallel_info.tp_size, g_parallel_info.tp_rank)
            self.init_mm_trt(
                weights_info, self.config.ckpt_path,
                self.config.mm_related_params, device,
                to_torch_dtype(self.config.data_type)
            )
        super().load(device=device)

    @staticmethod
    def multimodal_modify_prompt_plugin(prompt: Union[List[Dict[str, Any]], str], images: List[str],
                                        img_token: str, **kwargs: Any) -> Tuple[str, List[MultimodalInput]]:
        prompt, mm_inputs = MultiModalMixin.multimodal_modify_prompt_plugin(prompt, images, img_token, **kwargs)
        start_str = '<img>'
        end_str = '</img>'
        if img_token in prompt:
            # Placeholder branch: replace each img_token with the matching <img>...</img> span.
            split_prompts = prompt.split(img_token)
            if len(split_prompts) - 1 != len(images):
                raise Exception('num of ' + img_token + ' should equal the num of images')
            res = split_prompts[0]
            idx = 0
            for split_prompt in split_prompts[1:]:
                res = res + start_str + images[idx] + end_str + split_prompt
                idx = idx + 1
            return res, mm_inputs
        else:
            # No placeholder: prepend "Picture N:<img>...</img>" lines, then collect any
            # <img>...</img> spans already present in the prompt as extra image inputs.
            prefix_prompt = ''
            if len(images) > 0:
                for i in range(len(images)):
                    prefix_prompt += 'Picture {i}:'.format(i = i + 1) + start_str + images[i] + end_str + '\n'
            tmp_prompt = prompt
            while start_str in tmp_prompt:
                start_idx = tmp_prompt.find(start_str)
                end_idx = tmp_prompt.find(end_str)
                if end_idx < start_idx:
                    raise Exception(f'unclosed tag <img> pair in {prompt}')
                images.append(tmp_prompt[start_idx + len(start_str): end_idx])
                tmp_prompt = tmp_prompt[end_idx + len(end_str):]
            return prefix_prompt + prompt, [MultimodalInput(image) for image in images]

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = GptInitModelParameters(
            head_num=0,
            size_per_head=0,
            layer_num=0,
            max_seq_len=0,
            vocab_size=0
        )
        QWen_VL._common_config(config, ckpt_path)
        return config

    @staticmethod
    def _common_config(config: GptInitModelParameters, ckpt_path: str) -> GptInitModelParameters:
        QWen._common_config(config, ckpt_path)
        QWen._from_hf(config, ckpt_path)
        QWen_VL._load_vit_param(config, ckpt_path)
        return config

    @staticmethod
    def _load_vit_param(config: GptInitModelParameters, ckpt_path: str):
        config_path = os.path.join(ckpt_path, "config.json")
        if not os.path.exists(config_path):
            return
        with open(config_path) as reader:
            content = reader.read()
            config_json = json.loads(content)
            vit_config = config_json['visual']
            config.mm_related_params.config.update(vit_config)
            # image_end_id and image_pad_id are the two ids following image_start_id.
            config.mm_related_params.special_token_ids.update({
                'image_start_id': vit_config['image_start_id'],
                'image_end_id': vit_config['image_start_id'] + 1,
                'image_pad_id': vit_config['image_start_id'] + 2})
            config.mm_related_params.special_tokens.update({'default_mm_token': '<img/>'})
            config.mm_sep_tokens = [[vit_config['image_start_id'], vit_config['image_start_id'] + 1]]

    @classmethod
    def get_tokenizer(cls, config: GptInitModelParameters):
        return AutoTokenizer.from_pretrained(config.tokenizer_path, trust_remote_code=True)

    @staticmethod
    def get_weight_cls():
        return QWenVLWeightInfo

    @staticmethod
    def eval_model_size(config: GptInitModelParameters):
        llm_size = BaseModel.eval_model_size(config)
        # ViT weights are counted at 4 bytes per parameter.
        data_width = 4
        llm_size += QWen_VL.eval_vit_param_count(config) * data_width
        return llm_size

    @staticmethod
    def eval_vit_param_count(config: GptInitModelParameters):
        vit_config = config.mm_related_params.config
        embed_dim = vit_config["output_dim"]
        width = vit_config["width"]
        layers = vit_config["layers"]
        patch_size = vit_config["patch_size"]
        mlp_ratio = vit_config["mlp_ratio"]
        mlp_width = int(mlp_ratio * width)
        # Rough count of the patch embedding.
        llm_size = (3 * width * patch_size ** 2 + width * 2)
        # Rough count per transformer block (norms, attention, MLP), times the layer count.
        llm_size += (layers * (width * 2 * 2 + width ** 2 * 4 + width * 4 + mlp_width * width * 2 + mlp_width + width))
        # Rough count of the pooling / output projection onto the LLM embedding dim.
        llm_size += (width * embed_dim + embed_dim ** 2 + embed_dim + embed_dim * 2 * 3)
        return llm_size

    @staticmethod
    def eval_model_param_count(config: GptInitModelParameters):
        llm_param_count = BaseModel.eval_model_param_count(config)
        llm_param_count += QWen_VL.eval_vit_param_count(config)
        return llm_param_count


register_model('qwen_vl', QWen_VL, ["QWenMLMHeadModel"])
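
For reference, a minimal standalone sketch of the prompt rewriting performed by multimodal_modify_prompt_plugin above. rewrite_prompt is an illustrative helper, not part of maga_transformer, and it omits the MultiModalMixin pre-processing step; it only shows the two branches of the tag-splicing logic.

# Minimal sketch of the prompt rewriting in multimodal_modify_prompt_plugin.
# rewrite_prompt is illustrative only and not part of maga_transformer.
from typing import List, Tuple

def rewrite_prompt(prompt: str, images: List[str], img_token: str = '<img/>') -> Tuple[str, List[str]]:
    start_str, end_str = '<img>', '</img>'
    if img_token in prompt:
        # Placeholder branch: each img_token is replaced by <img>url</img>.
        parts = prompt.split(img_token)
        if len(parts) - 1 != len(images):
            raise ValueError('placeholder count must match the number of images')
        res = parts[0]
        for url, rest in zip(images, parts[1:]):
            res += start_str + url + end_str + rest
        return res, images
    # No placeholder: prefix "Picture N:<img>url</img>" lines and keep any
    # <img>...</img> spans already present in the prompt as extra image inputs.
    prefix = ''.join(
        'Picture {0}:{1}{2}{3}\n'.format(i + 1, start_str, url, end_str)
        for i, url in enumerate(images))
    found = list(images)
    rest = prompt
    while start_str in rest:
        s, e = rest.find(start_str), rest.find(end_str)
        if e < s:
            raise ValueError('unclosed <img> tag in prompt')
        found.append(rest[s + len(start_str):e])
        rest = rest[e + len(end_str):]
    return prefix + prompt, found

# Example: no placeholder in the prompt, one image URL supplied.
print(rewrite_prompt('Describe this picture.', ['http://example.com/cat.jpg']))
# -> ('Picture 1:<img>http://example.com/cat.jpg</img>\nDescribe this picture.',
#     ['http://example.com/cat.jpg'])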
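
Similarly, a minimal sketch of the eval_vit_param_count estimate applied to an example visual config. The numeric values below are illustrative (they resemble the published Qwen-VL ViT settings) and in practice come from the "visual" section of the checkpoint's config.json.

# Standalone sketch of the ViT parameter-count estimate used above.
def vit_param_count(vit_config: dict) -> int:
    embed_dim = vit_config["output_dim"]
    width = vit_config["width"]
    layers = vit_config["layers"]
    patch_size = vit_config["patch_size"]
    mlp_width = int(vit_config["mlp_ratio"] * width)
    # Same three terms as QWen_VL.eval_vit_param_count: patch embedding,
    # transformer blocks, and the pooling / projection onto the LLM embedding dim.
    count = 3 * width * patch_size ** 2 + width * 2
    count += layers * (width * 2 * 2 + width ** 2 * 4 + width * 4
                       + mlp_width * width * 2 + mlp_width + width)
    count += width * embed_dim + embed_dim ** 2 + embed_dim + embed_dim * 2 * 3
    return count

# Illustrative values only; real ones come from config.json["visual"].
example_visual_config = {
    "output_dim": 4096, "width": 1664, "layers": 48,
    "patch_size": 14, "mlp_ratio": 4.9231,
}
print(f"~{vit_param_count(example_visual_config) / 1e9:.2f}B ViT parameters")  # ~1.87B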