chatlearn/models/vllm_module_v2.py (422 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""VLLM module v2.

Defines ``VLLMModuleV2``, a ChatLearn ``TorchModule`` that also inherits vLLM's
``RayWorkerWrapper``. An instance constructed with worker kwargs initialises the
wrapper base and serves as a vLLM worker; the driver instance builds an ``LLM``
engine in ``setup_vllm`` and drives the workers through
``model_executor._run_workers``.
"""

import gc
import inspect
import os

import torch
from transformers import AutoTokenizer, AutoConfig
from vllm import SamplingParams
from vllm.config import LoadFormat
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import RayWorkerWrapper

from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
from chatlearn.utils.global_vars import set_vllm_actors
from chatlearn.utils.vllm_import_helper import parallel_state
from chatlearn.utils.vllm_import_helper import get_block_manager_cls
from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from chatlearn.utils.vllm_utils import initialize_vllm
from .torch_module import TorchModule

try:
    from .vllm.inference import InferenceMemoryManager
except ImportError:
    InferenceMemoryManager = None


# pylint: disable=unexpected-keyword-arg
class VLLMModuleV2(TorchModule, RayWorkerWrapper):
    """ChatLearn module that generates with a vLLM ``LLM`` engine.

    The driver instance creates ``self.llm`` in ``setup_vllm``. Instances built
    with worker kwargs also initialise ``RayWorkerWrapper`` and expose the
    worker-side hooks (``init_memory_manager``, ``worker_dump_parameters``,
    ``empty_cache``, ``empty_cuda_graph``, ``clear_cache``, ...) that the driver
    invokes by name via ``model_executor._run_workers``.
    """

    def __init__(self, *args, **kwargs):
        TorchModule.__init__(self, *args)
        # avoid overwrite methods: the two bases must not share any method other
        # than __init__, otherwise the MRO would silently shadow one of them.
        methods_class1 = {method[0] for method in inspect.getmembers(TorchModule, predicate=inspect.isfunction)}
        methods_class2 = {method[0] for method in inspect.getmembers(RayWorkerWrapper, predicate=inspect.isfunction)}
        common_methods = methods_class1.intersection(methods_class2)
        # common method is '__init__'
        assert common_methods == {'__init__'}, \
            f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}"
        self.local_rank = 0
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs:
                RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called
        else:
            if 'vllm_actor_type' in kwargs and 'worker' == kwargs['vllm_actor_type']:
                vllm_config = self.init_engine_args()
                RayWorkerWrapper.__init__(self, vllm_config=vllm_config) # pylint: disable=non-parent-init-called

        os.environ['VLLM_HOST_IP'] = self.get_address()

        self.tokenizer = None
        self._model = None
        self.llm = None
        self.model_config = AutoConfig.from_pretrained(self.model_args['tokenizer'])
        self.set_vllm_pp_layer_partition()
        self._metric_prefix = 'vllm_inference'

    def add_extra_args(self, parser):
        """
        Add extra arguments for vllm.

        Args
        ----
        parser : ArgumentParser
            Add extra arguments.
        """
        group = parser.add_argument_group(title='vLLM extra arguments')
        group.add_argument('--distributed-backend', default='nccl',
                           choices=['nccl', 'gloo'],
                           help='Which backend to use for distributed training.')
        group.add_argument('--distributed-timeout-minutes', type=int, default=10,
                           help='Timeout minutes for torch.distributed.')
        return parser

    def _resolve_engine_settings(self):
        """Derive the engine settings shared by ``init_engine_args`` and ``setup_vllm``.

        Returns:
            tuple: ``(dtype, load_format, model_loader_extra_config, seed)``.
            ``dtype`` is ``"float16"`` when ``fp16`` is set, else the ``dtype``
            arg (default ``"bfloat16"``). ``model_loader_extra_config`` is
            ``self.model_args`` only for ``LoadFormat.DUMMY``, else None.
            ``seed`` is the ``seed`` arg (default 0), offset by ``replica_id``
            unless ``apply_replica_id_to_seed`` is false.
        """
        dtype = self.model_args.get("dtype", "bfloat16")
        if self.model_args.get("fp16", False):
            dtype = "float16"

        # 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
        load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
        model_loader_extra_config = self.model_args if load_format == LoadFormat.DUMMY else None

        seed = self.model_args.get("seed", 0)
        if self.model_args.get("apply_replica_id_to_seed", True):
            seed = seed + self.replica_id
        return dtype, load_format, model_loader_extra_config, seed

    def init_engine_args(self):
        """Build the vLLM engine config for a worker actor (non-0.6.3 vLLM path).

        Stores the ``AsyncEngineArgs`` on ``self.engine_args`` and returns the
        result of ``create_engine_config`` with ``UsageContext.ENGINE_CONTEXT``.
        """
        dtype, load_format, model_loader_extra_config, seed = self._resolve_engine_settings()

        from vllm.engine.arg_utils import AsyncEngineArgs # pylint: disable=import-outside-toplevel
        from vllm.usage.usage_lib import UsageContext # pylint: disable=import-outside-toplevel

        self.engine_args = AsyncEngineArgs(
            model=self.model_args['tokenizer'],
            tokenizer=self.model_args['tokenizer'],
            max_seq_len_to_capture=self.model_args.get("max_seq_len_to_capture", 32768),
            seed=seed,
            # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
            load_format=load_format,
            model_loader_extra_config=model_loader_extra_config,
            # parallelism strategy
            tensor_parallel_size=self.module_args.tensor_model_parallel_size,
            pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
            dtype=dtype,
            # scheduling strategy
            max_num_seqs=self.module_args.generation_batch_size,
            max_num_batched_tokens=self.model_args.get("max_num_batched_tokens", None),
            num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1),
            gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
            # logger
            disable_log_requests=self.model_args.get("disable_log_requests", True),
            disable_log_stats=self.model_args.get("disable_log_stats", True),
            trust_remote_code=True,
            # NOTE(review): default is True here but False in setup_vllm; confirm this is intended.
            enforce_eager=self.model_args.get("enforce_eager", True),
            disable_custom_all_reduce=True,
            distributed_executor_backend="ray",
            preemption_mode=self.model_args.get("preemption_mode", 'recompute'), # swap, recompute
            swap_space=self.model_args.get("swap_space", 16))
        return self.engine_args.create_engine_config(usage_context=UsageContext.ENGINE_CONTEXT)

    def init(self):
        """
        :meta private:
        """
        parallel_state.set_custom_all_reduce(False)
        initialize_vllm(extra_args_provider=self.add_extra_args,
                        ignore_unknown_args=True,
                        args_dict=self.model_args)

    def setup(self):
        """Set up tokenizer."""
        super().setup()
        tokenizer = AutoTokenizer.from_pretrained(self.model_args['tokenizer'])
        # Self-reference so callers that expect a wrapper exposing `.tokenizer` also work
        # (presumably for vLLM-style tokenizer-group access; confirm against callers).
        tokenizer.tokenizer = tokenizer
        self.tokenizer = tokenizer

    def setup_vllm(self, workers):
        """Create the driver-side vLLM ``LLM`` engine on top of ``workers``.

        Returns immediately if ``self.llm`` already exists (for evaluator). After
        creating the engine, asks every worker to run ``init_memory_manager`` and
        then offload, drop CUDA graphs and empty caches.
        """
        if self.llm is not None: # for evaluator
            return
        # setup vllm engine in rank 0
        os.environ['VLLM_HOST_IP'] = self.get_address()
        set_vllm_actors(workers)

        dtype, load_format, model_loader_extra_config, seed = self._resolve_engine_settings()

        self.llm = LLM(
            model=self.model_args['tokenizer'],
            tokenizer=self.model_args['tokenizer'],
            max_seq_len_to_capture=self.model_args.get("max_seq_len_to_capture", 32768),
            seed=seed,
            # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
            load_format=load_format,
            model_loader_extra_config=model_loader_extra_config,
            # parallelism strategy
            tensor_parallel_size=self.module_args.tensor_model_parallel_size,
            pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
            dtype=dtype,
            # scheduling strategy
            max_num_seqs=self.module_args.generation_batch_size,
            max_num_batched_tokens=self.model_args.get("max_num_batched_tokens", None),
            num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1),
            gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
            # logger
            disable_log_requests=self.model_args.get("disable_log_requests", True),
            disable_log_stats=self.model_args.get("disable_log_stats", True),
            trust_remote_code=True,
            enforce_eager=self.model_args.get("enforce_eager", False),
            disable_custom_all_reduce=True,
            distributed_executor_backend="ray",
            preemption_mode=self.model_args.get("preemption_mode", 'recompute'), # swap, recompute
            swap_space=self.model_args.get("swap_space", 16))
        self.llm.llm_engine.model_executor._run_workers("init_memory_manager")
        self.offload_for_workers()
        self.empty_cuda_graph_for_workers()
        self.empty_cache_for_workers()

    def dump_parameters(self, dump_path_root):
        """Onload weights on all workers, then have each dump its parameters under ``dump_path_root``."""
        self.onload_for_workers()
        self.llm.llm_engine.model_executor._run_workers("worker_dump_parameters", dump_path_root=dump_path_root)

    def worker_dump_parameters(self, dump_path_root):
        """Worker side: save each named parameter to ``<dump_path_root>/<tp_rank>/<name>`` via ``torch.save``."""
        tp_rank = self.tensor_parallel_rank()
        model = self.model
        if isinstance(model, list):
            model = model[0]

        dir_path = os.path.join(dump_path_root, str(tp_rank))
        os.makedirs(dir_path, exist_ok=True)
        self._logger.info(f"dump parameters to {dir_path}")
        for name, param in self.named_parameters.items():
            pt_file = os.path.join(dir_path, name)
            # clone so the saved tensor does not alias live model storage
            torch.save(param.data.clone(), pt_file)

    def init_memory_manager(self):
        """Worker side: build an ``InferenceMemoryManager`` when ``offload_weights`` is enabled.

        Raises:
            ImportError: if offloading is requested but ``InferenceMemoryManager``
                could not be imported at module load.
        """
        if self.module_args.offload_weights:
            if InferenceMemoryManager is None:
                raise ImportError("Import InferenceMemoryManager failed, you may need to set right Megatron path first.")
            self._memory_manager = InferenceMemoryManager(
                self.model,
                self.runtime_args.bucket_size_mb_in_memory_manager,
            )

    def set_vllm_pp_layer_partition(self):
        """Compute and export ``VLLM_PP_LAYER_PARTITION`` (layers per pipeline stage).

        If ``allow_padding_num_layers`` is falsy, only asserts that the layer count
        divides evenly by the pipeline size and returns without touching the env.
        Otherwise, the per-stage sizes come from ``pipeline_layers`` if given. If not,
        the layer count is padded up to a multiple of the pipeline size and split
        evenly, and the padding (which must be exactly 2 layers) is removed from the
        last and first stages. A differing pre-existing env value is overwritten
        with a warning.
        """
        pipeline_world_size = self.module_args.pipeline_model_parallel_size
        num_layers = self.model_config.num_hidden_layers
        remainder = num_layers % pipeline_world_size

        if not self.model_args.get("allow_padding_num_layers", None):
            assert remainder == 0, \
                f"expect num_layers % pipeline_model_size == 0 when VLLM_PP_LAYER_PARTITION is not set. \
                while num_layers = {num_layers} pipeline_model_size = {pipeline_world_size}"
            return

        if remainder > 0:
            assert not self.model_args.get("standalone_embedding_stage", False), \
                "not support standalone embedding stage if allow_padding_num_layers is true"
            # pad num_layers to make num_layers % pipeline_model_parallel_size == 0
            num_layers_with_padding = num_layers - remainder + pipeline_world_size
        else:
            num_layers_with_padding = num_layers

        num_layers_without_padding = num_layers
        num_layers = num_layers_with_padding
        num_layers_per_stage_with_padding = num_layers // pipeline_world_size

        # Each stage gets a contiguous set of layers.
        if self.model_args.get("pipeline_layers", None) is not None:
            rank_sizes = self.model_args.get("pipeline_layers", None)
            assert isinstance(rank_sizes, list) and all(isinstance(ele, int) for ele in rank_sizes), \
                f"pipeline_layers expected to be list, and num layer of each stage to be integer, while {rank_sizes}."
        else:
            rank_sizes = [num_layers_per_stage_with_padding] * pipeline_world_size
            num_padding = num_layers - num_layers_without_padding
            if num_padding > 0:
                assert num_padding == 2, \
                    "Support num_padding_lsyers == 2 when applies inbalanced pp. \
                    Please set `args.pipeline_layers` for VLLMModule."
            # indices -1, 0: take one padding layer off the last and the first stage
            for _index in range(-1, num_padding - 1):
                rank_sizes[_index] -= 1

        assert len(rank_sizes) == pipeline_world_size

        # set env variable VLLM_PP_LAYER_PARTITION
        vllm_pp_layer_partition = ",".join([str(ele) for ele in rank_sizes])
        env_vllm_pp_layer_partition = os.getenv("VLLM_PP_LAYER_PARTITION", None)
        if env_vllm_pp_layer_partition is not None and vllm_pp_layer_partition != env_vllm_pp_layer_partition:
            self._logger.warning(
                f"expect VLLM_PP_LAYER_PARTITION to be {vllm_pp_layer_partition}, while {env_vllm_pp_layer_partition}")
        os.environ["VLLM_PP_LAYER_PARTITION"] = vllm_pp_layer_partition
        self._logger.info(f"Set VLLM_PP_LAYER_PARTITION={vllm_pp_layer_partition}")

    def _get_sampling_params(self, is_eval):
        """Build a fresh ``SamplingParams`` from model args, using ``eval_*`` variants when ``is_eval``.

        Temperature is forced to 0.0 when ``use_beam_search`` is set. A string
        ``stop_token_list`` is split on ``;``.
        """
        temperature = 0.0
        if not self.model_args.get("use_beam_search", False):
            temperature = self.model_args.get("eval_temperature", 1.0) if is_eval else self.model_args.get(
                "temperature", 1.0)
        top_p = self.model_args.get("eval_top_p", 1.0) if is_eval else self.model_args.get("top_p", 1.0)
        top_k = self.model_args.get("eval_top_k", -1) if is_eval else self.model_args.get("top_k", -1)
        min_p = self.model_args.get("eval_min_p", 0.0) if is_eval else self.model_args.get("min_p", 0.0)
        presence_penalty = self.model_args.get("eval_presence_penalty", 0.0) if is_eval else self.model_args.get(
            "presence_penalty", 0.0)
        frequency_penalty = self.model_args.get("eval_frequency_penalty", 0.0) if is_eval else self.model_args.get(
            "frequency_penalty", 0.0)
        repetition_penalty = self.model_args.get("eval_repetition_penalty", 1.0) if is_eval else self.model_args.get(
            "repetition_penalty", 1.0)

        stop = self.model_args.get("stop_token_list", None)
        if isinstance(stop, str):
            stop = stop.split(";")
        sampling_params = SamplingParams(
            n=self.model_args.get("n", 1),
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            ignore_eos=self.model_args.get("ignore_eos", False),
            stop=stop,
            logprobs=self.model_args.get("logprobs", 1),
            detokenize=self.model_args.get("detokenize", False),
            prompt_logprobs=self.model_args.get("prompt_logprobs", None),
            skip_special_tokens=self.model_args.get('skip_special_tokens', True)
        )
        # VLLMVersion.v_0_3_0, VLLMVersion.v_0_5_1
        if hasattr(sampling_params, 'use_beam_search'):
            sampling_params.use_beam_search = self.model_args.get("use_beam_search", False)
        return sampling_params

    def update_weights_from_ipc_handles(self, reduce_data):
        """Rebuild each shared tensor from ``(rebuild_func, rebuild_args)`` and load it into the model by name."""
        for name, reduced in reduce_data.items():
            rebuild_func, rebuild_args = reduced
            reconstructed_tensor = rebuild_func(*rebuild_args)
            self.model.load_weights([(name, reconstructed_tensor)])

    def _convert_v1_inputs(self, prompts, prompt_token_ids):
        """Pair ``prompts`` with ``prompt_token_ids`` into ``TextTokensPrompt`` items.

        Args:
            prompts: per-request prompt strings; an entry may be None when the
                matching token ids are provided.
            prompt_token_ids: per-request token-id lists; an entry may be None
                when the matching prompt string is provided.

        Returns:
            list of ``TextTokensPrompt``, one per request.
        """
        num_requests = len(prompts)
        assert num_requests == len(prompt_token_ids), \
            ("The lengths of prompts and prompt_token_ids must be the same.")

        inputs = []
        for i, (prompt, token_ids) in enumerate(zip(prompts, prompt_token_ids)):
            if prompt is None:
                # Subscripted generics such as List[int] cannot be used with
                # isinstance, so check the runtime container type instead.
                assert isinstance(token_ids, list), \
                    f"Expect prompt_token_ids[{i}] is List[int] when prompt is None, while {prompt_token_ids[i]}."
            if token_ids is None:
                assert isinstance(prompt, str), \
                    f"Expect prompts[{i}] is a string when prompt_token_ids is None, while {prompts[i]}."
            inputs.append(TextTokensPrompt(prompt=prompt, prompt_token_ids=token_ids))
        return inputs

    def preprocess_inputs(self, query, is_eval):
        """Turn a query batch into vLLM prompts and per-request sampling params.

        Reads prompts and token ids from ``query`` under the configurable keys
        ``vllm_prompt_key`` / ``vllm_input_ids_key``. Uses ``query['sampling_param']``
        when present (mutating its ``max_tokens``); otherwise builds a fresh one per
        request. ``max_tokens`` is ``seq_length - len(prompt_token_ids)`` unless
        ``new_token_limit`` is set, in which case it is ``max_new_tokens``.

        Returns:
            tuple: ``(parsed_prompts, sampling_params)``, two equal-length lists.
        """
        prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
        input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")

        prompts = query[prompt_key]
        prompts_token_ids = query[input_ids_key]
        seq_len = self.model_args.get("seq_length")
        parsed_prompts = []
        sampling_params = []
        for i, prompt in enumerate(prompts):
            prompt_token_ids = prompts_token_ids[i]
            if 'sampling_param' in query:
                sampling_param = query['sampling_param'][i]
            else:
                # one object per request because max_tokens is set per request below
                sampling_param = self._get_sampling_params(is_eval)
                if not self.model_args.get("new_token_limit", False):
                    max_tokens = seq_len - len(prompt_token_ids)
                else:
                    max_tokens = self.model_args.get("max_new_tokens")
                    assert max_tokens < seq_len, "max_new_tokens must less than seq length."
                sampling_param.max_tokens = max_tokens
            item = self._convert_v1_inputs(
                prompts=[prompt],
                prompt_token_ids=[prompt_token_ids],
            )[0]
            parsed_prompts.append(item)
            sampling_params.append(sampling_param)

        return parsed_prompts, sampling_params

    def run_vllm(self, parsed_prompts, sampling_params):
        """Run ``self.llm.generate`` on the prepared prompts with a progress bar."""
        outputs = self.llm.generate(
            parsed_prompts,
            sampling_params,
            use_tqdm=True
        )
        return outputs

    def generate_vllm(self, query, is_eval, iteration=0, is_first_run=True):
        """Generate for ``query``, resuming from or saving stage checkpoints.

        Returns previously saved stage outputs if any exist for ``(is_eval, iteration)``.
        Reinitialises the KV cache on the first run of a (multi-round) generation.
        When the ``SKIP_GENERATION`` env var is set, generation is skipped and an
        empty list is saved and returned.
        """
        # resume from stage checkpoint.
        outputs = self.load_stage_outputs(is_eval, iteration)
        if outputs is not None:
            return outputs
        if is_first_run: # using for multi-round generate
            self.reinit_cache_engine()
        parsed_prompts, sampling_params = self.preprocess_inputs(query, is_eval)

        outputs = []
        if os.getenv("SKIP_GENERATION", None) is None:
            outputs = self.run_vllm(parsed_prompts, sampling_params)
        # save stage outputs for resume.
        self.save_stage_outputs(is_eval, outputs, iteration)
        return outputs

    def is_last_rank(self):
        return True

    def num_layers(self):
        """
        :meta private:
        """
        return self.llm.llm_engine.model_config.hf_config.num_hidden_layers

    def peak_memory(self):
        """
        :meta private:
        """
        self._peak_memory = max(self._peak_memory, torch.cuda.max_memory_allocated() / (1024 ** 3))
        return self._peak_memory

    @property
    def data_parallel_size(self):
        """
        :meta private:
        """
        return 1

    @property
    def data_parallel_rank(self):
        """
        :meta private:
        """
        return 0

    @property
    def model(self):
        """The worker's vLLM model, fetched lazily from ``self.worker.model_runner`` and cached."""
        if self._model is None:
            assert self.worker is not None, \
                "please set env variables `VLLM_USE_RAY_SPMD_WORKER=1` and `VLLM_USE_RAY_COMPILED_DAG=1` first."
            self._model = self.worker.model_runner.model
        return self._model

    def tensor_parallel_rank(self):
        """
        :meta private:
        """
        return parallel_state.get_tensor_model_parallel_rank()

    def pipeline_parallel_rank(self):
        """
        :meta private:
        """
        return get_pipeline_model_parallel_rank()

    def tensor_model_parallel_size(self):
        return self.tensor_and_expert_model_parallel_size()

    def expert_model_parallel_size(self):
        return 1

    def tensor_and_expert_model_parallel_size(self):
        """
        get tensor_and_expert_model_parallel_size
        :meta private:
        """
        # vLLM not supported to enable expert parallel size
        # thus: tensor_and_expert_model_parallel_size = tensor_parallel_size
        return parallel_state.get_tensor_model_parallel_world_size()

    def model_setup_for_workers(self):
        self.llm.llm_engine.model_executor._run_workers("model_setup")

    # pylint: disable=unused-argument
    def offload_for_workers(self,
                            to_onload_weights=None,
                            to_build_grad_buffers=None,
                            to_onload_main_weights=None,
                            to_onload_optimizer_states=None):
        """
        call offload for all workers
        """
        self.llm.llm_engine.model_executor._run_workers("offload")

    def onload_for_workers(self,
                           to_onload_weights=None,
                           to_build_grad_buffers=None,
                           to_onload_main_weights=None,
                           to_onload_optimizer_states=None):
        """
        call onload for all workers
        """
        self.llm.llm_engine.model_executor._run_workers("onload")

    def empty_cache_for_workers(self):
        """
        call empty cache for all workers
        """
        self.llm.llm_engine.model_executor._run_workers("empty_cache")

    def empty_cuda_graph_for_workers(self):
        """
        call empty cuda_graph for all workers
        """
        self.llm.llm_engine.model_executor._run_workers("empty_cuda_graph")

    def offload_weights(self):
        """
        offload weights
        """
        if self.module_args.offload_weights:
            self._memory_manager.offload_weights()

    def onload_weights(self):
        """
        onload weights
        """
        if self.module_args.offload_weights:
            self._memory_manager.onload_weights()

    def empty_cache(self):
        """Worker side: drop references to the GPU KV cache and cache engines, then ``clear_cache``."""
        if self.worker.gpu_cache is not None:
            # Rebinding the attribute releases this object's reference to the cache list.
            self.worker.gpu_cache = None # pylint: disable=access-member-before-definition

        if hasattr(self.worker, "cache_engine") and self.worker.cache_engine is not None:
            for c_e in self.worker.cache_engine:
                c_e.cpu_cache = None
                c_e.gpu_cache = None
            self.worker.cache_engine = None

        self.clear_cache()

    def clear_cache(self):
        """Run ``gc.collect()`` under the ``gc`` timer, then the parent class's ``empty_cache``."""
        if not self.timers("gc").started_:
            self.timers("gc").start()
        gc.collect()
        self.timers("gc").stop()

        # super() targets the base implementation; self.empty_cache is overridden above.
        super().empty_cache()

    def empty_cuda_graph(self):
        """Worker side: drop captured CUDA graphs, their buffers and the graph memory pool."""
        model_runner = self.worker.model_runner
        if model_runner.graph_runners is not None:
            for graph_runner in model_runner.graph_runners:
                for runner in graph_runner.values():
                    runner.input_buffers = {}
                    runner.output_buffers = {}
                    runner._graph = None
            for i in range(len(model_runner.graph_runners)):
                model_runner.graph_runners[i] = {}
            model_runner.graph_memory_pool = None

    def reset_block_manager(self):
        """Recreate each scheduler's block manager to match the current cache config.

        The GPU/CPU block totals in ``cache_config`` are divided by the pipeline
        parallel size before being handed to each scheduler's block manager. This is
        the same split that the vLLM Scheduler and CacheEngine apply, so each block
        manager only hands out blocks that exist in its stage's KV cache.
        """
        engine = self.llm.llm_engine
        cache_config = engine.cache_config
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            is_pooling = engine.scheduler_config.embedding_mode
        else:
            is_pooling = engine.scheduler_config.runner_type == "pooling"
        version = "placeholder" if (is_pooling or cache_config.is_attention_free) else "selfattn"

        pp_size = self.module_args.pipeline_model_parallel_size
        num_gpu_blocks = cache_config.num_gpu_blocks
        if num_gpu_blocks:
            num_gpu_blocks //= pp_size

        num_cpu_blocks = cache_config.num_cpu_blocks
        if num_cpu_blocks:
            num_cpu_blocks //= pp_size

        BlockSpaceManagerImpl = get_block_manager_cls(version)

        for scheduler in engine.scheduler: # pylint: disable=not-an-iterable
            scheduler.block_manager = BlockSpaceManagerImpl( # pylint: disable=abstract-class-instantiated
                block_size=cache_config.block_size,
                # per-stage share, not the cache_config totals
                num_gpu_blocks=num_gpu_blocks,
                num_cpu_blocks=num_cpu_blocks,
                sliding_window=cache_config.sliding_window,
                enable_caching=cache_config.enable_prefix_caching)

    def reinit_cache_engine(self):
        """Re-profile and reallocate the KV cache, then rebuild the schedulers' block managers."""
        # reinit cache engine
        self.llm.llm_engine.model_executor._run_workers("clear_cache")
        self.llm.llm_engine._initialize_kv_caches()
        # reset block manager
        self.reset_block_manager()
        self.llm.llm_engine.model_executor._run_workers("clear_cache")