import math
import torch
import torch.nn.functional as F
import numpy as np

from .transformers import VisionRotary, Decoder
from .spinner import spinner_run

class Vision(torch.nn.Module):
    """Shared base for the per-architecture vision encoders used at export time.

    The constructor borrows the tokenizer, configs, hidden size and text
    embedding callable from the LLM wrapper ``base``, stores ``visual`` after
    switching it to eval mode, and then runs ``init_config()`` followed by
    ``load()``.  Subclasses supply ``load``, ``export``, ``forward`` and
    ``embed``; the base versions only raise ``NotImplementedError``.
    """

    def __init__(self, visual, base):
        super().__init__()
        self.model_type = base.model_type
        # Keep submodule registration order as visual -> embed_.
        self.visual = visual.eval()
        self.embed_ = base.embed
        self.tokenizer = base.tokenizer
        self.config = base.config
        self.hidden_size = base.hidden_size
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        # Cross-attention inputs; in this file only MllamaVision fills these in.
        self.cross_attention_states = None
        self.cross_attention_mask = None
        # Hooks run last because load() implementations read the state above.
        self.init_config()
        self.load()

    @staticmethod
    def get_vision(model_type):
        """Map a HF ``model_type`` string to its Vision subclass (None if unsupported)."""
        registry = {
            'internvl_chat': InternVLVision,
            'qwen': QwenVision,
            'qwen2_vl': Qwen2Vision,
            'qwen2_5_vl': Qwen2_5Vision,
            'qwen2_5_omni': Qwen2_5OmniVision,
            'mllama': MllamaVision,
        }
        return registry.get(model_type)

    def init_config(self):
        """Mark the LLM config as visual and store CLIP pixel normalisation.

        ``image_mean`` is the CLIP mean scaled to the 0-255 range and
        ``image_norm`` is the reciprocal of the 0-255-scaled CLIP std.
        """
        from transformers.image_utils import (OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)
        self.llm_config['is_visual'] = True
        scaled_mean = np.array(OPENAI_CLIP_MEAN) * 255.0
        scaled_std = np.array(OPENAI_CLIP_STD) * 255.0
        self.llm_config['image_mean'] = scaled_mean.tolist()
        self.llm_config['image_norm'] = (1 / scaled_std).tolist()

    def export(self, onnx_path):
        """Export the encoder to ONNX under ``onnx_path`` (subclass hook)."""
        raise NotImplementedError

    def load(self):
        """Read model-specific settings from the configs (subclass hook)."""
        raise NotImplementedError

    def str_to_ids(self, prompt):
        """Tokenize ``prompt`` and return the tokenizer's ``input_ids`` tensor."""
        encoded = self.tokenizer(prompt, return_tensors="pt")
        return encoded['input_ids']

    def forward(self, images):
        """Encode ``images`` (subclass hook)."""
        raise NotImplementedError

    def embed(self, input_ids, images = None, videos = None):
        """Produce input embeddings with visual features merged in (subclass hook)."""
        raise NotImplementedError

class InternVLVision(Vision):
    """Vision tower for InternVL chat models (``internvl_chat``).

    ``forward`` runs the vision model, drops the first output token, applies a
    pixel shuffle with ``downsample_ratio``, projects with ``mlp1`` and returns
    features laid out as (seq, batch, hidden).
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.quant_bit = 8
        self.vision_model = visual
        self.mlp1 = base.model.mlp1
        self.select_layer = base.model.select_layer

    def load(self):
        """Read image size / downsample ratio from the model config."""
        self.image_size = self.config.force_image_size
        self.downsample_ratio = self.config.downsample_ratio
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        # NOTE: vision_start / vision_end / image_pad token ids are not
        # exported for InternVL.

    @staticmethod
    def _as_int(value):
        """Convert a size expression to an integer in both eager and traced mode.

        When shape entries are tensors (the case the original ``.int()`` calls
        were written for, i.e. ONNX tracing), keep the conversion in the graph
        with ``.int()``.  In eager mode shape entries are Python numbers, which
        have no ``.int()`` method, so fall back to ``int()``.
        """
        if isinstance(value, torch.Tensor):
            return value.int()
        return int(value)

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels by ``scale_factor`` on each side.

        ``x`` is (N, W, H, C); the result is
        (N, W * scale, H * scale, C / scale**2).
        """
        to_int = self._as_int
        n, w, h, c = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, to_int(h * scale_factor), to_int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, to_int(h * scale_factor), to_int(w * scale_factor),
                   to_int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        """Encode ``pixel_values`` into projected features of shape (seq, batch, hidden)."""
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        # Drop the first token (presumably the class token) so a square patch grid remains.
        vit_embeds = vit_embeds[:, 1:, :]

        h = w = self._as_int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)

        # For mnn's embedding, the order is (seq, batch, hidden)
        vit_embeds = vit_embeds.permute(1, 0, 2)
        return vit_embeds

    def init_config(self):
        """Store ImageNet normalisation (0-255 mean, reciprocal 0-255 std) and the 14px size unit."""
        self.llm_config['is_visual'] = True
        IMAGENET_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_STD = [0.229, 0.224, 0.225]
        for i in range(3):
            IMAGENET_MEAN[i] = IMAGENET_MEAN[i] * 255.0
            IMAGENET_STD[i] = 1.0 / IMAGENET_STD[i] / 255.0
        self.llm_config['image_mean'] = IMAGENET_MEAN
        self.llm_config['image_norm'] = IMAGENET_STD
        self.llm_config['image_size_unit'] = 14

    def export(self, onnx_path):
        """Export ``forward`` to ``{onnx_path}/visual.onnx`` with dynamic batch/height/width."""
        input_images = torch.randn((1, 3, self.image_size, self.image_size), dtype=torch.float32)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (input_images),
                        onnx_model,
                        input_names=['input_images'],
                        output_names=['image_embeds'],
                        dynamic_axes={
                            "input_images": { 0: "size", 2: "height", 3: "width"},
                        },
                        do_constant_folding=True,
                        verbose=False,
                        opset_version=15)
        return onnx_model

    def forward(self, images):
        return self.extract_feature(images)

class QwenVision(Vision):
    """Vision tower for the original Qwen-VL (``qwen``) checkpoints.

    In this format an image reference is stored inside the token stream
    between ``image_start_id`` and ``image_start_id + 1`` as raw bytes of the
    path, terminated by ``image_start_id + 2``.  ``embed`` decodes those bytes,
    encodes the images with ``visual.encode`` and writes the features over the
    span between the two markers.
    """

    def __init__(self, visual, base):
        self.quant_bit = 16
        super().__init__(visual, base)

    def load(self):
        """Copy image settings and image special-token ids into ``llm_config``."""
        visual_cfg = self.config.visual
        self.image_start_id = visual_cfg['image_start_id']
        self.image_size = visual_cfg['image_size']
        cfg = self.llm_config
        tok = self.tokenizer
        cfg['is_visual'] = True
        cfg['image_size'] = self.image_size
        cfg['vision_start'] = tok.img_start_id
        cfg['vision_end'] = tok.img_end_id
        cfg['image_pad'] = tok.img_pad_id

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export ``forward`` to ``{onnx_path}/visual.onnx`` with a dynamic batch axis."""
        side = self.image_size
        dummy_images = torch.randn((1, 3, side, side))
        target = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (dummy_images),
                        target,
                        input_names=['input_images'],
                        output_names=['image_embeds'],
                        dynamic_axes={
                            "input_images": { 0: "size" },
                        },
                        do_constant_folding=True,
                        verbose=False,
                        opset_version=15)
        return target

    def forward(self, images):
        """Run the visual module and swap its first two axes."""
        return self.visual(images).transpose(1, 0)

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids``, replacing each in-stream image span with its features."""
        start_id = self.image_start_id
        if not torch.any(input_ids == start_id):
            return self.embed_(input_ids)
        opens = torch.where(input_ids == start_id)
        closes = torch.where(input_ids == start_id + 1)
        # One row per image: (batch index, open-marker position, close-marker position).
        spans = torch.stack((opens[0], opens[1], closes[1]), dim=1)
        paths = []
        for row, begin, end in spans:
            raw = input_ids[row][begin + 1 : end - 1].tolist()
            raw = raw[ : raw.index(start_id + 2)]
            paths.append(bytes(raw).decode('utf-8'))
        features = self.visual.encode(paths).transpose(1, 0)
        hidden_states = self.embed_(input_ids)
        for idx, (row, begin, end) in enumerate(spans):
            hidden_states[begin + 1 : end, row] = features[:, idx]
        return hidden_states

class Qwen2Vision(Vision):
    """Vision tower for Qwen2-VL (``qwen2_vl``), rebuilt from the HF modules.

    Images referenced in a prompt as ``<img>path-or-url</img>`` (optionally
    containing ``<hw>height,width</hw>``) are resized, patchified and encoded
    eagerly in ``str_to_ids``.  The embeddings are cached in ``image_embeds``
    (grid sizes in ``image_grid_thw``) and written over the image-pad token
    embeddings by ``embed``.
    """

    def __init__(self, visual, base):
        # Defaults assigned before Vision.__init__, because it calls load(),
        # which reads image_height.
        self.quant_bit = 4
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.merge_size = 2
        self.image_height = 420
        self.image_width = 420
        self.image_embeds = []
        self.image_grid_thw = []
        super().__init__(visual, base)

    def load(self):
        """Record special-token ids and wrap the HF vision blocks as Decoder layers."""
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_id
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_start|>'
        self.vision_end_token = '<|vision_end|>'
        self.image_pad_token = '<|image_pad|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        # Maps the generic Decoder attribute names to the HF vision block names.
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'qkv_proj': 'qkv',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger

    def str_to_ids(self, prompt):
        """Tokenize ``prompt``, encoding any ``<img>...</img>`` images on the way.

        Each image tag is replaced by vision-start, one pad token per image
        embedding row, and vision-end.  An optional ``<hw>h,w</hw>`` inside
        the tag overrides ``image_height`` / ``image_width`` for later images.
        """
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            parts = re.split(pattern, prompt)
            txt_prompt = ''
            for part in parts:
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    # find <hw></hw> in image_content
                    match = re.search(r'<hw>(.*?)</hw>', img_content)
                    if match:
                        img_content = img_content[:match.start()] + img_content[match.end():]
                        hw = match.group(1).split(',')
                        self.image_height, self.image_width = int(hw[0]), int(hw[1])
                    if img_content.startswith('http://') or img_content.startswith('https://'):
                        image_obj = Image.open(requests.get(img_content, stream=True).raw)
                    else:
                        image_obj = Image.open(img_content)
                    img_pad_len = self.img_process(image_obj)
                    img_pad_str = self.image_pad_token * img_pad_len
                    img_str = f'{self.vision_start_token}{img_pad_str}{self.vision_end_token}'
                    txt_prompt += img_str
                else:
                    txt_prompt += part
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        return input_ids

    def get_position_ids(self, input_ids, seq_len, token_len):
        """Build (3, N) int32 (temporal, height, width) position ids.

        When ``token_len`` is truthy, a single column ``seq_len - 1`` is
        returned for each of the three rows.  Otherwise text tokens get the
        same consecutive id in all three rows.  The pad tokens of each image
        get their (t, h, w) coordinates on the merged grid, offset by the
        current position.  Text after an image resumes one past the largest
        id the image used, i.e. at ``offset + max(t, h, w)``.
        """
        if token_len:
            position_ids = torch.tensor([[seq_len - 1]] * 3, dtype=torch.int)
            return position_ids
        input_ids = input_ids.flatten()
        txt_len, vision_idx, cur_idx = 0, 0, 0
        position_ids_list = []
        for token in input_ids:
            if token != self.image_pad_id:
                txt_len += 1
            if token == self.vision_start_id:
                # Flush pending text, including this vision_start token.
                text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
                cur_idx += txt_len
                txt_len = 0
                position_ids_list.append(torch.stack([text_index, text_index, text_index]))
            elif token == self.vision_end_id:
                t, h, w = self.image_grid_thw[vision_idx]
                h = h // self.merge_size
                w = w // self.merge_size
                t_index = torch.arange(t, dtype=torch.int).view(-1, 1).expand(-1, h * w).flatten()
                h_index = torch.arange(h, dtype=torch.int).view(1, -1, 1).expand(t, -1, w).flatten()
                w_index = torch.arange(w, dtype=torch.int).view(1, 1, -1).expand(t, h, -1).flatten()
                position_ids_list.append(torch.stack([t_index, h_index, w_index]) + cur_idx)
                # The largest id used above is cur_idx + max(t, h, w) - 1; advancing by
                # w alone would overlap the next text with tall (h > w) images.
                cur_idx += max(t, h, w)
                vision_idx += 1
        if txt_len > 0:
            text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
            position_ids_list.append(torch.stack([text_index, text_index, text_index]))
        position_ids = torch.cat(position_ids_list, dim=1)
        return position_ids

    def vision_position_ids(self, grid_thw):
        """Return (2, N) row/column patch coordinates for the patches of ``grid_thw``.

        Each ``merge_size x merge_size`` group is contiguous, matching the patch
        order produced by ``vision_reshape``.  Coordinates are tiled over the
        ``t`` temporal slices, and multiple grids are concatenated along the
        patch axis.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            llm_h, llm_w = h // self.merge_size, w // self.merge_size
            # compute pos_ids
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(llm_h, self.merge_size, llm_w, self.merge_size)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(llm_h, self.merge_size, llm_w, self.merge_size)
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids]).repeat(1, int(t)))
        # Concatenate on the patch axis so several grids still give (2, N).
        position_ids = torch.cat(pos_ids, dim=1)
        return position_ids

    def vision_attention_mask(self, grid_thw, cu_window_seqlens = None):
        """Return a (1, N, N) additive block-diagonal attention mask.

        By default each block is one temporal slice (``h * w`` patches) of a
        grid.  When ``cu_window_seqlens`` is given, the blocks are the spans
        between its consecutive boundaries.  Allowed entries are 0; all others
        are the float32 minimum.
        """
        # Total patches over all grids, as a plain int so torch.full gets a valid size.
        seq_len = int((grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).sum())
        if cu_window_seqlens is None:
            cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0)
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        else:
            cu_seqlens = cu_window_seqlens
        attention_mask = torch.full([1, seq_len, seq_len], torch.finfo(torch.float32).min)
        for i in range(1, len(cu_seqlens)):
            start, end = cu_seqlens[i - 1], cu_seqlens[i]
            attention_mask[..., start : end, start : end] = 0
        return attention_mask

    def vision_reshape(self, images):
        """Patchify one image batch into (grid_t*grid_h*grid_w, C*T*P*P) rows.

        The frame is repeated ``temporal_patch_size`` times to fill a temporal
        patch.  The grid is appended to ``image_grid_thw`` as a side effect.
        """
        images = [images] * self.temporal_patch_size
        patches = torch.concat(images, axis=0)
        _, channel, height, width = patches.shape
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = height // self.patch_size, width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )
        grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
        self.image_grid_thw.append([grid_t, grid_h, grid_w])
        return flatten_patches, grid_thw

    def images_forward(self, images):
        """Patchify ``images``, build position ids and mask, and run ``forward``."""
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        attention_mask = self.vision_attention_mask(grid_thw)
        return self.forward(flatten_patches, position_ids, attention_mask)

    def forward(self, flatten_patches, position_ids, attention_mask):
        """Run patch embedding, the vision blocks and the merger.

        Returns the merged embeddings with a singleton axis inserted at dim 1.
        """
        rotary_pos_emb = self.rotary(position_ids)
        hidden_states = self.patch_embed(flatten_patches)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        for blk in self.blocks:
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask)
        image_embeds = self.merger(hidden_states)
        image_embeds = image_embeds.unsqueeze(1)
        return image_embeds

    def smart_resize(self, height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280):
        """Round (height, width) to multiples of ``factor`` within the pixel budget.

        Raises ValueError if either side is smaller than ``factor`` or the
        aspect ratio exceeds 200.
        """
        if height < factor or width < factor:
            raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
        elif max(height, width) / min(height, width) > 200:
            raise ValueError(
                f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
            )
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            # Never shrink a side below one factor-sized unit.
            h_bar = max(factor, math.floor(height / beta / factor) * factor)
            w_bar = max(factor, math.floor(width / beta / factor) * factor)
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    def img_process(self, image):
        """Preprocess a PIL image, encode it, cache the embedding, and return its row count.

        The target size comes from ``smart_resize(image_height, image_width)``.
        The pixels are CLIP-normalised and laid out as NCHW before
        ``images_forward`` is called.
        """
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN,
            OPENAI_CLIP_STD,
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        resized_height, resized_width = self.smart_resize(self.image_height, self.image_width)
        data_format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=data_format)
        image = rescale(image, scale=1 / 255.0, input_data_format=data_format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=data_format)
        image = np.expand_dims(image, [0])
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        image_embed = self.images_forward(image)
        self.image_embeds.append(image_embed)
        return image_embed.shape[0]

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids`` and overwrite image-pad positions with cached image embeddings."""
        input_embeds = self.embed_(input_ids)
        if self.image_embeds is not None and len(self.image_embeds) > 0:
            # NOTE(review): assumes a single row of input_ids and a sequence-major
            # embed_ output, so the squeezed mask indexes dim 0 - confirm.
            image_mask = (input_ids == self.image_pad_id).squeeze()
            input_embeds[image_mask] = torch.concat(self.image_embeds, dim=0).to(input_embeds.dtype)
        return input_embeds

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export ``forward`` to ``{onnx_path}/visual.onnx`` with a dynamic patch count."""
        patch = torch.randn([900, 1176])
        position_ids = torch.zeros([2, 900], dtype=torch.int32)
        attention_mask = torch.zeros([1, 900, 900], dtype=torch.float)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (patch, position_ids, attention_mask),
                        onnx_model,
                        input_names=['patches', 'position_ids', 'attention_mask'],
                        output_names=['image_embeds'],
                        dynamic_axes={
                            "patches": { 0: "size" },
                            "position_ids": { 1: "size" },
                            "attention_mask": { 1: "size", 2: "size" }
                        },
                        do_constant_folding=True,
                        verbose=False,
                        opset_version=15)
        return onnx_model

class Qwen2_5Vision(Qwen2Vision):
    """Vision tower for Qwen2.5-VL (``qwen2_5_vl``).

    Adds windowed attention on top of Qwen2Vision.  Merged patch groups are
    reordered into windows (``get_window_index``).  Blocks listed in
    ``fullatt_block_indexes`` use the per-frame mask; all other blocks use the
    per-window mask.
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        # Number of patches folded into one merged token.
        self.merge_unit = self.merge_size * self.merge_size
        self.window_size = visual.window_size
        self.fullatt_block_indexes = visual.fullatt_block_indexes

    def get_window_index(self, grid_thw):
        """Compute the window reordering of merged tokens and window boundaries.

        Returns ``(window_index, cu_window_seqlens)``.  ``window_index`` lists
        merged-token indices in window order.  ``cu_window_seqlens`` is a
        Python list of cumulative patch counts per window, starting with 0.
        """
        reordered = []
        boundaries = [0]
        offset = 0
        win = self.window_size // self.merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            rows = grid_h // self.merge_size
            cols = grid_w // self.merge_size
            num_tokens = grid_t * rows * cols
            index = torch.arange(num_tokens).reshape(grid_t, rows, cols)
            # Pad the merged grid up to whole windows; -100 marks padding.
            pad_h = win - rows % win
            pad_w = win - cols % win
            windows_h = (rows + pad_h) // win
            windows_w = (cols + pad_w) // win
            padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            padded = padded.reshape(grid_t, windows_h, win, windows_w, win)
            padded = padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t, windows_h * windows_w, win, win,
            )
            per_window = (padded != -100).sum([2, 3]).reshape(-1)
            flat = padded.reshape(-1)
            reordered.append(flat[flat != -100] + offset)
            running = per_window.cumsum(0) * self.merge_unit + boundaries[-1]
            boundaries.extend(running.tolist())
            offset += num_tokens.item()
        return torch.cat(reordered, dim=0), boundaries

    def images_forward(self, images):
        """Patchify ``images``, build windowing inputs and run ``forward``."""
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        # Slot 0: block-diagonal per temporal frame (used by full-attention blocks).
        # Slot 1: block-diagonal per window (used by all other blocks).
        per_frame_mask = self.vision_attention_mask(grid_thw)
        per_window_mask = self.vision_attention_mask(grid_thw, cu_window_seqlens)
        attention_mask = torch.stack([per_frame_mask, per_window_mask], dim=0)
        return self.forward(flatten_patches, position_ids, attention_mask, window_index)

    def forward(self, flatten_patches, position_ids, attention_mask, window_index):
        """Encode patches with windowed attention.

        The output is restored to the original patch-group order, with a
        singleton axis at dim 1.
        """
        hidden_states = self.patch_embed(flatten_patches)
        seq_len, _ = hidden_states.size()
        unit = self.merge_unit
        # Permute position ids into window order, one merge unit at a time.
        position_ids = position_ids.reshape(2, seq_len // unit, unit)[:, window_index, :]
        position_ids = position_ids.reshape(2, seq_len)
        rotary_pos_emb = self.rotary(position_ids)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        # Apply the same window permutation to the hidden states.
        hidden_states = hidden_states.reshape(seq_len // unit, unit, -1)[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        for layer_num, blk in enumerate(self.blocks):
            mask = attention_mask[0] if layer_num in self.fullatt_block_indexes else attention_mask[1]
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=mask)
        merged = self.merger(hidden_states)
        # Undo the window permutation.
        merged = merged[torch.argsort(window_index), :]
        return merged.unsqueeze(1)

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export ``forward`` to ``{onnx_path}/visual.onnx`` with dynamic sizes."""
        # 1176 = channel(3) * temporal_patch_size(2) * patch_size(14) ** 2.
        dummy_patches = torch.randn([400, 1176])
        dummy_position_ids = torch.zeros([2, 400], dtype=torch.int32)
        dummy_mask = torch.zeros([2, 1, 400, 400], dtype=torch.float)
        dummy_window_index = torch.arange(100, dtype=torch.int32)
        target = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (dummy_patches, dummy_position_ids, dummy_mask, dummy_window_index),
                        target,
                        input_names=['patches', 'position_ids', 'attention_mask', 'window_index'],
                        output_names=['image_embeds'],
                        dynamic_axes={
                            "patches": { 0: "size" },
                            "position_ids": { 1: "size" },
                            "attention_mask": { 2: "size", 3: "size" },
                            "window_index": { 0: "size" }
                        },
                        do_constant_folding=True,
                        verbose=False,
                        opset_version=15)
        return target

class Qwen2_5OmniVision(Qwen2_5Vision):
    """Vision tower of the Qwen2.5-Omni thinker (``qwen2_5_omni``).

    Reuses the Qwen2.5-VL windowed pipeline.  It differs in where the
    special-token ids are read (``thinker_config``), the special-token
    strings, the separate q/k/v projections in ``model_map``, and the export
    bit width.
    """

    def __init__(self, visual, base):
        # Qwen2Vision.__init__ (reached through Qwen2_5Vision) unconditionally
        # assigns quant_bit = 4, the patch/merge/image-size defaults and a fresh
        # image_embeds list.  Values set here *before* super().__init__ were
        # therefore overwritten, so the Omni bit width must be applied afterwards.
        super().__init__(visual, base)
        self.quant_bit = 8

    def load(self):
        """Switch to the thinker config, record token ids and wrap the vision blocks."""
        self.config = self.config.thinker_config
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_index
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_bos|>'
        self.vision_end_token = '<|vision_eos|>'
        self.image_pad_token = '<|IMAGE|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        # Omni vision blocks expose separate q/k/v projections instead of a fused qkv.
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'q_proj': 'q',
                'k_proj': 'k',
                'v_proj': 'v',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger

class MllamaVision(Vision):
    """Vision tower for Llama 3.2 Vision (``mllama``).

    Each image is encoded as a single tile (plus three zero tiles).  The
    projected features are stored in ``cross_attention_states`` rather than
    merged into the token embeddings; ``embed`` returns text embeddings only.
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.multi_modal_projector = base.multi_modal_projector
        # Every image parsed from prompts so far, in order.
        self.image_objs = []

    def load(self):
        """Read the tile size from ``vision_config``."""
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.config.vision_config.image_size
        self.image_size = self.config.vision_config.image_size

    def str_to_ids(self, prompt):
        """Tokenize ``prompt``, replacing each ``<img>...</img>`` with ``<|image|>``.

        Images are loaded from a local path or an http(s) URL and appended to
        ``image_objs``.  The most recent image is then encoded into
        ``cross_attention_states``.
        """
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            parts = re.split(pattern, prompt)
            txt_prompt = ''
            for part in parts:
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    if img_content.startswith('http://') or img_content.startswith('https://'):
                        self.image_objs.append(Image.open(requests.get(img_content, stream=True).raw))
                    else:
                        self.image_objs.append(Image.open(img_content))
                    txt_prompt += '<|image|>'
                else:
                    txt_prompt += part
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        # img_process overwrites cross_attention_states, so only the latest image
        # determines the result.  Encode just that one instead of re-running the
        # vision tower on every image seen so far.
        if self.image_objs:
            self.img_process(self.image_objs[-1])
        return input_ids

    def img_process(self, image):
        """Preprocess a PIL image to one tile and store its projected features.

        The image is resized to ``image_size`` (from ``vision_config``) and
        CLIP-normalised.  It is then shaped (1, 1, 4, C, H, W), with tiles 1-3
        zero-filled.
        """
        resized_height = self.image_size
        resized_width = self.image_size
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN,
            OPENAI_CLIP_STD,
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        data_format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=data_format)
        image = rescale(image, scale=1 / 255.0, input_data_format=data_format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=data_format)
        image = image.transpose(2, 0, 1)
        image = np.expand_dims(image, [0, 1, 2])
        pad_val = np.zeros_like(image)
        image = np.concatenate([image, pad_val, pad_val, pad_val], axis=2)
        image = torch.from_numpy(image)
        self.cross_attention_states = self.forward(image)

    def forward(self, images):
        """Encode (1, 1, 4, C, H, W) tiles, using only the first, and project them.

        Returns features reshaped to (-1, tokens, hidden_size).
        """
        aspect_ratio_ids = torch.tensor([[1]])
        # Only the first of the four tiles is marked valid.
        aspect_ratio_mask = torch.tensor([[[1, 0, 0, 0]]])
        vision_outputs = self.visual(images, aspect_ratio_ids, aspect_ratio_mask)
        cross_attention_states = vision_outputs[0]
        cross_attention_states = cross_attention_states.type(self.multi_modal_projector.weight.dtype)
        cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size)
        return cross_attention_states

    def embed(self, input_ids, images = None, videos = None):
        """Return plain text embeddings; image features travel via ``cross_attention_states``."""
        txt_embeds = self.embed_(input_ids)
        return txt_embeds