export async function endpointCohere()

in src/lib/server/endpoints/cohere/endpointCohere.ts [21:163]


export async function endpointCohere(
	input: z.input<typeof endpointCohereParametersSchema>
): Promise<Endpoint> {
	const { apiKey, clientName, model, raw, forceSingleStep } =
		endpointCohereParametersSchema.parse(input);

	let cohere: CohereClient;

	try {
		cohere = new (await import("cohere-ai")).CohereClient({
			token: apiKey,
			clientName,
		});
	} catch (e) {
		throw new Error("Failed to import cohere-ai", { cause: e });
	}

	return async ({ messages, preprompt, generateSettings, continueMessage, tools, toolResults }) => {
		let system = preprompt;
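		// A leading system message takes precedence over the configured preprompt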
		if (messages?.[0]?.from === "system") {
			system = messages[0].content;
		}

		// Tool names must only contain letters and underscores ([a-zA-Z_]), and directly_answer is banned
		// It's safe to convert the tool names because we treat - and _ the same
		tools = tools
			?.filter((tool) => !toolHasName("directly_answer", tool))
			.map((tool) => ({ ...tool, name: tool.name.replaceAll("-", "_") }));

		const parameters = { ...model.parameters, ...generateSettings };

		return (async function* () {
			let stream;
			let tokenId = 0;

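			// Raw mode: build the full prompt string with the model's chat template and send it with rawPrompting enabled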
			if (raw) {
				const prompt = await buildPrompt({
					messages,
					model,
					preprompt: system,
					continueMessage,
					tools,
					toolResults,
				});

				stream = await cohere.chatStream({
					forceSingleStep,
					message: prompt,
					rawPrompting: true,
					model: model.id ?? model.name,
					p: parameters?.top_p,
					k: parameters?.top_k,
					maxTokens: parameters?.max_new_tokens,
					temperature: parameters?.temperature,
					stopSequences: parameters?.stop,
					frequencyPenalty: parameters?.frequency_penalty,
				});
			} else {
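				// Chat mode: send prior turns as chatHistory and the latest message separately; Cohere applies its own template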
				const formattedMessages = messages
					.filter((message) => message.from !== "system")
					.map((message) => ({
						role: message.from === "user" ? "USER" : "CHATBOT",
						message: message.content,
					})) satisfies Cohere.Message[];

				stream = await cohere
					.chatStream({
						forceSingleStep,
						model: model.id ?? model.name,
						chatHistory: formattedMessages.slice(0, -1),
						message: formattedMessages[formattedMessages.length - 1].message,
						preamble: system,
						p: parameters?.top_p,
						k: parameters?.top_k,
						maxTokens: parameters?.max_new_tokens,
						temperature: parameters?.temperature,
						stopSequences: parameters?.stop,
						frequencyPenalty: parameters?.frequency_penalty,
						tools,
						toolResults:
							toolResults && toolResults.length > 0
								? toolResults.map((toolResult) => {
										if (toolResult.status === ToolResultStatus.Error) {
											return { call: toolResult.call, outputs: [{ error: toolResult.message }] };
										}
										return { call: toolResult.call, outputs: toolResult.outputs };
									})
								: undefined,
					})
					.catch(async (err) => {
						if (!err.body) throw err;

						// Decode the error message and throw
						const message = await convertStreamToBuffer(err.body).catch(() => {
							throw err;
						});
						throw new Error(message, { cause: err });
					});
			}

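			// Re-emit Cohere stream events as TGI-style token chunks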
			for await (const output of stream) {
				if (output.eventType === "text-generation") {
					yield {
						token: {
							id: tokenId++,
							text: output.text,
							logprob: 0,
							special: false,
						},
						generated_text: null,
						details: null,
					} satisfies TextGenerationStreamOutput;
				} else if (output.eventType === "tool-calls-generation") {
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
							toolCalls: output.toolCalls as ToolCall[],
						},
						generated_text: null,
						details: null,
					};
				} else if (output.eventType === "stream-end") {
					if (["ERROR", "ERROR_TOXIC", "ERROR_LIMIT"].includes(output.finishReason)) {
						throw new Error(output.finishReason);
					}
					yield {
						token: {
							id: tokenId++,
							text: "",
							logprob: 0,
							special: true,
						},
						generated_text: output.response.text,
						details: null,
					};
				}
			}
		})();
	};
}
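
Usage sketch (not from the source): the import path, the "command-r-plus" model name, and the exact fields accepted by endpointCohereParametersSchema are assumptions. Only the fields destructured above (apiKey, clientName, model, raw, forceSingleStep) are shown, and the streamed chunks are consumed the way the generator above emits them.

// Hypothetical consumer; import path and config shape are assumptions.
import { endpointCohere } from "$lib/server/endpoints/cohere/endpointCohere";

const endpoint = await endpointCohere({
	apiKey: "<your Cohere API key>", // assumption: plain string token
	model: {
		name: "command-r-plus", // assumption: any Cohere chat model name/id
		parameters: { temperature: 0.3, max_new_tokens: 512 },
	},
});

// The returned function resolves to an async generator of token chunks.
// tools, toolResults, generateSettings and continueMessage are omitted here.
const generator = await endpoint({
	messages: [{ from: "user", content: "Write a haiku about the sea." }],
	preprompt: "You are a concise assistant.",
});

let answer = "";
for await (const output of generator) {
	if (output.generated_text !== null) {
		// "stream-end" chunk carries the full response text
		answer = output.generated_text;
	} else {
		// "text-generation" chunks carry incremental deltas
		answer += output.token.text;
	}
}
console.log(answer);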