def _add_request_internal()

in chatlearn/models/vllm_module.py


    def _add_request_internal(self, prompt_list, prompt_token_id_list, is_eval=False):
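        # reset scheduler/engine state left over from the previous round before adding new requests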
        if self._need_to_reset_scheduler:
            self._reset_scheduler()
        self.reset_vllm()

        # sampling params: beam search keeps temperature at 0.0; otherwise the (eval_) settings come from model_args
        temperature = 0.0
        if not self.model_args.get("use_beam_search"):
            temperature = self.model_args.get("eval_temperature", 1.0) if is_eval else self.model_args.get("temperature", 1.0)
        top_p = self.model_args.get("eval_top_p", 1.0) if is_eval else self.model_args.get("top_p", 1.0)
        top_k = self.model_args.get("eval_top_k", -1) if is_eval else self.model_args.get("top_k", -1)
        min_p = self.model_args.get("eval_min_p", 0.0) if is_eval else self.model_args.get("min_p", 0.0)
        presence_penalty = self.model_args.get("eval_presence_penalty", 0.0) if is_eval else self.model_args.get("presence_penalty", 0.0)
        frequency_penalty = self.model_args.get("eval_frequency_penalty", 0.0) if is_eval else self.model_args.get("frequency_penalty", 0.0)
        repetition_penalty = self.model_args.get("eval_repetition_penalty", 1.0) if is_eval else self.model_args.get("repetition_penalty", 1.0)

        # stop strings may be configured as a ';'-separated string or as a list
        stop = self.model_args.get("stop_token_list", None)
        if isinstance(stop, str):
            stop = stop.split(";")
        seq_len = self.model_args.get("seq_length")
        for prompt, prompt_token_ids in zip(prompt_list, prompt_token_id_list):
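            # one request per prompt: allocate an id and fit the prompt within the seq_length budget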
            request_id = next(self.request_counter)
            if self.model_args.get("new_token_limit", False):
                max_tokens = self.model_args.get("max_new_tokens")
                assert max_tokens < seq_len, "max_new_tokens must less than seq length."
                prompt_token_ids = prompt_token_ids \
                    if len(prompt_token_ids) <= seq_len-max_tokens \
                    else prompt_token_ids[:seq_len-max_tokens]
            else:
                if len(prompt_token_ids) >= seq_len:
                    prompt_token_ids = prompt_token_ids[:seq_len-1]
                max_tokens = seq_len - len(prompt_token_ids)

            # build SamplingParams; the accepted argument set differs across supported vLLM versions
            if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_3_0, VLLMVersion.v_0_5_1]:
                sampling_params = SamplingParams(
                    n=self.model_args.get("n"),
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    repetition_penalty=repetition_penalty,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    min_p=min_p,
                    use_beam_search=self.model_args.get("use_beam_search"),
                    ignore_eos=self.model_args.get("ignore_eos"),
                    stop=stop,
                    max_tokens=max_tokens,
                    logprobs=1,
                    prompt_logprobs=self.model_args.get("prompt_logprobs", None),
                    skip_special_tokens=self.model_args.get('skip_special_tokens', True)
                )
            elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
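                # same sampling arguments as above, but use_beam_search is not passed for vLLM 0.6.3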
                sampling_params = SamplingParams(
                    n=self.model_args.get("n"),
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    repetition_penalty=repetition_penalty,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    min_p=min_p,
                    ignore_eos=self.model_args.get("ignore_eos"),
                    stop=stop,
                    max_tokens=max_tokens,
                    logprobs=1,
                    prompt_logprobs=self.model_args.get("prompt_logprobs", None),
                    skip_special_tokens=self.model_args.get('skip_special_tokens', True)
                )
            else:
                raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expected one of {list(VLLMVersion)}")

            # v0.3.0 takes the prompt and its token ids directly; v0.5.1/v0.6.3 take the inputs built by convert_v1_inputs
            if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
                self.add_request(
                    request_id,
                    prompt,
                    sampling_params,
                    prompt_token_ids=prompt_token_ids
                )
            elif CURRENT_VLLM_VERSION in \
                    [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
                inputs = self.convert_v1_inputs(
                    prompts=[prompt],
                    prompt_token_ids=[prompt_token_ids],
                )[0]
                self.add_request(
                    request_id,
                    inputs,
                    sampling_params
                )

        # bookkeeping for the generation loop: collected outputs, number of pending requests, progress bar
        self.outputs = []
        self.num_requests = self.get_num_unfinished_requests()
        self._reset_metrics_stats_args()
        self.pbar = tqdm(total=self.num_requests, desc=f"Processed prompts (replica {self.replica_id+1}/{self._num_replica})")
        self._need_to_reset_scheduler = True
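
For reference, the sketch below shows the sampling-related keys this method reads from model_args. The key names are taken from the code above; the concrete values, the stop-token string, and the commented-out call are illustrative assumptions, not the project's actual configuration.

    # Illustrative model_args sketch for _add_request_internal (values are assumptions).
    model_args = {
        "seq_length": 2048,             # total token budget (prompt + generation)
        "new_token_limit": True,        # if True, cap generation at max_new_tokens
        "max_new_tokens": 512,          # must be less than seq_length
        "n": 1,                         # sequences returned per prompt
        "use_beam_search": False,       # False => sampled decoding with the settings below
        "temperature": 1.0,
        "eval_temperature": 0.0,        # eval_* variants apply when is_eval=True
        "top_p": 1.0,
        "top_k": -1,
        "min_p": 0.0,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "repetition_penalty": 1.0,
        "stop_token_list": "</s>;<|endoftext|>",  # ';'-separated string or a list
        "ignore_eos": False,
        "prompt_logprobs": None,
        "skip_special_tokens": True,
    }

    # Hypothetical call shape; constructing the module and tokenizing the prompts
    # is outside the scope of this snippet.
    # vllm_module._add_request_internal(prompt_list, prompt_token_id_list, is_eval=False)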