in chatlearn/models/vllm_module.py [0:0]
def _add_request_internal(self, prompt_list, prompt_token_id_list, is_eval=False):
    if self._need_to_reset_scheduler:
        self._reset_scheduler()
        self.reset_vllm()
    # sampling params
    temperature = 0.0
    if not self.model_args.get("use_beam_search"):
        temperature = self.model_args.get("eval_temperature", 1.0) if is_eval else self.model_args.get("temperature", 1.0)
    top_p = self.model_args.get("eval_top_p", 1.0) if is_eval else self.model_args.get("top_p", 1.0)
    top_k = self.model_args.get("eval_top_k", -1) if is_eval else self.model_args.get("top_k", -1)
    min_p = self.model_args.get("eval_min_p", 0.0) if is_eval else self.model_args.get("min_p", 0.0)
    presence_penalty = self.model_args.get("eval_presence_penalty", 0.0) if is_eval else self.model_args.get("presence_penalty", 0.0)
    frequency_penalty = self.model_args.get("eval_frequency_penalty", 0.0) if is_eval else self.model_args.get("frequency_penalty", 0.0)
    repetition_penalty = self.model_args.get("eval_repetition_penalty", 1.0) if is_eval else self.model_args.get("repetition_penalty", 1.0)
    stop = self.model_args.get("stop_token_list", None)
    if isinstance(stop, str):
        stop = stop.split(";")
    seq_len = self.model_args.get("seq_length")
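    # Queue one generation request per prompt; request ids come from the shared request counter.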
    for prompt, prompt_token_ids in zip(prompt_list, prompt_token_id_list):
        request_id = next(self.request_counter)
        if self.model_args.get("new_token_limit", False):
            max_tokens = self.model_args.get("max_new_tokens")
            assert max_tokens < seq_len, "max_new_tokens must be less than seq_length."
            prompt_token_ids = prompt_token_ids \
                if len(prompt_token_ids) <= seq_len - max_tokens \
                else prompt_token_ids[:seq_len - max_tokens]
        else:
            if len(prompt_token_ids) >= seq_len:
                prompt_token_ids = prompt_token_ids[:seq_len - 1]
            max_tokens = seq_len - len(prompt_token_ids)
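        # Build SamplingParams for the detected vLLM version; the 0.6.3 branch no longer passes use_beam_search.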
        if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_3_0, VLLMVersion.v_0_5_1]:
            sampling_params = SamplingParams(
                n=self.model_args.get("n"),
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                min_p=min_p,
                use_beam_search=self.model_args.get("use_beam_search"),
                ignore_eos=self.model_args.get("ignore_eos"),
                stop=stop,
                max_tokens=max_tokens,
                logprobs=1,
                prompt_logprobs=self.model_args.get("prompt_logprobs", None),
                skip_special_tokens=self.model_args.get('skip_special_tokens', True)
            )
        elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            sampling_params = SamplingParams(
                n=self.model_args.get("n"),
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                min_p=min_p,
                ignore_eos=self.model_args.get("ignore_eos"),
                stop=stop,
                max_tokens=max_tokens,
                logprobs=1,
                prompt_logprobs=self.model_args.get("prompt_logprobs", None),
                skip_special_tokens=self.model_args.get('skip_special_tokens', True)
            )
        else:
            raise RuntimeError(f"Unsupported vllm version {CURRENT_VLLM_VERSION}, expect one of {list(VLLMVersion)}")
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0:
            self.add_request(
                request_id,
                prompt,
                sampling_params,
                prompt_token_ids=prompt_token_ids
            )
        elif CURRENT_VLLM_VERSION in \
                [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]:
            inputs = self.convert_v1_inputs(
                prompts=[prompt],
                prompt_token_ids=[prompt_token_ids],
            )[0]
            self.add_request(
                request_id,
                inputs,
                sampling_params
            )
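    # All requests are queued; reset per-batch bookkeeping and the progress bar.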
    self.outputs = []
    self.num_requests = self.get_num_unfinished_requests()
    self._reset_metrics_stats_args()
    self.pbar = tqdm(total=self.num_requests, desc=f"Processed prompts (replica {self.replica_id+1}/{self._num_replica})")
    self._need_to_reset_scheduler = True
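
For reference, the prompt-truncation arithmetic above can be checked in isolation. The sketch below is hypothetical and not part of vllm_module.py; it only mirrors the two branches controlled by new_token_limit, using made-up values seq_length=8 and max_new_tokens=3.

def budget_tokens(prompt_token_ids, seq_len, max_new_tokens=None, new_token_limit=False):
    # Hypothetical helper mirroring the truncation logic above, for illustration only.
    if new_token_limit:
        # Fixed generation budget: clip the prompt so prompt + generation fits in seq_len.
        assert max_new_tokens < seq_len
        return prompt_token_ids[:seq_len - max_new_tokens], max_new_tokens
    # Default: keep at most seq_len - 1 prompt tokens and generate into whatever room is left.
    if len(prompt_token_ids) >= seq_len:
        prompt_token_ids = prompt_token_ids[:seq_len - 1]
    return prompt_token_ids, seq_len - len(prompt_token_ids)

prompt = list(range(10))  # a 10-token prompt
print(budget_tokens(prompt, seq_len=8, max_new_tokens=3, new_token_limit=True))
# -> ([0, 1, 2, 3, 4], 3): prompt clipped to 5 ids, generation budget of 3
print(budget_tokens(prompt, seq_len=8))
# -> ([0, 1, 2, 3, 4, 5, 6], 1): prompt clipped to seq_len - 1, one token of headroom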