# File: maga_transformer/models/internvl_vit.py

from typing import List, Optional, Tuple, Union, Dict, Any

from PIL import Image

try:
    from decord import VideoReader, cpu
except ModuleNotFoundError:
    # Video decoding is optional; load_video() raises a clear error when it is missing.
    VideoReader = None
    cpu = None

import os
import copy
import logging
import warnings

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
                                           BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms as T

from maga_transformer.models.multimodal.multimodal_common import MultiModalEmbeddingInterface
from maga_transformer.utils.multimodal_util import MMUrlType
from maga_transformer.utils.flash_attn_utils import can_use_flash_attn
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters

# True only when flash_attn imported successfully; InternAttention falls back to
# the naive attention path otherwise.
has_flash_attn = False
try:
    if can_use_flash_attn():
        from flash_attn.bert_padding import pad_input, unpad_input
        from flash_attn.flash_attn_interface import \
            flash_attn_varlen_qkvpacked_func
        has_flash_attn = True
except Exception as e:
    logging.info(f'initialize flash_attn failed, exception {e}, using default attention in internvl vit')

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Build the image preprocessing pipeline used for every tile.

    Converts to RGB, resizes to a square ``input_size`` x ``input_size`` with
    bicubic interpolation, converts to a tensor and normalizes with the
    ImageNet mean/std constants defined in this module.
    """
    mean, std = IMAGENET_MEAN, IMAGENET_STD
    return T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the ``(cols, rows)`` grid whose aspect ratio is closest to ``aspect_ratio``.

    Ties are resolved in favour of the later (larger) grid when the original
    image area exceeds half of the tiled area ``image_size**2 * cols * rows``.
    Returns ``(1, 1)`` if ``target_ratios`` is empty.
    """
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split ``image`` into a grid of ``image_size`` square tiles.

    The grid shape ``(cols, rows)`` is chosen among all grids with
    ``min_num <= cols * rows <= max_num`` to best match the image's aspect
    ratio. The image is resized to ``(cols * image_size, rows * image_size)``
    and cropped row-major into tiles.

    Args:
        image: a PIL image.
        min_num / max_num: inclusive bounds on the number of tiles.
        image_size: side length of each tile in pixels.
        use_thumbnail: when True and more than one tile was produced, append
            a whole-image thumbnail resized to ``image_size`` x ``image_size``.

    Returns:
        List of PIL images (tiles, optionally followed by the thumbnail).
    """
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate every candidate grid whose tile count lies within [min_num, max_num].
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image, then cut it into row-major tiles
    resized_img = image.resize((target_width, target_height))
    num_cols = target_width // image_size  # loop-invariant tiles per row
    processed_images = []
    for i in range(blocks):
        col, row = i % num_cols, i // num_cols
        box = (
            col * image_size,
            row * image_size,
            (col + 1) * image_size,
            (row + 1) * image_size,
        )
        processed_images.append(resized_img.crop(box))
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def pixel_shuffle(ps_version, x, scale_factor=0.5):
    """Trade spatial resolution for channels on a 4-D ``(N, W, H, C)`` tensor.

    With ``scale_factor=0.5`` the spatial dims are halved and the channel dim
    is multiplied by 4. For ``ps_version == 'v1'`` the final transpose is
    skipped (legacy behaviour) and a warning is emitted.
    """
    n, w, h, c = x.size()
    # N, W, H, C --> N, W, H * scale, C // scale
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
    x = x.permute(0, 2, 1, 3).contiguous()
    # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
    x = x.view(n, int(h * scale_factor), int(w * scale_factor),
               int(c / (scale_factor * scale_factor)))
    if ps_version == 'v1':
        warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
                      'which results in a transposed image.')
    else:
        x = x.permute(0, 2, 1, 3).contiguous()
    return x


def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    """Return ``num_segments`` evenly spaced frame indices as a numpy array.

    ``bound`` is an optional ``(start_seconds, end_seconds)`` pair; without it
    the whole video ``[first_idx, max_frame]`` is sampled. Each index is taken
    near the middle of its segment.
    """
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array([
        int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
        for idx in range(num_segments)
    ])
    return frame_indices


def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
    """Decode ``num_segments`` sampled frames of a video as RGB PIL images.

    ``input_size`` and ``max_num`` are accepted for interface compatibility
    but are not used here; tiling happens later in ``dynamic_preprocess``.

    Raises:
        ModuleNotFoundError: if the optional ``decord`` dependency is missing.
    """
    if VideoReader is None:
        # Without this check the call below fails with an opaque
        # "'NoneType' object is not callable" TypeError.
        raise ModuleNotFoundError('decord is required to load video inputs but it is not installed')
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())

    frame_indices = get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
    return [Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') for frame_index in frame_indices]


class InternVLImageEmbedding(MultiModalEmbeddingInterface):
    """Turns images / video frames into LLM-space embeddings with InternViT + an MLP projector."""

    def __init__(self, config: GptInitModelParameters):
        self.config = config
        vit_config = config.mm_related_params.config
        self.select_layer = vit_config["select_layer"]
        self.vision_model = InternVisionModel(InternVisionConfig(**vit_config))
        vit_hidden_size = vit_config["hidden_size"]
        llm_hidden_size = vit_config["llm_hidden_size"]
        self.downsample_ratio = vit_config["downsample_ratio"]
        self.ps_version = vit_config["ps_version"]
        # pixel_shuffle multiplies the channel dim by int(1 / downsample_ratio) ** 2.
        projector_in = vit_hidden_size * int(1 / self.downsample_ratio) ** 2
        self.mlp1 = nn.Sequential(
            nn.LayerNorm(projector_in),
            nn.Linear(projector_in, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

    @property
    def _device(self):
        return self.vision_model.device

    @torch.inference_mode()
    def mm_process(self, mm_input, **kwargs):
        """Embed a preprocessed input according to ``kwargs['mm_type']``.

        A list is treated as video frames (one tile per frame) and the stacked
        per-frame embeddings are returned. A single image is tiled into up to
        12 tiles and its embedding is returned.

        Raises:
            Exception: on a list/single mismatch with the declared type, or an
                unknown ``mm_type``.
        """
        mm_type = kwargs.get("mm_type")
        if mm_type == MMUrlType.DEFAULT:
            if isinstance(mm_input, list):
                return self.image_embedding(mm_input, 1)
            return self.image_embedding([mm_input], 12)[0]
        elif mm_type == MMUrlType.IMAGE:
            if isinstance(mm_input, list):
                raise Exception("expect single image input, but get a list")
            return self.image_embedding([mm_input], 12)[0]
        elif mm_type == MMUrlType.VIDEO:
            if not isinstance(mm_input, list):
                raise Exception("expect video input, but get a single image")
            return self.image_embedding(mm_input, 1)
        else:
            raise Exception("unknown mm url type")

    def _mm_preprocess(self, data, **kwargs):
        """Load raw ``data`` as an RGB image or as a list of 8 sampled video frames.

        For ``MMUrlType.DEFAULT`` the data is first opened as an image; if that
        fails, it is decoded as a video instead.
        """
        mm_type = kwargs.get("mm_type")
        if mm_type == MMUrlType.DEFAULT:
            # Keep an independent copy so the video fallback is not affected by
            # whatever Image.open did to ``data``.
            origin_data = copy.copy(data)
            try:
                return Image.open(data).convert("RGB")
            except Exception:
                try:
                    return load_video(origin_data, num_segments=8, max_num=1)
                except Exception as video_err:
                    raise Exception(str(video_err)) from video_err
        elif mm_type == MMUrlType.IMAGE:
            return Image.open(data).convert("RGB")
        elif mm_type == MMUrlType.VIDEO:
            return load_video(data, num_segments=8, max_num=1)
        else:
            raise Exception("unknown mm url type")

    @torch.no_grad()
    def image_embedding(self, images: List[Image.Image], max_num):
        """Embed each image and stack the results.

        Each image is tiled by ``dynamic_preprocess`` (at most ``max_num``
        tiles plus an optional thumbnail) and run through the vision model.
        The class token is dropped, then ``pixel_shuffle`` and ``mlp1`` are
        applied. The tokens of all tiles are flattened into one sequence per
        image.

        Returns:
            ``torch.stack`` of the per-image token tensors, so all images must
            produce the same number of tiles.
        """
        device = self._device
        vit_config = self.config.mm_related_params.config
        input_size = vit_config["image_size"]
        transform = build_transform(input_size=input_size)
        # Intermediate hidden states are only needed when a non-final layer is selected.
        need_hidden_states = self.select_layer != -1
        res = []
        for image in images:
            tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
            # NOTE(review): `_data_type` is not defined in this file; presumably provided
            # by MultiModalEmbeddingInterface — confirm.
            pixel_values = torch.stack([transform(tile) for tile in tiles]).to(device=device).to(self._data_type)
            outputs = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=need_hidden_states,
                return_dict=True)
            if need_hidden_states:
                vit_embeds = outputs.hidden_states[self.select_layer]
            else:
                vit_embeds = outputs.last_hidden_state
            # Drop the class token prepended by InternVisionEmbeddings.
            vit_embeds = vit_embeds[:, 1:, :]

            # Tiles are square, so the patch tokens form a square grid.
            h = w = int(vit_embeds.shape[1] ** 0.5)
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
            vit_embeds = pixel_shuffle(self.ps_version, vit_embeds, scale_factor=self.downsample_ratio)
            vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
            vit_embeds = self.mlp1(vit_embeds)
            res.append(vit_embeds.reshape(-1, vit_embeds.shape[-1]))
        return torch.stack(res)


class InternVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
    instantiate a vision encoder according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        num_channels (`int`, *optional*, defaults to 3):
            Number of color channels in the input images (e.g., 3 for RGB).
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        qkv_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the fused query/key/value projection in the self-attention layers.
        hidden_size (`int`, *optional*, defaults to 3200):
            Dimensionality of the encoder layers and the pooler layer.
        num_attention_heads (`int`, *optional*, defaults to 25):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 12800):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        qk_normalization (`bool`, *optional*, defaults to `True`):
            Whether to normalize the queries and keys in the self-attention layers.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        use_flash_attn (`bool`, *optional*, defaults to `True`):
            Whether to use flash attention. Only takes effect when `flash_attn` could be imported.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the feed-forward layers. Looked up in
            `transformers.activations.ACT2FN`, e.g. `"gelu"`, `"relu"`, `"selu"` or `"gelu_new"`.
        norm_type (`str`, *optional*, defaults to `"rms_norm"`):
            Normalization used in each encoder layer: `"rms_norm"` or `"layer_norm"`.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability applied after the attention output projection.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            Maximum stochastic-depth rate; rates increase linearly from 0 across the layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
            Stored on the config; the modules in this file do not read it.
        initializer_factor (`float`, *optional*, defaults to 0.1):
            Initial value of the per-channel layer-scale parameters (`ls1`, `ls2`).
    """

    model_type = 'intern_vit_6b'

    def __init__(
            self,
            num_channels=3,
            patch_size=14,
            image_size=224,
            qkv_bias=False,
            hidden_size=3200,
            num_attention_heads=25,
            intermediate_size=12800,
            qk_normalization=True,
            num_hidden_layers=48,
            use_flash_attn=True,
            hidden_act='gelu',
            norm_type='rms_norm',
            layer_norm_eps=1e-6,
            dropout=0.0,
            drop_path_rate=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=0.1,
            **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.norm_type = norm_type
        self.qkv_bias = qkv_bias
        self.qk_normalization = qk_normalization
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
        """Load the config, unwrapping a nested ``vision_config`` entry if present."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        if 'vision_config' in config_dict:
            config_dict = config_dict['vision_config']

        if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
            logging.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
            )

        return cls.from_dict(config_dict, **kwargs)


class FlashAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        # `device` / `dtype` are accepted for signature compatibility and unused.
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                max_s=None, need_weights=False):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                if unpadded: (nnz, 3, h, d)
            key_padding_mask: a bool tensor of shape (B, S)

        Returns:
            ``(output, None)``; attention weights are never returned.
        """
        assert not need_weights
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda

        dropout_p = self.dropout_p if self.training else 0.0
        if cu_seqlens is None:
            batch_size = qkv.shape[0]
            seqlen = qkv.shape[1]
            if key_padding_mask is None:
                # Every sequence has the same length: pack them back-to-back.
                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
                max_s = seqlen
                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                          device=qkv.device)
                output = flash_attn_varlen_qkvpacked_func(
                    qkv, cu_seqlens, max_s, dropout_p,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
            else:
                nheads = qkv.shape[-2]
                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                # Newer flash-attn versions return an extra trailing value; only the first four are needed.
                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)[:4]
                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                output_unpad = flash_attn_varlen_qkvpacked_func(
                    x_unpad, cu_seqlens, max_s, dropout_p,
                    softmax_scale=self.softmax_scale, causal=causal
                )
                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                             indices, batch_size, seqlen),
                                   'b s (h d) -> b s h d', h=nheads)
        else:
            assert max_s is not None
            output = flash_attn_varlen_qkvpacked_func(
                qkv, cu_seqlens, max_s, dropout_p,
                softmax_scale=self.softmax_scale, causal=causal
            )

        return output, None


class InternRMSNorm(nn.Module):
    """RMS normalization over the last dimension, computed in float32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


try:
    from apex.normalization import FusedRMSNorm

    InternRMSNorm = FusedRMSNorm  # noqa

    logging.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
    # using the normal InternRMSNorm
    pass
except Exception:
    logging.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
    pass

# Built after the optional apex substitution so 'rms_norm' maps to whichever
# implementation is active.
NORM2FN = {
    'rms_norm': InternRMSNorm,
    'layer_norm': nn.LayerNorm,
}


class InternVisionEmbeddings(nn.Module):
    """Patch embedding + prepended class token + (interpolated) position embedding."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resize the patch position embeddings to an ``H`` x ``W`` grid."""
        target_dtype = pos_embed.dtype
        grid = self.image_size // self.patch_size
        pos_embed = pos_embed.float().reshape(1, grid, grid, -1).permute(0, 3, 1, 2)
        pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False)
        pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
        return pos_embed

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, channel, width, height]
        batch_size, _, height, width = patch_embeds.shape
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        position_embedding = torch.cat([
            self.position_embedding[:, :1, :],
            self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
        ], dim=1)
        embeddings = embeddings + position_embedding.to(target_dtype)
        return embeddings


class InternAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_flash_attn = config.use_flash_attn and has_flash_attn
        if config.use_flash_attn and not has_flash_attn:
            # Logged at info level on purpose: this is emitted once per layer.
            logging.info('Warning: Flash Attention is not available, use_flash_attn is set to False.')
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
                f' {self.num_heads}).'
            )

        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
        self.attn_drop = nn.Dropout(config.attention_dropout)
        self.proj_drop = nn.Dropout(config.dropout)

        self.qk_normalization = config.qk_normalization

        if self.qk_normalization:
            # Normalization is applied across all heads jointly (embed_dim wide).
            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)

        if self.use_flash_attn:
            self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _naive_attn(self, x):
        """Plain PyTorch softmax attention on a ``(B, N, C)`` input."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            B_, H_, N_, D_ = q.shape
            q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
            k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
        """Attention through the flash-attn varlen kernel."""
        qkv = self.qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)

        if self.qk_normalization:
            q, k, v = qkv.unbind(2)
            q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
            k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
            qkv = torch.stack([q, k, v], dim=2)

        context, _ = self.inner_attn(
            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
        )
        outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
        outs = self.proj_drop(outs)
        return outs

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.use_flash_attn:
            return self._flash_attn(hidden_states)
        return self._naive_attn(hidden_states)


class InternMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        self.act = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class InternVisionEncoderLayer(nn.Module):
    """Pre-norm transformer block with layer scale (ls1/ls2) and optional drop-path."""

    def __init__(self, config: InternVisionConfig, drop_path_rate: float):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.norm_type = config.norm_type

        self.attn = InternAttention(config)
        self.mlp = InternMLP(config)
        self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
        self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)

        self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
        self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(
            self,
            hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`

        Returns:
            `torch.Tensor` with the same shape as the input.
        """
        # `.to(...)` keeps the residual dtype when the norm returns a different dtype.
        hidden_states = hidden_states + self.drop_path1(
            self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1)

        hidden_states = hidden_states + self.drop_path2(
            self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2)

        return hidden_states


class InternVisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`InternVisionEncoderLayer`].

    Args:
        config (`InternVisionConfig`):
            The corresponding vision configuration for the `InternVisionEncoder`.
    """

    def __init__(self, config: InternVisionConfig):
        super().__init__()
        self.config = config
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layers = nn.ModuleList([
            InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
        # Only takes effect in training mode (see forward).
        self.gradient_checkpointing = True

    def forward(
            self,
            inputs_embeds,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. When enabled, the tuple holds the
                input embeddings followed by every layer's output (`num_hidden_layers + 1` entries).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    encoder_layer,
                    hidden_states)
            else:
                hidden_states = encoder_layer(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states
        )


class InternVisionModel(PreTrainedModel):
    """InternViT vision encoder: embeddings followed by the transformer encoder."""

    main_input_name = 'pixel_values'
    _supports_flash_attn_2 = True
    config_class = InternVisionConfig
    _no_split_modules = ['InternVisionEncoderLayer']

    def __init__(self, config: InternVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = InternVisionEmbeddings(config)
        self.encoder = InternVisionEncoder(config)

    def resize_pos_embeddings(self, old_size, new_size, patch_size):
        """Bicubically resize the learned position embeddings from ``old_size`` to ``new_size`` images."""
        pos_emb = self.embeddings.position_embedding
        _, _, embed_dim = pos_emb.shape
        cls_emb = pos_emb[:, :1, :]
        old_grid = old_size // patch_size
        new_grid = new_size // patch_size
        pos_emb = pos_emb[:, 1:, :].reshape(1, old_grid, old_grid, -1).permute(0, 3, 1, 2)
        pos_emb = F.interpolate(pos_emb.float(), size=new_grid, mode='bicubic', align_corners=False)
        pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
        pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
        self.embeddings.position_embedding = nn.Parameter(pos_emb)
        self.embeddings.image_size = new_size
        # Keep the derived counters consistent with the new embedding table.
        self.embeddings.num_patches = new_grid ** 2
        self.embeddings.num_positions = self.embeddings.num_patches + 1
        logging.info('Resized position embeddings from {} to {}'.format(old_size, new_size))

    def get_input_embeddings(self):
        return self.embeddings

    def forward(
            self,
            pixel_values: Optional[torch.Tensor] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            pixel_embeds: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode ``pixel_values`` (4-D) or precomputed ``pixel_embeds``.

        The pooled output is the first (class) token of the last hidden state.

        Raises:
            ValueError: if neither input is given, or ``pixel_values`` is not 4-D.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError('You have to specify pixel_values or pixel_embeds')

        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            else:
                raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # With return_dict=False the encoder returns a plain tuple, so index
        # positionally; ModelOutput supports integer indexing as well.
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]

        if not return_dict:
            return (last_hidden_state, pooled_output) + tuple(encoder_outputs[1:])

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )