lib/AssistantStream.ts (740 lines of code) (raw):

import { ImageFile, ImageFileContentBlock, Message, MessageContent, MessageContentDelta, Messages, Text, TextContentBlock, TextDelta, } from "../resources/beta/threads/messages.ts"; import * as Core from "../core.ts"; import { RequestOptions } from "../core.ts"; import { Run, RunCreateParamsBase, RunCreateParamsStreaming, Runs, RunSubmitToolOutputsParamsBase, RunSubmitToolOutputsParamsStreaming, } from "../resources/beta/threads/runs/runs.ts"; import { type ReadableStream } from "../_shims/mod.ts"; import { Stream } from "../streaming.ts"; import { APIUserAbortError, OpenAIError } from "../error.ts"; import { AssistantStreamEvent, MessageStreamEvent, RunStepStreamEvent, RunStreamEvent, } from "../resources/beta/assistants.ts"; import { RunStep, RunStepDelta, ToolCall, ToolCallDelta, } from "../resources/beta/threads/runs/steps.ts"; import { ThreadCreateAndRunParamsBase, Threads, } from "../resources/beta/threads/threads.ts"; import { BaseEvents, EventStream } from "./EventStream.ts"; export type MessageDelta = Messages.MessageDelta; export interface AssistantStreamEvents extends BaseEvents { run: (run: Run) => void; //New event structure messageCreated: (message: Message) => void; messageDelta: (message: MessageDelta, snapshot: Message) => void; messageDone: (message: Message) => void; runStepCreated: (runStep: RunStep) => void; runStepDelta: (delta: RunStepDelta, snapshot: Runs.RunStep) => void; runStepDone: (runStep: Runs.RunStep, snapshot: Runs.RunStep) => void; toolCallCreated: (toolCall: ToolCall) => void; toolCallDelta: (delta: ToolCallDelta, snapshot: ToolCall) => void; toolCallDone: (toolCall: ToolCall) => void; textCreated: (content: Text) => void; textDelta: (delta: TextDelta, snapshot: Text) => void; textDone: (content: Text, snapshot: Message) => void; //No created or delta as this is not streamed imageFileDone: (content: ImageFile, snapshot: Message) => void; event: (event: AssistantStreamEvent) => void; } export type ThreadCreateAndRunParamsBaseStream = & Omit<ThreadCreateAndRunParamsBase, "stream"> & { stream?: true; }; export type RunCreateParamsBaseStream = Omit<RunCreateParamsBase, "stream"> & { stream?: true; }; export type RunSubmitToolOutputsParamsStream = & Omit<RunSubmitToolOutputsParamsBase, "stream"> & { stream?: true; }; export class AssistantStream extends EventStream<AssistantStreamEvents> implements AsyncIterable<AssistantStreamEvent> { //Track all events in a single list for reference #events: AssistantStreamEvent[] = []; //Used to accumulate deltas //We are accumulating many types so the value here is not strict #runStepSnapshots: { [id: string]: Runs.RunStep } = {}; #messageSnapshots: { [id: string]: Message } = {}; #messageSnapshot: Message | undefined; #finalRun: Run | undefined; #currentContentIndex: number | undefined; #currentContent: MessageContent | undefined; #currentToolCallIndex: number | undefined; #currentToolCall: ToolCall | undefined; //For current snapshot methods #currentEvent: AssistantStreamEvent | undefined; #currentRunSnapshot: Run | undefined; #currentRunStepSnapshot: Runs.RunStep | undefined; [Symbol.asyncIterator](): AsyncIterator<AssistantStreamEvent> { const pushQueue: AssistantStreamEvent[] = []; const readQueue: { resolve: (chunk: AssistantStreamEvent | undefined) => void; reject: (err: unknown) => void; }[] = []; let done = false; //Catch all for passing along all events this.on("event", (event) => { const reader = readQueue.shift(); if (reader) { reader.resolve(event); } else { pushQueue.push(event); } }); this.on("end", () => { done = true; for (const reader of readQueue) { reader.resolve(undefined); } readQueue.length = 0; }); this.on("abort", (err) => { done = true; for (const reader of readQueue) { reader.reject(err); } readQueue.length = 0; }); this.on("error", (err) => { done = true; for (const reader of readQueue) { reader.reject(err); } readQueue.length = 0; }); return { next: async (): Promise<IteratorResult<AssistantStreamEvent>> => { if (!pushQueue.length) { if (done) { return { value: undefined, done: true }; } return new Promise<AssistantStreamEvent | undefined>(( resolve, reject, ) => readQueue.push({ resolve, reject })).then(( chunk, ) => (chunk ? { value: chunk, done: false } : { value: undefined, done: true }) ); } const chunk = pushQueue.shift()!; return { value: chunk, done: false }; }, return: async () => { this.abort(); return { value: undefined, done: true }; }, }; } static fromReadableStream(stream: ReadableStream): AssistantStream { const runner = new AssistantStream(); runner._run(() => runner._fromReadableStream(stream)); return runner; } protected async _fromReadableStream( readableStream: ReadableStream, options?: Core.RequestOptions, ): Promise<Run> { const signal = options?.signal; if (signal) { if (signal.aborted) this.controller.abort(); signal.addEventListener("abort", () => this.controller.abort()); } this._connected(); const stream = Stream.fromReadableStream<AssistantStreamEvent>( readableStream, this.controller, ); for await (const event of stream) { this.#addEvent(event); } if (stream.controller.signal?.aborted) { throw new APIUserAbortError(); } return this._addRun(this.#endRequest()); } toReadableStream(): ReadableStream { const stream = new Stream( this[Symbol.asyncIterator].bind(this), this.controller, ); return stream.toReadableStream(); } static createToolAssistantStream( threadId: string, runId: string, runs: Runs, params: RunSubmitToolOutputsParamsStream, options: RequestOptions | undefined, ) { const runner = new AssistantStream(); runner._run(() => runner._runToolAssistantStream(threadId, runId, runs, params, { ...options, headers: { ...options?.headers, "X-Stainless-Helper-Method": "stream" }, }) ); return runner; } protected async _createToolAssistantStream( run: Runs, threadId: string, runId: string, params: RunSubmitToolOutputsParamsStream, options?: Core.RequestOptions, ): Promise<Run> { const signal = options?.signal; if (signal) { if (signal.aborted) this.controller.abort(); signal.addEventListener("abort", () => this.controller.abort()); } const body: RunSubmitToolOutputsParamsStreaming = { ...params, stream: true, }; const stream = await run.submitToolOutputs(threadId, runId, body, { ...options, signal: this.controller.signal, }); this._connected(); for await (const event of stream) { this.#addEvent(event); } if (stream.controller.signal?.aborted) { throw new APIUserAbortError(); } return this._addRun(this.#endRequest()); } static createThreadAssistantStream( params: ThreadCreateAndRunParamsBaseStream, thread: Threads, options?: RequestOptions, ) { const runner = new AssistantStream(); runner._run(() => runner._threadAssistantStream(params, thread, { ...options, headers: { ...options?.headers, "X-Stainless-Helper-Method": "stream" }, }) ); return runner; } static createAssistantStream( threadId: string, runs: Runs, params: RunCreateParamsBaseStream, options?: RequestOptions, ) { const runner = new AssistantStream(); runner._run(() => runner._runAssistantStream(threadId, runs, params, { ...options, headers: { ...options?.headers, "X-Stainless-Helper-Method": "stream" }, }) ); return runner; } currentEvent(): AssistantStreamEvent | undefined { return this.#currentEvent; } currentRun(): Run | undefined { return this.#currentRunSnapshot; } currentMessageSnapshot(): Message | undefined { return this.#messageSnapshot; } currentRunStepSnapshot(): Runs.RunStep | undefined { return this.#currentRunStepSnapshot; } async finalRunSteps(): Promise<Runs.RunStep[]> { await this.done(); return Object.values(this.#runStepSnapshots); } async finalMessages(): Promise<Message[]> { await this.done(); return Object.values(this.#messageSnapshots); } async finalRun(): Promise<Run> { await this.done(); if (!this.#finalRun) throw Error("Final run was not received."); return this.#finalRun; } protected async _createThreadAssistantStream( thread: Threads, params: ThreadCreateAndRunParamsBase, options?: Core.RequestOptions, ): Promise<Run> { const signal = options?.signal; if (signal) { if (signal.aborted) this.controller.abort(); signal.addEventListener("abort", () => this.controller.abort()); } const body: RunCreateParamsStreaming = { ...params, stream: true }; const stream = await thread.createAndRun(body, { ...options, signal: this.controller.signal, }); this._connected(); for await (const event of stream) { this.#addEvent(event); } if (stream.controller.signal?.aborted) { throw new APIUserAbortError(); } return this._addRun(this.#endRequest()); } protected async _createAssistantStream( run: Runs, threadId: string, params: RunCreateParamsBase, options?: Core.RequestOptions, ): Promise<Run> { const signal = options?.signal; if (signal) { if (signal.aborted) this.controller.abort(); signal.addEventListener("abort", () => this.controller.abort()); } const body: RunCreateParamsStreaming = { ...params, stream: true }; const stream = await run.create(threadId, body, { ...options, signal: this.controller.signal, }); this._connected(); for await (const event of stream) { this.#addEvent(event); } if (stream.controller.signal?.aborted) { throw new APIUserAbortError(); } return this._addRun(this.#endRequest()); } #addEvent(event: AssistantStreamEvent) { if (this.ended) return; this.#currentEvent = event; this.#handleEvent(event); switch (event.event) { case "thread.created": //No action on this event. break; case "thread.run.created": case "thread.run.queued": case "thread.run.in_progress": case "thread.run.requires_action": case "thread.run.completed": case "thread.run.failed": case "thread.run.cancelling": case "thread.run.cancelled": case "thread.run.expired": this.#handleRun(event); break; case "thread.run.step.created": case "thread.run.step.in_progress": case "thread.run.step.delta": case "thread.run.step.completed": case "thread.run.step.failed": case "thread.run.step.cancelled": case "thread.run.step.expired": this.#handleRunStep(event); break; case "thread.message.created": case "thread.message.in_progress": case "thread.message.delta": case "thread.message.completed": case "thread.message.incomplete": this.#handleMessage(event); break; case "error": //This is included for completeness, but errors are processed in the SSE event processing so this should not occur throw new Error( "Encountered an error event in event processing - errors should be processed earlier", ); } } #endRequest(): Run { if (this.ended) { throw new OpenAIError(`stream has ended, this shouldn't happen`); } if (!this.#finalRun) throw Error("Final run has not been received"); return this.#finalRun; } #handleMessage(this: AssistantStream, event: MessageStreamEvent) { const [accumulatedMessage, newContent] = this.#accumulateMessage( event, this.#messageSnapshot, ); this.#messageSnapshot = accumulatedMessage; this.#messageSnapshots[accumulatedMessage.id] = accumulatedMessage; for (const content of newContent) { const snapshotContent = accumulatedMessage.content[content.index]; if (snapshotContent?.type == "text") { this._emit("textCreated", snapshotContent.text); } } switch (event.event) { case "thread.message.created": this._emit("messageCreated", event.data); break; case "thread.message.in_progress": break; case "thread.message.delta": this._emit("messageDelta", event.data.delta, accumulatedMessage); if (event.data.delta.content) { for (const content of event.data.delta.content) { //If it is text delta, emit a text delta event if (content.type == "text" && content.text) { let textDelta = content.text; let snapshot = accumulatedMessage.content[content.index]; if (snapshot && snapshot.type == "text") { this._emit("textDelta", textDelta, snapshot.text); } else { throw Error( "The snapshot associated with this text delta is not text or missing", ); } } if (content.index != this.#currentContentIndex) { //See if we have in progress content if (this.#currentContent) { switch (this.#currentContent.type) { case "text": this._emit( "textDone", this.#currentContent.text, this.#messageSnapshot, ); break; case "image_file": this._emit( "imageFileDone", this.#currentContent.image_file, this.#messageSnapshot, ); break; } } this.#currentContentIndex = content.index; } this.#currentContent = accumulatedMessage.content[content.index]; } } break; case "thread.message.completed": case "thread.message.incomplete": //We emit the latest content we were working on on completion (including incomplete) if (this.#currentContentIndex !== undefined) { const currentContent = event.data.content[this.#currentContentIndex]; if (currentContent) { switch (currentContent.type) { case "image_file": this._emit( "imageFileDone", currentContent.image_file, this.#messageSnapshot, ); break; case "text": this._emit( "textDone", currentContent.text, this.#messageSnapshot, ); break; } } } if (this.#messageSnapshot) { this._emit("messageDone", event.data); } this.#messageSnapshot = undefined; } } #handleRunStep(this: AssistantStream, event: RunStepStreamEvent) { const accumulatedRunStep = this.#accumulateRunStep(event); this.#currentRunStepSnapshot = accumulatedRunStep; switch (event.event) { case "thread.run.step.created": this._emit("runStepCreated", event.data); break; case "thread.run.step.delta": const delta = event.data.delta; if ( delta.step_details && delta.step_details.type == "tool_calls" && delta.step_details.tool_calls && accumulatedRunStep.step_details.type == "tool_calls" ) { for (const toolCall of delta.step_details.tool_calls) { if (toolCall.index == this.#currentToolCallIndex) { this._emit( "toolCallDelta", toolCall, accumulatedRunStep.step_details .tool_calls[toolCall.index] as ToolCall, ); } else { if (this.#currentToolCall) { this._emit("toolCallDone", this.#currentToolCall); } this.#currentToolCallIndex = toolCall.index; this.#currentToolCall = accumulatedRunStep.step_details.tool_calls[toolCall.index]; if (this.#currentToolCall) { this._emit("toolCallCreated", this.#currentToolCall); } } } } this._emit("runStepDelta", event.data.delta, accumulatedRunStep); break; case "thread.run.step.completed": case "thread.run.step.failed": case "thread.run.step.cancelled": case "thread.run.step.expired": this.#currentRunStepSnapshot = undefined; const details = event.data.step_details; if (details.type == "tool_calls") { if (this.#currentToolCall) { this._emit("toolCallDone", this.#currentToolCall as ToolCall); this.#currentToolCall = undefined; } } this._emit("runStepDone", event.data, accumulatedRunStep); break; case "thread.run.step.in_progress": break; } } #handleEvent(this: AssistantStream, event: AssistantStreamEvent) { this.#events.push(event); this._emit("event", event); } #accumulateRunStep(event: RunStepStreamEvent): Runs.RunStep { switch (event.event) { case "thread.run.step.created": this.#runStepSnapshots[event.data.id] = event.data; return event.data; case "thread.run.step.delta": let snapshot = this.#runStepSnapshots[event.data.id] as Runs.RunStep; if (!snapshot) { throw Error("Received a RunStepDelta before creation of a snapshot"); } let data = event.data; if (data.delta) { const accumulated = AssistantStream.accumulateDelta( snapshot, data.delta, ) as Runs.RunStep; this.#runStepSnapshots[event.data.id] = accumulated; } return this.#runStepSnapshots[event.data.id] as Runs.RunStep; case "thread.run.step.completed": case "thread.run.step.failed": case "thread.run.step.cancelled": case "thread.run.step.expired": case "thread.run.step.in_progress": this.#runStepSnapshots[event.data.id] = event.data; break; } if (this.#runStepSnapshots[event.data.id]) { return this.#runStepSnapshots[event.data.id] as Runs.RunStep; } throw new Error("No snapshot available"); } #accumulateMessage( event: AssistantStreamEvent, snapshot: Message | undefined, ): [Message, MessageContentDelta[]] { let newContent: MessageContentDelta[] = []; switch (event.event) { case "thread.message.created": //On creation the snapshot is just the initial message return [event.data, newContent]; case "thread.message.delta": if (!snapshot) { throw Error( "Received a delta with no existing snapshot (there should be one from message creation)", ); } let data = event.data; //If this delta does not have content, nothing to process if (data.delta.content) { for (const contentElement of data.delta.content) { if (contentElement.index in snapshot.content) { let currentContent = snapshot.content[contentElement.index]; snapshot.content[contentElement.index] = this.#accumulateContent( contentElement, currentContent, ); } else { snapshot.content[contentElement.index] = contentElement as MessageContent; // This is a new element newContent.push(contentElement); } } } return [snapshot, newContent]; case "thread.message.in_progress": case "thread.message.completed": case "thread.message.incomplete": //No changes on other thread events if (snapshot) { return [snapshot, newContent]; } else { throw Error( "Received thread message event with no existing snapshot", ); } } throw Error("Tried to accumulate a non-message event"); } #accumulateContent( contentElement: MessageContentDelta, currentContent: MessageContent | undefined, ): TextContentBlock | ImageFileContentBlock { return AssistantStream.accumulateDelta( currentContent as unknown as Record<any, any>, contentElement, ) as | TextContentBlock | ImageFileContentBlock; } static accumulateDelta( acc: Record<string, any>, delta: Record<string, any>, ): Record<string, any> { for (const [key, deltaValue] of Object.entries(delta)) { if (!acc.hasOwnProperty(key)) { acc[key] = deltaValue; continue; } let accValue = acc[key]; if (accValue === null || accValue === undefined) { acc[key] = deltaValue; continue; } // We don't accumulate these special properties if (key === "index" || key === "type") { acc[key] = deltaValue; continue; } // Type-specific accumulation logic if (typeof accValue === "string" && typeof deltaValue === "string") { accValue += deltaValue; } else if ( typeof accValue === "number" && typeof deltaValue === "number" ) { accValue += deltaValue; } else if (Core.isObj(accValue) && Core.isObj(deltaValue)) { accValue = this.accumulateDelta( accValue as Record<string, any>, deltaValue as Record<string, any>, ); } else if (Array.isArray(accValue) && Array.isArray(deltaValue)) { if ( accValue.every((x) => typeof x === "string" || typeof x === "number") ) { accValue.push(...deltaValue); // Use spread syntax for efficient addition continue; } for (const deltaEntry of deltaValue) { if (!Core.isObj(deltaEntry)) { throw new Error( `Expected array delta entry to be an object but got: ${deltaEntry}`, ); } const index = deltaEntry["index"]; if (index == null) { console.error(deltaEntry); throw new Error( "Expected array delta entry to have an `index` property", ); } if (typeof index !== "number") { throw new Error( `Expected array delta entry \`index\` property to be a number but got ${index}`, ); } const accEntry = accValue[index]; if (accEntry == null) { accValue.push(deltaEntry); } else { accValue[index] = this.accumulateDelta(accEntry, deltaEntry); } } continue; } else { throw Error( `Unhandled record type: ${key}, deltaValue: ${deltaValue}, accValue: ${accValue}`, ); } acc[key] = accValue; } return acc; } #handleRun(this: AssistantStream, event: RunStreamEvent) { this.#currentRunSnapshot = event.data; switch (event.event) { case "thread.run.created": break; case "thread.run.queued": break; case "thread.run.in_progress": break; case "thread.run.requires_action": case "thread.run.cancelled": case "thread.run.failed": case "thread.run.completed": case "thread.run.expired": this.#finalRun = event.data; if (this.#currentToolCall) { this._emit("toolCallDone", this.#currentToolCall); this.#currentToolCall = undefined; } break; case "thread.run.cancelling": break; } } protected _addRun(run: Run): Run { return run; } protected async _threadAssistantStream( params: ThreadCreateAndRunParamsBase, thread: Threads, options?: Core.RequestOptions, ): Promise<Run> { return await this._createThreadAssistantStream(thread, params, options); } protected async _runAssistantStream( threadId: string, runs: Runs, params: RunCreateParamsBase, options?: Core.RequestOptions, ): Promise<Run> { return await this._createAssistantStream(runs, threadId, params, options); } protected async _runToolAssistantStream( threadId: string, runId: string, runs: Runs, params: RunSubmitToolOutputsParamsStream, options?: Core.RequestOptions, ): Promise<Run> { return await this._createToolAssistantStream( runs, threadId, runId, params, options, ); } }