// export async function endpointInferenceClient()
// in src/lib/server/endpoints/inference-client/endpointInferenceClient.ts [113:342]

/**
 * Builds an `Endpoint` backed by the Hugging Face `InferenceClient` chat-completion
 * streaming API.
 *
 * The returned function converts chat-ui messages into OpenAI-style chat messages
 * (including image attachments and tool-call replay), opens a completion stream,
 * and adapts the chunk stream into the internal text-generation stream shape.
 *
 * @param input - raw endpoint parameters, validated against
 *   `endpointInferenceClientParametersSchema`.
 * @returns the streaming `Endpoint` function.
 * @throws Error when both `provider` and `baseURL` are configured — they are
 *   mutually exclusive ways of addressing the model.
 */
export async function endpointInferenceClient(
	input: z.input<typeof endpointInferenceClientParametersSchema>
): Promise<Endpoint> {
	const { model, provider, modelName, baseURL, multimodal, customHeaders } =
		endpointInferenceClientParametersSchema.parse(input);

	if (provider && baseURL) {
		throw new Error("provider and baseURL cannot both be provided");
	}
	// With a baseURL we talk to a fixed endpoint; otherwise the client routes by provider.
	const client = baseURL
		? new InferenceClient(config.HF_TOKEN, { endpointUrl: baseURL })
		: new InferenceClient(config.HF_TOKEN);

	// Image handling is only wired up for image-multimodal models; for all other
	// models attachments are silently dropped by prepareFiles below.
	const imageProcessor = multimodal.image ? makeImageProcessor(multimodal.image) : undefined;

	/**
	 * Converts message attachments into OpenAI-style `image_url` content parts.
	 * Non-image attachments are ignored; "hash"-typed files are first resolved
	 * from storage via `downloadFile`.
	 */
	async function prepareFiles(files: MessageFile[], conversationId?: Conversation["_id"]) {
		if (!imageProcessor) {
			return [];
		}
		const processedFiles = await Promise.all(
			files
				.filter((file) => file.mime.startsWith("image/"))
				.map(async (file) => {
					if (file.type === "hash" && conversationId) {
						file = await downloadFile(file.value, conversationId);
					}
					return imageProcessor(file);
				})
		);
		return processedFiles.map((file) => ({
			type: "image_url" as const,
			image_url: {
				url: `data:${file.mime};base64,${file.image.toString("base64")}`,
			},
		}));
	}

	return async ({ messages, generateSettings, tools, toolResults, preprompt, conversationId }) => {
		// The payload mixes loosely-typed message shapes, so the array is any[] on purpose.
		/* eslint-disable @typescript-eslint/no-explicit-any */
		let messagesArray = (await Promise.all(
			messages.map(async (message) => {
				return {
					role: message.from,
					content: [
						...(await prepareFiles(message.files ?? [], conversationId)),
						{ type: "text" as const, text: message.content },
					],
				};
			})
		)) as any[];

		// System-prompt handling:
		// - models without system-role support get a leading system message demoted to user;
		// - otherwise ensure a system message carrying the preprompt exists up front.
		// Optional chaining guards the empty-conversation case, which previously threw
		// (`messagesArray[0].role` on an empty array).
		if (!model.systemRoleSupported && messagesArray[0]?.role === "system") {
			messagesArray[0].role = "user";
		} else if (messagesArray[0]?.role !== "system") {
			messagesArray.unshift({
				role: "system",
				content: preprompt ?? "",
			});
		}

		if (toolResults && toolResults.length > 0) {
			// Assign one stable id per call up front so the assistant's tool_calls entry
			// and its matching tool-result message share the same id. Previously, when
			// `toolId` was missing, two independent uuidv4() calls produced ids that
			// could never correlate, breaking the call/result pairing.
			const callIds = toolResults.map((toolResult) => toolResult?.call?.toolId || uuidv4());

			messagesArray = [
				...messagesArray,
				// Replay the tool invocations as a single assistant turn...
				{
					role: "assistant",
					content: [
						{
							type: "text" as const,
							text: "",
						},
					],
					tool_calls: toolResults.map((toolResult, i) => ({
						type: "function",
						function: {
							name: toolResult.call.name,
							arguments: JSON.stringify(toolResult.call.parameters),
						},
						id: callIds[i],
					})),
				},
				// ...followed by one result message per call. Models without system-role
				// support are assumed to also lack a "tool" role, so results go as "user".
				...toolResults.map((toolResult, i) => ({
					role: model.systemRoleSupported ? "tool" : "user",
					content: [
						{
							type: "text" as const,
							text: JSON.stringify(toolResult),
						},
					],
					tool_call_id: callIds[i],
				})),
			];
		}

		// Collapse consecutive same-role messages into one (providers generally reject
		// back-to-back turns from the same role). Non-text parts (images) are kept in
		// order; text parts are concatenated with newlines.
		messagesArray = messagesArray.reduce((acc: typeof messagesArray, current) => {
			const prevMessage = acc[acc.length - 1];
			if (!prevMessage || current.role !== prevMessage.role) {
				acc.push(current);
				return acc;
			}

			// The whole RHS is evaluated before assignment, so reading
			// prevMessage.content inside the literal is safe.
			prevMessage.content = [
				...prevMessage.content.filter((item: any) => item.type !== "text"),
				...current.content.filter((item: any) => item.type !== "text"),
				{
					type: "text" as const,
					text: [
						...prevMessage.content.filter((item: any) => item.type === "text"),
						...current.content.filter((item: any) => item.type === "text"),
					]
						.map((item: any) => item.text)
						.join("\n")
						.replace(/^\n/, ""),
				},
			];

			prevMessage.files = [...(prevMessage?.files ?? []), ...(current?.files ?? [])];
			prevMessage.tool_calls = [...(prevMessage?.tool_calls ?? []), ...(current?.tool_calls ?? [])];

			return acc;
		}, []);

		const toolCallChoices = createChatCompletionToolsArray(tools);
		const stream = client.chatCompletionStream(
			{
				...model.parameters,
				...generateSettings,
				model: modelName ?? model.id ?? model.name,
				// When hitting a fixed endpoint URL the provider must be omitted.
				provider: baseURL ? undefined : provider || ("hf-inference" as const),
				messages: messagesArray,
				...(toolCallChoices.length > 0 ? { tools: toolCallChoices, tool_choice: "auto" } : {}),
				// NOTE(review): `toolResults` is not a chat-completion request parameter and
				// looks like a leftover; kept to preserve the exact request body — confirm
				// the backend ignores this extra field before removing.
				toolResults,
			},
			{
				fetch: async (url, options) => {
					return fetch(url, {
						...options,
						headers: {
							...options?.headers,
							"X-Use-Cache": "false",
							"ChatUI-Conversation-ID": conversationId?.toString() ?? "",
							...customHeaders,
						},
					});
				},
			}
		);

		let tokenId = 0;
		let generated_text = "";
		const finalToolCalls: DeltaToolCall[] = [];

		/**
		 * Adapts the OpenAI-style chunk stream into the internal stream shape:
		 * one token event per content delta, then a final "special" token carrying
		 * the full generated text and any accumulated tool calls.
		 */
		async function* convertStream(): AsyncGenerator<
			TextGenerationStreamOutputWithToolsAndWebSources,
			void,
			void
		> {
			for await (const chunk of stream) {
				const token = chunk.choices?.[0]?.delta?.content || "";

				generated_text += token;

				// Tool-call deltas arrive fragmented: the first delta for an index carries
				// the id/name, subsequent deltas append pieces of the argument JSON.
				const toolCallDeltas = chunk.choices?.[0]?.delta?.tool_calls ?? [];
				for (const delta of toolCallDeltas) {
					const index = delta.index ?? 0;
					const existing = finalToolCalls[index];

					if (!existing) {
						finalToolCalls[index] = delta;
					} else if (delta.function.arguments) {
						existing.function.arguments =
							(existing.function.arguments ?? "") + delta.function.arguments;
					}
				}

				yield {
					token: {
						id: tokenId++,
						text: token,
						logprob: 0,
						special: false,
					},
					details: null,
					generated_text: null,
				};
			}

			// Parse the accumulated tool calls; jsonrepair tolerates the truncated or
			// slightly malformed JSON some providers emit. (finalToolCalls is declared
			// as an array, so the previous Array.isArray fallback was dead code.)
			let mappedToolCalls: ToolCall[] | undefined;
			try {
				if (finalToolCalls.length > 0) {
					mappedToolCalls = finalToolCalls.map((tc) => ({
						id: tc.id,
						name: tc.function.name ?? "",
						parameters: JSON.parse(jsonrepair(tc.function.arguments || "{}")),
					}));
				}
			} catch (e) {
				// Best effort: a malformed tool call must not kill the stream.
				logger.error(e, "error mapping tool calls");
			}

			yield {
				token: {
					id: tokenId++,
					text: "",
					logprob: 0,
					special: true,
					toolCalls: mappedToolCalls,
				},
				generated_text,
				details: null,
			};
		}

		return convertStream();
	};
}