chatlearn/models/vllm/vllm_model.py (63 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""vllm-based model"""

import torch
from torch import nn

from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
from chatlearn.utils.vllm_import_helper import QWenLMHeadModel
from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM
from chatlearn.utils.vllm_import_helper import get_model_architecture
from chatlearn.utils.utils import get_use_legacy_models
from chatlearn.utils.vllm_utils import (
    convert_llama_state_dict_from_megatron_to_vllm,
    convert_llama_state_dict_from_mcore_to_vllm,
    convert_qwen_state_dict_from_megatron_to_vllm,
    load_checkpoint
)

# additional imports for vLLM-0.6.3
try:
    from chatlearn.utils.vllm_import_helper import Qwen2MoeForCausalLM
except ImportError:
    # Optional: only resolvable when the vLLM 0.6.3 helpers are importable.
    Qwen2MoeForCausalLM = None
    print("Cannot import Qwen2MoeForCausalLM for vllm 0.6.3, please install vllm 0.6.3 first.")


class VLLMModel(nn.Module):
    """VLLM based Model.

    Wraps the vLLM model class returned by ``get_model_architecture(config)``
    and loads its weights by converting a Megatron / Megatron-Core checkpoint
    (described by ``model_args``) with the converters from
    ``chatlearn.utils.vllm_utils``.
    """

    def __init__(self, config, model_args, cache_config):
        """Instantiate the underlying vLLM model for the running vLLM version.

        Args:
            config: model config; ``config.hf_config`` is passed to the vLLM
                model constructor and is also read for ``architectures``.
            model_args: arguments later handed to the checkpoint converters.
            cache_config: forwarded as ``cache_config=`` for vLLM 0.5.1 / 0.6.3;
                unused for vLLM 0.3.0.

        Raises:
            RuntimeError: if ``hf_config.architectures`` is missing/empty under
                vLLM 0.6.3, or if ``CURRENT_VLLM_VERSION`` is not supported.
        """
        super().__init__()
        self.config = config
        self.model_args = model_args
        self.model_class = get_model_architecture(config)

        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
            self.model = self.model_class(config.hf_config)
        elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
            self.model = self.model_class(config.hf_config, cache_config=cache_config)
        elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            model_class_name = getattr(config.hf_config, "architectures", [])
            # Explicit check instead of `assert`, so validation survives `python -O`.
            if not model_class_name:
                raise RuntimeError(f"architectures should be set in model config, while {model_class_name}")
            if model_class_name[0] == "QWenLMHeadModel":
                # None for multimodal config.
                self.model = self.model_class(config.hf_config, None, cache_config=cache_config)
            else:
                self.model = self.model_class(config.hf_config, cache_config=cache_config)
        else:
            raise RuntimeError(
                "unsupported vllm version, supported version list: [0.3.0, 0.5.1, 0.6.3], "
                f"while {CURRENT_VLLM_VERSION}"
            )

    def load_weights(self):
        """Load weights through ``load_checkpoint``, fenced by distributed barriers.

        ``torch.distributed.barrier()`` is called before and after loading so
        that all ranks of the default process group enter and leave the load
        together.
        """
        torch.distributed.barrier()
        load_checkpoint(self, None, None)
        torch.distributed.barrier()

    def _select_state_dict_converter(self):
        """Return ``(converter, qwen_version)`` for the type of ``self.model``.

        Raises:
            RuntimeError: if ``self.model`` is not a supported architecture.
        """
        if isinstance(self.model, LlamaForCausalLM):
            # Legacy Megatron and Megatron-Core checkpoints use different layouts.
            if get_use_legacy_models(self.model_args):
                return convert_llama_state_dict_from_megatron_to_vllm, None
            return convert_llama_state_dict_from_mcore_to_vllm, None
        if isinstance(self.model, QWenLMHeadModel):
            return convert_qwen_state_dict_from_megatron_to_vllm, 1.0
        # Qwen2MoeForCausalLM may be None when its optional import failed.
        if isinstance(self.model, Qwen2ForCausalLM) or (
                Qwen2MoeForCausalLM is not None and isinstance(self.model, Qwen2MoeForCausalLM)):
            return convert_qwen_state_dict_from_megatron_to_vllm, 2.0
        raise RuntimeError(
            "Unsupported model for vllm backend. "
            "support [LlamaForCausalLM, QWenLMHeadModel, Qwen2ForCausalLM, Qwen2MoeForCausalLM] only, "
            f"while {self.model_class}"
        )

    def load_state_dict(self, state_dict, strict=True, assign=False): # pylint: disable=unused-argument
        """Convert the checkpoint described by ``model_args`` and load it.

        Args:
            state_dict: ignored; the loaded weights are produced by the
                converter selected for ``self.model`` from ``self.model_args``
                and ``self.config.hf_config``.
            strict: forwarded to ``nn.Module.load_state_dict``.
            assign: accepted for signature compatibility with
                ``nn.Module.load_state_dict``; currently not forwarded.

        Returns:
            The ``NamedTuple`` with ``missing_keys`` / ``unexpected_keys``
            returned by ``nn.Module.load_state_dict``.

        Raises:
            RuntimeError: if ``self.model`` is not a supported architecture.
        """
        convert_state_dict_internal, qwen_version = self._select_state_dict_converter()
        state_dict = convert_state_dict_internal(self.model_args, self.config.hf_config, qwen_version=qwen_version)
        # Propagate the load report instead of discarding it.
        return super().load_state_dict(state_dict, strict=strict)

    def forward(self, *args, **kwargs):
        """Delegate all positional and keyword arguments to the wrapped vLLM model."""
        return self.model(*args, **kwargs)