maga_transformer/models/chat_glm_v4_vision.py:

import os

import torch

from maga_transformer.config.gpt_init_model_parameters import \
    GptInitModelParameters
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.model_factory_register import register_model
from maga_transformer.models.chat_glm_v4 import ChatGlmV4
from maga_transformer.models.chat_glm_v4_vision_weight import (
    ChatGlmV4VisionVitWeights, ChatGlmV4VisionWeightInfo)
from maga_transformer.models.eva2clip_vit import EVA2CLIPImageEmbedding
from maga_transformer.models.multimodal.multimodal_mixin import MultiModalMixin
from maga_transformer.utils.util import get_config_from_path, to_torch_dtype


class ChatGlmV4VisionImageEmbedding(EVA2CLIPImageEmbedding):
    @torch.inference_mode()
    def mm_process(self, mm_input, **kwargs):
        # Embed the image with the EVA2-CLIP ViT, then build GLM-4V style
        # position ids: the first slot gets 0, the last slot gets 2, and
        # every patch token in between shares position id 1.
        embeddings = self.image_embedding([mm_input])[0]
        pos_ids = [1] * embeddings.shape[0]
        pos_ids[0] = 0
        pos_ids[-1] = 2
        return embeddings, torch.tensor(pos_ids, dtype=torch.int32)


class ChatGlmV4Vision(ChatGlmV4, MultiModalMixin):
    def _init_multimodal(self, config: GptInitModelParameters):
        self.mm_part = ChatGlmV4VisionImageEmbedding(config)
        config.mm_related_params.vit_weights = ChatGlmV4VisionVitWeights(
            {"vit": self.mm_part.vit}
        )

    def load(self, device: str):
        # When VIT_TRT=1, build a TensorRT engine for the ViT before the
        # base model weights are loaded.
        if os.environ.get("VIT_TRT", "0") == "1":
            weights_info = self.get_weight_cls()(self.config, g_parallel_info.tp_size, g_parallel_info.tp_rank)
            self.init_mm_trt(
                weights_info, self.config.ckpt_path,
                self.config.mm_related_params, device,
                to_torch_dtype(self.config.data_type)
            )
        super().load(device=device)

    @classmethod
    def _create_config(cls, ckpt_path: str):
        config = ChatGlmV4._create_config(ckpt_path)
        config_dict = get_config_from_path(ckpt_path)
        vit_config = config_dict["vision_config"]
        config.mm_related_params.config.update(vit_config)
        config.build_position_ids = True
        # use initial hidden size for linear_proj and conv layer in eva2clip
        config.mm_related_params.config['use_vision_hidden_size'] = False
        # The begin/end-of-image token ids delimit the image span in the
        # prompt; include_sep_tokens keeps them inside the multimodal slice.
        config.mm_related_params.config["boi_token_id"] = config_dict.get("boi_token_id", 0)
        config.mm_related_params.config["eoi_token_id"] = config_dict.get("eoi_token_id", 0)
        config.mm_sep_tokens = [[config_dict.get("boi_token_id", 0), config_dict.get("eoi_token_id", 0)]]
        config.include_sep_tokens = True
        config.mm_position_ids_style = 1
        return config

    @staticmethod
    def get_weight_cls():
        return ChatGlmV4VisionWeightInfo


register_model("chatglm4v", ChatGlmV4Vision, [], ["THUDM/glm-4v-9b"])
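
For reference, a minimal sketch of the position-id layout that ChatGlmV4VisionImageEmbedding.mm_process returns. The random tensor below is a hypothetical stand-in for the EVA2-CLIP output, and its shape is illustrative only; the id scheme itself is taken directly from the code above:

    import torch

    # Stand-in for self.image_embedding([mm_input])[0]: six image tokens of width 4.
    embeddings = torch.randn(6, 4)

    # Same layout as mm_process: first slot 0, last slot 2, patches in between 1.
    pos_ids = [1] * embeddings.shape[0]
    pos_ids[0] = 0
    pos_ids[-1] = 2
    print(torch.tensor(pos_ids, dtype=torch.int32))
    # tensor([0, 1, 1, 1, 1, 2], dtype=torch.int32)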