# megatron_patch/model/qwen2_5_vl/model.py

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import logging
from collections import namedtuple
from typing import List

import torch

from megatron.core import InferenceParams, parallel_state
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.packed_seq_params import PackedSeqParams

from .transformer_config import Qwen2VLTransformerConfig
from .visionmodel import Qwen2_5VisionModel
from megatron_patch.model.qwen2_vl.gpt_model import GPTModel


# Note: This is under development and may be missing features.
class Qwen2_5VLModel(MegatronModule):
    """Qwen2.5VL multi-modal model.

    Args:
        language_transformer_config (Qwen2VLTransformerConfig): Transformer config for the language model.
        language_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
            of the language model.
        language_vocab_size (int): Language model vocabulary size.
        language_max_sequence_length (int): Language model maximum sequence length. This is used for
            positional embedding.
        vision_transformer_config (TransformerConfig): Transformer config for the vision model.
        vision_transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
            of the vision model.
        drop_vision_class_token (bool): Drop vision class token(s) before input to the language model.
        vision_projection_config (TransformerConfig): Config for the projection from vision model
            outputs to language model inputs.
        vision_projection_layer_spec (ModuleSpec): Specifies the module to use for the vision projection.
        vision_projection_type (str): Type of the vision projection to use. Default is a 2-layer MLP.
        allow_missing_vision_projection_checkpoint (bool): Allow vision projection weights to be
            missing when loading a checkpoint. Default False.
        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel
            ranks. This is typically True for training and False for inference.
        language_position_embedding_type (str): Position embedding type to use in the language model.
            Defaults to 'rope'.
        language_rotary_percent (float): Percent of rotary dimension to use for rotary position
            embeddings in the language model. Defaults to 1.0.
        pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline
            parallelism). Defaults to True.
        post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with
            pipeline parallelism). Defaults to True.
        add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults
            to True. When we use pipelining, the encoder will live on only a subset of the pipeline
            stages (specifically, only the first stage).
        add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults
            to True. When we use pipelining, the decoder will live on only a subset of the pipeline
            stages (specifically, every stage after the first one).
        language_rotary_base (int): Base period for rotary position embeddings in the language model.
            Defaults to 10000.
        fp16_lm_cross_entropy (bool): Compute the language model cross entropy loss in fp16.
            Defaults to False.
        language_share_embeddings_and_output_weights (bool): Share the language model's input
            embedding and output layer weights. Defaults to False.
    """
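
    # Construction sketch (illustrative only, not part of the original file). The configs, layer
    # specs and flags below are assumed to be built by the surrounding Pai-Megatron-Patch model
    # provider; every variable name in this comment is hypothetical.
    #
    #     model = Qwen2_5VLModel(
    #         language_transformer_config=language_config,
    #         language_transformer_layer_spec=language_layer_spec,
    #         language_vocab_size=padded_vocab_size,
    #         language_max_sequence_length=max_sequence_length,
    #         vision_transformer_config=vision_config,
    #         vision_transformer_layer_spec=vision_layer_spec,
    #         drop_vision_class_token=False,
    #         vision_projection_config=vision_projection_config,
    #         vision_projection_layer_spec=vision_projection_layer_spec,
    #         pre_process=pre_process,
    #         post_process=post_process,
    #     )
    #     # Optionally train only the language model:
    #     model.freeze(
    #         freeze_language_model=False, freeze_vision_model=True, freeze_vision_projection=True
    #     )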
""" def __init__( self, language_transformer_config: Qwen2VLTransformerConfig, language_transformer_layer_spec: ModuleSpec, language_vocab_size: int, language_max_sequence_length: int, vision_transformer_config: TransformerConfig, vision_transformer_layer_spec: ModuleSpec, drop_vision_class_token: bool, vision_projection_config: TransformerConfig, vision_projection_layer_spec: ModuleSpec, vision_projection_type: str = "mlp", allow_missing_vision_projection_checkpoint: bool = False, parallel_output: bool = True, language_position_embedding_type: str = 'rope', language_rotary_percent: float = 1.0, pre_process: bool = True, post_process: bool = True, add_encoder: bool = True, add_decoder: bool = True, language_rotary_base: int = 10000, fp16_lm_cross_entropy: bool = False, language_share_embeddings_and_output_weights: bool=False ) -> None: super().__init__(config=language_transformer_config) logging.getLogger(__name__).warning( "Qwen2VL model is under development and may be missing features." ) self.pre_process = pre_process self.post_process = post_process self.add_encoder = add_encoder self.add_decoder = add_decoder self.encoder_hidden_state = None self.vision_model = None self.vision_projection = None self.language_model = None self.square_merge_size = vision_projection_config.ffn_hidden_size // vision_transformer_config.hidden_size # This attribute is needed to check if an all-reduce is required # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. self.share_embeddings_and_output_weights = False if self.pre_process: self.vision_model = Qwen2_5VisionModel( vision_transformer_config, vision_transformer_layer_spec, vision_projection_config, vision_projection_layer_spec, projection_type=vision_projection_type, pre_process=True, post_process=True ) self.language_model = GPTModel( config=language_transformer_config, transformer_layer_spec=language_transformer_layer_spec, vocab_size=language_vocab_size, max_sequence_length=language_max_sequence_length, parallel_output=parallel_output, position_embedding_type=language_position_embedding_type, rotary_percent=language_rotary_percent, pre_process=self.pre_process, post_process=self.post_process, rotary_base=language_rotary_base, mrope_section=language_transformer_config.mrope_section, fp16_lm_cross_entropy=fp16_lm_cross_entropy, share_embeddings_and_output_weights=language_share_embeddings_and_output_weights ) self.share_embeddings_and_output_weights = ( self.language_model.share_embeddings_and_output_weights ) def shared_embedding_or_output_weight(self): """This is a convenience method to surface the language model's word embeddings, which is necessary for `finalize_model_grads._allreduce_word_embedding_grads`.""" if self.add_decoder: return self.language_model.shared_embedding_or_output_weight() return None def set_input_tensor(self, input_tensor) -> None: # This is usually handled in schedules.py but some inference code still # gives us non-lists or None if not isinstance(input_tensor, list): input_tensor = [input_tensor] assert len(input_tensor) == 1, 'input_tensor should only be length 1 for Qwen2VL' if self.pre_process: self.encoder_hidden_state = input_tensor[0] else: self.language_model.set_input_tensor(input_tensor[0]) def freeze( self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool ): """Freeze model modules. Make specific modules non-trainable by setting requires_grad to False for the module's parameters. 

        Args:
            freeze_language_model (bool): Freeze the language model module.
            freeze_vision_model (bool): Freeze the vision model module.
            freeze_vision_projection (bool): Freeze the vision projection module.
        """
        modules = []
        if freeze_language_model and self.language_model is not None:
            modules.append(self.language_model)
        if freeze_vision_model and self.vision_model is not None:
            modules.append(self.vision_model)
        if freeze_vision_projection and self.vision_projection is not None:
            modules.append(self.vision_projection)

        for module in modules:
            for param in module.parameters():
                param.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        vision_data: torch.Tensor = None,
        vision_grid_thw: torch.Tensor = None,
        video_start_index: int = -1,
        image_input_mask: torch.Tensor = None,
        video_input_mask: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,
        inference_params: InferenceParams = None,
        packed_seq_params: PackedSeqParams = None,
        extra_block_kwargs: dict = None,
    ) -> torch.Tensor:
        """Forward function of the Qwen2.5VL model.

        Args:
            input_ids (torch.Tensor): input text ids [batch, text_seq_len].
            position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
            vision_data (torch.Tensor): input image/video patches of shape [total_thw_size, n_features].
            vision_grid_thw (torch.Tensor): (t, h, w) grid sizes of the images/videos, shape [num_items, 3].
            video_start_index (int): index in the vision token sequence where the video tokens start.
                0 -- all vision tokens are video tokens,
                vision_embeds.shape[0] -- all vision tokens are image tokens,
                otherwise -- a mixture: image tokens first, video tokens from this index on.
            image_input_mask / video_input_mask (torch.Tensor): masks over `input_ids` marking the
                image/video token positions; should not be None in the first PP stage.
            attention_mask (torch.Tensor): attention mask for the language model
                [batch, 1, combined_seq_len, combined_seq_len].
            labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
            inference_params (InferenceParams): Inference-time parameters including KV cache.

        Returns:
            output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of
                shape [b, s, vocab_size].
        """
        use_inference_kv_cache = (
            inference_params is not None
            and "image_tokens_count" in inference_params.key_value_memory_dict
        )
        if use_inference_kv_cache:
            raise NotImplementedError()

        if self.pre_process:
            vision_embeds = None
            if vision_grid_thw.shape[0] > 0:
                vision_embeds = self.vision_model(
                    vision_data=vision_data,  # If None, vision model should use intermediate outputs (EPP > 1)
                    grid_thw=vision_grid_thw,  # should be provided in each EPP stage
                )

            # If running inference, the language model KV cache will be updated for image token
            # positions. Here we store the image tokens sequence length, which can be used as an
            # offset to the KV cache later.
            if inference_params is not None:
                raise NotImplementedError()
                # inference_params.key_value_memory_dict["image_tokens_count"] = (
                #     vision_embeddings.shape[0]
                # )
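
            # The first pipeline stage builds the decoder input below. Note: `vision_embeds` is laid
            # out as [image token embeddings, video token embeddings] along dim 0, with
            # `video_start_index` marking where the video part begins (0 means video only,
            # vision_embeds.shape[0] means image only). `image_input_mask` / `video_input_mask`
            # select the placeholder positions in `input_ids` that the language embedding layer is
            # expected to overwrite with the corresponding vision embeddings.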

            # If running inference, we can skip image token computation if they were computed
            # already earlier for this sample.
            if use_inference_kv_cache:
                language_embeddings: torch.Tensor = self.language_model.embedding(
                    input_ids=input_ids,
                    position_ids=None,  # NOTE: disable
                )  # [text_seq_len, b, h_language]
                # NOTE: why not cat here? are the combined embeddings useless?
                combined_embeddings = language_embeddings
            elif vision_embeds is not None:
                if video_start_index == 0:
                    image_embeds = None
                    video_embeds = vision_embeds
                elif video_start_index == vision_embeds.shape[0]:
                    image_embeds = vision_embeds
                    video_embeds = None
                elif 0 < video_start_index < vision_embeds.shape[0]:
                    image_embeds = vision_embeds[:video_start_index]
                    video_embeds = vision_embeds[video_start_index:]
                else:
                    raise ValueError(
                        f"Expect video token start index in range [0, {vision_embeds.shape[0]}], "
                        f"but got {video_start_index}"
                    )

                if image_embeds is not None:
                    image_input_mask = image_input_mask.T  # shape [seqlen, mbs]
                if video_embeds is not None:
                    video_input_mask = video_input_mask.T

                combined_embeddings = self.language_model.embedding(
                    input_ids=input_ids,
                    position_ids=None,  # NOTE: disable
                    image_input_mask=image_input_mask,
                    video_input_mask=video_input_mask,
                    image_embeds=image_embeds,
                    video_embeds=video_embeds,
                )  # [text_seq_len, b, h_language]
            else:
                combined_embeddings = self.language_model.embedding(
                    input_ids=input_ids,
                    position_ids=None,  # NOTE: disable
                )  # [text_seq_len, b, h_language]
        else:
            combined_embeddings = None

        output = self.language_model(
            input_ids=None,
            position_ids=position_ids,  # None in encoder
            attention_mask=attention_mask,  # None in encoder
            decoder_input=combined_embeddings,  # only not None in the first decoder PP stage
            labels=labels,  # only not None in the last decoder PP stage
            inference_params=inference_params,  # currently always None
            packed_seq_params=packed_seq_params,  # currently always None
            **(extra_block_kwargs or {}),
        )

        return output


def _load_state_dict_hook_ignore_param_names(
    param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple
):
    """Hook to ignore missing keys during checkpoint loading.

    By default, this should not be used to avoid accidentally missing weights in checkpoint loading.

    Example use case: Use this for the vision projection if you want to load a checkpoint that
    contains vision and language model weights but not the vision projection weights.

    Args:
        param_names (list of str): Parameter names allowed to be missing when calling load_state_dict.
        module (torch.nn.Module): The torch module this hook applies to. Unused here but required by
            the torch API.
        incompatible_keys (namedtuple): Namedtuple with fields missing_keys and unexpected_keys,
            which collect the missing and unexpected keys when calling load_state_dict on this torch
            module, respectively.
    """
    for param_name in param_names:
        if param_name in incompatible_keys.missing_keys:
            logging.getLogger(__name__).warning(
                f"{param_name} being removed from incompatible_keys.missing_keys in Qwen2_5VLModel"
            )
            incompatible_keys.missing_keys.remove(param_name)
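
# Example (illustrative sketch, not part of the original module): how the hook above could be
# attached so that missing vision-projection weights are tolerated when loading a checkpoint.
# `model` is an assumed Qwen2_5VLModel instance, and the "vision_projection" substring used to
# collect parameter names is an assumption about how those parameters are named.
#
#     from functools import partial
#
#     allowed_missing = [
#         name for name, _ in model.named_parameters() if "vision_projection" in name
#     ]
#     model.register_load_state_dict_post_hook(
#         partial(_load_state_dict_hook_ignore_param_names, allowed_missing)
#     )
#     # The hook runs inside `load_state_dict` and prunes the allowed names from
#     # `incompatible_keys.missing_keys` before the strict check raises.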