modules/SwissArmyTransformer/sat/model/official/eva2_model.py

import torch
import torch.nn as nn

from sat.model.base_model import BaseModel
from sat.model.mixins import BaseMixin
from sat.ops.layernorm import LayerNorm
from .vit_model import ViTProperty


class MaskedPatchEmbedMixin(BaseMixin):
    def __init__(self, in_channels, hidden_size, property):
        super(MaskedPatchEmbedMixin, self).__init__()
        self.property = property
        # Learnable token that replaces the embedding of masked patches.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=property.patch_size, stride=property.patch_size)

    def word_embedding_forward(self, input_ids, **kwargs):
        """
        Input:
        * input_ids with shape (batch_size, pre_len+post_len)
        * kwargs["image"] with shape (B, C, H, W)
        * kwargs["bool_masked_pos"] with shape (B, num_patches)
        Output:
        * (batch_size, pre_len+num_patches+post_len, hidden_size)
        """
        images = kwargs["image"]
        embeddings = self.proj(images)
        embeddings = embeddings.flatten(2).transpose(1, 2)
        if kwargs.get("bool_masked_pos", None) is not None:
            # Blend in the mask token at the masked patch positions.
            batch_size, seq_len, _ = embeddings.size()
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = kwargs["bool_masked_pos"].unsqueeze(-1).type_as(mask_token)
            embeddings = embeddings * (1 - w) + mask_token * w
        # Surround the patch embeddings with the embeddings of the pre/post special tokens ([cls] etc.).
        pre_word_embeddings = self.transformer.word_embeddings(input_ids[:, :self.property.pre_len])
        post_word_embeddings = self.transformer.word_embeddings(input_ids[:, self.property.pre_len:self.property.pre_len + self.property.post_len])
        embeddings = torch.cat([pre_word_embeddings, embeddings, post_word_embeddings], dim=1)
        return embeddings


class EVA2FinalMixin(BaseMixin):
    def __init__(self, predict_feature_dim, hidden_size):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, predict_feature_dim)

    def final_forward(self, logits, **kwargs):
        # Drop the [cls] position; predict features only for the patch tokens.
        logits = logits[:, 1:]
        if kwargs.get("bool_masked_pos", None) is not None:
            return self.lm_head(logits[kwargs["bool_masked_pos"]])
        return self.lm_head(logits)


class SwiGLUMixin(BaseMixin):
    def __init__(self, num_layers, in_features, hidden_features, act_layer=nn.SiLU, drop=0., eps=1e-6):
        super().__init__()
        # self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.ModuleList([nn.Linear(in_features, hidden_features) for _ in range(num_layers)])
        self.act = act_layer()
        self.ffn_ln = nn.ModuleList([LayerNorm(hidden_features, eps=eps) for _ in range(num_layers)])
        # self.w3 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def mlp_forward(self, hidden_states, **kw_args):
        # SwiGLU: the layer's original dense_h_to_4h acts as w1, the per-layer w2 is the extra gate
        # branch, and dense_4h_to_h projects back after a LayerNorm on the gated hidden states.
        x = hidden_states
        origin = self.transformer.layers[kw_args['layer_id']].mlp
        x1 = origin.dense_h_to_4h(x)
        x2 = self.w2[kw_args['layer_id']](x)
        hidden = self.act(x1) * x2
        x = self.ffn_ln[kw_args['layer_id']](hidden)
        x = origin.dense_4h_to_h(x)
        x = self.drop(x)
        return x


# I don't know why the original EVA2 model adds bias to query and value but not to key. Bias is simply added to all of them here.
from sat.model.position_embedding.vision_rotary_embeddings import VisionRotaryEmbeddingFast
from sat.transformer_defaults import standard_attention
from sat.mpu.utils import split_tensor_along_last_dim


class EVA2AttnMixin(BaseMixin):
    def __init__(self, hidden_size, num_attention_heads, property):
        super().__init__()
        # 2D rotary embedding applied over half of each head dimension.
        half_head_dim = hidden_size // num_attention_heads // 2
        hw_seq_len = property.image_size[0] // property.patch_size
        self.rope = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=hw_seq_len,
        )

    def attention_forward(self, hidden_states, mask, **kw_args):
        # Keep a handle to the mixin ("origin") and let `self` refer to the wrapped attention layer,
        # so the rest of the body reads like the standard SAT attention forward.
        origin = self
        self = self.transformer.layers[kw_args['layer_id']].attention
        attention_fn = standard_attention
        if 'attention_fn' in self.hooks:
            attention_fn = self.hooks['attention_fn']

        mixed_raw_layer = self.query_key_value(hidden_states)
        (mixed_query_layer, mixed_key_layer, mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)

        dropout_fn = self.attention_dropout if self.training else None

        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)

        # Apply rotary position embedding to the patch tokens only; the leading [cls] slot is left untouched.
        q_t = query_layer[:, :, 1:, :]
        ro_q_t = origin.rope(q_t)
        q = torch.cat((query_layer[:, :, :1, :], ro_q_t), -2).type_as(value_layer)

        k_t = key_layer[:, :, 1:, :]
        ro_k_t = origin.rope(k_t)
        k = torch.cat((key_layer[:, :, :1, :], ro_k_t), -2).type_as(value_layer)

        context_layer = attention_fn(q, k, value_layer, mask, dropout_fn, **kw_args)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        output = self.dense(context_layer)

        if self.training:
            output = self.output_dropout(output)

        return output


class EVA2Model(BaseModel):
    def __init__(self, args, transformer=None, **kwargs):
        self.property = ViTProperty(args.image_size, args.patch_size, args.pre_len, args.post_len)
        args.max_sequence_length = self.property.seq_len
        super().__init__(args, transformer=transformer, **kwargs)
        self.add_mixin("patch_embedding", MaskedPatchEmbedMixin(args.in_channels, args.hidden_size, self.property))
        # The old_property of ViTModel is not elegant, but there is no time to fix the related models
        # (vit, cait, deit, yolos) for now, so it is simply dropped starting from the EVA models.
        # self.add_mixin("pos_embedding", InterpolatedPositionEmbeddingMixin(args.hidden_size, self.old_property, self.property))
        self.add_mixin("eva2-final", EVA2FinalMixin(args.predict_feature_dim, args.hidden_size))
        self.add_mixin("eva2-mlp", SwiGLUMixin(args.num_layers, args.hidden_size, args.inner_hidden_size, eps=args.layernorm_epsilon))
        self.add_mixin("eva2-attn", EVA2AttnMixin(args.hidden_size, args.num_attention_heads, self.property))

    def position_embedding_forward(self, position_ids, output_cross_layer, **kw_args):
        return self.transformer.position_embeddings.weight.unsqueeze(0)

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('EVA2', 'EVA2 Configurations')
        group.add_argument('--image-size', nargs='+', type=int, default=[224, 224])
        group.add_argument('--pre-len', type=int, default=1)  # [cls] by default
        group.add_argument('--post-len', type=int, default=0)  # empty by default, but sometimes holds special tokens, such as [det] in yolos
        group.add_argument('--in-channels', type=int, default=3)
        group.add_argument('--patch-size', type=int, default=14)
        group.add_argument('--predict-feature-dim', type=int, default=768)
        return parser
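
For intuition, the two non-standard pieces above, the masked patch embedding and the SwiGLU MLP, reduce to a few lines of plain PyTorch. The snippet below is a minimal, self-contained sketch of that arithmetic with toy shapes and an ordinary nn.LayerNorm standing in for sat's LayerNorm; it is illustrative only and does not go through BaseMixin, the SAT transformer, or the real mixin classes.

import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Masked patch embedding (cf. MaskedPatchEmbedMixin.word_embedding_forward) ---
# Toy shapes: a 224x224 image with 14x14 patches gives 16*16 = 256 patch tokens; hidden size 32.
hidden_size, patch_size, num_patches = 32, 14, 256
proj = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size))

image = torch.randn(2, 3, 224, 224)
bool_masked_pos = torch.rand(2, num_patches) > 0.5       # True where a patch is masked

patches = proj(image).flatten(2).transpose(1, 2)         # (B, 256, hidden_size)
w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
patches = patches * (1 - w) + mask_token * w             # masked positions get the mask token
print(patches.shape)                                     # torch.Size([2, 256, 32])

# --- SwiGLU MLP (cf. SwiGLUMixin.mlp_forward) ---
w1 = nn.Linear(hidden_size, 4 * hidden_size)             # role of the layer's dense_h_to_4h
w2 = nn.Linear(hidden_size, 4 * hidden_size)             # the extra gate branch added by the mixin
ln = nn.LayerNorm(4 * hidden_size)                       # stands in for sat.ops.layernorm.LayerNorm
w3 = nn.Linear(4 * hidden_size, hidden_size)             # role of the layer's dense_4h_to_h
out = w3(ln(F.silu(w1(patches)) * w2(patches)))
print(out.shape)                                         # torch.Size([2, 256, 32])

The rotary part of EVA2AttnMixin is omitted here because it depends on VisionRotaryEmbeddingFast; the key point in the file above is that the rotary embedding is applied to the patch tokens only, with the leading [cls] slot concatenated back unchanged.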