chatlearn/models/vllm_module_v2.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VLLM module v2"""
import gc
import inspect
import os
import torch
from transformers import AutoTokenizer, AutoConfig
from vllm import SamplingParams
from vllm.config import LoadFormat
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import RayWorkerWrapper
from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
from chatlearn.utils.global_vars import set_vllm_actors
from chatlearn.utils.vllm_import_helper import parallel_state
from chatlearn.utils.vllm_import_helper import get_block_manager_cls
from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from chatlearn.utils.vllm_utils import initialize_vllm
from .torch_module import TorchModule
try:
from .vllm.inference import InferenceMemoryManager
except ImportError:
InferenceMemoryManager = None
# pylint: disable=unexpected-keyword-arg
class VLLMModuleV2(TorchModule, RayWorkerWrapper):
"""VLLMModuleV2"""
def __init__(self, *args, **kwargs):
TorchModule.__init__(self, *args)
        # avoid overwriting methods shared by TorchModule and RayWorkerWrapper
methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)}
methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)}
common_methods = methods_class1.intersection(methods_class2)
# common method is '__init__'
assert common_methods == {'__init__'}, \
f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
self.local_rank = 0
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
else:
if 'vllm_actor_type' in kwargs and 'worker' == kwargs['vllm_actor_type']:
vllm_config = self.init_engine_args()
RayWorkerWrapper.__init__(self, vllm_config=vllm_config) # pylint: disable=non-parent-init-called
os.environ['VLLM_HOST_IP'] = self.get_address()
self.tokenizer = None
self._model = None
self.llm = None
self.model_config = AutoConfig.from_pretrained(self.model_args['tokenizer'])
self.set_vllm_pp_layer_partition()
self._metric_prefix = 'vllm_inference'
def add_extra_args(self, parser):
"""
Add extra arguments for vllm.
Args
----
parser : ArgumentParser
Add extra arguments.
"""
group = parser.add_argument_group(title='vLLM extra arguments')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
group.add_argument('--distributed-timeout-minutes', type=int, default=10,
help='Timeout minutes for torch.distributed.')
return parser
def init_engine_args(self):
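        """
        Build vLLM AsyncEngineArgs from the model/module args and return the created
        engine config; used when this actor is initialized as a vLLM worker.
        """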
dtype = self.model_args.get("dtype", "bfloat16")
if self.model_args.get("fp16", False):
dtype = "float16"
load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
if load_format == LoadFormat.DUMMY:
model_loader_extra_config = self.model_args
else:
model_loader_extra_config = None
if self.model_args.get("apply_replica_id_to_seed", True):
seed = self.model_args.get("seed", 0) + self.replica_id
else:
seed = self.model_args.get("seed", 0)
from vllm.engine.arg_utils import AsyncEngineArgs # pylint: disable=import-outside-toplevel
from vllm.usage.usage_lib import UsageContext # pylint: disable=import-outside-toplevel
self.engine_args = AsyncEngineArgs(
model=self.model_args['tokenizer'],
tokenizer=self.model_args['tokenizer'],
max_seq_len_to_capture=self.model_args.get("max_seq_len_to_capture", 32768),
seed=seed,
# load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
load_format=load_format,
model_loader_extra_config=model_loader_extra_config,
# parallelism strategy
tensor_parallel_size=self.module_args.tensor_model_parallel_size,
pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
dtype=dtype,
# scheduling strategy
max_num_seqs=self.module_args.generation_batch_size,
            max_num_batched_tokens=self.model_args.get("max_num_batched_tokens", None),
num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1),
gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
# logger
disable_log_requests=self.model_args.get("disable_log_requests", True),
disable_log_stats=self.model_args.get("disable_log_stats", True),
trust_remote_code=True,
enforce_eager=self.model_args.get("enforce_eager", True),
disable_custom_all_reduce=True,
distributed_executor_backend="ray",
            preemption_mode=self.model_args.get("preemption_mode", 'recompute'),  # 'swap' or 'recompute'
swap_space=self.model_args.get("swap_space", 16))
return self.engine_args.create_engine_config(usage_context=UsageContext.ENGINE_CONTEXT)
def init(self):
"""
:meta private:
"""
parallel_state.set_custom_all_reduce(False)
initialize_vllm(extra_args_provider=self.add_extra_args,
ignore_unknown_args=True,
args_dict=self.model_args)
def setup(self):
"""Set up tokenizer."""
super().setup()
tokenizer = AutoTokenizer.from_pretrained(self.model_args['tokenizer'])
tokenizer.tokenizer = tokenizer
self.tokenizer = tokenizer
def setup_vllm(self, workers):
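        """
        Create the vLLM LLM engine on this (driver) replica, register the ChatLearn
        worker actors, then offload weights and release CUDA graphs / KV caches so
        GPU memory stays free until generation starts.
        """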
if self.llm is not None: # for evaluator
return
# setup vllm engine in rank 0
os.environ['VLLM_HOST_IP'] = self.get_address()
set_vllm_actors(workers)
dtype = self.model_args.get("dtype", "bfloat16")
if self.model_args.get("fp16", False):
dtype = "float16"
load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
if load_format == LoadFormat.DUMMY:
model_loader_extra_config = self.model_args
else:
model_loader_extra_config = None
if self.model_args.get("apply_replica_id_to_seed", True):
seed = self.model_args.get("seed", 0) + self.replica_id
else:
seed = self.model_args.get("seed", 0)
self.llm = LLM(
model=self.model_args['tokenizer'],
tokenizer=self.model_args['tokenizer'],
max_seq_len_to_capture=self.model_args.get("max_seq_len_to_capture", 32768),
seed=seed,
# load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
load_format=load_format,
model_loader_extra_config=model_loader_extra_config,
# parallelism strategy
tensor_parallel_size=self.module_args.tensor_model_parallel_size,
pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
dtype=dtype,
# scheduling strategy
max_num_seqs=self.module_args.generation_batch_size,
            max_num_batched_tokens=self.model_args.get("max_num_batched_tokens", None),
num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1),
gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
# logger
disable_log_requests=self.model_args.get("disable_log_requests", True),
disable_log_stats=self.model_args.get("disable_log_stats", True),
trust_remote_code=True,
enforce_eager=self.model_args.get("enforce_eager", False),
disable_custom_all_reduce=True,
distributed_executor_backend="ray",
            preemption_mode=self.model_args.get("preemption_mode", 'recompute'),  # 'swap' or 'recompute'
swap_space=self.model_args.get("swap_space", 16))
self.llm.llm_engine.model_executor._run_workers("init_memory_manager")
self.offload_for_workers()
self.empty_cuda_graph_for_workers()
self.empty_cache_for_workers()
def dump_parameters(self, dump_path_root):
self.onload_for_workers()
self.llm.llm_engine.model_executor._run_workers("worker_dump_parameters", dump_path_root=dump_path_root)
def worker_dump_parameters(self, dump_path_root):
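        """
        Save every named parameter of this worker's model shard to
        dump_path_root/<tp_rank>/<param_name> via torch.save.
        """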
tp_rank = self.tensor_parallel_rank()
model = self.model
if isinstance(model, list):
model = model[0]
dir_path = os.path.join(dump_path_root, str(tp_rank))
if not os.path.exists(dir_path):
os.makedirs(dir_path)
self._logger.info(f"dump parameters to {dir_path}")
for name, param in self.named_parameters.items():
pt_file = os.path.join(dir_path, name)
torch.save(param.data.clone(), pt_file)
def init_memory_manager(self):
if self.module_args.offload_weights:
if InferenceMemoryManager is None:
raise Exception("Import InferenceMemoryManager failed, you may need to set right Megatron path first.")
self._memory_manager = InferenceMemoryManager(
self.model,
self.runtime_args.bucket_size_mb_in_memory_manager,
)
def set_vllm_pp_layer_partition(self):
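        """
        Derive the per-stage layer split and export it via VLLM_PP_LAYER_PARTITION.
        When allow_padding_num_layers is set and num_layers is not divisible by the
        pipeline size, layers are padded to the next multiple and the two padding
        layers are taken from the first and last stages, e.g. num_layers=30 with
        pipeline_model_parallel_size=4 yields "7,8,8,7".
        """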
pipeline_world_size = self.module_args.pipeline_model_parallel_size
num_layers = self.model_config.num_hidden_layers
remainder = num_layers % pipeline_world_size
if not self.model_args.get("allow_padding_num_layers", None):
            assert remainder == 0, \
                f"expect num_layers % pipeline_model_parallel_size == 0 when allow_padding_num_layers is disabled, " \
                f"while num_layers = {num_layers} and pipeline_model_parallel_size = {pipeline_world_size}"
return
if remainder > 0:
assert not self.model_args.get("standalone_embedding_stage", False), \
"not support standalone embedding stage if allow_padding_num_layers is true"
# pad num_layers to make num_layers % pipeline_model_parallel_size == 0
num_layers_with_padding = num_layers - remainder + pipeline_world_size
else:
num_layers_with_padding = num_layers
num_layers_without_padding = num_layers
num_layers = num_layers_with_padding
num_layers_per_stage_with_padding = (
num_layers // pipeline_world_size)
# Each stage gets a contiguous set of layers.
if self.model_args.get("pipeline_layers", None) is not None:
rank_sizes = self.model_args.get("pipeline_layers", None)
            assert isinstance(rank_sizes, list) and all(isinstance(ele, int) for ele in rank_sizes), \
                f"pipeline_layers is expected to be a list of integers (layers per stage), while {rank_sizes}."
else:
rank_sizes = [num_layers_per_stage_with_padding] * pipeline_world_size
num_padding = num_layers - num_layers_without_padding
if num_padding > 0:
            assert num_padding == 2, \
                "Only num_padding_layers == 2 is supported when applying imbalanced pp. Please set `args.pipeline_layers` for VLLMModule."
for _index in range(-1, num_padding - 1):
rank_sizes[_index] -= 1
assert len(rank_sizes) == pipeline_world_size
# set env variable VLLM_PP_LAYER_PARTITION
vllm_pp_layer_partition = ",".join([str(ele) for ele in rank_sizes])
if os.getenv("VLLM_PP_LAYER_PARTITION", None) is not None:
env_vllm_pp_layer_partition = os.getenv("VLLM_PP_LAYER_PARTITION", None)
if vllm_pp_layer_partition != env_vllm_pp_layer_partition:
                self._logger.warning(
                    f"expected VLLM_PP_LAYER_PARTITION to be {vllm_pp_layer_partition}, but the environment sets {env_vllm_pp_layer_partition}")
os.environ["VLLM_PP_LAYER_PARTITION"] = vllm_pp_layer_partition
self._logger.info(f"Set VLLM_PP_LAYER_PARTITION={vllm_pp_layer_partition}")
def _get_sampling_params(self, is_eval):
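        """
        Build vLLM SamplingParams from model_args, preferring the eval_* variants when
        is_eval is True; temperature stays 0.0 when use_beam_search is enabled.
        """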
temperature = 0.0
if not self.model_args.get("use_beam_search", False):
temperature = self.model_args.get("eval_temperature", 1.0) if is_eval else self.model_args.get(
"temperature", 1.0)
top_p = self.model_args.get("eval_top_p", 1.0) if is_eval else self.model_args.get("top_p", 1.0)
top_k = self.model_args.get("eval_top_k", -1) if is_eval else self.model_args.get("top_k", -1)
min_p = self.model_args.get("eval_min_p", 0.0) if is_eval else self.model_args.get("min_p", 0.0)
presence_penalty = self.model_args.get("eval_presence_penalty", 0.0) if is_eval else self.model_args.get(
"presence_penalty", 0.0)
frequency_penalty = self.model_args.get("eval_frequency_penalty", 0.0) if is_eval else self.model_args.get(
"frequency_penalty", 0.0)
repetition_penalty = self.model_args.get("eval_repetition_penalty", 1.0) if is_eval else self.model_args.get(
"repetition_penalty", 1.0)
stop = self.model_args.get("stop_token_list", None)
if stop is not None and isinstance(stop, str):
stop = stop.split(";")
sampling_params = SamplingParams(
n=self.model_args.get("n", 1),
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
ignore_eos=self.model_args.get("ignore_eos", False),
stop=stop,
logprobs=self.model_args.get("logprobs", 1),
detokenize=self.model_args.get("detokenize", False),
prompt_logprobs=self.model_args.get("prompt_logprobs", None),
skip_special_tokens=self.model_args.get('skip_special_tokens', True)
)
# VLLMVersion.v_0_3_0, VLLMVersion.v_0_5_1
if hasattr(sampling_params, 'use_beam_search'):
sampling_params.use_beam_search = self.model_args.get("use_beam_search", False)
return sampling_params
def update_weights_from_ipc_handles(self, reduce_data):
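        """
        Load updated weights shared through CUDA IPC. reduce_data maps parameter names
        to (rebuild_func, rebuild_args) pairs (e.g. as produced by
        torch.multiprocessing.reductions); each tensor is rebuilt in this process and
        passed to the vLLM model via load_weights.
        """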
for name, reduced in reduce_data.items():
rebuild_func, rebuild_args = reduced
reconstructed_tensor = rebuild_func(*rebuild_args)
self.model.load_weights([(name, reconstructed_tensor)])
def _convert_v1_inputs(self, prompts, prompt_token_ids):
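        """
        Pack parallel lists of prompt strings and prompt token ids into vLLM
        TextTokensPrompt inputs; when one of the two is None for a request, the
        other field is type-checked.
        """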
num_requests = len(prompts)
assert num_requests == len(prompt_token_ids), \
("The lengths of prompts and prompt_token_ids must be the same.")
inputs = []
for i in range(num_requests):
if prompts[i] is None:
                assert isinstance(prompt_token_ids[i], list), \
                    f"Expect prompt_token_ids[{i}] to be a list of ints when prompt is None, while {prompt_token_ids[i]}."
            if prompt_token_ids[i] is None:
                assert isinstance(prompts[i], str), \
                    f"Expect prompts[{i}] to be a string when prompt_token_ids is None, while {prompts[i]}."
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
inputs.append(item)
return inputs
def preprocess_inputs(self, query, is_eval):
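        """
        Convert a query batch (prompt text + token ids, plus optional per-sample
        'sampling_param') into vLLM prompt inputs and per-request SamplingParams.
        max_tokens is seq_length minus the prompt length unless new_token_limit is
        set, in which case max_new_tokens is used directly.
        """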
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")
prompts = query[prompt_key]
prompts_token_ids = query[input_ids_key]
seq_len = self.model_args.get("seq_length")
parsed_prompts = []
sampling_params = []
for i, prompt in enumerate(prompts):
prompt_token_ids = prompts_token_ids[i]
if 'sampling_param' in query:
sampling_param = query['sampling_param'][i]
else:
sampling_param = self._get_sampling_params(is_eval)
if not self.model_args.get("new_token_limit", False):
max_tokens = seq_len - len(prompt_token_ids)
else:
max_tokens = self.model_args.get("max_new_tokens")
                assert max_tokens < seq_len, "max_new_tokens must be less than seq_length."
sampling_param.max_tokens = max_tokens
item = self._convert_v1_inputs(
prompts=[prompt],
prompt_token_ids=[prompt_token_ids],
)[0]
parsed_prompts.append(item)
sampling_params.append(sampling_param)
return parsed_prompts, sampling_params
def run_vllm(self, parsed_prompts, sampling_params):
outputs = self.llm.generate(
parsed_prompts,
sampling_params,
use_tqdm=True
)
return outputs
def generate_vllm(self, query, is_eval, iteration=0, is_first_run=True):
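        """
        Run vLLM generation for a query batch. Resumes from previously saved stage
        outputs when available, reinitializes the KV-cache engine on the first run of
        a multi-round generation, and saves the outputs for stage resume afterwards.
        """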
# resume from stage checkpoint.
outputs = self.load_stage_outputs(is_eval, iteration)
if outputs is not None:
return outputs
        if is_first_run:  # used for multi-round generation
self.reinit_cache_engine()
parsed_prompts, sampling_params = self.preprocess_inputs(query, is_eval)
outputs = []
if os.getenv("SKIP_GENERATION", None) is None:
outputs = self.run_vllm(parsed_prompts, sampling_params)
# save stage outputs for resume.
self.save_stage_outputs(is_eval, outputs, iteration)
return outputs
def is_last_rank(self):
return True
def num_layers(self):
"""
:meta private:
"""
return self.llm.llm_engine.model_config.hf_config.num_hidden_layers
def peak_memory(self):
"""
:meta private:
"""
self._peak_memory = max(self._peak_memory, torch.cuda.max_memory_allocated() / (1024 ** 3))
return self._peak_memory
@property
def data_parallel_size(self):
"""
:meta private:
"""
return 1
@property
def data_parallel_rank(self):
"""
:meta private:
"""
return 0
@property
def model(self):
if self._model is None:
assert self.worker is not None, \
"please set env variables `VLLM_USE_RAY_SPMD_WORKER=1` and `VLLM_USE_RAY_COMPILED_DAG=1` first."
self._model = self.worker.model_runner.model
return self._model
def tensor_parallel_rank(self):
"""
:meta private:
"""
return parallel_state.get_tensor_model_parallel_rank()
def pipeline_parallel_rank(self):
"""
:meta private:
"""
return get_pipeline_model_parallel_rank()
def tensor_model_parallel_size(self):
return self.tensor_and_expert_model_parallel_size()
def expert_model_parallel_size(self):
return 1
def tensor_and_expert_model_parallel_size(self):
"""
get tensor_and_expert_model_parallel_size
:meta private:
"""
        # vLLM does not enable expert parallelism here,
        # so tensor_and_expert_model_parallel_size equals tensor_parallel_size
return parallel_state.get_tensor_model_parallel_world_size()
def model_setup_for_workers(self):
self.llm.llm_engine.model_executor._run_workers("model_setup")
# pylint: disable=unused-argument
def offload_for_workers(self,
to_onload_weights=None,
to_build_grad_buffers=None,
to_onload_main_weights=None,
to_onload_optimizer_states=None):
"""
call offload for all workers
"""
self.llm.llm_engine.model_executor._run_workers("offload")
def onload_for_workers(self,
to_onload_weights=None,
to_build_grad_buffers=None,
to_onload_main_weights=None,
to_onload_optimizer_states=None):
"""
call onload for all workers
"""
self.llm.llm_engine.model_executor._run_workers("onload")
def empty_cache_for_workers(self):
"""
call empty cache for all workers
"""
self.llm.llm_engine.model_executor._run_workers("empty_cache")
def empty_cuda_graph_for_workers(self):
"""
call empty cuda_graph for all workers
"""
self.llm.llm_engine.model_executor._run_workers("empty_cuda_graph")
def offload_weights(self):
"""
offload weights
"""
if self.module_args.offload_weights:
self._memory_manager.offload_weights()
def onload_weights(self):
"""
onload weights
"""
if self.module_args.offload_weights:
self._memory_manager.onload_weights()
def empty_cache(self):
        # drop references to the KV-cache tensors so their memory can be reclaimed
        if self.worker.gpu_cache is not None:
            self.worker.gpu_cache = None  # pylint: disable=access-member-before-definition
if hasattr(self.worker, "cache_engine") and self.worker.cache_engine is not None:
for c_e in self.worker.cache_engine:
c_e.cpu_cache = None
c_e.gpu_cache = None
self.worker.cache_engine = None
self.clear_cache()
def clear_cache(self):
if not self.timers("gc").started_:
self.timers("gc").start()
gc.collect()
self.timers("gc").stop()
super().empty_cache()
def empty_cuda_graph(self):
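        """
        Drop CUDA-graph input/output buffers and the shared graph memory pool so the
        captured-graph memory can be reclaimed.
        """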
if self.worker.model_runner.graph_runners is not None:
len_graph_runners = len(self.worker.model_runner.graph_runners)
for graph_runner in self.worker.model_runner.graph_runners:
for _, runner in graph_runner.items():
runner.input_buffers = {}
runner.output_buffers = {}
runner._graph = None
for i in range(len_graph_runners):
self.worker.model_runner.graph_runners[i] = {}
self.worker.model_runner.graph_memory_pool = None
def reset_block_manager(self):
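        """
        Recreate each scheduler's block manager after the KV caches are rebuilt,
        mirroring vLLM's own scheduler construction: the GPU/CPU block counts are
        divided across pipeline stages before building the BlockSpaceManager.
        """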
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
version = "selfattn"
if (self.llm.llm_engine.scheduler_config.embedding_mode
or self.llm.llm_engine.cache_config.is_attention_free):
version = "placeholder"
else:
version = "selfattn"
if (self.llm.llm_engine.scheduler_config.runner_type == "pooling"
or self.llm.llm_engine.cache_config.is_attention_free):
version = "placeholder"
num_gpu_blocks = self.llm.llm_engine.cache_config.num_gpu_blocks
if num_gpu_blocks:
num_gpu_blocks //= self.module_args.pipeline_model_parallel_size
num_cpu_blocks = self.llm.llm_engine.cache_config.num_cpu_blocks
if num_cpu_blocks:
num_cpu_blocks //= self.module_args.pipeline_model_parallel_size
BlockSpaceManagerImpl = get_block_manager_cls(version)
for scheduler in self.llm.llm_engine.scheduler: # pylint: disable=not-an-iterable
scheduler.block_manager = BlockSpaceManagerImpl( # pylint: disable=abstract-class-instantiated
block_size=self.llm.llm_engine.cache_config.block_size,
                # use the per-pipeline-stage block counts computed above
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
sliding_window=self.llm.llm_engine.cache_config.sliding_window,
enable_caching=self.llm.llm_engine.cache_config.enable_prefix_caching)
def reinit_cache_engine(self):
# reinit cache engine
self.llm.llm_engine.model_executor._run_workers("clear_cache")
self.llm.llm_engine._initialize_kv_caches()
        # reset block manager
self.reset_block_manager()
self.llm.llm_engine.model_executor._run_workers("clear_cache")