def _init_args()

in chatlearn/models/vllm_module.py


    def _init_args(self):
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            self.set_vllm_pp_layer_partition()

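        # Offset the sampling seed by the replica id so that parallel inference replicas
        # do not all generate identical samples; set apply_replica_id_to_seed=False to
        # share a single seed across replicas.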
        if self.model_args.get("apply_replica_id_to_seed", True):
            seed = self.model_args.get("seed", 0) + self.replica_id
        else:
            seed = self.model_args.get("seed", 0)

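        # Assemble vLLM EngineArgs from ChatLearn's config. Note that the model path is
        # taken from the "tokenizer" entry, and the tensor/pipeline parallel sizes come
        # from module_args rather than model_args.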
        engine_args = EngineArgs(
            model=self.model_args.get("tokenizer"),
            tokenizer=self.model_args.get("tokenizer"),
            tokenizer_mode=self.model_args.get("tokenizer_mode", "auto"),
            trust_remote_code=self.model_args.get("trust_remote_code", True),
            tensor_parallel_size=self.module_args.tensor_model_parallel_size,
            pipeline_parallel_size=self.module_args.pipeline_model_parallel_size,
            dtype=self.model_args.get("params_dtype", "auto"),
            quantization=self.model_args.get("quantization", None),
            revision=self.model_args.get("revision", None),
            tokenizer_revision=self.model_args.get("tokenizer_revision", None),
            seed=seed,
            gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90),
            block_size=self.model_args.get("block_size"),
            swap_space=self.model_args.get("swap_space"),
            max_num_batched_tokens=self.model_args.get("max_num_batched_tokens"),
            max_num_seqs=self.model_args.get("micro_batch_size"),
            max_model_len=self.model_args.get("seq_length"),
            enforce_eager=self.model_args.get("enforce_eager", True),
            disable_custom_all_reduce=True
        )

        self.quant_config = None
        self.pipeline_layer_offset = None
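        # vLLM 0.3.0 still uses the legacy create_engine_configs() API and the older
        # Worker constructor signature; newer versions go through create_engine_config().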
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
            engine_args.max_paddings = self.model_args.get("max_paddings", 256)
            engine_args.max_context_len_to_capture = self.model_args.get("max_context_len_to_capture", 8192)
            self.model_config, self.cache_config, self.parallel_config, self.scheduler_config, self.lora_config = \
                engine_args.create_engine_configs()
            self.worker = Worker(
                self.model_config,
                self.parallel_config,
                self.scheduler_config,
                local_rank=0,
                rank=0,
                distributed_init_method=None,
                lora_config=self.lora_config,
                kv_cache_dtype=self.cache_config.cache_dtype,
                is_driver_worker=True,
            )
            self._init_tokenizer()
        elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
            engine_args.max_seq_len_to_capture = self.model_args.get("max_context_len_to_capture", 8192)
            if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
                engine_args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1)
            engine_config = engine_args.create_engine_config()
            self.cache_config = engine_config.cache_config
            self.device_config = engine_config.device_config
            self.load_config = engine_config.load_config
            self.lora_config = engine_config.lora_config
            self.model_config = engine_config.model_config
            self.parallel_config = engine_config.parallel_config
            self.scheduler_config = engine_config.scheduler_config

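            # Generation defaults and the input processor are derived from the model
            # config, mirroring what vLLM's LLMEngine sets up on construction.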
            self.generation_config_fields = _load_generation_config_dict(
                self.model_config)
            self.input_processor = INPUT_REGISTRY.create_input_processor(
                self.model_config)

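            # vLLM 0.6.3 with num_scheduler_steps > 1 schedules multiple steps per call
            # and needs the multi-step worker; otherwise use the standard Worker.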
            if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3 and self.scheduler_config.is_multi_step:
                from vllm.worker.multi_step_worker import MultiStepWorker
                self.worker = MultiStepWorker(
                    self.model_config,
                    self.parallel_config,
                    self.scheduler_config,
                    self.device_config,
                    self.cache_config,
                    self.load_config,
                    local_rank=0,
                    rank=0,
                    distributed_init_method=None,
                    lora_config=self.lora_config,
                    is_driver_worker=True,
                )
            else:
                self.worker = Worker(
                    self.model_config,
                    self.parallel_config,
                    self.scheduler_config,
                    self.device_config,
                    self.cache_config,
                    self.load_config,
                    local_rank=0,
                    rank=0,
                    distributed_init_method=None,
                    lora_config=self.lora_config,
                    is_driver_worker=True,
                )
            self.tokenizer = self._init_tokenizer()
            self.detokenizer = Detokenizer(self.tokenizer)
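
For reference, a minimal sketch of the model_args entries this method reads. The key names are taken directly from the get() calls above; the values are illustrative placeholders rather than ChatLearn defaults, and the tensor/pipeline parallel sizes live in module_args instead.

    model_args = {
        "tokenizer": "/path/to/hf_model",    # also used as the model path in EngineArgs
        "tokenizer_mode": "auto",
        "trust_remote_code": True,
        "params_dtype": "bfloat16",
        "seed": 1234,
        "apply_replica_id_to_seed": True,    # replica i samples with seed 1234 + i
        "gpu_memory_utilization": 0.90,
        "block_size": 16,
        "swap_space": 4,
        "max_num_batched_tokens": 8192,
        "micro_batch_size": 256,             # mapped to max_num_seqs
        "seq_length": 4096,                  # mapped to max_model_len
        "enforce_eager": True,
        "max_context_len_to_capture": 8192,
        "num_scheduler_steps": 1,            # only read for vLLM 0.6.3
    }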