# transformers/llm/export/utils/vision.py

import math
import torch
import torch.nn.functional as F
import numpy as np
from .transformers import VisionRotary, Decoder
from .spinner import spinner_run

class Vision(torch.nn.Module):
    def __init__(self, visual, base):
        super().__init__()
        self.model_type = base.model_type
        self.visual = visual.eval()
        self.embed_ = base.embed
        self.tokenizer = base.tokenizer
        self.config = base.config
        self.hidden_size = base.hidden_size
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        # mllama
        self.cross_attention_states = None
        self.cross_attention_mask = None
        self.init_config()
        self.load()

    @staticmethod
    def get_vision(model_type):
        visual_models = {
            'internvl_chat': InternVLVision,
            'qwen': QwenVision,
            'qwen2_vl': Qwen2Vision,
            'qwen2_5_vl': Qwen2_5Vision,
            'qwen2_5_omni': Qwen2_5OmniVision,
            'mllama': MllamaVision
        }
        if model_type in visual_models:
            return visual_models[model_type]
        return None

    def init_config(self):
        from transformers.image_utils import (OPENAI_CLIP_MEAN, OPENAI_CLIP_STD)
        self.llm_config['is_visual'] = True
        image_mean = np.array(OPENAI_CLIP_MEAN) * 255.0
        image_norm = 1 / (np.array(OPENAI_CLIP_STD) * 255.0)
        self.llm_config['image_mean'] = image_mean.tolist()
        self.llm_config['image_norm'] = image_norm.tolist()

    def export(self, onnx_path):
        raise NotImplementedError

    def load(self):
        raise NotImplementedError

    def str_to_ids(self, prompt):
        input_ids = self.tokenizer(prompt, return_tensors="pt")['input_ids']
        return input_ids

    def forward(self, images):
        raise NotImplementedError

    def embed(self, input_ids, images=None, videos=None):
        raise NotImplementedError
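
# Usage sketch (illustrative, not part of this module): a driver that owns the
# language model ("base") and the HF vision tower ("visual") would typically do
#
#     vision_cls = Vision.get_vision(base.model_type)   # e.g. 'qwen2_vl' -> Qwen2Vision
#     if vision_cls is not None:
#         vision = vision_cls(visual, base)
#         onnx_file = vision.export(onnx_path)          # writes {onnx_path}/visual.onnx
#
# `base`, `visual` and `onnx_path` are assumed names used only for this example.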

class InternVLVision(Vision):
    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.quant_bit = 8
        self.vision_model = visual
        self.mlp1 = base.model.mlp1
        self.select_layer = base.model.select_layer

    def load(self):
        self.image_size = self.config.force_image_size
        self.downsample_ratio = self.config.downsample_ratio
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        # self.llm_config['vision_start'] = self.tokenizer.img_start_id
        # self.llm_config['vision_end'] = self.tokenizer.img_end_id
        # self.llm_config['image_pad'] = self.tokenizer.img_pad_id

    def pixel_shuffle(self, x, scale_factor=0.5):
        n, w, h, c = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).contiguous()
        return x

    def extract_feature(self, pixel_values):
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]
        vit_embeds = vit_embeds[:, 1:, :]
        h = w = int(vit_embeds.shape[1] ** 0.5)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        vit_embeds = self.mlp1(vit_embeds)
        # For mnn's embedding, the order is (seq, batch, hidden)
        vit_embeds = vit_embeds.permute(1, 0, 2)
        return vit_embeds

    def init_config(self):
        self.llm_config['is_visual'] = True
        IMAGENET_MEAN = [0.485, 0.456, 0.406]
        IMAGENET_STD = [0.229, 0.224, 0.225]
        for i in range(3):
            IMAGENET_MEAN[i] = IMAGENET_MEAN[i] * 255.0
            IMAGENET_STD[i] = 1.0 / IMAGENET_STD[i] / 255.0
        self.llm_config['image_mean'] = IMAGENET_MEAN
        self.llm_config['image_norm'] = IMAGENET_STD
        self.llm_config['image_size_unit'] = 14

    def export(self, onnx_path):
        input_images = torch.randn((1, 3, self.image_size, self.image_size), dtype=torch.float32)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (input_images),
                          onnx_model,
                          input_names=['input_images'],
                          output_names=['image_embeds'],
                          dynamic_axes={
                              "input_images": {0: "size", 2: "height", 3: "width"},
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model

    def forward(self, images):
        return self.extract_feature(images)

class QwenVision(Vision):
    def __init__(self, visual, base):
        self.quant_bit = 16
        super().__init__(visual, base)

    def load(self):
        self.image_start_id = self.config.visual['image_start_id']
        self.image_size = self.config.visual['image_size']
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        self.llm_config['vision_start'] = self.tokenizer.img_start_id
        self.llm_config['vision_end'] = self.tokenizer.img_end_id
        self.llm_config['image_pad'] = self.tokenizer.img_pad_id

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        input_images = torch.randn((1, 3, self.image_size, self.image_size))
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (input_images),
                          onnx_model,
                          input_names=['input_images'],
                          output_names=['image_embeds'],
                          dynamic_axes={
                              "input_images": {0: "size"},
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model

    def forward(self, images):
        return self.visual(images).transpose(1, 0)

    def embed(self, input_ids, images=None, videos=None):
        if not torch.any(input_ids == self.image_start_id):
            return self.embed_(input_ids)
        bos_pos = torch.where(input_ids == self.image_start_id)
        eos_pos = torch.where(input_ids == self.image_start_id + 1)
        img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
        images = []
        for i, a, b in img_pos:
            image = input_ids[i][a + 1 : b - 1].tolist()
            image = image[ : image.index(self.image_start_id + 2)]
            images.append(bytes(image).decode('utf-8'))
        images = self.visual.encode(images).transpose(1, 0)
        hidden_states = self.embed_(input_ids)
        for idx, (i, a, b) in enumerate(img_pos):
            hidden_states[a + 1 : b, i] = images[:, idx]
        return hidden_states
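
# Prompt convention parsed by Qwen2Vision.str_to_ids below (example values are
# illustrative):
#
#     '<img>./example.jpg<hw>420,420</hw></img>Describe the image.'
#
# An optional <hw>height,width</hw> tag overrides image_height/image_width before
# resizing; each <img>...</img> span is replaced by <|vision_start|>, N copies of
# <|image_pad|> (N = number of image embeddings returned by img_process) and
# <|vision_end|>.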

class Qwen2Vision(Vision):
    def __init__(self, visual, base):
        self.quant_bit = 4
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.merge_size = 2
        self.image_height = 420
        self.image_width = 420
        self.image_embeds = []
        self.image_grid_thw = []
        super().__init__(visual, base)

    def load(self):
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_id
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_start|>'
        self.vision_end_token = '<|vision_end|>'
        self.image_pad_token = '<|image_pad|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'qkv_proj': 'qkv',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger

    def str_to_ids(self, prompt):
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            parts = re.split(pattern, prompt)
            txt_prompt = ''
            for part in parts:
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    # find <hw></hw> in image_content
                    match = re.search(r'<hw>(.*?)</hw>', img_content)
                    if match:
                        img_content = img_content[:match.start()] + img_content[match.end():]
                        hw = match.group(1).split(',')
                        self.image_height, self.image_width = int(hw[0]), int(hw[1])
                    if img_content.startswith('http://') or img_content.startswith('https://'):
                        image_obj = Image.open(requests.get(img_content, stream=True).raw)
                    else:
                        image_obj = Image.open(img_content)
                    img_pad_len = self.img_process(image_obj)
                    img_pad_str = self.image_pad_token * img_pad_len
                    img_str = f'{self.vision_start_token}{img_pad_str}{self.vision_end_token}'
                    txt_prompt += img_str
                else:
                    txt_prompt += part
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        return input_ids

    def get_position_ids(self, input_ids, seq_len, token_len):
        if token_len:
            position_ids = torch.tensor([[seq_len - 1]] * 3, dtype=torch.int)
            return position_ids
        input_ids = input_ids.flatten()
        txt_len, vision_idx, cur_idx = 0, 0, 0
        position_ids_list = []
        for i, token in enumerate(input_ids):
            if token != self.image_pad_id:
                txt_len += 1
            if token == self.vision_start_id:
                text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
                cur_idx += txt_len
                txt_len = 0
                position_ids_list.append(torch.stack([text_index, text_index, text_index]))
            elif token == self.vision_end_id:
                t, h, w = self.image_grid_thw[vision_idx]
                h = h // self.merge_size
                w = w // self.merge_size
                t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
                h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
                w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
                position_ids_list.append(torch.stack([t_index, h_index, w_index]) + cur_idx)
                cur_idx += w
                vision_idx += 1
        if txt_len > 0:
            text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
            position_ids_list.append(torch.stack([text_index, text_index, text_index]))
        position_ids = torch.cat(position_ids_list, dim=1)
        return position_ids

    def vision_position_ids(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            llm_h, llm_w = h // self.merge_size, w // self.merge_size
            # compute pos_ids
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(llm_h, self.merge_size, llm_w, self.merge_size)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(llm_h, self.merge_size, llm_w, self.merge_size)
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids]))
        position_ids = torch.cat(pos_ids, dim=0)
        return position_ids
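
    # Worked example (sketch): for a single image with grid_thw = [[1, 6, 6]],
    # vision_attention_mask (below) computes cu_seqlens = [0, 36] and returns a
    # [1, 36, 36] mask that is 0 inside the single 36x36 block and float32-min
    # elsewhere, i.e. full attention within an image and none across images.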

    def vision_attention_mask(self, grid_thw, cu_window_seqlens=None):
        seq_len = grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]
        if cu_window_seqlens is None:
            cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0)
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        else:
            cu_seqlens = cu_window_seqlens
        attention_mask = torch.full([1, seq_len, seq_len], torch.finfo(torch.float32).min)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def vision_reshape(self, images):
        images = [images] * self.temporal_patch_size
        patches = torch.concat(images, axis=0)
        _, channel, height, width = patches.shape
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = height // self.patch_size, width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w,
            channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )
        grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
        self.image_grid_thw.append([grid_t, grid_h, grid_w])
        return flatten_patches, grid_thw

    def images_forward(self, images):
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        attention_mask = self.vision_attention_mask(grid_thw)
        return self.forward(flatten_patches, position_ids, attention_mask)

    def forward(self, flatten_patches, position_ids, attention_mask):
        rotary_pos_emb = self.rotary(position_ids)
        hidden_states = self.patch_embed(flatten_patches)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        for blk in self.blocks:
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask)
        image_embeds = self.merger(hidden_states)
        image_embeds = image_embeds.unsqueeze(1)
        return image_embeds

    def smart_resize(self, height: int, width: int, factor: int = 28,
                     min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280):
        if height < factor or width < factor:
            raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
        elif max(height, width) / min(height, width) > 200:
            raise ValueError(
                f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
            )
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar
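
    # Worked example (sketch): smart_resize(420, 420, factor=28) rounds both sides
    # to multiples of 28, giving (420, 420); since 420*420 = 176400 lies between the
    # default min_pixels (56*56 = 3136) and max_pixels (14*14*4*1280 = 1003520), no
    # extra rescaling is applied. An input such as (430, 410) also rounds to (420, 420).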

    def img_process(self, image):
        from transformers.image_transforms import (
            convert_to_rgb, resize, rescale, normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling,
            infer_channel_dimension_format, to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        resized_height, resized_width = self.smart_resize(self.image_height, self.image_width)
        format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample,
                       input_data_format=format)
        image = rescale(image, scale=1 / 255.0, input_data_format=format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=format)
        image = np.expand_dims(image, [0])
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        image_embed = self.images_forward(image)
        self.image_embeds.append(image_embed)
        return image_embed.shape[0]

    def embed(self, input_ids, images=None, videos=None):
        input_embeds = self.embed_(input_ids)
        if self.image_embeds is not None and len(self.image_embeds) > 0:
            image_mask = (input_ids == self.image_pad_id).squeeze()
            input_embeds[image_mask] = torch.concat(self.image_embeds, dim=0).to(input_embeds.dtype)
        return input_embeds

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        patch = torch.randn([900, 1176])
        position_ids = torch.zeros([2, 900], dtype=torch.int32)
        attention_mask = torch.zeros([1, 900, 900], dtype=torch.float)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (patch, position_ids, attention_mask),
                          onnx_model,
                          input_names=['patches', 'position_ids', 'attention_mask'],
                          output_names=['image_embeds'],
                          dynamic_axes={
                              "patches": {0: "size"},
                              "position_ids": {1: "size"},
                              "attention_mask": {1: "size", 2: "size"}
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model
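
# Qwen2.5-VL adds windowed attention on top of Qwen2Vision: the exported graph
# takes an extra `window_index` input and a stacked [2, 1, seq, seq] attention
# mask (per-image full-attention mask at index 0, window mask built from
# cu_window_seqlens at index 1); forward() picks one of the two per block based
# on fullatt_block_indexes.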

class Qwen2_5Vision(Qwen2Vision):
    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.merge_unit = self.merge_size * self.merge_size
        self.window_size = visual.window_size
        self.fullatt_block_indexes = visual.fullatt_block_indexes

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.merge_size // self.patch_size
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.merge_size,
                grid_w // self.merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def images_forward(self, images):
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        normal_attention_mask = self.vision_attention_mask(grid_thw)
        fullatt_attention_mask = self.vision_attention_mask(grid_thw, cu_window_seqlens)
        attention_mask = torch.stack([normal_attention_mask, fullatt_attention_mask], dim=0)
        return self.forward(flatten_patches, position_ids, attention_mask, window_index)

    def forward(self, flatten_patches, position_ids, attention_mask, window_index):
        hidden_states = self.patch_embed(flatten_patches)
        seq_len, _ = hidden_states.size()
        position_ids = position_ids.reshape(2, seq_len // self.merge_unit, self.merge_unit)
        position_ids = position_ids[:, window_index, :]
        position_ids = position_ids.reshape(2, seq_len)
        rotary_pos_emb = self.rotary(position_ids)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        hidden_states = hidden_states.reshape(seq_len // self.merge_unit, self.merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                attention_mask_now = attention_mask[0]
            else:
                attention_mask_now = attention_mask[1]
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask_now)
        image_embeds = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        image_embeds = image_embeds[reverse_indices, :]
        image_embeds = image_embeds.unsqueeze(1)
        return image_embeds

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        patch = torch.randn([400, 1176])
        position_ids = torch.zeros([2, 400], dtype=torch.int32)
        attention_mask = torch.zeros([2, 1, 400, 400], dtype=torch.float)
        window_index = torch.arange(100, dtype=torch.int32)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (patch, position_ids, attention_mask, window_index),
                          onnx_model,
                          input_names=['patches', 'position_ids', 'attention_mask', 'window_index'],
                          output_names=['image_embeds'],
                          dynamic_axes={
                              "patches": {0: "size"},
                              "position_ids": {1: "size"},
                              "attention_mask": {2: "size", 3: "size"},
                              "window_index": {0: "size"}
                          },
                          do_constant_folding=True,
                          verbose=False,
                          opset_version=15)
        return onnx_model

class Qwen2_5OmniVision(Qwen2_5Vision):
    def __init__(self, visual, base):
        self.quant_bit = 8
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.merge_size = 2
        self.image_height = 420
        self.image_width = 420
        self.image_embeds = None
        super().__init__(visual, base)

    def load(self):
        self.config = self.config.thinker_config
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_index
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_bos|>'
        self.vision_end_token = '<|vision_eos|>'
        self.image_pad_token = '<|IMAGE|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'q_proj': 'q',
                'k_proj': 'k',
                'v_proj': 'v',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger
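
# Mllama routes image features through cross-attention instead of replacing
# tokens in the input embedding: img_process stores the projected vision output
# in self.cross_attention_states, the prompt only carries an '<|image|>'
# placeholder, and embed() returns plain text embeddings.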

class MllamaVision(Vision):
    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.multi_modal_projector = base.multi_modal_projector
        self.image_objs = []

    def load(self):
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.config.vision_config.image_size
        self.image_size = self.config.vision_config.image_size

    def str_to_ids(self, prompt):
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            parts = re.split(pattern, prompt)
            txt_prompt = ''
            for part in parts:
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    if img_content.startswith('http://') or img_content.startswith('https://'):
                        self.image_objs.append(Image.open(requests.get(img_content, stream=True).raw))
                    else:
                        self.image_objs.append(Image.open(img_content))
                    txt_prompt += '<|image|>'
                else:
                    txt_prompt += part
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        # image process
        for img in self.image_objs:
            self.img_process(img)
        return input_ids

    def img_process(self, image):
        self.image_size = 560
        resized_height = self.image_size
        resized_width = self.image_size
        from transformers.image_transforms import (
            convert_to_rgb, resize, rescale, normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling,
            infer_channel_dimension_format, to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=format)
        image = rescale(image, scale=1 / 255.0, input_data_format=format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=format)
        image = image.transpose(2, 0, 1)
        image = np.expand_dims(image, [0, 1, 2])
        pad_val = np.zeros_like(image)
        image = np.concatenate([image, pad_val, pad_val, pad_val], axis=2)
        image = torch.from_numpy(image)
        self.cross_attention_states = self.forward(image)

    def forward(self, images):
        aspect_ratio_ids = torch.tensor([[1]])
        aspect_ratio_mask = torch.tensor([[[1, 0, 0, 0]]])
        vision_outputs = self.visual(images, aspect_ratio_ids, aspect_ratio_mask)
        cross_attention_states = vision_outputs[0]
        cross_attention_states = cross_attention_states.type(self.multi_modal_projector.weight.dtype)
        cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
            -1, cross_attention_states.shape[-2], self.hidden_size)
        return cross_attention_states

    def embed(self, input_ids, images=None, videos=None):
        txt_embeds = self.embed_(input_ids)
        return txt_embeds
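
# End-to-end sketch for the Qwen2-VL family (names such as `vision`, the image
# path and the seq_len argument are assumptions for illustration only):
#
#     input_ids = vision.str_to_ids('<img>demo.jpg</img>What is shown here?')
#     embeds = vision.embed(input_ids)          # <|image_pad|> rows replaced by image embeds
#     position_ids = vision.get_position_ids(input_ids, seq_len=input_ids.shape[1], token_len=0)
#
# The resulting embeddings and 3D position ids feed the exported language model;
# the decoder wiring itself lives outside this file.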