lib/ChatCompletionStream.ts (757 lines of code) (raw):

import * as Core from "../core.ts";
import {
  APIUserAbortError,
  ContentFilterFinishReasonError,
  LengthFinishReasonError,
  OpenAIError,
} from "../error.ts";
import {
  type ChatCompletion,
  type ChatCompletionChunk,
  type ChatCompletionCreateParams,
  type ChatCompletionCreateParamsBase,
  type ChatCompletionCreateParamsStreaming,
  ChatCompletionTokenLogprob,
} from "../resources/chat/completions.ts";
import {
  AbstractChatCompletionRunner,
  type AbstractChatCompletionRunnerEvents,
} from "./AbstractChatCompletionRunner.ts";
import { type ReadableStream } from "../_shims/mod.ts";
import { Stream } from "../streaming.ts";
import OpenAI from "../mod.ts";
import { ParsedChatCompletion } from "../resources/beta/chat/completions.ts";
import {
  AutoParseableResponseFormat,
  hasAutoParseableInput,
  isAutoParsableResponseFormat,
  isAutoParsableTool,
  maybeParseChatCompletion,
  shouldParseToolCall,
} from "./parser.ts";
import { partialParse } from "../_vendor/partial-json-parser/parser.ts";

/** Payload of `content.delta`: the new text, the accumulated text, and the snapshot's current `parsed` value. */
export interface ContentDeltaEvent {
  delta: string;
  snapshot: string;
  parsed: unknown | null;
}

/** Payload of `content.done`: the full content plus its parsed form (`null` when no auto-parseable response format was given). */
export interface ContentDoneEvent<ParsedT = null> {
  content: string;
  parsed: ParsedT | null;
}

/** Payload of `refusal.delta`: the new refusal text and the accumulated refusal text. */
export interface RefusalDeltaEvent {
  delta: string;
  snapshot: string;
}

/** Payload of `refusal.done`: the full refusal text. */
export interface RefusalDoneEvent {
  refusal: string;
}

/** Payload of `tool_calls.function.arguments.delta`. */
export interface FunctionToolCallArgumentsDeltaEvent {
  name: string;

  index: number;

  /** Arguments accumulated so far for this tool call. */
  arguments: string;

  /** The snapshot's current `parsed_arguments` value for this tool call. */
  parsed_arguments: unknown;

  /** The arguments fragment carried by the current chunk (`""` when the chunk carried none). */
  arguments_delta: string;
}

/** Payload of `tool_calls.function.arguments.done`. */
export interface FunctionToolCallArgumentsDoneEvent {
  name: string;

  index: number;

  arguments: string;

  parsed_arguments: unknown;
}

/** Payload of `logprobs.content.delta`: the chunk's token logprobs and the accumulated list. */
export interface LogProbsContentDeltaEvent {
  content: Array<ChatCompletionTokenLogprob>;
  snapshot: Array<ChatCompletionTokenLogprob>;
}

/** Payload of `logprobs.content.done`: the accumulated content logprobs. */
export interface LogProbsContentDoneEvent {
  content: Array<ChatCompletionTokenLogprob>;
}

/** Payload of `logprobs.refusal.delta`: the chunk's refusal logprobs and the accumulated list. */
export interface LogProbsRefusalDeltaEvent {
  refusal: Array<ChatCompletionTokenLogprob>;
  snapshot: Array<ChatCompletionTokenLogprob>;
}

/** Payload of `logprobs.refusal.done`: the accumulated refusal logprobs. */
export interface LogProbsRefusalDoneEvent {
  refusal: Array<ChatCompletionTokenLogprob>;
}

/**
 * Events emitted by {@link ChatCompletionStream}, in addition to those of the
 * base runner.
 */
export interface ChatCompletionStreamEvents<ParsedT = null>
  extends AbstractChatCompletionRunnerEvents {
  content: (contentDelta: string, contentSnapshot: string) => void;
  chunk: (chunk: ChatCompletionChunk, snapshot: ChatCompletionSnapshot) => void;

  "content.delta": (props: ContentDeltaEvent) => void;
  "content.done": (props: ContentDoneEvent<ParsedT>) => void;

  "refusal.delta": (props: RefusalDeltaEvent) => void;
  "refusal.done": (props: RefusalDoneEvent) => void;

  "tool_calls.function.arguments.delta": (
    props: FunctionToolCallArgumentsDeltaEvent,
  ) => void;
  "tool_calls.function.arguments.done": (
    props: FunctionToolCallArgumentsDoneEvent,
  ) => void;

  "logprobs.content.delta": (props: LogProbsContentDeltaEvent) => void;
  "logprobs.content.done": (props: LogProbsContentDoneEvent) => void;

  "logprobs.refusal.delta": (props: LogProbsRefusalDeltaEvent) => void;
  "logprobs.refusal.done": (props: LogProbsRefusalDoneEvent) => void;
}

/** Request parameters accepted by the streaming helper; `stream` may only be omitted or `true`. */
export type ChatCompletionStreamParams =
  & Omit<ChatCompletionCreateParamsBase, "stream">
  & {
    stream?: true;
  };

/** Per-choice bookkeeping that ensures each `*.done` event fires at most once. */
interface ChoiceEventState {
  content_done: boolean;
  refusal_done: boolean;
  logprobs_content_done: boolean;
  logprobs_refusal_done: boolean;
  /** Index of the tool call most recently seen in a delta, or `null` if none yet. */
  current_tool_call_index: number | null;
  /**
   * Tool call indices consulted by `#emitToolCallDoneEvent` to skip repeats.
   * NOTE(review): nothing in this file adds to this set — confirm whether
   * indices are meant to be recorded when the done event fires.
   */
  done_tool_calls: Set<number>;
}

/**
 * Consumes a stream of `ChatCompletionChunk`s, folds them into a
 * {@link ChatCompletionSnapshot}, and emits the fine-grained events described
 * by {@link ChatCompletionStreamEvents}. Instances are also async-iterable,
 * yielding the raw chunks.
 */
export class ChatCompletionStream<ParsedT = null>
  extends AbstractChatCompletionRunner<
    ChatCompletionStreamEvents<ParsedT>,
    ParsedT
  >
  implements AsyncIterable<ChatCompletionChunk> {
  #params: ChatCompletionCreateParams | null;
  #choiceEventStates: ChoiceEventState[];
  #currentChatCompletionSnapshot: ChatCompletionSnapshot | undefined;

  /**
   * @param params The request parameters, or `null` when they are not known
   *   (as in {@link ChatCompletionStream.fromReadableStream}). They are used to
   *   decide whether and how content and tool-call arguments are parsed.
   */
  constructor(params: ChatCompletionCreateParams | null) {
    super();
    this.#params = params;
    this.#choiceEventStates = [];
  }

  /** The snapshot accumulated for the in-flight request, or `undefined` when none is in progress. */
  get currentChatCompletionSnapshot(): ChatCompletionSnapshot | undefined {
    return this.#currentChatCompletionSnapshot;
  }

  /**
   * Intended for use on the frontend, consuming a stream produced with
   * `.toReadableStream()` on the backend.
   *
   * Note that messages sent to the model do not appear in `.on('message')`
   * in this context.
   */
  static fromReadableStream(
    stream: ReadableStream,
  ): ChatCompletionStream<null> {
    const runner = new ChatCompletionStream(null);
    runner._run(() => runner._fromReadableStream(stream));
    return runner;
  }

  /**
   * Creates a runner for a streaming chat completion request made through
   * `client`, and hands the request to `_run`.
   *
   * The request is always sent with `stream: true` and with the
   * `X-Stainless-Helper-Method: stream` header added to any caller-supplied
   * headers.
   */
  static createChatCompletion<ParsedT>(
    client: OpenAI,
    params: ChatCompletionStreamParams,
    options?: Core.RequestOptions,
  ): ChatCompletionStream<ParsedT> {
    const runner = new ChatCompletionStream<ParsedT>(
      params as ChatCompletionCreateParamsStreaming,
    );
    runner._run(() =>
      runner._runChatCompletion(
        client,
        { ...params, stream: true },
        {
          ...options,
          headers: {
            ...options?.headers,
            "X-Stainless-Helper-Method": "stream",
          },
        },
      )
    );
    return runner;
  }

  /** Clears the current snapshot ahead of a new request; does nothing once the stream has ended. */
  #beginRequest() {
    if (this.ended) return;
    this.#currentChatCompletionSnapshot = undefined;
  }

  /** Returns the event state for `choice.index`, creating it on first use. */
  #getChoiceEventState(
    choice: ChatCompletionSnapshot.Choice,
  ): ChoiceEventState {
    let state = this.#choiceEventStates[choice.index];
    if (state) {
      return state;
    }

    state = {
      content_done: false,
      refusal_done: false,
      logprobs_content_done: false,
      logprobs_refusal_done: false,
      done_tool_calls: new Set(),
      current_tool_call_index: null,
    };
    this.#choiceEventStates[choice.index] = state;
    return state;
  }

  /**
   * Folds `chunk` into the snapshot, emits `chunk`, and then emits the
   * per-choice delta and done events that the chunk triggers. Ignored once the
   * stream has ended.
   */
  #addChunk(this: ChatCompletionStream<ParsedT>, chunk: ChatCompletionChunk) {
    if (this.ended) return;

    const completion = this.#accumulateChatCompletion(chunk);
    this._emit("chunk", chunk, completion);

    for (const choice of chunk.choices) {
      const choiceSnapshot = completion.choices[choice.index]!;

      if (
        choice.delta.content != null &&
        choiceSnapshot.message?.role === "assistant" &&
        choiceSnapshot.message?.content
      ) {
        this._emit(
          "content",
          choice.delta.content,
          choiceSnapshot.message.content,
        );
        this._emit("content.delta", {
          delta: choice.delta.content,
          snapshot: choiceSnapshot.message.content,
          parsed: choiceSnapshot.message.parsed,
        });
      }

      if (
        choice.delta.refusal != null &&
        choiceSnapshot.message?.role === "assistant" &&
        choiceSnapshot.message?.refusal
      ) {
        this._emit("refusal.delta", {
          delta: choice.delta.refusal,
          snapshot: choiceSnapshot.message.refusal,
        });
      }

      if (
        choice.logprobs?.content != null &&
        choiceSnapshot.message?.role === "assistant"
      ) {
        this._emit("logprobs.content.delta", {
          content: choice.logprobs?.content,
          snapshot: choiceSnapshot.logprobs?.content ?? [],
        });
      }

      if (
        choice.logprobs?.refusal != null &&
        choiceSnapshot.message?.role === "assistant"
      ) {
        this._emit("logprobs.refusal.delta", {
          refusal: choice.logprobs?.refusal,
          snapshot: choiceSnapshot.logprobs?.refusal ?? [],
        });
      }

      const state = this.#getChoiceEventState(choiceSnapshot);

      // The choice has finished: flush the done events for its content and
      // for the tool call that was in progress, if any.
      if (choiceSnapshot.finish_reason) {
        this.#emitContentDoneEvents(choiceSnapshot);

        if (state.current_tool_call_index != null) {
          this.#emitToolCallDoneEvent(
            choiceSnapshot,
            state.current_tool_call_index,
          );
        }
      }

      for (const toolCall of choice.delta.tool_calls ?? []) {
        if (state.current_tool_call_index !== toolCall.index) {
          this.#emitContentDoneEvents(choiceSnapshot);

          // new tool call started, the previous one is done
          if (state.current_tool_call_index != null) {
            this.#emitToolCallDoneEvent(
              choiceSnapshot,
              state.current_tool_call_index,
            );
          }
        }

        state.current_tool_call_index = toolCall.index;
      }

      for (const toolCallDelta of choice.delta.tool_calls ?? []) {
        const toolCallSnapshot = choiceSnapshot.message.tool_calls
          ?.[toolCallDelta.index];
        // The `type` may not have arrived yet; wait for a later chunk.
        if (!toolCallSnapshot?.type) {
          continue;
        }

        if (toolCallSnapshot?.type === "function") {
          this._emit("tool_calls.function.arguments.delta", {
            name: toolCallSnapshot.function?.name,
            index: toolCallDelta.index,
            arguments: toolCallSnapshot.function.arguments,
            parsed_arguments: toolCallSnapshot.function.parsed_arguments,
            arguments_delta: toolCallDelta.function?.arguments ?? "",
          });
        } else {
          assertNever(toolCallSnapshot?.type);
        }
      }
    }
  }

  /**
   * Emits `tool_calls.function.arguments.done` for the tool call at
   * `toolCallIndex`, unless that index is already in `done_tool_calls`.
   *
   * `parsed_arguments` comes from the matching input tool's `$parseRaw` when
   * it is auto-parseable, from `JSON.parse` when the tool is `strict`, and is
   * `null` otherwise.
   *
   * @throws Error if the snapshot has no tool call at that index, or the tool
   *   call has no `type`.
   */
  #emitToolCallDoneEvent(
    choiceSnapshot: ChatCompletionSnapshot.Choice,
    toolCallIndex: number,
  ) {
    const state = this.#getChoiceEventState(choiceSnapshot);
    if (state.done_tool_calls.has(toolCallIndex)) {
      // we've already fired the done event
      return;
    }

    const toolCallSnapshot = choiceSnapshot.message.tool_calls?.[toolCallIndex];
    if (!toolCallSnapshot) {
      throw new Error("no tool call snapshot");
    }
    if (!toolCallSnapshot.type) {
      throw new Error("tool call snapshot missing `type`");
    }

    if (toolCallSnapshot.type === "function") {
      const inputTool = this.#params?.tools?.find(
        (tool) =>
          tool.type === "function" &&
          tool.function.name === toolCallSnapshot.function.name,
      );

      this._emit("tool_calls.function.arguments.done", {
        name: toolCallSnapshot.function.name,
        index: toolCallIndex,
        arguments: toolCallSnapshot.function.arguments,
        parsed_arguments: isAutoParsableTool(inputTool)
          ? inputTool.$parseRaw(toolCallSnapshot.function.arguments)
          : inputTool?.function.strict
          ? JSON.parse(toolCallSnapshot.function.arguments)
          : null,
      });
    } else {
      assertNever(toolCallSnapshot.type);
    }
  }

  /**
   * Emits `content.done`, `refusal.done`, `logprobs.content.done` and
   * `logprobs.refusal.done` for whichever of those fields are present on the
   * choice and have not been reported yet.
   */
  #emitContentDoneEvents(choiceSnapshot: ChatCompletionSnapshot.Choice) {
    const state = this.#getChoiceEventState(choiceSnapshot);

    if (choiceSnapshot.message.content && !state.content_done) {
      state.content_done = true;

      const responseFormat = this.#getAutoParseableResponseFormat();

      this._emit("content.done", {
        content: choiceSnapshot.message.content,
        parsed: responseFormat
          ? responseFormat.$parseRaw(choiceSnapshot.message.content)
          : (null as any),
      });
    }

    if (choiceSnapshot.message.refusal && !state.refusal_done) {
      state.refusal_done = true;

      this._emit("refusal.done", { refusal: choiceSnapshot.message.refusal });
    }

    if (choiceSnapshot.logprobs?.content && !state.logprobs_content_done) {
      state.logprobs_content_done = true;

      this._emit("logprobs.content.done", {
        content: choiceSnapshot.logprobs.content,
      });
    }

    if (choiceSnapshot.logprobs?.refusal && !state.logprobs_refusal_done) {
      state.logprobs_refusal_done = true;

      this._emit("logprobs.refusal.done", {
        refusal: choiceSnapshot.logprobs.refusal,
      });
    }
  }

  /**
   * Finalizes the current snapshot into a completion and resets the
   * per-request state (snapshot and choice event states).
   *
   * @throws OpenAIError if the stream has already ended, or no chunk was
   *   received for the request.
   */
  #endRequest(): ParsedChatCompletion<ParsedT> {
    if (this.ended) {
      throw new OpenAIError(`stream has ended, this shouldn't happen`);
    }
    const snapshot = this.#currentChatCompletionSnapshot;
    if (!snapshot) {
      throw new OpenAIError(`request ended without sending any chunks`);
    }
    this.#currentChatCompletionSnapshot = undefined;
    this.#choiceEventStates = [];
    return finalizeChatCompletion(snapshot, this.#params);
  }

  /**
   * Sends the request with `stream: true`, feeds every received chunk through
   * `#addChunk`, and returns the finalized completion.
   *
   * A caller-supplied `options.signal` is forwarded to this runner's own
   * controller, whose signal is the one passed to the request.
   *
   * @throws APIUserAbortError if the stream's signal is aborted once
   *   iteration stops.
   */
  protected override async _createChatCompletion(
    client: OpenAI,
    params: ChatCompletionCreateParams,
    options?: Core.RequestOptions,
  ): Promise<ParsedChatCompletion<ParsedT>> {
    // Bare property access, not a call: this statement has no runtime effect.
    super._createChatCompletion;
    const signal = options?.signal;
    if (signal) {
      if (signal.aborted) this.controller.abort();
      signal.addEventListener("abort", () => this.controller.abort());
    }
    this.#beginRequest();

    const stream = await client.chat.completions.create(
      { ...params, stream: true },
      { ...options, signal: this.controller.signal },
    );
    this._connected();
    for await (const chunk of stream) {
      this.#addChunk(chunk);
    }
    if (stream.controller.signal?.aborted) {
      throw new APIUserAbortError();
    }
    return this._addChatCompletion(this.#endRequest());
  }

  /**
   * Replays chunks from `readableStream`. A change of chunk `id` is treated
   * as the start of a new request: the completion accumulated so far is
   * finalized and added before the new chunk is processed.
   *
   * @returns The completion built from the last run of chunks.
   * @throws APIUserAbortError if the stream's signal is aborted once
   *   iteration stops.
   */
  protected async _fromReadableStream(
    readableStream: ReadableStream,
    options?: Core.RequestOptions,
  ): Promise<ChatCompletion> {
    const signal = options?.signal;
    if (signal) {
      if (signal.aborted) this.controller.abort();
      signal.addEventListener("abort", () => this.controller.abort());
    }
    this.#beginRequest();
    this._connected();
    const stream = Stream.fromReadableStream<ChatCompletionChunk>(
      readableStream,
      this.controller,
    );
    let chatId;
    for await (const chunk of stream) {
      if (chatId && chatId !== chunk.id) {
        // A new request has been made.
        this._addChatCompletion(this.#endRequest());
      }

      this.#addChunk(chunk);
      chatId = chunk.id;
    }
    if (stream.controller.signal?.aborted) {
      throw new APIUserAbortError();
    }
    return this._addChatCompletion(this.#endRequest());
  }

  /** Returns the request's `response_format` when it is auto-parseable, otherwise `null`. */
  #getAutoParseableResponseFormat():
    | AutoParseableResponseFormat<ParsedT>
    | null {
    const responseFormat = this.#params?.response_format;
    if (isAutoParsableResponseFormat<ParsedT>(responseFormat)) {
      return responseFormat;
    }

    return null;
  }

  /**
   * Merges `chunk` into the current snapshot (creating the snapshot on the
   * first chunk) and returns it.
   *
   * String fields (content, refusal, function/tool-call arguments) are
   * concatenated and logprob lists are appended. The snapshot keeps its own
   * copies of the logprob arrays and of `function_call`, so accumulating later
   * chunks never mutates a chunk that was already handed to listeners.
   *
   * @throws LengthFinishReasonError | ContentFilterFinishReasonError when the
   *   request has auto-parseable input and a choice finishes with `length` or
   *   `content_filter` respectively.
   */
  #accumulateChatCompletion(
    chunk: ChatCompletionChunk,
  ): ChatCompletionSnapshot {
    let snapshot = this.#currentChatCompletionSnapshot;
    const { choices, ...rest } = chunk;
    if (!snapshot) {
      snapshot = this.#currentChatCompletionSnapshot = {
        ...rest,
        choices: [],
      };
    } else {
      Object.assign(snapshot, rest);
    }

    for (
      const { delta, finish_reason, index, logprobs = null, ...other } of chunk
        .choices
    ) {
      let choice = snapshot.choices[index];
      if (!choice) {
        // Start with no logprobs; the branch below installs a private copy.
        // Storing the chunk's own object here would make the append path push
        // the chunk's arrays onto themselves, duplicating the first chunk's
        // entries.
        choice = snapshot.choices[index] = {
          finish_reason,
          index,
          message: {},
          logprobs: null,
          ...other,
        };
      }

      if (logprobs) {
        if (!choice.logprobs) {
          // Copy the object and its arrays so that later pushes do not write
          // into the chunk.
          const initialLogprobs = Object.assign({}, logprobs);
          if (initialLogprobs.content) {
            initialLogprobs.content = [...initialLogprobs.content];
          }
          if (initialLogprobs.refusal) {
            initialLogprobs.refusal = [...initialLogprobs.refusal];
          }
          choice.logprobs = initialLogprobs;
        } else {
          const { content, refusal, ...rest } = logprobs;
          assertIsEmpty(rest);
          Object.assign(choice.logprobs, rest);

          if (content) {
            choice.logprobs.content ??= [];
            choice.logprobs.content.push(...content);
          }

          if (refusal) {
            choice.logprobs.refusal ??= [];
            choice.logprobs.refusal.push(...refusal);
          }
        }
      }

      if (finish_reason) {
        choice.finish_reason = finish_reason;

        if (this.#params && hasAutoParseableInput(this.#params)) {
          if (finish_reason === "length") {
            throw new LengthFinishReasonError();
          }

          if (finish_reason === "content_filter") {
            throw new ContentFilterFinishReasonError();
          }
        }
      }

      Object.assign(choice, other);

      if (!delta) continue; // Shouldn't happen; just in case.

      const { content, refusal, function_call, role, tool_calls, ...rest } =
        delta;
      assertIsEmpty(rest);
      Object.assign(choice.message, rest);

      if (refusal) {
        choice.message.refusal = (choice.message.refusal || "") + refusal;
      }

      if (role) choice.message.role = role;
      if (function_call) {
        if (!choice.message.function_call) {
          // Copy, so that appending arguments below does not mutate the
          // chunk's delta.
          choice.message.function_call = { ...function_call };
        } else {
          if (function_call.name) {
            choice.message.function_call.name = function_call.name;
          }
          if (function_call.arguments) {
            choice.message.function_call.arguments ??= "";
            choice.message.function_call.arguments += function_call.arguments;
          }
        }
      }
      if (content) {
        choice.message.content = (choice.message.content || "") + content;

        if (!choice.message.refusal && this.#getAutoParseableResponseFormat()) {
          choice.message.parsed = partialParse(choice.message.content);
        }
      }

      if (tool_calls) {
        if (!choice.message.tool_calls) choice.message.tool_calls = [];

        for (const { index, id, type, function: fn, ...rest } of tool_calls) {
          const tool_call = (choice.message.tool_calls[index] ??=
            {} as ChatCompletionSnapshot.Choice.Message.ToolCall);
          Object.assign(tool_call, rest);
          if (id) tool_call.id = id;
          if (type) tool_call.type = type;
          if (fn) {
            tool_call.function ??= { name: fn.name ?? "", arguments: "" };
          }
          if (fn?.name) tool_call.function!.name = fn.name;
          if (fn?.arguments) {
            tool_call.function!.arguments += fn.arguments;

            if (shouldParseToolCall(this.#params, tool_call)) {
              tool_call.function!.parsed_arguments = partialParse(
                tool_call.function!.arguments,
              );
            }
          }
        }
      }
    }
    return snapshot;
  }

  /**
   * Iterates over the raw chunks as they are emitted.
   *
   * Chunks that arrive before the consumer asks for them are buffered in
   * `pushQueue`; `next()` calls made before a chunk is available wait in
   * `readQueue`. Waiting readers are resolved as done on `end` and rejected on
   * `abort` or `error`. Calling `return()` (for example by breaking out of a
   * `for await` loop) aborts the stream.
   */
  [Symbol.asyncIterator](
    this: ChatCompletionStream<ParsedT>,
  ): AsyncIterator<ChatCompletionChunk> {
    const pushQueue: ChatCompletionChunk[] = [];
    const readQueue: {
      resolve: (chunk: ChatCompletionChunk | undefined) => void;
      reject: (err: unknown) => void;
    }[] = [];
    let done = false;

    this.on("chunk", (chunk) => {
      const reader = readQueue.shift();
      if (reader) {
        reader.resolve(chunk);
      } else {
        pushQueue.push(chunk);
      }
    });

    this.on("end", () => {
      done = true;
      for (const reader of readQueue) {
        reader.resolve(undefined);
      }
      readQueue.length = 0;
    });

    this.on("abort", (err) => {
      done = true;
      for (const reader of readQueue) {
        reader.reject(err);
      }
      readQueue.length = 0;
    });

    this.on("error", (err) => {
      done = true;
      for (const reader of readQueue) {
        reader.reject(err);
      }
      readQueue.length = 0;
    });

    return {
      next: async (): Promise<IteratorResult<ChatCompletionChunk>> => {
        if (!pushQueue.length) {
          if (done) {
            return { value: undefined, done: true };
          }
          return new Promise<ChatCompletionChunk | undefined>((
            resolve,
            reject,
          ) => readQueue.push({ resolve, reject })).then((
            chunk,
          ) => (chunk
            ? { value: chunk, done: false }
            : { value: undefined, done: true })
          );
        }
        const chunk = pushQueue.shift()!;
        return { value: chunk, done: false };
      },
      return: async () => {
        this.abort();
        return { value: undefined, done: true };
      },
    };
  }

  /** Wraps this runner's chunk iterator in a `Stream` and returns its `ReadableStream` form. */
  toReadableStream(): ReadableStream {
    const stream = new Stream(
      this[Symbol.asyncIterator].bind(this),
      this.controller,
    );
    return stream.toReadableStream();
  }
}

/**
 * Converts an accumulated snapshot into a `ChatCompletion` and passes it
 * through `maybeParseChatCompletion` together with the request params.
 *
 * @throws OpenAIError if any choice lacks a `finish_reason` or `role`, if a
 *   `function_call` lacks its name or arguments, or if a tool call lacks its
 *   `id`, `type`, function name or function arguments.
 */
function finalizeChatCompletion<ParsedT>(
  snapshot: ChatCompletionSnapshot,
  params: ChatCompletionCreateParams | null,
): ParsedChatCompletion<ParsedT> {
  const { id, choices, created, model, system_fingerprint, ...rest } = snapshot;
  const completion: ChatCompletion = {
    ...rest,
    id,
    choices: choices.map(
      (
        { message, finish_reason, index, logprobs, ...choiceRest },
      ): ChatCompletion.Choice => {
        if (!finish_reason) {
          throw new OpenAIError(`missing finish_reason for choice ${index}`);
        }

        const { content = null, function_call, tool_calls, ...messageRest } =
          message;
        const role = message.role as "assistant"; // this is what we expect; in theory it could be different which would make our types a slight lie but would be fine.
        if (!role) {
          throw new OpenAIError(`missing role for choice ${index}`);
        }

        if (function_call) {
          const { arguments: args, name } = function_call;
          if (args == null) {
            throw new OpenAIError(
              `missing function_call.arguments for choice ${index}`,
            );
          }

          if (!name) {
            throw new OpenAIError(
              `missing function_call.name for choice ${index}`,
            );
          }

          return {
            ...choiceRest,
            message: {
              content,
              function_call: { arguments: args, name },
              role,
              refusal: message.refusal ?? null,
            },
            finish_reason,
            index,
            logprobs,
          };
        }

        if (tool_calls) {
          return {
            ...choiceRest,
            index,
            finish_reason,
            logprobs,
            message: {
              ...messageRest,
              role,
              content,
              refusal: message.refusal ?? null,
              tool_calls: tool_calls.map((tool_call, i) => {
                const { function: fn, type, id, ...toolRest } = tool_call;
                const { arguments: args, name, ...fnRest } = fn || {};
                if (id == null) {
                  throw new OpenAIError(
                    `missing choices[${index}].tool_calls[${i}].id\n${
                      str(snapshot)
                    }`,
                  );
                }
                if (type == null) {
                  throw new OpenAIError(
                    `missing choices[${index}].tool_calls[${i}].type\n${
                      str(snapshot)
                    }`,
                  );
                }
                if (name == null) {
                  throw new OpenAIError(
                    `missing choices[${index}].tool_calls[${i}].function.name\n${
                      str(snapshot)
                    }`,
                  );
                }
                if (args == null) {
                  throw new OpenAIError(
                    `missing choices[${index}].tool_calls[${i}].function.arguments\n${
                      str(snapshot)
                    }`,
                  );
                }

                return {
                  ...toolRest,
                  id,
                  type,
                  function: { ...fnRest, name, arguments: args },
                };
              }),
            },
          };
        }

        return {
          ...choiceRest,
          message: {
            ...messageRest,
            content,
            role,
            refusal: message.refusal ?? null,
          },
          finish_reason,
          index,
          logprobs,
        };
      },
    ),
    created,
    model,
    object: "chat.completion",
    ...(system_fingerprint ? { system_fingerprint } : {}),
  };

  return maybeParseChatCompletion(completion, params);
}

/** Serializes `x` with `JSON.stringify` for inclusion in error messages. */
function str(x: unknown) {
  return JSON.stringify(x);
}

/**
 * Represents a streamed chunk of a chat completion response returned by model,
 * based on the provided input.
 */
export interface ChatCompletionSnapshot {
  /**
   * A unique identifier for the chat completion.
   */
  id: string;

  /**
   * A list of chat completion choices. Can be more than one if `n` is greater
   * than 1.
   */
  choices: Array<ChatCompletionSnapshot.Choice>;

  /**
   * The Unix timestamp (in seconds) of when the chat completion was created.
   */
  created: number;

  /**
   * The model to generate the completion.
   */
  model: string;

  // Note we do not include an "object" type on the snapshot,
  // because the object is not a valid "chat.completion" until finalized.
  // object: 'chat.completion';

  /**
   * This fingerprint represents the backend configuration that the model runs with.
   *
   * Can be used in conjunction with the `seed` request parameter to understand when
   * backend changes have been made that might impact determinism.
   */
  system_fingerprint?: string;
}

export namespace ChatCompletionSnapshot {
  export interface Choice {
    /**
     * A chat completion delta generated by streamed model responses.
     */
    message: Choice.Message;

    /**
     * The reason the model stopped generating tokens. This will be `stop` if the model
     * hit a natural stop point or a provided stop sequence, `length` if the maximum
     * number of tokens specified in the request was reached, `content_filter` if
     * content was omitted due to a flag from our content filters, or `function_call`
     * if the model called a function.
     */
    finish_reason: ChatCompletion.Choice["finish_reason"] | null;

    /**
     * Log probability information for the choice.
     */
    logprobs: ChatCompletion.Choice.Logprobs | null;

    /**
     * The index of the choice in the list of choices.
     */
    index: number;
  }

  export namespace Choice {
    /**
     * A chat completion delta generated by streamed model responses.
     */
    export interface Message {
      /**
       * The contents of the chunk message.
       */
      content?: string | null;

      refusal?: string | null;

      parsed?: unknown | null;

      /**
       * The name and arguments of a function that should be called, as generated by the
       * model.
       */
      function_call?: Message.FunctionCall;

      tool_calls?: Array<Message.ToolCall>;

      /**
       * The role of the author of this message.
       */
      role?: "system" | "user" | "assistant" | "function" | "tool";
    }

    export namespace Message {
      export interface ToolCall {
        /**
         * The ID of the tool call.
         */
        id: string;

        function: ToolCall.Function;

        /**
         * The type of the tool.
         */
        type: "function";
      }

      export namespace ToolCall {
        export interface Function {
          /**
           * The arguments to call the function with, as generated by the model in JSON
           * format. Note that the model does not always generate valid JSON, and may
           * hallucinate parameters not defined by your function schema. Validate the
           * arguments in your code before calling your function.
           */
          arguments: string;

          parsed_arguments?: unknown;

          /**
           * The name of the function to call.
           */
          name: string;
        }
      }

      /**
       * The name and arguments of a function that should be called, as generated by the
       * model.
       */
      export interface FunctionCall {
        /**
         * The arguments to call the function with, as generated by the model in JSON
         * format. Note that the model does not always generate valid JSON, and may
         * hallucinate parameters not defined by your function schema. Validate the
         * arguments in your code before calling your function.
         */
        arguments?: string;

        /**
         * The name of the function to call.
         */
        name?: string;
      }
    }
  }
}

type AssertIsEmpty<T extends {}> = keyof T extends never ? T : never;

/**
 * Ensures the given argument is an empty object, useful for
 * asserting that all known properties on an object have been
 * destructured.
 */
function assertIsEmpty<T extends {}>(
  obj: AssertIsEmpty<T>,
): asserts obj is AssertIsEmpty<T> {
  return;
}

/**
 * Compile-time exhaustiveness check: only type-checks when the argument has
 * been narrowed to `never`. Does nothing at runtime.
 */
function assertNever(_x: never) {}