chatlearn/models/vllm_module.py (798 lines of code) (raw):
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VLLM module"""
import gc
from typing import List, Tuple
import math
import os
import time
import torch
from tqdm import tqdm
from chatlearn.models.vllm.vllm_model import VLLMModel
from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
from chatlearn.utils.dist_utils import broadcast_var_object_dict
from chatlearn.utils.vllm_import_helper import get_block_manager_cls
from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank
from chatlearn.utils.vllm_import_helper import Scheduler
from chatlearn.utils.vllm_import_helper import EngineArgs
from chatlearn.utils.vllm_import_helper import LLM
from chatlearn.utils.vllm_import_helper import LLMEngine
from chatlearn.utils.vllm_import_helper import parallel_state
from chatlearn.utils.vllm_import_helper import SamplingParams
from chatlearn.utils.vllm_import_helper import Counter
from chatlearn.utils.vllm_import_helper import Worker
# additional imports for vLLM-0.5.1/0.6.3
try:
from chatlearn.utils.vllm_import_helper import Detokenizer
from chatlearn.utils.vllm_import_helper import ExecuteModelRequest
from chatlearn.utils.vllm_import_helper import INPUT_REGISTRY
from chatlearn.utils.vllm_import_helper import _load_generation_config_dict
from chatlearn.utils.vllm_import_helper import SequenceGroupOutputProcessor
from chatlearn.utils.vllm_import_helper import StopChecker
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
except ImportError:
print("Cannot import addtional module for vllm 0.5.1 or 0.6.3, please install vllm 0.5.1/0.6.3 first.")
# additional imports for vLLM-0.6.3
try:
from chatlearn.utils.vllm_import_helper import InputPreprocessor
from chatlearn.utils.vllm_import_helper import SchedulerContext, SchedulerOutputState
except ImportError:
print("Cannot import addtional module for vllm 0.6.3, please install vllm 0.6.3 first.")
from chatlearn.utils.vllm_utils import initialize_vllm
from chatlearn.utils.vllm_utils import get_model, print_rank_0
from .torch_module import TorchModule
try:
from .megatron.memory_manager import InferenceMemoryManager
except ImportError:
InferenceMemoryManager = None
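# Minimum interval, in seconds, between periodic metrics log lines emitted by log_metrics_stats.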
_LOGGING_INTERVAL_SEC = 5.0
# pylint: disable=import-outside-toplevel,unexpected-keyword-arg,no-value-for-parameter,too-many-function-args
class VLLMModule(TorchModule, LLMEngine, LLM):
"""VLLMModule is the class for vLLM models.
Args
----
name : str
model name
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.log_stats = False
# inference only
if self.model_args.get("micro_batch_size") != self.module_args.generation_batch_size:
self._logger.info(f"{self.name} Overwrite micro_batch_size with generation_batch_size {self.module_args.generation_batch_size}")
self.model_args["micro_batch_size"] = self.module_args.generation_batch_size
# parallel size
self.model_args["pipeline_model_parallel_size"] = self.module_args.pipeline_model_parallel_size
self.model_args["tensor_model_parallel_size"] = self.module_args.tensor_model_parallel_size
# precision
if self.model_args.get("fp16", False):
assert not self.model_args.get("bf16", False)
self.model_args["params_dtype"] = torch.half
if self.model_args.get("bf16", False):
assert not self.model_args.get("fp16", False)
self.model_args["params_dtype"] = torch.bfloat16
# To save GPU memory, prompt_logprobs defaults to None. If you need to evaluate loss on prompts, set prompt_logprobs=1.
if self.model_args.get("loss_on_prompts", False) and self.model_args.get("prompt_logprobs", None) is None:
raise RuntimeError("expect loss_on_prompts to be False to save GPU memory, or set prompt_logprobs to 1 in sampling_params.")
self.scheduler = None
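# Block-manager state depends on the profiled KV-cache size, so the scheduler must be reset
# before each new round of requests; this flag tracks whether a reset is still pending.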
self._need_to_reset_scheduler = True
self._log_metrics = self.model_args.get("log_metrics", False)
self._init_args()
def _init_args(self):
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
self.set_vllm_pp_layer_partition()
if self.model_args.get("apply_replica_id_to_seed", True):
seed = self.model_args.get("seed", 0) + self.replica_id
else:
seed = self.model_args.get("seed", 0)
engine_args = EngineArgs(
model=self.model_args.get("tokenizer"),
tokenizer=self.model_args.get("tokenizer"),
tokenizer_mode=self.model_args.get("tokenizer_mode", "auto"),
trust_remote_code=self.model_args.get("trust_remote_code", True),
tensor_parallel_size=self.module_args.tensor_model_parallel_size,
pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
dtype=self.model_args.get("params_dtype", "auto"),
quantization=self.model_args.get("quantization", None),
revision=self.model_args.get("revision", None),
tokenizer_revision=self.model_args.get("tokenizer_revision", None),
seed=seed,
gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
block_size=self.model_args.get("block_size"),
swap_space=self.model_args.get("swap_space"),
max_num_batched_tokens=self.model_args.get("max_num_batched_tokens"),
max_num_seqs=self.model_args.get("micro_batch_size"),
max_model_len=self.model_args.get("seq_length"),
enforce_eager=self.model_args.get("enforce_eager", True),
disable_custom_all_reduce=True
)
self.quant_config = None
self.pipeline_layer_offset = None
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
engine_args.max_paddings = self.model_args.get("max_paddings", 256)
engine_args.max_context_len_to_capture = self.model_args.get("max_context_len_to_capture", 8192)
self.model_config, self.cache_config, self.parallel_config, self.scheduler_config, self.lora_config = \
engine_args.create_engine_configs()
self.worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
local_rank=0,
rank=0,
distributed_init_method=None,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self._init_tokenizer()
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
engine_args.max_seq_len_to_capture = self.model_args.get("max_context_len_to_capture", 8192)
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
engine_args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1)
engine_config = \
engine_args.create_engine_config()
self.cache_config = engine_config.cache_config
self.device_config = engine_config.device_config
self.load_config = engine_config.load_config
self.lora_config = engine_config.lora_config
self.model_config = engine_config.model_config
self.parallel_config = engine_config.parallel_config
self.scheduler_config = engine_config.scheduler_config
self.generation_config_fields = _load_generation_config_dict(
self.model_config)
self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3 and self.scheduler_config.is_multi_step:
from vllm.worker.multi_step_worker import MultiStepWorker
self.worker = MultiStepWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank=0,
rank=0,
distributed_init_method=None,
lora_config=self.lora_config,
is_driver_worker=True,
)
else:
self.worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
self.load_config,
local_rank=0,
rank=0,
distributed_init_method=None,
lora_config=self.lora_config,
is_driver_worker=True,
)
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
def setup(self):
"""Set up model and load checkpoint"""
model = [get_model(self.model_provider, self.model_args)]
if self.model_args["load"] is None:
print_rank_0(f"Warning: Using random parameter for {self.name} model.")
assert len(model) == 1, "expect exactly one model instance"
self.model = model[0]
def model_provider(self):
"""Build the model."""
print_rank_0('building vLLM model ...')
model = VLLMModel(self.model_config, self.model_args, self.cache_config)
return model
def _reset_metrics_stats_args(self):
self.start_time = None
# Logging.
self.last_stats_time = 0.0
self.forward_count = 0
self.num_done_requests = 0
self.num_processed_prompt = 0
self.num_generated_tokens = 0
self.action_length = 0
self.action_max_length = float("-inf")
self.action_min_length = float("inf")
self.batch_size_stats = 0.0
self.gpu_cache_usage = 0.0
self.cpu_cache_usage = 0.0
self.max_prompt_length_static_batching = [
0 for _ in range(math.ceil(self.num_requests/self.scheduler_config.max_num_seqs))]
self.max_output_length_static_batching = [
0 for _ in range(math.ceil(self.num_requests/self.scheduler_config.max_num_seqs))]
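# Reset per-episode engine bookkeeping (request counter, token statistics, sliding window)
# before a new batch of requests is added.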
def reset_vllm(self):
self.request_counter = Counter()
self.log_stats = self.model_args.get("log_stats", False)
# Logging.
self.last_logging_time = 0.0
# List of (timestamp, num_tokens)
self.num_prompt_tokens: List[Tuple[float, int]] = []
# List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = []
self.sliding_window = self.cache_config.sliding_window
def add_extra_args(self, parser):
"""
Add extra arguments for vllm.
Args
----
parser : ArgumentParser
Add extra arguments.
"""
group = parser.add_argument_group(title='vLLM extra arguments')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
return parser
def init(self):
"""
:meta private:
"""
if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
parallel_state.set_custom_all_reduce(not self.parallel_config.disable_custom_all_reduce)
initialize_vllm(extra_args_provider=self.add_extra_args,
ignore_unknown_args=True,
args_dict=self.model_args)
self.parallel_config.rank = torch.distributed.get_rank()
def build_scheduler(self):
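# vLLM 0.3.0 keeps a single scheduler; 0.5.1/0.6.3 keep one scheduler per pipeline stage.
# On re-entry only the block managers are rebuilt so block counts match the freshly profiled cache.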
self.seq_counter = Counter()
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
if self.scheduler is None:
self.scheduler = Scheduler(self.scheduler_config, self.cache_config, None)
else:
BlockSpaceManagerImpl = get_block_manager_cls(None)
self.scheduler.block_manager = BlockSpaceManagerImpl( # pylint: disable=abstract-class-instantiated
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window)
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
if self.scheduler is None:
self.scheduler = [
Scheduler(self.scheduler_config, self.cache_config, None,
self.parallel_config.pipeline_parallel_size)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
def get_tokenizer_for_seq(sequence):
tokenizer_group = self.get_tokenizer_group()
assert tokenizer_group, ("tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False")
return tokenizer_group.get_lora_tokenizer(sequence.lora_request)
tokenizer_for_seq = get_tokenizer_for_seq if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3 \
else self.get_tokenizer_for_seq
self.output_processor = (
SequenceGroupOutputProcessor.create_output_processor(
self.scheduler_config,
self.detokenizer,
self.scheduler,
self.seq_counter,
tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
tokenizer_for_seq,
),
))
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.scheduler_contexts = [
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.use_cached_outputs = False
self.process_request_outputs_callback = None
self.tracer = None
else:
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
version = "selfattn"
if (self.scheduler_config.embedding_mode
or self.cache_config.is_attention_free):
version = "placeholder"
else:
version = "v1"
if self.scheduler_config.use_v2_block_manager:
version = "v2"
if self.scheduler_config.embedding_mode:
version = "embedding"
BlockSpaceManagerImpl = get_block_manager_cls(version)
num_gpu_blocks = self.cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= self.pipeline_model_parallel_size()
num_cpu_blocks = self.cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= self.pipeline_model_parallel_size()
for scheduler in self.scheduler:
scheduler.block_manager = BlockSpaceManagerImpl( # pylint: disable=abstract-class-instantiated
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
def _reset_scheduler(self):
# reset scheduler
scheduler_list = self.scheduler if isinstance(self.scheduler, list) else [self.scheduler]
for scheduler in scheduler_list:
scheduler.block_manager.reset()
def reinit_cache_engine(self):
# reinit cache engine
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
self.worker.init_cache_engine(cache_config=self.cache_config)
self.worker.warm_up_model()
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
self.worker.initialize_cache(self.cache_config.num_gpu_blocks, self.cache_config.num_cpu_blocks)
def empty_cache(self):
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition
self.worker.cache_engine.cpu_cache = None
self.worker.cache_engine.gpu_cache = None
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
if self.worker.gpu_cache is not None:
for ele in self.worker.gpu_cache: # pylint: disable=unused-variable
ele = None
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition
if hasattr(self.worker, "cache_engine") and self.worker.cache_engine is not None:
for c_e in self.worker.cache_engine:
c_e.cpu_cache = None
c_e.gpu_cache = None
self.worker.cache_engine = None
self.clear_cache()
def clear_cache(self):
if not self.timers("gc").started_:
self.timers("gc").start()
gc.collect()
self.timers("gc").stop()
super().empty_cache()
def profile_cache_blocks(self):
"""Profiles the memory usage and initializes the KV cache."""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
self.clear_cache()
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
num_gpu_blocks, num_cpu_blocks = self.worker.profile_num_available_blocks(
self.cache_config.block_size,
self.cache_config.gpu_memory_utilization,
self.cache_config.swap_space_bytes,
self.cache_config.cache_dtype
)
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
num_gpu_blocks, num_cpu_blocks = self.worker.determine_num_available_blocks()
else:
raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expect one of {list(VLLMVersion)}")
self._need_to_reset_scheduler = False
self.clear_cache()
return num_gpu_blocks, num_cpu_blocks
def set_cache_config(self, num_gpu_blocks, num_cpu_blocks):
# debug log.
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
self._logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
self._need_to_reset_scheduler = False
def convert_v1_inputs(self, prompts, prompt_token_ids):
num_requests = len(prompts)
assert num_requests == len(prompt_token_ids), \
("The lengths of prompts and prompt_token_ids must be the same.")
inputs = []
for i in range(num_requests):
if prompts[i] is None:
assert isinstance(prompt_token_ids[i], list), \
f"Expect prompt_token_ids[{i}] to be a list of ints when the prompt is None, got {prompt_token_ids[i]}."
if prompt_token_ids[i] is None:
assert isinstance(prompts[i], str), \
f"Expect prompts[{i}] to be a string when prompt_token_ids is None, got {prompts[i]}."
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
inputs.append(item)
return inputs
def _add_request(self, data, is_eval=False): # pylint: disable=arguments-differ
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")
return self._add_request_internal(data[prompt_key], data[input_ids_key], is_eval=is_eval)
def _add_request_internal(self, prompt_list, prompt_token_id_list, is_eval=False):
if self._need_to_reset_scheduler:
self._reset_scheduler()
self.reset_vllm()
# sampling params
temperature = 0.0
if not self.model_args.get("use_beam_search"):
temperature = self.model_args.get("eval_temperature", 1.0) if is_eval else self.model_args.get("temperature", 1.0)
top_p = self.model_args.get("eval_top_p", 1.0) if is_eval else self.model_args.get("top_p", 1.0)
top_k = self.model_args.get("eval_top_k", -1) if is_eval else self.model_args.get("top_k", -1)
min_p = self.model_args.get("eval_min_p", 0.0) if is_eval else self.model_args.get("min_p", 0.0)
presence_penalty = self.model_args.get("eval_presence_penalty", 0.0) if is_eval else self.model_args.get("presence_penalty", 0.0)
frequency_penalty = self.model_args.get("eval_frequency_penalty", 0.0) if is_eval else self.model_args.get("frequency_penalty", 0.0)
repetition_penalty = self.model_args.get("eval_repetition_penalty", 1.0) if is_eval else self.model_args.get("repetition_penalty", 1.0)
stop = self.model_args.get("stop_token_list", None)
if isinstance(stop, str):
stop = stop.split(";")
seq_len = self.model_args.get("seq_length")
for prompt, prompt_token_ids in zip(prompt_list, prompt_token_id_list):
request_id = next(self.request_counter)
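# Truncate over-long prompts so that prompt plus generated tokens never exceed seq_length;
# with new_token_limit, a fixed budget of max_new_tokens is reserved for generation.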
if self.model_args.get("new_token_limit", False):
max_tokens = self.model_args.get("max_new_tokens")
assert max_tokens < seq_len, "max_new_tokens must be less than seq_length."
prompt_token_ids = prompt_token_ids \
if len(prompt_token_ids) <= seq_len-max_tokens \
else prompt_token_ids[:seq_len-max_tokens]
else:
if len(prompt_token_ids) >= seq_len:
prompt_token_ids = prompt_token_ids[:seq_len-1]
max_tokens = seq_len - len(prompt_token_ids)
if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_3_0, VLLMVersion.v_0_5_1]:
sampling_params = SamplingParams(
n=self.model_args.get("n"),
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
use_beam_search=self.model_args.get("use_beam_search"),
ignore_eos=self.model_args.get("ignore_eos"),
stop=stop,
max_tokens=max_tokens,
logprobs=1,
prompt_logprobs=self.model_args.get("prompt_logprobs", None),
skip_special_tokens=self.model_args.get('skip_special_tokens', True)
)
elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
sampling_params = SamplingParams(
n=self.model_args.get("n"),
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
ignore_eos=self.model_args.get("ignore_eos"),
stop=stop,
max_tokens=max_tokens,
logprobs=1,
prompt_logprobs=self.model_args.get("prompt_logprobs", None),
skip_special_tokens=self.model_args.get('skip_special_tokens', True)
)
else:
raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expect one of {list(VLLMVersion)}")
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids
)
elif CURRENT_VLLM_VERSION in \
[VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
inputs = self.convert_v1_inputs(
prompts=[prompt],
prompt_token_ids=[prompt_token_ids],
)[0]
self.add_request(
request_id,
inputs,
sampling_params
)
self.outputs = []
self.num_requests = self.get_num_unfinished_requests()
self._reset_metrics_stats_args()
self.pbar = tqdm(total=self.num_requests, desc=f"Processed prompts (replica {self.replica_id+1}/{self._num_replica})")
self._need_to_reset_scheduler = True
def model_setup(self):
"""
:meta private:
"""
super().model_setup()
# TODO: we may need to let setup return model, optimizer and opt_param_scheduler
if self.trainable:
assert hasattr(self, "model")
assert hasattr(self, "optimizer")
assert hasattr(self, "opt_param_scheduler")
self.model.eval()
else:
assert hasattr(self, "model")
self.model.eval()
self.worker.model_runner.model = self.model.model
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
from vllm.worker.multi_step_worker import MultiStepWorker
if isinstance(self.worker, MultiStepWorker):
self.worker.model_runner._base_model_runner.model = self.model.model
if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
self.worker.device = torch.device(f"cuda:{torch.cuda.current_device()}")
self.worker.init_gpu_memory = torch.cuda.mem_get_info()[0]
if self.module_args.offload_weights:
if InferenceMemoryManager is None:
raise Exception("Import InferenceMemoryManager failed, you may need to set right Megatron path first.")
self._memory_manager = InferenceMemoryManager(
self.model,
self.runtime_args.bucket_size_mb_in_memory_manager,
)
self.offload()
def pipeline_model_parallel_size(self):
"""
get pipeline_model_parallel_size
:meta private:
"""
return self.parallel_config.pipeline_parallel_size
def tensor_model_parallel_size(self):
"""
get tensor_model_parallel_size
:meta private:
"""
return self.parallel_config.tensor_parallel_size
def tensor_and_expert_model_parallel_size(self):
"""
get tensor_and_expert_model_parallel_size
:meta private:
"""
# vLLM does not support expert parallelism,
# so tensor_and_expert_model_parallel_size equals tensor_parallel_size.
return self.parallel_config.tensor_parallel_size
@property
def data_parallel_size(self):
"""
:meta private:
"""
return 1
@property
def data_parallel_rank(self):
"""
:meta private:
"""
return 0
def tensor_parallel_rank(self):
"""
:meta private:
"""
return parallel_state.get_tensor_model_parallel_rank()
def pipeline_parallel_rank(self):
"""
:meta private:
"""
return get_pipeline_model_parallel_rank()
def num_layers(self):
"""
:meta private:
"""
return self.model_config.hf_config.num_hidden_layers
def generate_vllm(self, query, is_eval, iteration=0): # pylint: disable=unused-argument
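# Profile the available KV-cache blocks on every rank and take the global minimum
# so that all workers allocate identical cache sizes.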
num_gpu_blocks, num_cpu_blocks = self.profile_cache_blocks()
num_blocks = torch.tensor([num_gpu_blocks, num_cpu_blocks], device='cuda')
torch.distributed.all_reduce(num_blocks, op=torch.distributed.ReduceOp.MIN)
min_gpu_blocks = num_blocks[0].item()
min_cpu_blocks = num_blocks[1].item()
self.set_cache_config(min_gpu_blocks, min_cpu_blocks)
if self.is_last_rank():
self.build_scheduler()
self.reinit_cache_engine()
# add requests of current episode to vllm scheduler
if self.is_last_rank():
self._add_request(query, is_eval=is_eval)
step_outputs = True
while step_outputs:
schedule_query = None
if self.is_last_rank():
# support multi step schedule.
virtual_engine = 0
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = False
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration
scheduler_outputs = self.schedule()
seq_group_metadata_list = scheduler_outputs["seq_group_metadata_list"]
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
if (self.scheduler_config.is_multi_step
and scheduler_outputs["num_lookahead_slots"] > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
schedule_query = scheduler_outputs
if len(scheduler_outputs) == 0:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
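# Only the last pipeline rank owns the scheduler; broadcast its scheduling decision to the other ranks.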
schedule_query = broadcast_var_object_dict(schedule_query, torch.distributed.get_world_size()-1)
output = self.execute_step(schedule_query)
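# The last rank decides whether any requests remain and broadcasts a continue/stop signal to the other ranks.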
if self.is_last_rank():
step_outputs = bool(output)
signal_tensor = torch.tensor(step_outputs, device='cuda')
torch.distributed.broadcast(signal_tensor, torch.distributed.get_world_size()-1)
else:
signal_tensor = torch.tensor(True, device='cuda')
torch.distributed.broadcast(signal_tensor, torch.distributed.get_world_size()-1)
step_outputs = signal_tensor.item()
if self.is_last_rank():
self.outputs = sorted(self.outputs, key=lambda x: int(x.request_id))
return self.outputs
def schedule(self):
if self.start_time is None:
self.start_time = time.monotonic()
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
for scheduler in self.scheduler:
self.seq_group_metadata_list, self.scheduler_outputs, _ = scheduler.schedule()
if self.seq_group_metadata_list:
break
else:
self.seq_group_metadata_list, self.scheduler_outputs = self.scheduler[0].schedule()
if self.scheduler_outputs.is_empty():
return {}
data = {
"seq_group_metadata_list" : self.seq_group_metadata_list,
"blocks_to_swap_in" : self.scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out" : self.scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy" : self.scheduler_outputs.blocks_to_copy
}
if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
finished_requests_ids = self.scheduler[0].get_and_reset_finished_requests_ids()
data.update({
"num_lookahead_slots": self.scheduler_outputs.num_lookahead_slots,
"running_queue_size": self.scheduler_outputs.running_queue_size,
"finished_requests_ids": finished_requests_ids
})
return data
def process_model_outputs(self, output, seq_group_metadata_list=None):
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
step_outputs = self._process_model_outputs(output, self.scheduler_outputs)
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
step_outputs = self._process_model_outputs(
output, self.scheduler_outputs.scheduled_seq_groups,
self.scheduler_outputs.ignored_seq_groups, self.seq_group_metadata_list)
else:
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
virtual_engine = 0
allow_async_output_proc = False
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, output)
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(outputs=output,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=self.scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True,
is_first_step_output=is_first_step_output)
self._process_model_outputs(ctx=ctx)
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
step_outputs = ctx.request_outputs
else:
# Multi-step case
step_outputs = ctx.request_outputs
else:
raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expect one of {list(VLLMVersion)}")
done = 0
for out in step_outputs:
if out.finished:
self.outputs.append(out)
done += 1
self.pbar.update(1)
self.num_requests -= done
if self.num_requests <= 0:
self.pbar.close()
if self._log_metrics:
self.log_metrics_stats(done)
return self.num_requests
@torch.inference_mode()
def execute_step(self, data):
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
output = self.worker.execute_model(
data["seq_group_metadata_list"],
data["blocks_to_swap_in"],
data["blocks_to_swap_out"],
data["blocks_to_copy"]
)
elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1:
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=data["seq_group_metadata_list"],
blocks_to_swap_in=data["blocks_to_swap_in"],
blocks_to_swap_out=data["blocks_to_swap_out"],
blocks_to_copy=data["blocks_to_copy"],
num_lookahead_slots=data["num_lookahead_slots"],
running_queue_size=data["running_queue_size"],
finished_requests_ids=data["finished_requests_ids"]
)
output = self.worker.execute_model(execute_model_req=execute_model_req)
else:
if len(data) > 0:
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0.
virtual_engine = 0
# These are cached outputs from previous iterations. None if on first
# iteration
seq_group_metadata_list = data["seq_group_metadata_list"]
allow_async_output_proc = False
assert seq_group_metadata_list is not None
finished_requests_ids = data["finished_requests_ids"]
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = None
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=data["blocks_to_swap_in"],
blocks_to_swap_out=data["blocks_to_swap_out"],
blocks_to_copy=data["blocks_to_copy"],
num_lookahead_slots=data["num_lookahead_slots"],
running_queue_size=data["running_queue_size"],
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
output = self.worker.execute_model(execute_model_req=execute_model_req)
else:
# No outputs in this case
output = []
else:
raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expect one of {list(VLLMVersion)}")
if self.is_last_rank() and hasattr(self, "scheduler_outputs"):
return self.process_model_outputs(output, seq_group_metadata_list=data["seq_group_metadata_list"])
return output
def offload_weights(self):
"""
offload weights
"""
if self.module_args.offload_weights:
self._memory_manager.offload_weights()
def onload_weights(self):
"""
onload weights
"""
if self.module_args.offload_weights:
self._memory_manager.onload_weights()
def set_vllm_pp_layer_partition(self):
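# vLLM reads the VLLM_PP_LAYER_PARTITION environment variable to decide how many decoder layers
# each pipeline stage owns; export ChatLearn's (possibly padded) per-stage layer counts through it.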
pipeline_world_size = self.module_args.pipeline_model_parallel_size
num_layers = self.model_args.get("num_layers")
remainder = num_layers % pipeline_world_size
if not self.model_args.get("allow_padding_num_layers", None):
assert remainder == 0, \
f"expect num_layers % pipeline_model_parallel_size == 0 when allow_padding_num_layers is not set, " \
f"while num_layers = {num_layers} and pipeline_model_parallel_size = {pipeline_world_size}"
return
if remainder > 0:
assert not self.model_args.get("standalone_embedding_stage", False), \
"standalone_embedding_stage is not supported when allow_padding_num_layers is true"
# pad num_layers to make num_layers % pipeline_model_parallel_size == 0
num_layers_with_padding = num_layers - remainder + pipeline_world_size
else:
num_layers_with_padding = num_layers
num_layers_without_padding = num_layers
num_layers = num_layers_with_padding
num_layers_per_stage_with_padding = (
num_layers // pipeline_world_size)
# Each stage gets a contiguous set of layers.
if self.model_args.get("pipeline_layers", None) is not None:
rank_sizes = self.model_args.get("pipeline_layers", None)
assert isinstance(rank_sizes, list) and all(isinstance(ele, int) for ele in rank_sizes), \
f"pipeline_layers expected to be list, and num layer of each stage to be integer, while {rank_sizes}."
else:
rank_sizes = [num_layers_per_stage_with_padding] * pipeline_world_size
num_padding = num_layers - num_layers_without_padding
if num_padding > 0:
assert num_padding == 2, \
"Only num_padding == 2 is supported for imbalanced pipeline parallelism. Please set `pipeline_layers` for VLLMModule instead."
for _index in range(-1, num_padding - 1):
rank_sizes[_index] -= 1
assert len(rank_sizes) == pipeline_world_size
# set env variable VLLM_PP_LAYER_PARTITION
vllm_pp_layer_partition = ",".join([str(ele) for ele in rank_sizes])
if os.getenv("VLLM_PP_LAYER_PARTITION", None) is not None:
env_vllm_pp_layer_partition = os.getenv("VLLM_PP_LAYER_PARTITION", None)
if vllm_pp_layer_partition != env_vllm_pp_layer_partition:
self._logger.warning(
f"expect VLLM_PP_LAYER_PARTITION to be {vllm_pp_layer_partition}, while {env_vllm_pp_layer_partition}")
os.environ["VLLM_PP_LAYER_PARTITION"] = vllm_pp_layer_partition
def log_metrics_stats(self, num_done_requests):
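# Accumulate throughput, cache-usage and response-length statistics, and estimate the token counts
# a static-batching baseline would have processed for comparison.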
now = time.monotonic()
self.num_done_requests += num_done_requests
scheduler_list = self.scheduler if isinstance(self.scheduler, list) else [self.scheduler]
avg_request_throughput = self.num_done_requests / (now - self.start_time)
if self.scheduler_outputs.prompt_run:
self.num_processed_prompt += self.scheduler_outputs.num_batched_tokens
else:
self.num_generated_tokens += self.scheduler_outputs.num_batched_tokens
avg_generation_throughput = self.num_generated_tokens / (now - self.start_time)
avg_prompt_throughput = self.num_processed_prompt / (now - self.start_time)
self.forward_count += 1
total_num_gpu_blocks = self.cache_config.num_gpu_blocks
num_free_gpu_blocks = sum(
scheduler.block_manager.get_num_free_gpu_blocks() for scheduler in scheduler_list)
num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
self.gpu_cache_usage += num_used_gpu_blocks / total_num_gpu_blocks
avg_gpu_cache_usage = self.gpu_cache_usage / self.forward_count
total_num_cpu_blocks = self.cache_config.num_cpu_blocks
if total_num_cpu_blocks > 0:
num_free_cpu_blocks = sum(
scheduler.block_manager.get_num_free_cpu_blocks() for scheduler in scheduler_list)
num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
else:
cpu_cache_usage = 0.0
self.cpu_cache_usage += cpu_cache_usage
avg_cpu_cache_usage = self.cpu_cache_usage / self.forward_count
for idx in range(self.num_done_requests - num_done_requests, self.num_done_requests):
output = self.outputs[idx]
prompt_length = len(output.prompt_token_ids)
output_length = len(output.outputs[0].token_ids)
batch_index = int(output.request_id / self.scheduler_config.max_num_seqs)
self.max_prompt_length_static_batching[batch_index] = max(
self.max_prompt_length_static_batching[batch_index], prompt_length)
self.max_output_length_static_batching[batch_index] = max(
self.max_output_length_static_batching[batch_index], output_length)
self.action_length += output_length
self.action_max_length = max(self.action_max_length, output_length)
self.action_min_length = min(self.action_min_length, output_length)
action_length_mean = float(self.action_length / self.num_done_requests) if self.num_done_requests else 0.0
for scheduler in scheduler_list:
self.batch_size_stats += len(scheduler.running)
avg_batch_size = self.batch_size_stats / self.forward_count
if not self.num_requests or (now - self.last_stats_time >= _LOGGING_INTERVAL_SEC):
self.last_stats_time = now
message = ""
if not self.num_requests:
batch_size = [self.scheduler_config.max_num_seqs \
for _ in range(math.ceil(self.num_done_requests / self.scheduler_config.max_num_seqs))]
if self.num_done_requests % self.scheduler_config.max_num_seqs:
batch_size[-1] = self.num_done_requests % self.scheduler_config.max_num_seqs
num_prompt_tokens_static_batching = sum( # pylint: disable=consider-using-generator
[prompt_len * bs for prompt_len, bs in zip(self.max_prompt_length_static_batching, batch_size)])
num_output_tokens_static_batching = sum( # pylint: disable=consider-using-generator
[output_length * bs for output_length, bs in zip(self.max_output_length_static_batching, batch_size)])
message = f"num_processed_prompts_continuous_batching: {self.num_processed_prompt}, " \
f"num_processed_prompts_static_batching: {num_prompt_tokens_static_batching}, " \
f"num_processed_prompts_continuous_batching/num_processed_prompts_static_batching: \
{self.num_processed_prompt/num_prompt_tokens_static_batching:.1f}, " \
f"num_output_tokens_continuous_batching: {self.num_generated_tokens}, " \
f"num_output_tokens_static_batching: {num_output_tokens_static_batching}, " \
f"num_output_tokens_continuous_batching/num_output_tokens_static_batching: \
{self.num_generated_tokens/num_output_tokens_static_batching:.1f}, " \
self._logger.info(f"allready generate responses for {self.num_done_requests} reqs, "
f"avg_request_throughput: {avg_request_throughput:.1f} reqs/s, "
f"avg_prompt_throughput: {avg_prompt_throughput:.1f} tokens/s, "
f"avg_generation_throughput: {avg_generation_throughput:.1f} tokens/s, "
f"avg_batch_size: {avg_batch_size:.1f} reqs, "
f"avg_gpu_cache_usage: {avg_gpu_cache_usage * 100:.1f}%, "
f"avg_cpu_cache_usage {avg_cpu_cache_usage * 100:.1f}%, "
f"action_length_mean: {action_length_mean:.1f}, "
f"action_max_length: {self.action_max_length if self.num_done_requests else 'inf'}, "
f"action_min_length: {self.action_min_length if self.num_done_requests else '-inf'}, "
f"{message}")
# pylint: enable=import-outside-toplevel,unexpected-keyword-arg,no-value-for-parameter,too-many-function-args