maga_transformer/models/llava.py (158 lines of code) (raw):

import os
import json
import torch
import re
from typing import List, Any, Dict, Tuple, Union
from transformers import AutoConfig, CLIPVisionConfig, AutoTokenizer

from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.models.llava_weight import LlavaWeightInfo
from maga_transformer.models.llama import Llama
from maga_transformer.models.multimodal.multimodal_mixin import MultiModalMixin, BaseVitWeights
from maga_transformer.distribute.worker_info import g_parallel_info
from maga_transformer.models.llava_vit import LlavaImageEmbedding
from maga_transformer.utils.util import to_torch_dtype
from maga_transformer.model_factory_register import register_model


class LlavaTokenizer(object):
    def __init__(self,
                 tokenizer_path: str,
                 mm_use_im_patch_token: bool,
                 mm_use_im_start_end: bool,
                 special_token_ids: Dict[str, Any],
                 special_tokens: Dict[str, Any],
                 bos_id: int = 1):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.mm_use_im_patch_token = mm_use_im_patch_token
        self.mm_use_im_start_end = mm_use_im_start_end

        extra_tokens: List[str] = []
        if self.mm_use_im_patch_token:
            extra_tokens.extend(["<im_patch>"])
        if self.mm_use_im_start_end:
            extra_tokens.extend(["<im_start>", "<im_end>"])
        self.tokenizer.add_tokens(extra_tokens, special_tokens=True)

        self.image_token_index: int = special_token_ids["image_token_index"]
        self.ignore_token_index: int = special_token_ids["ignore_token_index"]
        self.default_image_token = special_tokens["default_mm_token"]
        self.default_im_start_token = special_tokens["default_im_start_token"]
        self.default_im_end_token = special_tokens["default_im_end_token"]
        self.bos_id = bos_id

    def encode(self, s: str, **kwargs) -> List[int]:
        replace_token = self.default_image_token
        if self.mm_use_im_start_end:
            replace_token = self.default_im_start_token + replace_token + self.default_im_end_token
        s = s.replace(self.default_image_token, replace_token)

        # Tokenize the text between image placeholders, then splice the
        # image token index in between the chunks.
        prompt_chunks: List[List[int]] = [self.tokenizer.encode(chunk) for chunk in s.split(self.default_image_token)]
        images = len(prompt_chunks) - 1

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        t: List[int] = []
        offset = 0
        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == self.bos_id:
            # Keep a single leading BOS and drop the BOS the tokenizer
            # prepends to every subsequent chunk.
            offset = 1
            t.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [self.image_token_index] * (offset + 1)):
            t.extend(x[offset:])

        return t

    def decode(self, t: List[int]) -> str:
        return self.tokenizer.decode(t)

    def apply_chat_template(self, messages, **kwargs):
        return self.tokenizer.apply_chat_template(messages, **kwargs)


class Llava(Llama, MultiModalMixin):
    def _init_multimodal(self, config: GptInitModelParameters):
        self.mm_part = LlavaImageEmbedding(config)
        vit_weight_dict: Dict[str, Any] = {"mm_projector": self.mm_part.mm_projector}
        if config.mm_related_params.config["unfreeze_mm_vision_tower"] or \
           "mm_vision_tower" in config.mm_related_params.config["mm_tunable_parts"]:
            vit_weight_dict["vision_tower"] = self.mm_part.vision_tower
        if "unpad" in config.mm_related_params.config.get("mm_patch_merge_type", "flat"):
            vit_weight_dict["image_newline"] = self.mm_part.image_newline
        config.mm_related_params.vit_weights = BaseVitWeights(vit_weight_dict, True)

    @staticmethod
    def multimodal_modify_prompt_plugin(prompt: Union[List[Dict[str, Any]], str], images: List[str],
                                        img_token: str, **kwargs: Any) -> Tuple[str, List[Any]]:
        prompt, mm_inputs = MultiModalMixin.multimodal_modify_prompt_plugin(prompt, images, img_token, **kwargs)
        if img_token in prompt:
            return prompt, mm_inputs
        else:
            # Append one image placeholder per image when the prompt has none.
            return prompt + (img_token + "\n") * len(images), mm_inputs

    @staticmethod
    def _create_config(ckpt_path):
        config = GptInitModelParameters(
            head_num=0,
            size_per_head=0,
            layer_num=0,
            max_seq_len=0,
            vocab_size=0,
            ckpt_path=ckpt_path,
            activation_type="SiGLU",
            norm_type="rmsnorm",
            rotary_embedding_dim=128,
            rotary_embedding_style=1,
            has_post_decoder_layernorm=True
        )
        # huggingface
        config_path = os.path.join(ckpt_path, "config.json")
        param_path = os.path.join(ckpt_path, "params.json")
        if os.path.exists(config_path):
            with open(config_path) as reader:
                content = reader.read()
                content = content.replace("LlavaForCausalLM", "LLaVAForCausalLM")
                config_json = json.loads(content)
            Llava.from_huggingface(config, config_json)
        else:
            raise Exception("llava parameter from unknown source")
        return config

    @staticmethod
    def get_weight_cls():
        return LlavaWeightInfo

    @staticmethod
    def from_huggingface(config: GptInitModelParameters, config_json: Dict[str, Any]):
        if "text_config" in config_json:
            text_config = config_json["text_config"]
            # if text_config.get("_name_or_path", "") != "":
            #     text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
            Llama.from_huggingface(config, text_config)
            vision_config = config_json["vision_config"]
            config.mm_related_params.config["vision_config"] = CLIPVisionConfig(vision_config)
        else:
            Llama.from_huggingface(config, config_json)

        mm_related_params_list = [
            ("mm_use_im_patch_token", False),
            ("mm_use_im_start_end", False),
            ("image_aspect_ratio", None),
            ("tune_mm_mlp_adapter", False),
            ("image_grid_pinpoints", []),
            ("mm_projector_type", "linear"),
            ("mm_patch_merge_type", "flat"),
            ("hidden_size", 0),
            ("mm_vision_select_layer", None),
            ("mm_vision_select_feature", "patch"),
            ("unfreeze_mm_vision_tower", False),
            ("mm_tunable_parts", ""),
            ("add_faster_video", False),
            ("mm_newline_position", "grid"),
            ("mm_spatial_pool_mode", "bilinear")
        ]
        for param_name, default_value in mm_related_params_list:
            config.mm_related_params.config[param_name] = config_json.get(param_name, default_value)

        config.mm_related_params.config["mm_hidden_size"] = config_json.get("mm_hidden_size", config_json["hidden_size"])
        config.mm_related_params.special_token_ids.update({"ignore_token_index": -100, "image_token_index": -200})
        config.mm_related_params.special_tokens.update({
            "default_mm_token": "<image>",
            "default_im_start_token": "<im_start>",
            "default_im_end_token": "<im_end>"
        })

        # Recover patch/image size from vision tower names like "...patch14-336".
        vis_tower_name = config_json.get("mm_vision_tower", config_json.get("vision_tower", None))
        img_expand_match = re.search(r"patch(\d+)-(\d+)", vis_tower_name)
        if img_expand_match:
            patch_size = int(img_expand_match.group(1))
            img_size = int(img_expand_match.group(2))
            config.mm_related_params.config["patch_size"] = patch_size
            config.mm_related_params.config["image_size"] = img_size
        config.mm_related_params.config["vit_tower_path"] = vis_tower_name
        config.mm_sep_tokens = [[-200]]  # image_token_index

    @classmethod
    def get_tokenizer(cls, config: GptInitModelParameters):
        return LlavaTokenizer(config.tokenizer_path,
                              config.mm_related_params.config["mm_use_im_patch_token"],
                              config.mm_related_params.config["mm_use_im_start_end"],
                              config.mm_related_params.special_token_ids,
                              config.mm_related_params.special_tokens,
                              config.special_tokens.bos_token_id)


register_model("llava", Llava, ["LlavaLlamaForCausalLM"])
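
The core of LlavaTokenizer.encode is the chunk-and-splice step: the prompt is split on "<image>", each chunk is tokenized separately, and the image_token_index (-200) is inserted between chunks while duplicate BOS tokens are dropped. The following is an illustrative sketch of that logic only, not part of the file above; StubTokenizer and encode_with_image_tokens are hypothetical names used so the example runs without transformers or a checkpoint. Real usage goes through Llava.get_tokenizer(config).

from typing import List


class StubTokenizer:
    """Hypothetical tokenizer: BOS (id 1) plus one fake id per whitespace-separated word."""

    def encode(self, s: str) -> List[int]:
        return [1] + [1000 + i for i, _ in enumerate(s.split())]


def encode_with_image_tokens(s: str, image_token_index: int = -200, bos_id: int = 1) -> List[int]:
    tokenizer = StubTokenizer()
    prompt_chunks = [tokenizer.encode(chunk) for chunk in s.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    t: List[int] = []
    offset = 0
    if prompt_chunks and prompt_chunks[0] and prompt_chunks[0][0] == bos_id:
        # Keep one leading BOS; later chunks drop their BOS via the offset slice.
        offset = 1
        t.append(prompt_chunks[0][0])
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        t.extend(x[offset:])
    return t


if __name__ == "__main__":
    ids = encode_with_image_tokens("describe this <image> please")
    # One BOS, text ids, a single -200 marking the image slot, then more text ids.
    print(ids)  # [1, 1000, 1001, -200, 1000]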