in chatlearn/models/vllm_module.py
def _init_args(self):
    """Build vLLM engine args, per-version engine configs, and the driver worker for this replica."""
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
        # vLLM 0.6.3 needs an explicit pipeline-parallel layer partition.
        self.set_vllm_pp_layer_partition()
    # Offset the base seed by the replica id so that replicas sample differently.
    if self.model_args.get("apply_replica_id_to_seed", True):
        seed = self.model_args.get("seed", 0) + self.replica_id
    else:
        seed = self.model_args.get("seed", 0)
    # Both model and tokenizer point at the path stored under "tokenizer" in model_args.
    engine_args = EngineArgs(
        model=self.model_args.get("tokenizer"),
        tokenizer=self.model_args.get("tokenizer"),
        tokenizer_mode=self.model_args.get("tokenizer_mode", "auto"),
        trust_remote_code=self.model_args.get("trust_remote_code", True),
        tensor_parallel_size=self.module_args.tensor_model_parallel_size,
        pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
        dtype=self.model_args.get("params_dtype", "auto"),
        quantization=self.model_args.get("quantization", None),
        revision=self.model_args.get("revision", None),
        tokenizer_revision=self.model_args.get("tokenizer_revision", None),
        seed=seed,
        gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
        block_size=self.model_args.get("block_size"),
        swap_space=self.model_args.get("swap_space"),
        max_num_batched_tokens=self.model_args.get("max_num_batched_tokens"),
        max_num_seqs=self.model_args.get("micro_batch_size"),
        max_model_len=self.model_args.get("seq_length"),
        enforce_eager=self.model_args.get("enforce_eager", True),
        disable_custom_all_reduce=True
    )
    self.quant_config = None
    self.pipeline_layer_offset = None
    if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
        # vLLM 0.3.0: create_engine_configs() returns the individual configs,
        # and the Worker is built without device/cache/load configs.
        engine_args.max_paddings = self.model_args.get("max_paddings", 256)
        engine_args.max_context_len_to_capture = self.model_args.get("max_context_len_to_capture", 8192)
        self.model_config, self.cache_config, self.parallel_config, self.scheduler_config, self.lora_config = \
            engine_args.create_engine_configs()
        self.worker = Worker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            local_rank=0,
            rank=0,
            distributed_init_method=None,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=True,
        )
        self._init_tokenizer()
    elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
        engine_args.max_seq_len_to_capture = self.model_args.get("max_context_len_to_capture", 8192)
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            engine_args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1)
        # vLLM 0.5.1 / 0.6.3: a single EngineConfig object bundles all sub-configs.
        engine_config = engine_args.create_engine_config()
        self.cache_config = engine_config.cache_config
        self.device_config = engine_config.device_config
        self.load_config = engine_config.load_config
        self.lora_config = engine_config.lora_config
        self.model_config = engine_config.model_config
        self.parallel_config = engine_config.parallel_config
        self.scheduler_config = engine_config.scheduler_config
        self.generation_config_fields = _load_generation_config_dict(
            self.model_config)
        self.input_processor = INPUT_REGISTRY.create_input_processor(
            self.model_config)
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3 and self.scheduler_config.is_multi_step:
            # Multi-step scheduling (num_scheduler_steps > 1) requires the MultiStepWorker.
            from vllm.worker.multi_step_worker import MultiStepWorker
            self.worker = MultiStepWorker(
                self.model_config,
                self.parallel_config,
                self.scheduler_config,
                self.device_config,
                self.cache_config,
                self.load_config,
                local_rank=0,
                rank=0,
                distributed_init_method=None,
                lora_config=self.lora_config,
                is_driver_worker=True,
            )
        else:
            self.worker = Worker(
                self.model_config,
                self.parallel_config,
                self.scheduler_config,
                self.device_config,
                self.cache_config,
                self.load_config,
                local_rank=0,
                rank=0,
                distributed_init_method=None,
                lora_config=self.lora_config,
                is_driver_worker=True,
            )
        self.tokenizer = self._init_tokenizer()
        self.detokenizer = Detokenizer(self.tokenizer)
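
For context, the 0.5.1/0.6.3 branch above follows the same shape as a standalone vLLM worker bring-up: EngineArgs -> create_engine_config() -> Worker. The sketch below is illustrative only and reuses the vLLM interfaces already called in this method; the model path, distributed_init_method, and the commented bring-up calls are placeholders and assumptions, not ChatLearn code.

# Illustrative sketch (not ChatLearn code): standalone equivalent of the
# 0.5.1/0.6.3 branch, assuming the vLLM 0.6.3 interfaces used above.
from vllm import EngineArgs
from vllm.worker.worker import Worker

engine_args = EngineArgs(
    model="facebook/opt-125m",   # hypothetical model path
    tensor_parallel_size=1,
    seed=1234,
    gpu_memory_utilization=0.90,
    max_num_seqs=8,
    max_model_len=2048,
    enforce_eager=True,
    disable_custom_all_reduce=True,
)
engine_config = engine_args.create_engine_config()

worker = Worker(
    engine_config.model_config,
    engine_config.parallel_config,
    engine_config.scheduler_config,
    engine_config.device_config,
    engine_config.cache_config,
    engine_config.load_config,
    local_rank=0,
    rank=0,
    distributed_init_method="tcp://127.0.0.1:29500",  # placeholder; _init_args passes None
    lora_config=engine_config.lora_config,
    is_driver_worker=True,
)
# Device and weight bring-up are not part of _init_args; in plain vLLM this would
# typically be worker.init_device() followed by worker.load_model().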