in deepseek_vl2/models/modeling_deepseek_vl_v2.py [0:0]
def __init__(self, config: DeepseekVLV2Config):
    """Assemble the vision encoder, VL projector, image-format tokens and language model.

    Args:
        config: Top-level model config. This constructor reads
            ``vision_config``, ``projector_config``, ``language_config``,
            ``tile_tag``, ``global_view_pos`` and, when ``tile_tag == "1D"``,
            ``candidate_resolutions``.

    Raises:
        ValueError: If ``tile_tag`` is neither ``"1D"`` nor ``"2D"``, or if
            ``tile_tag == "1D"`` and ``candidate_resolutions`` is empty.
    """
    super().__init__(config)

    self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    # ----------- vision encoder ------------
    vision_config = config.vision_config
    self.vision = VisionTransformer(
        img_size=vision_config.image_size,
        patch_size=vision_config.patch_size,
        embed_dim=vision_config.width,
        depth=vision_config.layers,
        num_heads=vision_config.heads,
        mlp_ratio=vision_config.mlp_ratio,
        class_token=vision_config.class_token,
        global_pool=vision_config.global_pool,
        ignore_head=vision_config.ignore_head,
        weight_init=vision_config.weight_init,
        num_classes=0,
        deterministic=vision_config.deterministic,
        num_recomputing_layers=vision_config.num_recomputing_layers,
    )

    # ----------- vl projector ------------
    projector_config = config.projector_config
    self.projector = MlpProjector(projector_config)

    # ----------- image token format ------------
    # FIXME: the current defaults for tile_tag and global_view_pos come from
    # earlier experiment settings; the defaults should eventually be removed
    # so that a missing value raises an error instead.
    self.tile_tag = config.tile_tag
    self.global_view_pos = config.global_view_pos

    # Special learnable tokens used to format the image token sequence.
    # They are drawn from N(0, 1) and scaled by 1 / sqrt(n_embed).
    n_embed = projector_config.n_embed
    embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))

    if self.tile_tag == "2D":
        # <|view_separator|>, <|\n|>
        self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
        # The attribute name keeps the historical misspelling "seperator"
        # on purpose (fix the typo: view_seperater).
        self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
    elif self.tile_tag == "1D":
        # <|tile_x|>, <|tile_global|>
        candidate_resolutions = config.candidate_resolutions
        if len(candidate_resolutions) == 0:
            raise ValueError(
                f"len(candidate_resolutions) should be larger than 0, but got {len(candidate_resolutions)}")
        tile_variants_num = len(candidate_resolutions)
        # Width comes from the projector config so that it agrees with
        # embed_std and with the 2D-branch parameters; reading it from
        # config.aligner.params was inconsistent with the rest of this
        # constructor, which never uses config.aligner.
        # One extra row beyond the candidate resolutions — presumably the
        # <|tile_global|> indicator; confirm against the forward pass.
        self.tile_indicators = nn.Parameter(
            torch.randn(size=(tile_variants_num + 1, n_embed)) * embed_std
        )
    else:
        raise ValueError(f"tile tag should be either 1D or 2D, but got {self.tile_tag}")

    # ----------- language model ------------
    language_config = config.language_config
    self.language = DeepseekV2ForCausalLM(language_config)