# megatron_patch/model/qwen_vl/visual.py (296 lines of code) (raw):
# Copyright (c) 2023 Alibaba PAI Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import math
import requests
from functools import partial
from PIL import Image
from typing import Callable, Optional, List
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
def get_abs_pos(abs_pos, tgt_size):
    """
    Resample a square grid of absolute positional embeddings to a new token count.

    Both `abs_pos.size(0)` and `tgt_size` are interpreted as the area of a square
    grid; the integer part of their square roots gives the source and target side
    lengths. When the two side lengths already agree, the input tensor itself is
    returned (no copy); otherwise the embeddings are resized with bicubic
    interpolation in float32 and cast back to the input dtype.

    Args:
        abs_pos (torch.Tensor): Positional embeddings whose first dimension is the
            flattened grid (side * side) and whose last dimension is the channel.
        tgt_size (int): The target number of positions (flattened grid length).
    Returns:
        torch.Tensor: Embeddings with `side_tgt * side_tgt` rows and the same
            channel width and dtype as `abs_pos`.
    """
    src_side = int(math.sqrt(abs_pos.size(0)))
    tgt_side = int(math.sqrt(tgt_size))

    # Nothing to resample when both grids have the same side length.
    if src_side == tgt_side:
        return abs_pos

    orig_dtype = abs_pos.dtype
    # (side*side, C) -> (1, C, side, side), the layout F.interpolate expects.
    grid = abs_pos.float().reshape(1, src_side, src_side, -1).permute(0, 3, 1, 2)
    resized = F.interpolate(
        grid,
        size=(tgt_side, tgt_side),
        mode="bicubic",
        align_corners=False,
    )
    # (1, C, side, side) -> (side*side, C), restoring the caller's dtype.
    return resized.permute(0, 2, 3, 1).flatten(0, 2).to(dtype=orig_dtype)
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Build a fixed 2D sine-cosine positional embedding for a square grid.

    Args:
        embed_dim (int): The channel width of the embedding.
        grid_size (int): The side length of the square grid.
        cls_token (bool): When True, an all-zero row is prepended for a class token.
    Returns:
        np.ndarray: Array of shape [grid_size*grid_size, embed_dim] when `cls_token`
            is False, otherwise [1+grid_size*grid_size, embed_dim].
    """
    axis_h = np.arange(grid_size, dtype=np.float32)
    axis_w = np.arange(grid_size, dtype=np.float32)
    # meshgrid is called with the width axis first, so plane 0 holds w coordinates.
    mesh = np.stack(np.meshgrid(axis_w, axis_h), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not cls_token:
        return pos_embed
    # The class token gets a zero embedding in front of the grid positions.
    return np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Build a 2D sine-cosine positional embedding from two coordinate planes.

    `grid[0]` and `grid[1]` are each encoded with `embed_dim // 2` channels and the
    two encodings are concatenated along the channel axis, `grid[0]` first.

    Args:
        embed_dim (int): The channel width of the embedding; must be even.
        grid (np.ndarray): Coordinate array indexed as `grid[0]` / `grid[1]`. The
            caller in this file passes shape [2, 1, grid_size, grid_size].
    Returns:
        np.ndarray: Array of shape (num_positions, embed_dim), where num_positions
            is the number of elements in each coordinate plane.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # One (num_positions, embed_dim/2) encoding per coordinate plane.
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=1)  # (num_positions, embed_dim)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a set of scalar positions with a 1D sine-cosine embedding.

    Args:
        embed_dim (int): The channel width of the embedding; must be even.
        pos (np.ndarray): Positions to encode; any shape, flattened to (M,).
    Returns:
        np.ndarray: Array of shape (M, embed_dim) whose first half of channels holds
            the sines and whose second half holds the cosines.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000^(i / half_dim) for i in [0, half_dim).
    exponents = np.arange(half_dim, dtype=np.float32) / (embed_dim / 2.)
    omega = 1. / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = np.outer(flat_pos, omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
class Resampler(nn.Module):
    """
    A 2D perceiver-resampler with a single cross-attention layer: (grid_size**2)
    learnable queries attend over the input tokens, and both sides receive 2D
    sine-cosine positional embeddings.

    Attributes:
        num_queries (int): Number of learnable queries (grid_size ** 2); this is
            the output sequence length.
        embed_dim (int): Channel width of the queries and of the projected inputs.
        num_heads (int): Number of attention heads.
        pos_embed (nn.Parameter): Fixed (non-trainable) 2D sin-cos embeddings.
        query (nn.Parameter): Learnable query embeddings.
        kv_proj (nn.Module): Bias-free linear projection from `kv_dim` to
            `embed_dim`, or identity when no projection is needed.
        attn (nn.MultiheadAttention): The cross-attention module.
        ln_q (nn.Module): Normalization applied to the queries.
        ln_kv (nn.Module): Normalization applied to the projected inputs.
    """

    def __init__(
            self,
            grid_size,
            embed_dim,
            num_heads,
            kv_dim=None,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Fixed positional table: stored as a parameter but excluded from training.
        sincos_table = torch.from_numpy(
            get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
        self.pos_embed = nn.Parameter(sincos_table).requires_grad_(False)

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        # Only project the inputs when their width differs from embed_dim.
        needs_projection = kv_dim is not None and kv_dim != embed_dim
        if needs_projection:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

    def forward(self, x, attn_mask=None):
        """
        Resample the input tokens onto the learnable queries.

        Args:
            x (torch.Tensor): 3-D input tokens; dim 0 is used as the batch and
                dim 1 as the key/value sequence length.
            attn_mask (torch.Tensor, optional): Mask forwarded to the attention module.
        Returns:
            torch.Tensor: The attention output, permuted back to batch-first layout.
        """
        # Positional embeddings resized to the incoming sequence length.
        kv_pos = get_abs_pos(self.pos_embed, x.size(1))

        # Project, normalise, then switch to sequence-first for nn.MultiheadAttention.
        kv = self.ln_kv(self.kv_proj(x)).permute(1, 0, 2)
        batch = kv.shape[1]

        queries = self._repeat(self.ln_q(self.query), batch) + self.pos_embed.unsqueeze(1)
        keys = kv + kv_pos.unsqueeze(1)  # values deliberately carry no positions

        attended, _ = self.attn(queries, keys, kv, attn_mask=attn_mask)
        return attended.permute(1, 0, 2)

    def _repeat(self, query, N: int):
        """Tile the (num_queries, dim) queries to (num_queries, N, dim)."""
        return query.unsqueeze(1).repeat(1, N, 1)
class VisualAttention(nn.Module):
    """Self-attention layer.

    Takes input of size [s, b, h] and returns output of the same size.
    Only self-attention is supported: `key` must be the same tensor as `query`
    (or numerically equal to it), and `kdim` / `vdim` must equal `embed_dim`.
    """

    def __init__(self, embed_dim, num_heads,
                 bias=True, kdim=None, vdim=None):
        """
        Args:
            embed_dim (int): Hidden size `h`; must be divisible by `num_heads`.
            num_heads (int): Number of attention heads.
            bias (bool): Accepted for signature compatibility. NOTE(review): it is
                not forwarded to the linear layers, which always carry a bias.
            kdim (int, optional): Key width; must equal `embed_dim` if given.
            vdim (int, optional): Value width; must equal `embed_dim` if given.
        """
        super(VisualAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer: one projection yields q, k and v interleaved per head.
        assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(self, query, key, value, attn_mask=None):
        """
        Args:
            query (torch.Tensor): Input of shape [sq, b, h].
            key (torch.Tensor): Must be `query` itself or numerically equal to it.
            value (torch.Tensor): Unused; values are derived from `query` through
                the fused input projection.
            attn_mask (torch.Tensor, optional): Additive mask passed to
                `torch.baddbmm` together with the [b * np, sq, sk] scores.
        Returns:
            torch.Tensor: Output of shape [sq, b, h].
        """
        # query/key/value: [sq, b, h]
        sq, b, _ = query.size()

        # Identity check first: it is free, avoids a full-tensor comparison (and the
        # accompanying host sync) on the common self-attention call path, and does
        # not spuriously fail when the activations contain NaN (NaN != NaN).
        assert key is query or torch.allclose(query, key), \
            'Only Support Self-Attention Currently'
        sk = sq
        mixed_x_layer = self.in_proj(query)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1)

        # [sq, b, np, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(
            sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)
        # [sk, b, np, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(
            sk,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # Scaled dot-product scores: [b * np, sq, sk]
        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            attention_scores = torch.baddbmm(
                attn_mask, q_scaled, key_layer.transpose(-2, -1))
        else:
            attention_scores = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_scores.softmax(dim=-1)

        # [sk, b, np, hn] -> [b * np, sk, hn]
        value_layer = value_layer.view(
            sk,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(
            b,
            self.num_attention_heads_per_partition,
            sq, self.hidden_size_per_attention_head)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output = self.out_proj(context_layer)

        return output
class VisualAttentionBlock(nn.Module):
    """
    Pre-norm residual block: attention followed by a two-layer MLP, each wrapped
    in a residual connection.

    When `is_cross_attention` is True an extra normalization layer (`ln_1_kv`) is
    created and applied to the key/value inputs given to `forward`.
    """

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
            is_cross_attention: bool = False,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        hidden_width = int(d_model * mlp_ratio)
        self.attn = VisualAttention(d_model, n_head)

        # Named sub-layers so parameter names stay c_fc / gelu / c_proj.
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(d_model, hidden_width)
        mlp_layers["gelu"] = act_layer()
        mlp_layers["c_proj"] = nn.Linear(hidden_width, d_model)
        self.mlp = nn.Sequential(mlp_layers)

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run the attention module, defaulting key and value to the query."""
        if k_x is None:
            k_x = q_x
        if v_x is None:
            v_x = q_x

        # The mask is cast to the query dtype before being handed to the attention.
        if attn_mask is not None:
            attn_mask = attn_mask.to(q_x.dtype)
        return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """
        Apply attention and MLP with residual connections.

        Key/value inputs are only used when the block owns `ln_1_kv`; otherwise
        they are discarded and the block attends over `q_x` alone.
        """
        kv_norm = getattr(self, "ln_1_kv", None)
        normed_k = kv_norm(k_x) if kv_norm is not None and k_x is not None else None
        normed_v = kv_norm(v_x) if kv_norm is not None and v_x is not None else None

        attn_out = self.attention(
            q_x=self.ln_1(q_x), k_x=normed_k, v_x=normed_v, attn_mask=attn_mask)
        hidden = q_x + attn_out
        return hidden + self.mlp(self.ln_2(hidden))
class TransformerBlock(nn.Module):
    """
    A stack of `layers` VisualAttentionBlock modules applied one after another.

    Attributes:
        width (int): Channel width handed to every block.
        layers (int): Number of blocks in the stack.
        resblocks (nn.ModuleList): The blocks, in application order.
    """

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList()
        for _ in range(layers):
            self.resblocks.append(
                VisualAttentionBlock(
                    width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
            )

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's first MLP weight."""
        reference_weight = self.resblocks[0].mlp.c_fc.weight
        return reference_weight.dtype

    def get_cast_device(self) -> torch.device:
        """Return the device of the first block's first MLP weight."""
        reference_weight = self.resblocks[0].mlp.c_fc.weight
        return reference_weight.device

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Feed `x` through every block in order, passing the same mask to each."""
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden
class VisionTransformer(nn.Module):
    """
    A Vision Transformer (ViT) image encoder: patch embedding, a transformer
    stack, attention pooling onto a fixed number of queries, and a final
    normalization and projection.

    Attributes:
        image_size (tuple): (height, width) of the input images (square).
        patch_size (tuple): (height, width) of each image patch (square).
        grid_size (tuple): Number of patches along each image side.
        output_dim (int): The dimensionality of the output token embeddings.
        image_transform (transforms.Compose): Resize / to-tensor / normalize
            pipeline applied by `encode`.
        positional_embedding (torch.nn.Parameter): The learnable positional
            embeddings (256 positions, resampled to the actual patch count).
        conv1 (torch.nn.Conv2d): The convolutional layer used to obtain patch embeddings.
        ln_pre (torch.nn.LayerNorm): Normalization applied before the transformer.
        transformer (TransformerBlock): The sequence of transformer blocks.
        attn_pool (Resampler): The attention pooling layer.
        ln_post (torch.nn.LayerNorm): The final layer normalization layer.
        proj (torch.nn.Parameter): The projection matrix for the output embeddings.
    """

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            n_queries: int = 256,
            output_dim: int = 512,
            **kwargs
    ):
        """
        Args:
            image_size (int): Side length of the (square) input images.
            patch_size (int): Side length of each (square) patch.
            width (int): The dimensionality of the token embeddings.
            layers (int): The number of transformer blocks.
            heads (int): The number of attention heads in each block.
            mlp_ratio (float): MLP hidden width as a multiple of `width`.
            n_queries (int): The number of queries for the attention pooling.
            output_dim (int): The dimensionality of the output token embeddings.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        image_height, image_width = self.image_size = (image_size, image_size)
        patch_height, patch_width = self.patch_size = (patch_size, patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.output_dim = output_dim

        # Per-channel normalization statistics used by the preprocessing pipeline.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.image_transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
                               kernel_size=patch_size, stride=patch_size, bias=False)

        # class embeddings and positional embeddings
        scale = width ** -0.5
        # The table always holds 256 positions; forward() resamples it with
        # get_abs_pos when the patch grid has a different size.
        self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.ln_pre = norm_layer(width)
        self.transformer = TransformerBlock(
            width,
            layers,
            heads,
            mlp_ratio,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        self.attn_pool = Resampler(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
        )
        self.ln_post = norm_layer(output_dim)
        self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim))

    def forward(self, x: torch.Tensor):
        """
        Encode a batch of preprocessed images.

        Args:
            x (torch.Tensor): Image batch accepted by `conv1` (3 input channels);
                it is cast to the transformer's dtype and device first.
        Returns:
            torch.Tensor: The pooled, normalized and projected token embeddings.
        """
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )
        # to patches
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        x = x + get_abs_pos(self.positional_embedding, x.size(1))

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj

        return x

    @staticmethod
    def _load_rgb_image(image_path: str):
        """Load one image from a local path or an http(s) URL as an RGB PIL image.

        The HTTP response and the underlying image file are closed before
        returning; `convert` has already decoded the pixels into a new image.
        """
        if image_path.startswith(("http://", "https://")):
            # NOTE(review): no timeout is set, matching the previous behaviour;
            # consider adding one so a stalled server cannot block indefinitely.
            with requests.get(image_path, stream=True) as response:
                # Fail with a clear HTTP error instead of trying to decode an
                # error page as an image.
                response.raise_for_status()
                with Image.open(response.raw) as remote_image:
                    return remote_image.convert("RGB")
        with Image.open(image_path) as local_image:
            return local_image.convert("RGB")

    def encode(self, image_paths: List[str]):
        """
        Load, preprocess and encode a list of images.

        Args:
            image_paths (List[str]): Local file paths or http(s) URLs.
        Returns:
            torch.Tensor: The output of `forward` for the stacked image batch.
        Raises:
            requests.HTTPError: If a URL responds with an error status.
        """
        images = [
            self.image_transform(self._load_rgb_image(image_path))
            for image_path in image_paths
        ]
        images = torch.stack(images, dim=0)
        return self(images)