transformers/llm/export/utils/vision.py (633 lines of code) (raw):
import math
import torch
import torch.nn.functional as F
import numpy as np
from .transformers import VisionRotary, Decoder
from .spinner import spinner_run
class Vision(torch.nn.Module):
    """Common base for the exportable vision encoders in this module.

    A subclass wraps the vision tower ``visual`` of a multimodal model and
    implements:

    * ``load``    -- read model-specific ids / sizes into ``self`` and ``llm_config``;
    * ``forward`` -- the computation traced by ``export``;
    * ``export``  -- write the ONNX graph and return its path;
    * ``embed``   -- merge image features into the text embeddings.

    ``__init__`` calls ``init_config`` and then ``load``, so any attribute a
    subclass needs inside ``load`` must be set before ``super().__init__``.
    """

    def __init__(self, visual, base):
        super().__init__()
        # Everything below is borrowed from the owning exporter object `base`.
        self.model_type = base.model_type
        self.visual = visual.eval()
        self.embed_ = base.embed
        self.tokenizer = base.tokenizer
        self.config = base.config
        self.hidden_size = base.hidden_size
        self.llm_config = base.llm_config
        self.rope_ratio = 1.0
        # Cross-attention inputs used by the mllama wrapper.
        self.cross_attention_states = None
        self.cross_attention_mask = None
        # Subclass hooks, always in this order.
        self.init_config()
        self.load()

    @staticmethod
    def get_vision(model_type):
        """Return the Vision subclass registered for ``model_type``, or None."""
        registry = dict(
            internvl_chat=InternVLVision,
            qwen=QwenVision,
            qwen2_vl=Qwen2Vision,
            qwen2_5_vl=Qwen2_5Vision,
            qwen2_5_omni=Qwen2_5OmniVision,
            mllama=MllamaVision,
        )
        return registry.get(model_type)

    def init_config(self):
        """Flag the model as visual and store CLIP mean / inverse std on a 0-255 scale."""
        from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
        cfg = self.llm_config
        cfg['is_visual'] = True
        scaled_mean = np.array(OPENAI_CLIP_MEAN) * 255.0
        scaled_norm = 1 / (np.array(OPENAI_CLIP_STD) * 255.0)
        cfg['image_mean'] = scaled_mean.tolist()
        cfg['image_norm'] = scaled_norm.tolist()

    def export(self, onnx_path):
        """Export ``forward`` to ONNX under ``onnx_path``; implemented by subclasses."""
        raise NotImplementedError

    def load(self):
        """Model-specific setup hook called from ``__init__``; implemented by subclasses."""
        raise NotImplementedError

    def str_to_ids(self, prompt):
        """Tokenize ``prompt`` and return its ``input_ids`` as a PyTorch tensor."""
        encoded = self.tokenizer(prompt, return_tensors="pt")
        return encoded['input_ids']

    def forward(self, images):
        """Vision computation traced by ``export``; implemented by subclasses."""
        raise NotImplementedError

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids`` together with any images; implemented by subclasses."""
        raise NotImplementedError
class InternVLVision(Vision):
    """Vision tower wrapper for InternVL-Chat.

    Vision-transformer features are pixel-shuffled (spatial positions folded
    into channels by ``downsample_ratio``) and projected into the LLM hidden
    space by ``mlp1``.
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.quant_bit = 8
        # The same module object as ``self.visual`` (``eval()`` returns self).
        self.vision_model = visual
        self.mlp1 = base.model.mlp1
        self.select_layer = base.model.select_layer

    def load(self):
        """Read the image geometry from the model config and publish it in llm_config."""
        cfg = self.config
        self.image_size = cfg.force_image_size
        self.downsample_ratio = cfg.downsample_ratio
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        # vision_start / vision_end / image_pad ids are not exported for InternVL.

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Rearrange (N, W, H, C) into (N, W*s, H*s, C/(s*s)) for scale factor ``s``.

        NOTE(review): ``.int()`` is called on shape arithmetic, which presumably
        relies on shape entries being tensors while tracing for ONNX export;
        on plain Python numbers ``.int()`` does not exist -- confirm before
        calling this eagerly.
        """
        batch, width, height, chans = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        scaled_h = (height * scale_factor).int()
        # N, W, H, C -> N, W, H*s, C/s
        x = x.view(batch, width, scaled_h, (chans / scale_factor).int())
        # -> N, H*s, W, C/s
        x = x.permute(0, 2, 1, 3).contiguous()
        # -> N, H*s, W*s, C/(s*s)
        x = x.view(batch, scaled_h, (width * scale_factor).int(),
                   (chans / (scale_factor * scale_factor)).int())
        return x.permute(0, 2, 1, 3).contiguous()

    def extract_feature(self, pixel_values):
        """Run the ViT, pixel-shuffle its patch tokens and project them with ``mlp1``.

        Returns embeddings ordered (seq, batch, hidden), the layout MNN expects.
        """
        take_last = self.select_layer == -1
        outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=not take_last,
            return_dict=True)
        if take_last:
            feats = outputs.last_hidden_state
        else:
            feats = outputs.hidden_states[self.select_layer]
        # Drop the first token (presumably the class token) to keep patch tokens only.
        feats = feats[:, 1:, :]
        side = (feats.shape[1] ** 0.5).int()
        feats = feats.reshape(feats.shape[0], side, side, -1)
        feats = self.pixel_shuffle(feats, scale_factor=self.downsample_ratio)
        feats = feats.reshape(feats.shape[0], -1, feats.shape[-1])
        feats = self.mlp1(feats)
        # For mnn's embedding, the order is (seq, batch, hidden)
        return feats.permute(1, 0, 2)

    def init_config(self):
        """Store ImageNet mean (x255) and inverse std (/255) used for preprocessing."""
        self.llm_config['is_visual'] = True
        self.llm_config['image_mean'] = [m * 255.0 for m in (0.485, 0.456, 0.406)]
        self.llm_config['image_norm'] = [1.0 / s / 255.0 for s in (0.229, 0.224, 0.225)]
        self.llm_config['image_size_unit'] = 14

    def export(self, onnx_path):
        """Export the vision graph to ``{onnx_path}/visual.onnx`` and return that path."""
        dummy_images = torch.randn((1, 3, self.image_size, self.image_size), dtype=torch.float32)
        onnx_model = f'{onnx_path}/visual.onnx'
        axes = {"input_images": {0: "size", 2: "height", 3: "width"}}
        torch.onnx.export(
            self,
            dummy_images,
            onnx_model,
            input_names=['input_images'],
            output_names=['image_embeds'],
            dynamic_axes=axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model

    def forward(self, images):
        """Traced entry point: image batch -> (seq, batch, hidden) embeddings."""
        return self.extract_feature(images)
class QwenVision(Vision):
    """Vision tower wrapper for the original Qwen-VL model."""

    def __init__(self, visual, base):
        # Assigned before Vision.__init__, which runs init_config()/load().
        self.quant_bit = 16
        super().__init__(visual, base)

    def load(self):
        """Read image ids/size from the visual config and tokenizer into llm_config."""
        visual_cfg = self.config.visual
        self.image_start_id = visual_cfg['image_start_id']
        self.image_size = visual_cfg['image_size']
        tok = self.tokenizer
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.image_size
        self.llm_config['vision_start'] = tok.img_start_id
        self.llm_config['vision_end'] = tok.img_end_id
        self.llm_config['image_pad'] = tok.img_pad_id

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export the vision graph to ``{onnx_path}/visual.onnx`` and return that path."""
        dummy_images = torch.randn((1, 3, self.image_size, self.image_size))
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(
            self,
            dummy_images,
            onnx_model,
            input_names=['input_images'],
            output_names=['image_embeds'],
            dynamic_axes={"input_images": {0: "size"}},
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model

    def forward(self, images):
        """Run the visual model and swap its first two output axes."""
        features = self.visual(images)
        return features.transpose(1, 0)

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids`` and overwrite each image span with visual features.

        Each span runs from an ``image_start_id`` token to the following
        ``image_start_id + 1`` token. Its token ids, up to the first
        ``image_start_id + 2`` token, are decoded as UTF-8 bytes (presumably the
        image path or URL) and passed to ``visual.encode``. The returned hidden
        states are indexed ``[position, batch]``.
        """
        start_id = self.image_start_id
        if not torch.any(input_ids == start_id):
            return self.embed_(input_ids)
        starts = torch.where(input_ids == start_id)
        ends = torch.where(input_ids == start_id + 1)
        spans = torch.stack((starts[0], starts[1], ends[1]), dim=1)
        sources = []
        for row, begin, end in spans:
            byte_ids = input_ids[row][begin + 1 : end - 1].tolist()
            byte_ids = byte_ids[:byte_ids.index(start_id + 2)]
            sources.append(bytes(byte_ids).decode('utf-8'))
        features = self.visual.encode(sources).transpose(1, 0)
        hidden_states = self.embed_(input_ids)
        for n, (row, begin, end) in enumerate(spans):
            hidden_states[begin + 1 : end, row] = features[:, n]
        return hidden_states
class Qwen2Vision(Vision):
    """Vision tower wrapper for Qwen2-VL.

    Images are resized, cut into ``temporal_patch_size x patch_size x
    patch_size`` patches and run through ``patch_embed`` -> transformer
    ``blocks`` -> ``merger``. ``str_to_ids`` encodes every ``<img>`` in a
    prompt, appending the features to ``self.image_embeds`` and the patch grid
    to ``self.image_grid_thw``. ``embed`` scatters those features over the
    image-pad tokens, and ``get_position_ids`` uses the grids for the 3-row
    (t, h, w) rotary positions.
    """

    def __init__(self, visual, base):
        # These must exist before Vision.__init__ runs load() (it reads image_height).
        self.quant_bit = 4
        self.temporal_patch_size = 2
        self.patch_size = 14
        self.merge_size = 2
        self.image_height = 420
        self.image_width = 420
        # NOTE(review): both lists only grow and are never reset between prompts,
        # while get_position_ids indexes image_grid_thw from 0 on every call.
        self.image_embeds = []
        self.image_grid_thw = []
        super().__init__(visual, base)

    def load(self):
        """Read special-token ids, build the rotary helper and wrap the ViT blocks."""
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_id
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_start|>'
        self.vision_end_token = '<|vision_end|>'
        self.image_pad_token = '<|image_pad|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        # Attribute-name mapping, presumably consumed by Decoder to locate the
        # Qwen2-VL block submodules.
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'qkv_proj': 'qkv',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger

    def str_to_ids(self, prompt):
        """Tokenize ``prompt``, encoding every ``<img>...</img>`` segment on the way.

        An image segment holds a URL (http/https) or a local path, optionally
        with ``<hw>H,W</hw>`` that sets the target size for this and later
        images. Each image is replaced by ``vision_start``, one ``image_pad``
        token per output embedding, and ``vision_end``.
        """
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            parts = re.split(pattern, prompt)
            txt_prompt = ''
            for part in parts:
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    # find <hw></hw> in image_content
                    match = re.search(r'<hw>(.*?)</hw>', img_content)
                    if match:
                        img_content = img_content[:match.start()] + img_content[match.end():]
                        hw = match.group(1).split(',')
                        self.image_height, self.image_width = int(hw[0]), int(hw[1])
                    if img_content.startswith(('http://', 'https://')):
                        image_obj = Image.open(requests.get(img_content, stream=True).raw)
                    else:
                        image_obj = Image.open(img_content)
                    img_pad_len = self.img_process(image_obj)
                    img_pad_str = self.image_pad_token * img_pad_len
                    img_str = f'{self.vision_start_token}{img_pad_str}{self.vision_end_token}'
                    txt_prompt += img_str
                else:
                    txt_prompt += part
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        return input_ids

    def get_position_ids(self, input_ids, seq_len, token_len):
        """Build the (t, h, w) rotary position ids for Qwen2-VL's M-RoPE.

        Args:
            input_ids: prompt token ids (any shape; flattened here).
            seq_len: total length so far; only used when ``token_len`` is truthy.
            token_len: when truthy (a decode step), return position
                ``seq_len - 1`` repeated on all three rows.

        Returns:
            Integer tensor of shape (3, num_tokens), or (3, 1) for a decode step.
            Text tokens share one position on all rows. Image tokens get
            (t, h, w) grid coordinates offset by the current position.
        """
        if token_len:
            position_ids = torch.tensor([[seq_len - 1]] * 3, dtype=torch.int)
            return position_ids
        txt_len, vision_idx, cur_idx = 0, 0, 0
        position_ids_list = []
        for token in input_ids.flatten().tolist():
            if token != self.image_pad_id:
                txt_len += 1
            if token == self.vision_start_id:
                # The text run (including vision_start) gets 1-D positions.
                text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
                cur_idx += txt_len
                txt_len = 0
                position_ids_list.append(torch.stack([text_index, text_index, text_index]))
            elif token == self.vision_end_id:
                t, h, w = self.image_grid_thw[vision_idx]
                h = h // self.merge_size
                w = w // self.merge_size
                t_index = torch.arange(t).view(-1, 1).expand(-1, h * w).flatten()
                h_index = torch.arange(h).view(1, -1, 1).expand(t, -1, w).flatten()
                w_index = torch.arange(w).view(1, 1, -1).expand(t, h, -1).flatten()
                position_ids_list.append(torch.stack([t_index, h_index, w_index]) + cur_idx)
                # Following text resumes at (largest image position + 1), i.e. the
                # largest grid extent. Advancing by `w` alone was only right for
                # square images.
                cur_idx += max(t, h, w)
                vision_idx += 1
        if txt_len > 0:
            text_index = torch.arange(cur_idx, cur_idx + txt_len, dtype=torch.int)
            position_ids_list.append(torch.stack([text_index, text_index, text_index]))
        position_ids = torch.cat(position_ids_list, dim=1)
        return position_ids

    def vision_position_ids(self, grid_thw):
        """Return (2, num_patches) row/column ids for the patches, in merge-window order."""
        pos_ids = []
        for t, h, w in grid_thw:
            llm_h, llm_w = h // self.merge_size, w // self.merge_size
            # compute pos_ids
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(llm_h, self.merge_size, llm_w, self.merge_size)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(llm_h, self.merge_size, llm_w, self.merge_size)
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids]))
        # Each entry is (2, h*w); join images along the patch axis (dim=1 --
        # dim=0 would have stacked extra rows for multiple images).
        position_ids = torch.cat(pos_ids, dim=1)
        return position_ids

    def vision_attention_mask(self, grid_thw, cu_window_seqlens = None):
        """Block-diagonal additive mask: 0 inside each segment, float32 min elsewhere.

        Without ``cu_window_seqlens`` the segments are the per-frame patch
        ranges of every image. Otherwise they are the given cumulative bounds.
        """
        # Total patch count over all images (a per-image vector is not a valid size).
        seq_len = int((grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).sum())
        if cu_window_seqlens is None:
            cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(dim=0)
            cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        else:
            cu_seqlens = cu_window_seqlens
        attention_mask = torch.full([1, seq_len, seq_len], torch.finfo(torch.float32).min)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
        return attention_mask

    def vision_reshape(self, images):
        """Turn a (1, C, H, W) image into flattened patches plus its (t, h, w) grid.

        The frame is repeated ``temporal_patch_size`` times. As a side effect
        the grid is appended to ``self.image_grid_thw``.
        """
        images = [images] * self.temporal_patch_size
        patches = torch.concat(images, axis=0)
        _, channel, height, width = patches.shape
        grid_t = patches.shape[0] // self.temporal_patch_size
        grid_h, grid_w = height // self.patch_size, width // self.patch_size
        patches = patches.reshape(
            grid_t,
            self.temporal_patch_size,
            channel,
            grid_h // self.merge_size,
            self.merge_size,
            self.patch_size,
            grid_w // self.merge_size,
            self.merge_size,
            self.patch_size,
        )
        patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
        flatten_patches = patches.reshape(
            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
        )
        grid_thw = torch.tensor([[grid_t, grid_h, grid_w]])
        self.image_grid_thw.append([grid_t, grid_h, grid_w])
        return flatten_patches, grid_thw

    def images_forward(self, images):
        """Eagerly encode one preprocessed image tensor into merged embeddings."""
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        attention_mask = self.vision_attention_mask(grid_thw)
        return self.forward(flatten_patches, position_ids, attention_mask)

    def forward(self, flatten_patches, position_ids, attention_mask):
        """Traced graph: patches -> blocks -> merger; returns (tokens, 1, hidden)."""
        rotary_pos_emb = self.rotary(position_ids)
        hidden_states = self.patch_embed(flatten_patches)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        for blk in self.blocks:
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=attention_mask)
        image_embeds = self.merger(hidden_states)
        image_embeds = image_embeds.unsqueeze(1)
        return image_embeds

    def smart_resize(self, height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280):
        """Round (height, width) to multiples of ``factor`` within [min_pixels, max_pixels].

        Raises:
            ValueError: if a side is smaller than ``factor`` or the aspect ratio
                exceeds 200.
        """
        if height < factor or width < factor:
            raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
        elif max(height, width) / min(height, width) > 200:
            raise ValueError(
                f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
            )
        h_bar = round(height / factor) * factor
        w_bar = round(width / factor) * factor
        if h_bar * w_bar > max_pixels:
            beta = math.sqrt((height * width) / max_pixels)
            h_bar = math.floor(height / beta / factor) * factor
            w_bar = math.floor(width / beta / factor) * factor
        elif h_bar * w_bar < min_pixels:
            beta = math.sqrt(min_pixels / (height * width))
            h_bar = math.ceil(height * beta / factor) * factor
            w_bar = math.ceil(width * beta / factor) * factor
        return h_bar, w_bar

    def img_process(self, image):
        """Preprocess a PIL image, encode it and record the embeddings.

        The image is resized to ``smart_resize(image_height, image_width)``
        (the configured size, not the image's own size) and CLIP-normalized.

        Returns:
            Number of embedding rows, i.e. how many image-pad tokens it needs.
        """
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN,
            OPENAI_CLIP_STD,
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        resized_height, resized_width = self.smart_resize(self.image_height, self.image_width)
        data_format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=data_format)
        image = rescale(image, scale=1 / 255.0, input_data_format=data_format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=data_format)
        image = np.expand_dims(image, [0])
        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        image_embed = self.images_forward(image)
        self.image_embeds.append(image_embed)
        return image_embed.shape[0]

    def embed(self, input_ids, images = None, videos = None):
        """Embed ``input_ids`` and replace image-pad positions with stored image features."""
        input_embeds = self.embed_(input_ids)
        if self.image_embeds:
            image_mask = (input_ids == self.image_pad_id).squeeze()
            # Chunks without image-pad tokens (e.g. a decode step) are left alone;
            # assigning N feature rows to an empty selection would fail to broadcast.
            if image_mask.any():
                input_embeds[image_mask] = torch.concat(self.image_embeds, dim=0).to(input_embeds.dtype)
        return input_embeds

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export ``forward`` to ``{onnx_path}/visual.onnx`` with a dynamic patch count."""
        # 900 patches = a 30x30 grid (420 / 14); 1176 = 3 * 2 * 14 * 14.
        patch = torch.randn([900, 1176])
        position_ids = torch.zeros([2, 900], dtype=torch.int32)
        attention_mask = torch.zeros([1, 900, 900], dtype=torch.float)
        onnx_model = f'{onnx_path}/visual.onnx'
        torch.onnx.export(self, (patch, position_ids, attention_mask),
                        onnx_model,
                        input_names=['patches', 'position_ids', 'attention_mask'],
                        output_names=['image_embeds'],
                        dynamic_axes={
                            "patches": { 0: "size" },
                            "position_ids": { 1: "size" },
                            "attention_mask": { 1: "size", 2: "size" }
                        },
                        do_constant_folding=True,
                        verbose=False,
                        opset_version=15)
        return onnx_model
class Qwen2_5Vision(Qwen2Vision):
    """Vision tower wrapper for Qwen2.5-VL (windowed attention variant of Qwen2-VL).

    Merged patch groups are permuted into window order before the blocks and
    restored after the merger. Blocks listed in ``fullatt_block_indexes`` use
    per-frame attention; all others attend only within their window.
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        # Patches merged into one output token.
        self.merge_unit = self.merge_size * self.merge_size
        self.window_size = visual.window_size
        self.fullatt_block_indexes = visual.fullatt_block_indexes

    def get_window_index(self, grid_thw):
        """Return the window-order permutation of merged tokens and cumulative window bounds.

        Returns:
            (window_index, cu_window_seqlens): ``window_index`` is a 1-D tensor
            permuting merged-token groups into window order.
            ``cu_window_seqlens`` is a list of cumulative patch counts (starting
            at 0) marking where each window ends.
        """
        win = self.window_size // self.merge_size // self.patch_size
        order_pieces: list = []
        cu_window_seqlens: list = [0]
        offset = 0
        for grid_t, grid_h, grid_w in grid_thw:
            rows, cols = grid_h // self.merge_size, grid_w // self.merge_size
            index = torch.arange(grid_t * rows * cols).reshape(grid_t, rows, cols)
            pad_h = win - rows % win
            pad_w = win - cols % win
            n_win_h = (rows + pad_h) // win
            n_win_w = (cols + pad_w) // win
            # Pad to whole windows; -100 marks padding to be dropped.
            padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            padded = padded.reshape(grid_t, n_win_h, win, n_win_w, win)
            padded = padded.permute(0, 1, 3, 2, 4).reshape(grid_t, n_win_h * n_win_w, win, win)
            valid = padded != -100
            per_window = valid.sum([2, 3]).reshape(-1)
            order_pieces.append(padded[valid] + offset)
            bounds = per_window.cumsum(0) * self.merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(bounds.tolist())
            offset += (grid_t * rows * cols).item()
        return torch.cat(order_pieces, dim=0), cu_window_seqlens

    def images_forward(self, images):
        """Eagerly encode one preprocessed image tensor into merged embeddings."""
        flatten_patches, grid_thw = self.vision_reshape(images)
        position_ids = self.vision_position_ids(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        # Slot 0: per-frame mask (full-attention blocks); slot 1: per-window mask.
        frame_mask = self.vision_attention_mask(grid_thw)
        window_mask = self.vision_attention_mask(grid_thw, cu_window_seqlens)
        attention_mask = torch.stack([frame_mask, window_mask], dim=0)
        return self.forward(flatten_patches, position_ids, attention_mask, window_index)

    def forward(self, flatten_patches, position_ids, attention_mask, window_index):
        """Traced graph: embed, window-permute, run blocks, merge, un-permute.

        ``attention_mask[0]`` feeds the blocks in ``fullatt_block_indexes``;
        ``attention_mask[1]`` feeds the rest. Returns (tokens, 1, hidden).
        """
        hidden_states = self.patch_embed(flatten_patches)
        seq_len, _ = hidden_states.size()
        groups = seq_len // self.merge_unit
        # Permute merge groups of position ids into window order.
        position_ids = position_ids.reshape(2, groups, self.merge_unit)[:, window_index, :]
        position_ids = position_ids.reshape(2, seq_len)
        rotary_pos_emb = self.rotary(position_ids)
        if rotary_pos_emb.dtype != hidden_states.dtype:
            rotary_pos_emb = rotary_pos_emb.to(hidden_states.dtype)
        # Apply the same permutation to the patch embeddings.
        hidden_states = hidden_states.reshape(groups, self.merge_unit, -1)[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        for depth, blk in enumerate(self.blocks):
            mask = attention_mask[0] if depth in self.fullatt_block_indexes else attention_mask[1]
            hidden_states, _ = blk(hidden_states, rotary_pos_emb=rotary_pos_emb, attention_mask=mask)
        merged = self.merger(hidden_states)
        # Undo the window permutation.
        merged = merged[torch.argsort(window_index), :]
        return merged.unsqueeze(1)

    @spinner_run(f'export visual to ')
    def export(self, onnx_path):
        """Export ``forward`` to ``{onnx_path}/visual.onnx`` with dynamic sizes."""
        patch = torch.randn([400, 1176])
        position_ids = torch.zeros([2, 400], dtype=torch.int32)
        attention_mask = torch.zeros([2, 1, 400, 400], dtype=torch.float)
        window_index = torch.arange(100, dtype=torch.int32)
        onnx_model = f'{onnx_path}/visual.onnx'
        axes = {
            "patches": {0: "size"},
            "position_ids": {1: "size"},
            "attention_mask": {2: "size", 3: "size"},
            "window_index": {0: "size"},
        }
        torch.onnx.export(
            self,
            (patch, position_ids, attention_mask, window_index),
            onnx_model,
            input_names=['patches', 'position_ids', 'attention_mask', 'window_index'],
            output_names=['image_embeds'],
            dynamic_axes=axes,
            do_constant_folding=True,
            verbose=False,
            opset_version=15,
        )
        return onnx_model
class Qwen2_5OmniVision(Qwen2_5Vision):
    """Vision tower wrapper for Qwen2.5-Omni (the thinker's vision encoder).

    Shares the Qwen2.5-VL windowed pipeline. It differs in the config location
    (``thinker_config``), the special tokens, separate q/k/v projections and
    8-bit quantization.
    """

    def __init__(self, visual, base):
        # Qwen2Vision.__init__ (reached via super()) assigns quant_bit = 4 and the
        # patch/merge geometry, image size and image_embeds = [], so overriding
        # them beforehand had no effect. Apply the Omni-specific value afterwards.
        super().__init__(visual, base)
        self.quant_bit = 8

    def load(self):
        """Switch to ``thinker_config``, read Omni special tokens and wrap the ViT blocks."""
        self.config = self.config.thinker_config
        self.vision_start_id = self.config.vision_start_token_id
        self.vision_end_id = self.config.vision_end_token_id
        self.image_pad_id = self.config.image_token_index
        self.llm_config['image_size'] = self.image_height
        self.llm_config['vision_start'] = self.vision_start_id
        self.llm_config['vision_end'] = self.vision_end_id
        self.llm_config['image_pad'] = self.image_pad_id
        self.vision_start_token = '<|vision_bos|>'
        self.vision_end_token = '<|vision_eos|>'
        self.image_pad_token = '<|IMAGE|>'
        # load model
        config = self.visual.config
        if hasattr(config, "embed_dim"):
            self.hidden_size = config.embed_dim
        else:
            self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_heads
        self.num_key_value_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.rope_theta = 10000.0
        self.rotary_dim = self.head_dim // 2
        self.rotary = VisionRotary(self)
        # Omni blocks use separate q/k/v projections instead of a fused qkv.
        self.model_map = {
            'decoder': {
                'self_attn': 'attn',
                'mlp': 'mlp',
                'input_layernorm': 'norm1',
                'post_attention_layernorm': 'norm2'
            },
            'attention': {
                'q_proj': 'q',
                'k_proj': 'k',
                'v_proj': 'v',
                'o_proj': 'proj'
            }
        }
        self.patch_embed = self.visual.patch_embed
        self.blocks = []
        for block in self.visual.blocks.children():
            layer_id = len(self.blocks)
            self.blocks.append(Decoder(block, layer_id, self))
        self.merger = self.visual.merger
class MllamaVision(Vision):
    """Vision tower wrapper for mllama (Llama 3.2 Vision).

    Images are not embedded inline. ``str_to_ids`` encodes the prompt's images
    and leaves the projected features in ``self.cross_attention_states``.
    """

    def __init__(self, visual, base):
        super().__init__(visual, base)
        self.multi_modal_projector = base.multi_modal_projector
        # PIL images referenced by the most recent prompt given to str_to_ids.
        self.image_objs = []

    def load(self):
        """Publish the configured vision image size in llm_config."""
        self.llm_config['is_visual'] = True
        self.llm_config['image_size'] = self.config.vision_config.image_size
        self.image_size = self.config.vision_config.image_size

    def str_to_ids(self, prompt):
        """Tokenize ``prompt``, replacing each ``<img>...</img>`` with ``<|image|>``.

        Each referenced image (http/https URL or local path) is opened and
        encoded. After the loop, ``cross_attention_states`` holds the features
        of the last image in this prompt.
        """
        # Start fresh each call. Previously the list grew across calls, so every
        # earlier image was re-encoded again (the final cross_attention_states
        # is the same either way).
        self.image_objs = []
        if '<img>' in prompt and '</img>' in prompt:
            import re
            import requests
            from PIL import Image
            pattern = r'(<img>.*?</img>)'
            parts = re.split(pattern, prompt)
            txt_prompt = ''
            for part in parts:
                if re.match(pattern, part):
                    img_content = re.search(r'<img>(.*?)</img>', part).group(1)
                    if img_content.startswith(('http://', 'https://')):
                        self.image_objs.append(Image.open(requests.get(img_content, stream=True).raw))
                    else:
                        self.image_objs.append(Image.open(img_content))
                    txt_prompt += '<|image|>'
                else:
                    txt_prompt += part
        else:
            txt_prompt = prompt
        input_ids = self.tokenizer(txt_prompt, return_tensors="pt")['input_ids']
        # image process
        for img in self.image_objs:
            self.img_process(img)
        return input_ids

    def img_process(self, image):
        """Resize to 560x560, CLIP-normalize and encode one image as the first of 4 tiles.

        NOTE(review): this overwrites ``self.image_size`` with 560 regardless of
        the config value written by ``load`` -- confirm this is intended.
        """
        self.image_size = 560
        resized_height = self.image_size
        resized_width = self.image_size
        from transformers.image_transforms import (
            convert_to_rgb,
            resize,
            rescale,
            normalize
        )
        from transformers.image_utils import (
            OPENAI_CLIP_MEAN,
            OPENAI_CLIP_STD,
            PILImageResampling,
            infer_channel_dimension_format,
            to_numpy_array
        )
        image = convert_to_rgb(image)
        image = to_numpy_array(image)
        data_format = infer_channel_dimension_format(image)
        resample = PILImageResampling.BICUBIC
        image = resize(image, size=(resized_height, resized_width), resample=resample, input_data_format=data_format)
        image = rescale(image, scale=1 / 255.0, input_data_format=data_format)
        image = normalize(image=image, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, input_data_format=data_format)
        image = image.transpose(2, 0, 1)
        # (C, H, W) -> (1, 1, 1, C, H, W), then zero-pad axis 2 to 4 tiles.
        image = np.expand_dims(image, [0, 1, 2])
        pad_val = np.zeros_like(image)
        image = np.concatenate([image, pad_val, pad_val, pad_val], axis=2)
        image = torch.from_numpy(image)
        self.cross_attention_states = self.forward(image)

    def forward(self, images):
        """Encode tiled images and project them to (-1, tokens, hidden_size) states.

        A fixed aspect-ratio id of 1 and a mask marking only the first of four
        tiles as valid are passed to the vision model.
        """
        aspect_ratio_ids = torch.tensor([[1]])
        aspect_ratio_mask = torch.tensor([[[1, 0, 0, 0]]])
        vision_outputs = self.visual(images, aspect_ratio_ids, aspect_ratio_mask)
        cross_attention_states = vision_outputs[0]
        cross_attention_states = cross_attention_states.type(self.multi_modal_projector.weight.dtype)
        cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
            -1, cross_attention_states.shape[-2], self.hidden_size)
        return cross_attention_states

    def embed(self, input_ids, images = None, videos = None):
        """Return plain text embeddings; image features go through cross-attention."""
        txt_embeds = self.embed_(input_ids)
        return txt_embeds