# modules/SwissArmyTransformer/sat/model/official/clip_model.py (114 lines of code) (raw):

"""
CLIP model follows a Siamese (dual-encoder) architecture.

The image encoder is a ViTModel whose patch embedding is a bias-free Conv2d
with kernel size and stride equal to the configured patch size
(``property.patch_size``). The text encoder is a BaseModel whose attention
mask is made causal by ``TextMixin``. Each encoder projects one pooled
position into a shared ``projection_dim`` space, where ``CLIP.forward``
computes scaled cosine similarities between the two.
"""
import os
import math
import argparse
from re import L  # NOTE(review): appears unused in this module; kept in case external code imports it from here.

import torch
import torch.nn as nn
import torch.nn.functional as F

from sat.model.base_model import BaseMixin, BaseModel, non_conflict, load_checkpoint, get_model
from sat.model.official.vit_model import ViTModel, ImagePatchEmbeddingMixin
from sat.model.mixins import BaseMixin  # rebinds BaseMixin; this later import is the binding in effect below
from sat import mpu
from sat.model.transformer import LayerNorm
from sat import update_args_with_file
from sat.resources import auto_create

# Fallback for ``logit_scale`` when no ``logit_scale_init_value`` is configured:
# log(1 / 0.07), i.e. an initial softmax temperature of 0.07.
_DEFAULT_LOGIT_SCALE_INIT_VALUE = math.log(1 / 0.07)


class QuickGELUActivation(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input):
        # x * sigmoid(1.702 * x) is the sigmoid-based approximation of GELU.
        return input * torch.sigmoid(1.702 * input)


class ImageMixin(BaseMixin):
    """Image-side CLIP hooks: a LayerNorm before the first layer and a projection of position 0."""

    def __init__(self, vision_embed_dim, projection_dim, layernorm_epsilon):
        super().__init__()
        self.pre_layernorm = LayerNorm(vision_embed_dim, eps=layernorm_epsilon)
        self.visual_projection = nn.Linear(vision_embed_dim, projection_dim, bias=False)

    def layer_forward(self, hidden_states, mask, *args, **kw_args):
        '''
        hidden_states: [batch, seq_len, hidden_size]
        mask: [(1, 1), seq_len, seq_len]
        '''
        layer = self.transformer.layers[kw_args['layer_id']]
        if kw_args['layer_id'] == 0:
            # The extra LayerNorm is applied once, to the embeddings entering the first layer.
            hidden_states = self.pre_layernorm(hidden_states)
        output = layer(hidden_states, mask, *args, **kw_args)
        return output

    def final_forward(self, logits, **kw_args):
        # Pool by taking sequence position 0, then project into the shared embedding space.
        return self.visual_projection(logits[:, 0])


class PatchMixin(ImagePatchEmbeddingMixin):
    """Patch embedding that sets ``self.proj`` to a bias-free Conv2d after the parent's init."""

    def __init__(self, in_channels, hidden_size, property):
        super().__init__(in_channels, hidden_size, property)
        self.property = property
        # kernel_size == stride == patch_size -> non-overlapping patches; no bias term.
        self.proj = nn.Conv2d(in_channels, hidden_size, kernel_size=property.patch_size,
                              stride=property.patch_size, bias=False)


class ImageEncoder(ViTModel):
    """ViT image tower for CLIP with CLIP-specific patch embedding, pooling and projection.

    Args:
        args: model configuration namespace (reads ``hidden_size``, ``projection_dim``,
            ``in_channels`` here; the rest is consumed by ViTModel).
        layernorm_epsilon: epsilon for the transformer LayerNorms and ImageMixin's pre-LayerNorm.
        activation_func: MLP activation; ``None`` (default) creates a new QuickGELUActivation.
    """

    def __init__(self, args, layernorm_epsilon=1e-5, activation_func=None):
        # Build the activation per instance instead of sharing one module object
        # created once as a default argument.
        if activation_func is None:
            activation_func = QuickGELUActivation()
        super().__init__(args, layernorm_epsilon=layernorm_epsilon, activation_func=activation_func)
        # ViT's 'cls' mixin is removed; ImageMixin provides pooling/projection instead.
        self.del_mixin('cls')
        self.add_mixin('image_enc', ImageMixin(args.hidden_size, args.projection_dim, layernorm_epsilon))
        self.del_mixin('patch_embedding')
        self.add_mixin('patch_embedding', PatchMixin(args.in_channels, args.hidden_size, self.property))

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('CLIP-image', 'CLIP image encoder Configurations')
        group.add_argument('--projection-dim', type=int)
        return super().add_model_specific_args(parser)


class TextMixin(BaseMixin):
    """Text-side CLIP hooks: a causal attention mask and a projection of the last position."""

    def __init__(self, text_embed_dim, projection_dim):
        super().__init__()
        self.text_projection = nn.Linear(text_embed_dim, projection_dim, bias=False)

    def final_forward(self, logits, **kw_args):
        # Pool by taking the last sequence position, then project into the shared space.
        # NOTE(review): assumes the last position holds the token to pool
        # (e.g. inputs are not right-padded) -- confirm against the caller/tokenizer.
        return self.text_projection(logits[:, -1])

    def layer_forward(self, hidden_states, mask, *args, **kw_args):
        # causal mask: subtracting the strictly-upper triangle zeroes every entry
        # above the diagonal, so row i keeps only entries for positions <= i.
        mask = mask - mask.triu(1)
        layer = self.transformer.layers[kw_args['layer_id']]
        output = layer(hidden_states, mask, *args, **kw_args)
        return output


class TextEncoder(BaseModel):
    """Causal transformer text tower for CLIP.

    Args:
        args: model configuration namespace (reads ``hidden_size`` and ``projection_dim`` here).
        layernorm_epsilon: epsilon for the transformer LayerNorms.
        activation_func: MLP activation; ``None`` (default) creates a new QuickGELUActivation.
    """

    def __init__(self, args, layernorm_epsilon=1e-5, activation_func=None):
        # Per-instance activation, as in ImageEncoder.
        if activation_func is None:
            activation_func = QuickGELUActivation()
        super().__init__(args, layernorm_epsilon=layernorm_epsilon, activation_func=activation_func)
        self.add_mixin('text_enc', TextMixin(args.hidden_size, args.projection_dim))


class CLIP(nn.Module):
    """Dual-encoder CLIP: an ImageEncoder, a TextEncoder and a learnable logit scale.

    The text encoder is built from a copy of ``args`` in which any non-None
    ``text_<name>`` attribute overrides ``<name>``; unset overrides reuse the
    image-encoder value.
    """

    def __init__(self, args, layernorm_epsilon=1e-5):
        super().__init__()
        self.image_encoder = ImageEncoder(args, layernorm_epsilon=layernorm_epsilon)

        # Shallow copy so the overrides below do not touch the caller's namespace.
        text_args = argparse.Namespace(**vars(args))
        override_attrs = ['vocab_size', 'num_layers', 'hidden_size', 'num_attention_heads', 'layernorm_order',
                          'max_sequence_length', 'inner_hidden_size', 'hidden_size_per_attention_head']
        for name in override_attrs:
            text_attr = getattr(text_args, 'text_' + name, None)
            if text_attr is not None:  # else use encoder-config
                setattr(text_args, name, text_attr)
        self.text_encoder = TextEncoder(text_args, layernorm_epsilon=layernorm_epsilon)

        # logit_scale is kept in log space and exponentiated in forward(). The CLI
        # default is None, which previously crashed here; fall back to log(1/0.07).
        init_value = getattr(args, 'logit_scale_init_value', None)
        if init_value is None:
            init_value = _DEFAULT_LOGIT_SCALE_INIT_VALUE
        self.logit_scale = nn.Parameter(torch.ones([]) * init_value)

    def encode_image(self, input_ids, position_ids, attention_mask=None, **kw_args):
        """Run the image encoder; returns whatever ImageEncoder's forward returns."""
        return self.image_encoder(input_ids, position_ids, attention_mask, **kw_args)

    def encode_text(self, input_ids, position_ids, attention_mask, **kw_args):
        """Run the text encoder; returns whatever TextEncoder's forward returns."""
        return self.text_encoder(input_ids, position_ids, attention_mask, **kw_args)

    def reinit(self, mixin_names):
        """Forward ``reinit`` to both encoders."""
        # please use different mixin names for two encoders
        self.image_encoder.reinit(mixin_names)
        self.text_encoder.reinit(mixin_names)

    def forward(self, image_input_ids, image_position_ids, text_input_ids, text_position_ids, *,
                image_attention_mask=None, text_attention_mask=None, **kw_args):
        """Encode both modalities and compute scaled cosine-similarity logits.

        Returns:
            (image_embeds, text_embeds, logits_per_text, logits_per_image), where the
            embeddings are L2-normalized along the last dim,
            ``logits_per_text = exp(logit_scale) * text_embeds @ image_embeds.T`` and
            ``logits_per_image`` is its transpose.
        """
        # Only the first output of each encoder is used; the remaining outputs are discarded.
        image_embeds, *_ = self.encode_image(image_input_ids, image_position_ids,
                                             attention_mask=image_attention_mask, **kw_args)
        text_embeds, *_ = self.encode_text(text_input_ids, text_position_ids,
                                           attention_mask=text_attention_mask, **kw_args)
        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T
        return image_embeds, text_embeds, logits_per_text, logits_per_image

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register the text-encoder override flags and the logit-scale init value."""
        group = parser.add_argument_group('SiameseModel', 'CLIP')
        # Each --text-* flag overrides the matching attribute for the text encoder only
        # (see ``override_attrs`` in __init__); None means "reuse the image-encoder value".
        group.add_argument("--text-vocab-size", type=int, default=None)
        group.add_argument("--text-layernorm-order", type=str, default=None)
        group.add_argument("--text-num-layers", type=int, default=None)
        group.add_argument("--text-hidden-size", type=int, default=None)
        group.add_argument("--text-num-attention-heads", type=int, default=None)
        group.add_argument("--text-max-sequence-length", type=int, default=None)
        group.add_argument("--text-inner-hidden-size", type=int, default=None)
        group.add_argument("--text-hidden-size-per-attention-head", type=int, default=None)
        group.add_argument("--logit-scale-init-value", type=float, default=None)
        return parser

    @classmethod
    def from_pretrained(cls, args, name, *, path=None, url=None):
        """Build a CLIP from a pretrained directory.

        ``name`` (with optional ``path``/``url``) is resolved to a local model
        directory by ``auto_create``. ``args`` is updated from that directory's
        ``model_config.json``, the model is built with ``get_model`` and weights
        are loaded with ``load_checkpoint``.

        Returns:
            (model, args) -- the loaded model and the updated args.
        """
        model_path = auto_create(name, path=path, url=url)
        args = update_args_with_file(args, path=os.path.join(model_path, 'model_config.json'))
        model = get_model(args, cls)
        load_checkpoint(model, args, load_path=model_path)
        return model, args