src/models/struxgpt_base.py
# Relies on module-level imports: `Union`, `List`, `Literal` from `typing`,
# `SamplingParams` from `vllm`, and `tqdm` from `tqdm`.
def call_model_local(self, prompt: Union[str, List[str]], history=None, batched=False, bs=8,
                     template: Literal['default', 'llama2', 'qwen', 'mistral', 'intern2', 'plain'] = 'default',
                     **kwargs) -> Union[str, List[str]]:
    """Run the local model on one prompt (str) or, with `batched=True`, a list of prompts."""
    assert history is None, 'Unsupported now.'
    system = kwargs.get('system', self.prompt_system)

    # 'default' resolves to the chat template of the loaded model family.
    if template == 'default':
        template = self.model_type
    def _format_example(prompt):
        # Wrap the raw prompt in the chat markup the target model expects.
        if template == 'llama2':
            prompt = f'[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{prompt} [/INST]'
        elif template == 'mistral':
            prompt = f'[INST] {prompt} [/INST]'
        elif template in ['intern2', 'qwen']:
            prompt = f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
        elif template == 'plain':
            pass  # leave the prompt untouched
        else:
            raise NotImplementedError(template)
        return prompt
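    # For example, with the 'mistral' template a prompt "What is RAG?" becomes
    # "[INST] What is RAG? [/INST]"; the ChatML-style templates ('intern2',
    # 'qwen') additionally prepend the system prompt as a separate turn.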
    # Decoding options, with greedy decoding and a short completion by default.
    temperature = kwargs.get('temperature', 0.0)
    max_tokens = kwargs.get('maxOutputLength', 128)
    show_progress = kwargs.get('show_progress', True)
    progress_desc = kwargs.get('progress_desc', None)
    stop = kwargs.get('stop', ['</s>', '<|im_end|>'])
    if batched:
        # Batched generation is only wired up for the vLLM backend.
        assert self.use_vllm and not isinstance(prompt, str), 'Batched mode requires vLLM and a list of prompts.'
        sampling_kwargs = SamplingParams(temperature=temperature, stop=stop, max_tokens=max_tokens)
        pred = []
        pbar = range(0, len(prompt), bs)
        if show_progress:
            pbar = tqdm(pbar, desc=progress_desc)
        for i in pbar:
            questions = [_format_example(qry) for qry in prompt[i:i + bs]]
            outputs = self.model.generate(questions, sampling_kwargs, use_tqdm=False)
            pred.extend([output.outputs[0].text for output in outputs])
    else:
        assert isinstance(prompt, str)
        prompt = _format_example(prompt)
        if self.use_vllm:
            sampling_kwargs = SamplingParams(temperature=temperature, stop=stop, max_tokens=max_tokens)
            outputs = self.model.generate([prompt], sampling_kwargs, use_tqdm=False)
            pred = outputs[0].outputs[0].text
        else:
            # Plain transformers fallback: tokenize, generate, and decode only
            # the newly generated tokens after the prompt.
            inputs = self.tokenizer(prompt, truncation=False, return_tensors="pt").to('cuda:0')
            context_length = inputs.input_ids.shape[-1]
            eos_token_id = [self.tokenizer.eos_token_id]
            for stop_word in stop:
                # `add_special_tokens=False` keeps BOS/EOS markers out of the
                # stop-word ids; note that a multi-token stop word contributes
                # each of its tokens as a separate stop id.
                eos_token_id.extend(self.tokenizer.encode(stop_word, add_special_tokens=False))
            output = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                # Sample only when a positive temperature is requested;
                # temperature == 0.0 means greedy decoding.
                do_sample=temperature > 0,
                temperature=temperature if temperature > 0 else 1.0,
                eos_token_id=eos_token_id,
            )[0]
            pred = self.tokenizer.decode(output[context_length:], skip_special_tokens=True)
    return pred
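
# Usage sketch, assuming a concrete subclass that sets `use_vllm`, `model`,
# `tokenizer`, `model_type`, and `prompt_system`; the `StruxGPT` constructor
# name and its arguments below are hypothetical.
#
#   llm = StruxGPT(model_path='...', use_vllm=True)
#   answer = llm.call_model_local('Summarize this paragraph: ...',
#                                 template='qwen', maxOutputLength=256)
#   answers = llm.call_model_local(questions, batched=True, bs=16,
#                                  progress_desc='inference')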