# modules/SwissArmyTransformer/sat/model/official/glm4v_model.py
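"""GLM-4V: the ChatGLM4 language model extended with an EVA2-CLIP vision tower.

Image features from the ViT are downsampled by a stride-2 convolution,
projected to the language-model hidden size by a SwiGLU-style GLU, and
spliced into the word-embedding sequence between learned boi/eoi embeddings.
"""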
import argparse
import json
import os
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.init as init

from sat.helpers import print_rank0
from sat.model.base_model import BaseMixin
from sat.training.model_io import extract_model_specific_args_to_dump

from .chatglm4_model import ChatGLM4Model
from .eva_clip_model import EVA2CLIPModel


def init_weights(module):
    """Standard transformer init: N(0, 0.02) for Linear/Embedding weights,
    identity for LayerNorm, zeros for all biases."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
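

# The GLU below is a SwiGLU-style projector: an input projection followed by
# a gated feed-forward block, mapping ViT features into the LM hidden size.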
class GLU(nn.Module):
    def __init__(self, args, in_features):
        super().__init__()
        self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False)
        self.norm1 = nn.LayerNorm(args.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = nn.functional.silu
        self.dense_h_to_4h = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
        self.gate_proj = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
        self.dense_4h_to_h = nn.Linear(args.inner_hidden_size, args.hidden_size, bias=False)
        # self.norm2 = nn.LayerNorm(args.hidden_size)

    def forward(self, x):
        x = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        # gated feed-forward: silu(gate(x)) * up(x), then project back down
        x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
        x = self.dense_4h_to_h(x)
        # x = self.norm2(x)
        return x
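

# Copy only the distributed/dtype/device settings (plus activation-checkpointing
# flags when training) from the LM args into a fresh namespace for the ViT,
# merged with the eva_args overrides.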
def override_dist_dtype_device_args(args, b=None):
    b = {} if b is None else dict(b)  # copy so the caller's dict is never mutated
    minimal = dict(
        world_size=args.world_size,
        rank=args.rank,
        local_rank=args.local_rank,
        skip_init=args.skip_init,
        use_gpu_initialization=args.use_gpu_initialization,
        deepspeed=args.deepspeed,
        bf16=args.bf16,
        fp16=args.fp16,
        mode=args.mode,
        device=args.device,
    )
    if args.mode != 'inference':
        # training additionally needs the checkpointing flags; dropout is
        # disabled inside the vision tower
        minimal.update(
            checkpoint_activations=args.checkpoint_activations,
            checkpoint_num_layers=args.checkpoint_num_layers,
            hidden_dropout=0.,
            attention_dropout=0.,
        )
    if hasattr(args, 'model_parallel_size'):
        b['model_parallel_size'] = args.model_parallel_size
    return argparse.Namespace(**deepcopy(b), **minimal)
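

# ImageMixin owns the vision side of the model: the EVA2-CLIP encoder, the
# conv downsampler, the GLU projector, and the learned boi/eoi boundary
# embeddings. It hooks word_embedding_forward to splice image features into
# the token embedding sequence.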
class ImageMixin(BaseMixin):
    def __init__(self, args):
        super().__init__()
        if args.eva_args:
            # Option 1: not loading from a checkpoint, build the ViT from scratch
            vit_args = override_dist_dtype_device_args(args, args.eva_args)
            self.vit_model = EVA2CLIPModel(EVA2CLIPModel.get_args(**vars(vit_args)))
        else:
            # Option 2: load the ViT from a pretrained checkpoint under SAT_HOME
            url = os.path.join(os.getenv("SAT_HOME"), 'eva-clip-4b-14-x-drop-last-layer')
            print_rank0(f"loading vit checkpoint from {url}")
            vit_args = override_dist_dtype_device_args(args, args.eva_args)
            self.vit_model, vit_args = EVA2CLIPModel.from_pretrained(
                url, vit_args,
                overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size > 1 else {}
            )
            args.eva_args = extract_model_specific_args_to_dump(vit_args, self.vit_model)
            print_rank0(f"loading finished {url}")
        args.proj_hidden_size = args.hidden_size if args.proj_hidden_size is None else args.proj_hidden_size
        # a 2x2 stride-2 conv halves each spatial side, quartering the token count
        self.conv = nn.Conv2d(in_channels=self.vit_model.transformer.hidden_size, out_channels=args.proj_hidden_size, kernel_size=2, stride=2)
        self.linear_proj = GLU(args, args.proj_hidden_size)
        self.linear_proj.apply(init_weights)
        self.image_length = args.image_length
        # learned begin-of-image / end-of-image boundary embeddings
        self.boi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        self.eoi = nn.Parameter(torch.ones(1, 1, args.hidden_size).float())
        init.xavier_uniform_(self.boi)
        init.xavier_uniform_(self.eoi)
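
    # Hook called by the SAT transformer in place of the plain embedding
    # lookup. During incremental decoding (seq_len == 1) or for text-only
    # batches it falls back to ordinary word embeddings; otherwise it runs
    # the ViT and writes boi + image features + eoi into the positions
    # flagged by image_embed_mask.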
    def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
        # collect the vision_* kwargs for the ViT, stripping the prefix
        vision_inputs = {}
        for k in kw_args:
            if k.startswith('vision_') and k != 'vision_expert_mask':
                vision_inputs[k[len('vision_'):]] = kw_args[k]
        if input_ids.shape[1] == 1 or not vision_inputs:
            # decoding step or text-only batch: plain embedding lookup
            word_embedding = self.transformer.word_embeddings(input_ids)
        else:
            if 'position_ids' not in vision_inputs:
                vision_inputs['position_ids'] = None
            image_emb = self.vit_model(**vision_inputs)[0]
            b, s, e = image_emb.shape  # e.g. (b, 6400, 1792)
            grid_size = int(s**0.5)
            image_emb = image_emb.view(b, grid_size, grid_size, e).permute(0, 3, 1, 2)  # (b, 1792, 80, 80)
            image_emb = self.conv(image_emb)  # (b, proj_hidden_size, 40, 40)
            image_emb = image_emb.flatten(2).transpose(1, 2)  # (b, 1600, proj_hidden_size)
            image_emb = self.linear_proj(image_emb)  # (b, 1600, hidden_size)
            image_embed_mask = kw_args['image_embed_mask']
            word_embedding = self.transformer.word_embeddings(input_ids).clone()
            # overwrite the masked positions with boi + image tokens + eoi
            word_embedding[image_embed_mask.bool()] = torch.cat(
                [self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)],
                dim=1).reshape(-1, image_emb.shape[-1])
            word_embedding = word_embedding.contiguous()
        return word_embedding
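

# GLM4VModel = ChatGLM4Model + the vision mixin above. Vision-specific kwargs
# are consumed during the multimodal prefill; single-token decoding steps drop
# them before delegating to the text model.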
class GLM4VModel(ChatGLM4Model):
    def __init__(self, args, transformer=None, **kwargs):
        super().__init__(args, transformer=transformer, **kwargs)
        self.image_length = args.image_length
        self.add_mixin("eva", ImageMixin(args))

    @classmethod
    def add_model_specific_args(cls, parser):
        group = parser.add_argument_group('GLM4V', 'GLM4V Configurations')
        group.add_argument('--image_length', type=int, default=256)
        group.add_argument('--eva_args', type=json.loads, default={})
        group.add_argument('--proj_hidden_size', type=int, default=None)
        return super().add_model_specific_args(parser)

    def forward(self, input_ids, **kwargs):
        if input_ids.shape[1] > 1:
            # multimodal prefill: keep the vision kwargs for the mixin
            return super().forward(input_ids=input_ids, **kwargs)
        # single-token decoding: the image features are already in the KV
        # cache, so drop the vision masks before calling the text model
        kwargs.pop("vision_expert_mask", None)
        kwargs.pop("image_embed_mask", None)
        return super().forward(input_ids=input_ids, **kwargs)
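

# Example usage (a minimal sketch; the checkpoint name below is hypothetical,
# and from_pretrained follows the same pattern used for EVA2CLIPModel above):
#   model, model_args = GLM4VModel.from_pretrained('glm-4v-9b', args)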