chatlearn/utils/vllm_import_helper.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""""Version compatibility for vLLM"""
from typing import List, TypedDict
from typing_extensions import NotRequired
from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
# pylint: disable=unused-import,import-outside-toplevel,wrong-import-position,wrong-import-order
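# Each branch below re-exports the same set of symbols from wherever the
# installed vLLM version defines them, so the rest of ChatLearn can import
# everything from this module without scattering version checks.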
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
    # imports for vllm-030
    from vllm.core.block_manager import BlockSpaceManager
    from vllm.engine.llm_engine import LLMEngine
    from vllm.model_executor.model_loader import _set_default_torch_dtype
    from vllm.model_executor.parallel_utils import parallel_state
    from vllm.model_executor.parallel_utils.communication_op import tensor_model_parallel_all_gather
    from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
    from vllm.model_executor.weight_utils import initialize_dummy_weights
elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
    # imports for vllm-051
    from vllm.core.interfaces import BlockSpaceManager
    from vllm.distributed import parallel_state
    from vllm.distributed.communication_op import tensor_model_parallel_all_gather
    from vllm.distributed.parallel_state import init_world_group
    from vllm.distributed.parallel_state import initialize_model_parallel
    from vllm.engine.llm_engine import LLMEngine
    from vllm.engine.llm_engine import _load_generation_config_dict
    from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
    from vllm.engine.output_processor.stop_checker import StopChecker
    from vllm.inputs import INPUT_REGISTRY
    from vllm.inputs import TextTokensPrompt
    from vllm.model_executor.model_loader.utils import set_default_torch_dtype as _set_default_torch_dtype
    from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
    from vllm.sequence import ExecuteModelRequest
    from vllm.transformers_utils.detokenizer import Detokenizer
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
    # imports for vllm-063/066
    from vllm.core.interfaces import BlockSpaceManager
    from vllm.distributed import parallel_state
    from vllm.distributed.communication_op import tensor_model_parallel_all_gather
    from vllm.distributed.parallel_state import init_world_group
    from vllm.distributed.parallel_state import initialize_model_parallel
    from vllm.distributed.utils import get_pp_indices
    from vllm.engine.async_llm_engine import _AsyncLLMEngine as LLMEngine
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
        from vllm.engine.llm_engine import _load_generation_config_dict
    from vllm.engine.llm_engine import SchedulerContext, SchedulerOutputState
    from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor
    from vllm.engine.output_processor.stop_checker import StopChecker
    from vllm.inputs import INPUT_REGISTRY
    from vllm.inputs.preprocess import InputPreprocessor
    from vllm.model_executor.model_loader.utils import set_default_torch_dtype as _set_default_torch_dtype
    from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
    from vllm.sequence import ExecuteModelRequest
    from vllm.transformers_utils.detokenizer import Detokenizer

    # vLLM 0.6.x no longer exposes TextTokensPrompt from vllm.inputs, so the
    # TypedDict is redefined here with the same fields.
    class TextTokensPrompt(TypedDict):
        """It is assumed that :attr:`prompt` is consistent with
        :attr:`prompt_token_ids`. This is currently used in
        :class:`AsyncLLMEngine` for logging both the text and token IDs."""

        prompt: str
        """The prompt text."""

        prompt_token_ids: List[int]
        """The token IDs of the prompt."""

        multi_modal_data: NotRequired["MultiModalDataDict"]
        """
        Optional multi-modal data to pass to the model,
        if the model supports it.
        """
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.llm import LLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter
from vllm.worker.worker import Worker


def get_block_manager_cls(version):
    """Return the block space manager class for the installed vLLM version."""
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        # vLLM 0.3.0 exposes a single concrete BlockSpaceManager.
        return BlockSpaceManager
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
        # Later versions select a concrete implementation by version string.
        return BlockSpaceManager.get_block_space_manager_class(version)
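# A hedged usage sketch (the "v1" argument is illustrative; valid values depend
# on the installed vLLM):
#   block_manager_cls = get_block_manager_cls("v1")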


def get_model_architecture(config):
    """Resolve the vLLM model class for the given model config."""
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        from vllm.model_executor.model_loader import _get_model_architecture as get_model_architecture_v1
        return get_model_architecture_v1(config)
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
        from vllm.model_executor.model_loader.utils import get_model_architecture as get_model_architecture_v2
        # Later versions return a (model_class, architecture_name) tuple.
        return get_model_architecture_v2(config)[0]
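# A hedged usage sketch, assuming `model_config` matches what the installed
# vLLM expects for its model loader:
#   model_cls = get_model_architecture(model_config)  # e.g. Qwen2ForCausalLM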


def get_pipeline_model_parallel_rank():
    """Return this process's rank within its pipeline-parallel group."""
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        return parallel_state.get_pipeline_model_parallel_rank()
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
        # Later versions expose pipeline-parallel state through the PP group object.
        return parallel_state.get_pp_group().rank_in_group


def get_pipeline_model_parallel_world_size():
    """Return the number of pipeline-parallel stages."""
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        return parallel_state.get_pipeline_model_parallel_world_size()
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]:
        return parallel_state.get_pp_group().world_size
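# A hedged usage sketch: together, the two helpers above make pipeline-stage
# checks version-agnostic, e.g.
#   is_last_pp_stage = (
#       get_pipeline_model_parallel_rank()
#       == get_pipeline_model_parallel_world_size() - 1
#   )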