megatron_patch/model/qwen_vl/visual.py (296 lines of code) (raw):

# Copyright (c) 2023 Alibaba PAI Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import math
from functools import partial
from typing import Callable, Optional, List

import numpy as np
import requests
import torch
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode


def get_abs_pos(abs_pos, tgt_size):
    """
    Resample square-grid absolute positional embeddings to a new length.

    `abs_pos` is viewed as a (src_size, src_size, dim) grid, where
    src_size = int(sqrt(abs_pos.size(0))). It is bicubically interpolated to
    a (t, t) grid with t = int(sqrt(tgt_size)). The interpolation is done in
    float32 and then cast back to the input dtype.

    Args:
        abs_pos (torch.Tensor): Positional embeddings of shape (src_len, dim).
            src_len is expected to be a perfect square.
        tgt_size (int): The target number of positions. It is expected to be
            a perfect square. Otherwise the result has int(sqrt(tgt_size))**2
            rows, not tgt_size rows.

    Returns:
        torch.Tensor: Embeddings of shape (t * t, dim). When the grid sizes
        already match, `abs_pos` itself is returned (not a copy).
    """
    src_size = int(math.sqrt(abs_pos.size(0)))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        return F.interpolate(
            # (src_len, dim) -> (1, dim, src_size, src_size) for F.interpolate
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos


# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Generate a fixed 2D sine-cosine positional embedding for a square grid.

    Args:
        embed_dim (int): The dimension of the embedding. It must be divisible
            by 4, because it is split in half per axis and each half is
            split into sin/cos.
        grid_size (int): The height and width of the 2D grid.
        cls_token (bool): If True, an all-zero row is prepended for a class
            token.

    Returns:
        np.ndarray: Array of shape [grid_size*grid_size, embed_dim], or
        [1 + grid_size*grid_size, embed_dim] when `cls_token` is True.
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Generate a 2D sine-cosine positional embedding from coordinate grids.

    Args:
        embed_dim (int): The dimension of the embedding. It must be even.
        grid (np.ndarray): A stack of two coordinate grids, indexable as
            grid[0] and grid[1]. In get_2d_sincos_pos_embed it has shape
            (2, 1, H, W), and grid[0] holds the *w* coordinates because
            meshgrid is called with w first.

    Returns:
        np.ndarray: Array of shape (H*W, embed_dim). The first embed_dim/2
        columns encode grid[0] and the last embed_dim/2 columns encode
        grid[1].
    """
    assert embed_dim % 2 == 0

    # First half encodes grid[0], second half grid[1]. The names emb_h and
    # emb_w come from the upstream MAE code and are kept for fidelity.
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Generate a 1D sine-cosine positional embedding for a set of positions.

    Args:
        embed_dim (int): The output dimension for each position. It must be
            even.
        pos (np.ndarray): Positions of any shape. They are flattened to M
            values.

    Returns:
        np.ndarray: Array of shape (M, embed_dim). The first half of the
        columns are sines and the second half are cosines, using
        frequencies 1 / 10000**(2i/embed_dim).
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class Resampler(nn.Module):
    """
    A 2D perceiver-resampler network.

    It has a single cross-attention layer that uses (grid_size**2)
    learnable queries and fixed 2D sin-cos positional embeddings.

    Attributes:
        num_queries (int): The number of learnable queries. This is the
            output sequence length.
        embed_dim (int): The embedding dimension of the queries and of the
            projected key/value inputs.
        num_heads (int): The number of attention heads.
        pos_embed (nn.Parameter): Frozen 2D sin-cos positional embeddings of
            shape (num_queries, embed_dim).
        query (nn.Parameter): Learnable query embeddings of shape
            (num_queries, embed_dim).
        kv_proj (nn.Module): Linear projection from kv_dim to embed_dim, or
            Identity when the two dimensions already match.
        attn (nn.MultiheadAttention): The multi-head attention module
            (sequence-first layout).
        ln_q (nn.LayerNorm): Layer normalization applied to the queries.
        ln_kv (nn.LayerNorm): Layer normalization applied to the projected
            key/value inputs.
    """

    def __init__(
            self,
            grid_size,
            embed_dim,
            num_heads,
            kv_dim=None,
            norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.num_queries = grid_size ** 2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.pos_embed = nn.Parameter(
            torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
        ).requires_grad_(False)

        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
        trunc_normal_(self.query, std=.02)

        if kv_dim is not None and kv_dim != embed_dim:
            self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        else:
            self.kv_proj = nn.Identity()

        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.ln_q = norm_layer(embed_dim)
        self.ln_kv = norm_layer(embed_dim)

    def forward(self, x, attn_mask=None):
        """
        Cross-attend the learnable queries over the input sequence.

        Args:
            x (torch.Tensor): Key/value inputs, presumably of shape
                (batch, seq, kv_dim). Dim 1 is treated as the sequence
                length and permuted to the front for nn.MultiheadAttention.
            attn_mask (torch.Tensor, optional): Passed through to
                nn.MultiheadAttention.

        Returns:
            torch.Tensor: Resampled output of shape
            (batch, num_queries, embed_dim).
        """
        # Keys get positional embeddings resampled to the input length.
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = self.kv_proj(x)
        x = self.ln_kv(x).permute(1, 0, 2)  # -> (seq, batch, embed_dim)

        N = x.shape[1]  # batch size after the permute
        q = self.ln_q(self.query)
        # There are always num_queries queries, so they use the un-resampled
        # self.pos_embed.
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask)[0]
        return out.permute(1, 0, 2)

    def _repeat(self, query, N: int):
        # (num_queries, D) -> (num_queries, N, D): one copy per batch element.
        return query.unsqueeze(1).repeat(1, N, 1)


class VisualAttention(nn.Module):
    """Self-attention layer.

    It takes input of size [s, b, h] and returns output of the same size.
    Only self-attention is supported: key must equal query.
    """

    def __init__(self, embed_dim, num_heads,
                 bias=True, kdim=None, vdim=None):
        super(VisualAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads

        # Per attention head and per partition values.
        assert embed_dim % num_heads == 0
        self.hidden_size_per_attention_head = embed_dim // num_heads
        self.num_attention_heads_per_partition = num_heads
        self.hidden_size_per_partition = embed_dim

        # Strided linear layer.
        assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

    def forward(self, query, key, value, attn_mask = None):
        """
        Compute multi-head self-attention.

        Args:
            query (torch.Tensor): Input of shape [sq, b, h]. Q, K and V are
                all projected from this tensor.
            key (torch.Tensor): Must be (close to) `query`. It is only
                checked, not used.
            value (torch.Tensor): Unused. V is derived from `query`.
            attn_mask (torch.Tensor, optional): Additive mask added to the
                raw attention scores before softmax. It must broadcast to
                [b * num_heads, sq, sq].

        Returns:
            torch.Tensor: Output of shape [sq, b, h].
        """
        # query/key/value: [sq, b, h]
        sq, b, _ = query.size()

        # Identity check first: callers usually pass the same tensor object,
        # and that avoids a full elementwise comparison and device sync.
        assert key is query or torch.allclose(query, key), \
            'Only Support Self-Attention Currently'
        sk = sq
        mixed_x_layer = self.in_proj(query)

        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
        new_tensor_shape = mixed_x_layer.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             3 * self.hidden_size_per_attention_head)
        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
        query_layer, key_layer, value_layer = mixed_x_layer.split(
            self.hidden_size_per_attention_head, dim=-1)

        # [sq, b, np, hn] -> [b * np, sq, hn]
        query_layer = query_layer.view(sq,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)
        # [sk, b, np, hn] -> [b * np, sk, hn]
        key_layer = key_layer.view(sk,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        q_scaled = query_layer / self.norm_factor
        if attn_mask is not None:
            # baddbmm computes attn_mask + q_scaled @ k^T in one kernel.
            attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
        else:
            attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
        attention_probs = attention_probs.softmax(dim=-1)

        value_layer = value_layer.view(sk,
            b * self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head).transpose(0, 1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer)

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(b,
            self.num_attention_heads_per_partition,
            sq, self.hidden_size_per_attention_head)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output = self.out_proj(context_layer)

        return output


class VisualAttentionBlock(nn.Module):
    """Pre-norm transformer block: LN -> self-attention -> residual, LN -> MLP -> residual."""

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
            is_cross_attention: bool = False,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        if is_cross_attention:
            self.ln_1_kv = norm_layer(d_model)

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.attn = VisualAttention(d_model, n_head)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))

    def attention(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Run self.attn, defaulting keys/values to the queries.

        The mask is cast to the query dtype.
        """
        k_x = k_x if k_x is not None else q_x
        v_x = v_x if v_x is not None else q_x

        attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
        return self.attn(q_x, k_x, v_x, attn_mask=attn_mask)

    def forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        """Apply the block.

        Without `ln_1_kv` (i.e. is_cross_attention=False), any k_x/v_x
        passed in are discarded (set to None). Attention then runs on
        ln_1(q_x) alone.
        """
        k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
        v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None

        x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return x


class TransformerBlock(nn.Module):
    """A stack of `layers` VisualAttentionBlock modules applied in sequence."""

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        self.resblocks = nn.ModuleList([
            VisualAttentionBlock(
                width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer)
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Return the dtype of the first block's MLP weights."""
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def get_cast_device(self) -> torch.device:
        """Return the device of the first block's MLP weights."""
        return self.resblocks[0].mlp.c_fc.weight.device

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        for r in self.resblocks:
            x = r(x, attn_mask=attn_mask)
        return x


class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT) image encoder with an attention-pooling resampler.

    Images are split into patches, encoded by a transformer, and pooled by
    a Resampler into `n_queries` output tokens of size `output_dim`.

    Attributes:
        image_size (tuple): (height, width) of the input images (square).
        patch_size (tuple): (height, width) of each image patch.
        grid_size (tuple): Number of patches along each axis.
        output_dim (int): The dimensionality of the output token embeddings.
        image_transform (transforms.Compose): Preprocessing used by
            `encode`: bicubic resize, ToTensor, Normalize.
        conv1 (torch.nn.Conv2d): Patch-embedding convolution.
        positional_embedding (torch.nn.Parameter): Learnable positional
            embeddings with 256 rows. They are resampled with get_abs_pos
            to the actual number of patches.
        ln_pre (torch.nn.LayerNorm): Normalization applied before the
            transformer.
        transformer (TransformerBlock): The sequence of transformer blocks.
        attn_pool (Resampler): Attention pooling to `n_queries` tokens.
        ln_post (torch.nn.LayerNorm): The final layer normalization layer.
        proj (torch.nn.Parameter): Output projection matrix of shape
            (output_dim, output_dim).
    """

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            n_queries: int = 256,
            output_dim: int = 512,
            **kwargs
    ):
        super().__init__()
        image_height, image_width = self.image_size = (image_size, image_size)
        patch_height, patch_width = self.patch_size = (patch_size, patch_size)
        self.grid_size = (image_height // patch_height, image_width // patch_width)
        self.output_dim = output_dim

        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        self.image_transform = transforms.Compose([
            transforms.Resize(
                (image_size, image_size),
                interpolation=InterpolationMode.BICUBIC
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        # class embeddings and positional embeddings
        scale = width ** -0.5
        # 256 rows are kept as-is: get_abs_pos resamples them to the real
        # patch count, and the row count must match existing checkpoints.
        self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))

        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        act_layer = nn.GELU

        self.ln_pre = norm_layer(width)
        self.transformer = TransformerBlock(
            width,
            layers,
            heads,
            mlp_ratio,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        self.attn_pool = Resampler(
            grid_size=int(math.sqrt(n_queries)),
            embed_dim=output_dim,
            num_heads=output_dim // 128,
            kv_dim=width,
            norm_layer=norm_layer,
        )
        self.ln_post = norm_layer(output_dim)
        self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))

    def forward(self, x: torch.Tensor):
        """
        Encode a batch of images.

        Args:
            x (torch.Tensor): Images, presumably of shape [batch, 3, H, W]
                (conv1 expects 3 input channels). They are cast to the
                transformer's weight dtype and device.

        Returns:
            torch.Tensor: Tokens of shape [batch, n_queries, output_dim].
        """
        x = x.to(
            dtype=self.transformer.get_cast_dtype(),
            device=self.transformer.get_cast_device(),
        )
        # to patches
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        x = x + get_abs_pos(self.positional_embedding, x.size(1))

        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.attn_pool(x)
        x = self.ln_post(x)
        x = x @ self.proj

        return x

    def encode(self, image_paths: List[str]):
        """
        Load, preprocess and encode images from local paths or http(s) URLs.

        Each image is converted to RGB, passed through `image_transform`,
        stacked into one batch and run through `forward`. The HTTP response
        and the opened image files are closed once each image is decoded.

        Args:
            image_paths (List[str]): Local file paths or URLs starting with
                "http://" or "https://".

        Returns:
            torch.Tensor: Tokens of shape [len(image_paths), n_queries,
            output_dim].

        Note:
            An empty `image_paths` makes torch.stack fail. No request
            timeout is set.
        """
        images = []
        for image_path in image_paths:
            if image_path.startswith(("http://", "https://")):
                # Stream the body into PIL and release the connection afterwards.
                with requests.get(image_path, stream=True) as response:
                    with Image.open(response.raw) as raw_image:
                        # convert() loads the pixels, so the result outlives the stream.
                        image = raw_image.convert("RGB")
            else:
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert("RGB")
            images.append(self.image_transform(image))
        images = torch.stack(images, dim=0)
        return self(images)