maga_transformer/models/llava_weight.py (18 lines of code) (raw):

from maga_transformer.utils.model_weight import W
from maga_transformer.models.llama_weight import LlamaWeightInfo
from maga_transformer.config.gpt_init_model_parameters import GptInitModelParameters
from maga_transformer.models.multimodal.multimodal_mixin import BaseMultiModalWeightInfo
from maga_transformer.model_loader.model_weight_info import ModelWeightInfo


class LlavaWeightInfo(LlamaWeightInfo, BaseMultiModalWeightInfo):
    """Weight description for LLaVA models.

    Combines the weight layout inherited from ``LlamaWeightInfo`` with the
    additions contributed by ``BaseMultiModalWeightInfo._get_vit_info``.
    """

    def __init__(self, config: GptInitModelParameters, tp_size: int, tp_rank: int):
        """Initialise both base classes explicitly.

        Args:
            config: Model parameters, forwarded to both base initialisers.
            tp_size: Forwarded to ``LlamaWeightInfo.__init__``.
            tp_rank: Forwarded to ``LlamaWeightInfo.__init__``.
        """
        # The two bases take different constructor arguments, so each one is
        # initialised by name rather than through a cooperative super() chain.
        LlamaWeightInfo.__init__(self, config, tp_size, tp_rank)
        BaseMultiModalWeightInfo.__init__(self, config)

    def _get_weight_info(self):
        """Return the inherited weight info, adjusted for LLaVA.

        Starts from ``super()._get_weight_info()``, removes the first layer
        weight whose name is ``W.attn_o_b`` (if there is one), and passes the
        result through ``self._get_vit_info``.

        Returns:
            Whatever ``self._get_vit_info`` returns for the adjusted weight
            info (presumably a ``ModelWeightInfo`` -- confirm against the
            mixin).
        """
        llava_weight = super()._get_weight_info()

        # for llava-next
        # NOTE(review): presumably llava-next checkpoints carry no attention
        # output bias, so the entry must not be requested -- confirm.
        for weight in llava_weight.layer_weights:
            if weight.name == W.attn_o_b:
                llava_weight.layer_weights.remove(weight)
                # Stop immediately: only the first match is dropped, and the
                # list must not be iterated further after being mutated.
                break

        return self._get_vit_info(llava_weight)