in LLM/mlx_language_model.py
def process(self, prompt):
    logger.debug("inferring language model...")
    language_code = None

    if isinstance(prompt, tuple):
        prompt, language_code = prompt
        # A "-auto" suffix marks a language code auto-detected by the STT
        # stage; strip it and ask the model to answer in that language.
        if language_code.endswith("-auto"):
            language_code = language_code[:-5]
            prompt = f"Please reply to my message in {WHISPER_LANGUAGE_TO_LLM_LANGUAGE[language_code]}. " + prompt
    self.chat.append({"role": self.user_role, "content": prompt})

    # Remove system messages if using a Gemma model: Gemma chat templates
    # reject the "system" role.
    if "gemma" in self.model_name.lower():
        chat_messages = [
            msg for msg in self.chat.to_list() if msg["role"] != "system"
        ]
    else:
        chat_messages = self.chat.to_list()

    prompt = self.tokenizer.apply_chat_template(
        chat_messages, tokenize=False, add_generation_prompt=True
    )
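    # For illustration only (an assumption, not part of the original file):
    # with a ChatML-style chat template the rendered prompt for one user turn
    # would look roughly like
    #
    #   <|im_start|>user
    #   Please reply to my message in French. Bonjour !<|im_end|>
    #   <|im_start|>assistant
    #
    # The exact markup depends entirely on the tokenizer's template.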
    output = ""
    curr_output = ""
    for t in stream_generate(
        self.model,
        self.tokenizer,
        prompt,
        max_tokens=self.gen_kwargs["max_new_tokens"],
    ):
        output += t.text
        curr_output += t.text
        # Flush at sentence boundaries so downstream consumers (e.g. TTS)
        # can start speaking before generation finishes.
        if curr_output.endswith((".", "?", "!", "<|end|>")):
            yield (curr_output.replace("<|end|>", ""), language_code)
            curr_output = ""
    generated_text = output.replace("<|end|>", "")
    torch.mps.empty_cache()  # release cached Metal (MPS) memory
    self.chat.append({"role": "assistant", "content": generated_text})
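
For context, the same sentence-chunked streaming pattern can be reproduced standalone with mlx_lm; a minimal sketch, assuming the mlx-community/Qwen2.5-0.5B-Instruct-4bit checkpoint (the model name is an assumption, any MLX chat model works):

    from mlx_lm import load, stream_generate

    model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")  # assumed checkpoint
    messages = [{"role": "user", "content": "Name three planets."}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    chunk = ""
    for t in stream_generate(model, tokenizer, prompt, max_tokens=128):
        chunk += t.text
        if chunk.endswith((".", "?", "!")):  # flush at sentence boundaries
            print(chunk.strip())
            chunk = ""
    if chunk:  # unlike process() above, also flush a trailing partial sentence
        print(chunk.strip())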