in router/src/lib.rs [941:1036]
/// Converts this chat request into a [`GenerateRequest`] for the inference engine.
///
/// The chat `messages` are rendered into a single prompt via
/// `Infer::apply_chat_template`. If `tools` are supplied and `ToolGrammar::apply`
/// yields a tool schema, the tools are injected into the template and a JSON
/// grammar built from that schema is attached. Otherwise `response_format`
/// (possibly `None`) is forwarded as the grammar.
///
/// Returns the generate request and a flag that is `true` only when the prompt
/// was rendered with tools and a tool grammar is attached.
///
/// # Errors
///
/// - [`InferError::ToolError`] if both `response_format` and `tools` are set.
/// - Any error propagated from `ToolGrammar::apply` or `Infer::apply_chat_template`.
fn try_into_generate(self, infer: &Infer) -> Result<(GenerateRequest, bool), InferError> {
    let ChatRequest {
        model,
        max_tokens,
        messages,
        seed,
        stop,
        tools,
        tool_choice,
        tool_prompt,
        temperature,
        response_format,
        presence_penalty,
        frequency_penalty,
        top_p,
        top_logprobs,
        ..
    } = self;

    // NOTE(review): presence_penalty is shifted by +2.0 to produce the engine's
    // repetition_penalty; confirm this mapping is the intended one.
    let repetition_penalty = presence_penalty.map(|x| x + 2.0);
    let max_new_tokens = max_tokens;
    let stop = stop.unwrap_or_default();

    // A temperature of exactly 0 requests greedy decoding: turn sampling off
    // and do not forward the temperature at all.
    let (do_sample, temperature) = match temperature {
        Some(0.0) => (false, None),
        other => (true, other),
    };

    if response_format.is_some() && tools.is_some() {
        return Err(InferError::ToolError(
            "Grammar and tools are mutually exclusive".into(),
        ));
    }

    // Resolve the tool grammar before rendering the prompt so a tool error
    // short-circuits without running the chat template. `ToolGrammar::apply`
    // may also decide that no tool grammar applies (`None`).
    let tool_grammar = match tools {
        Some(tools) => ToolGrammar::apply(tools, tool_choice)?,
        None => None,
    };

    let (inputs, grammar, using_tools) = if let Some((updated_tools, tool_schema)) = tool_grammar
    {
        // An empty tool prompt is treated the same as a missing one.
        let tool_prompt = tool_prompt
            .filter(|s| !s.is_empty())
            .unwrap_or_else(default_tool_prompt);
        let grammar = GrammarType::Json(serde_json::json!(tool_schema));
        let inputs = infer.apply_chat_template(messages, Some((updated_tools, tool_prompt)))?;
        (inputs, Some(grammar), true)
    } else {
        // Either a response_format grammar was requested (tools are then absent,
        // per the check above), or no tool grammar applies: render the plain
        // chat template and forward response_format unchanged (possibly None).
        let inputs = infer.apply_chat_template(messages, None)?;
        (inputs, response_format, false)
    };

    Ok((
        GenerateRequest {
            // `inputs` is already an owned String; move it instead of copying it.
            inputs,
            // NOTE(review): presumably the chat template already emits any special
            // tokens, which is why none are added here — confirm.
            add_special_tokens: false,
            parameters: GenerateParameters {
                best_of: None,
                temperature,
                repetition_penalty,
                frequency_penalty,
                top_k: None,
                top_p,
                typical_p: None,
                do_sample,
                max_new_tokens,
                return_full_text: None,
                stop,
                truncate: None,
                watermark: false,
                details: true,
                decoder_input_details: false,
                seed,
                top_n_tokens: top_logprobs,
                grammar,
                // The model name "tgi" is not forwarded as an adapter id.
                adapter_id: model.filter(|m| *m != "tgi"),
            },
        },
        using_tools,
    ))
}