maga_transformer/models/qwen_vl_weight.py (14 lines of code) (raw):
from maga_transformer.models.qwen import QWenWeight
from maga_transformer.models.multimodal.multimodal_mixin import BaseVitWeights, BaseMultiModalWeightInfo
class QwenVLVitWeight(BaseVitWeights):
    """ViT weight description for Qwen-VL, built on ``BaseVitWeights``.

    The only customisation made here is the pair of name prefixes assigned
    in ``_set_weight_prefix``.
    """

    def _set_weight_prefix(self):
        """Assign the two prefix attributes used for this model's ViT weights.

        NOTE(review): judging by the attribute names, ``_ckpt_prefix`` is the
        prefix of the tensors inside the checkpoint and ``_ft_prefix`` is the
        prefix on the runtime side -- confirm against ``BaseVitWeights``.
        """
        # Assigned left to right: ``_ckpt_prefix`` first, then ``_ft_prefix``.
        self._ckpt_prefix, self._ft_prefix = (
            "transformer.visual.",
            "self.mm_part.vit.",
        )
class QWenVLWeightInfo(QWenWeight, BaseMultiModalWeightInfo):
    """Weight info for Qwen-VL.

    Combines ``QWenWeight`` with ``BaseMultiModalWeightInfo`` and routes the
    weight info produced by the parent class through ``_get_vit_info``.
    """

    def __init__(self, config, tp_size, tp_rank):
        """Initialise both base classes explicitly.

        The two bases take different constructor arguments, so each
        ``__init__`` is invoked directly rather than through ``super()``.
        """
        QWenWeight.__init__(self, config, tp_size, tp_rank)
        BaseMultiModalWeightInfo.__init__(self, config)

    def _get_weight_info(self):
        """Return the parent's weight info after passing it to ``_get_vit_info``."""
        return self._get_vit_info(super()._get_weight_info())