streamaccumulator.go (139 lines of code) (raw):

package openai import "github.com/openai/openai-go/shared/constant" // Helper to accumulate chunks from a stream type ChatCompletionAccumulator struct { // The up-to-date accumulation of model's responses ChatCompletion choiceChatCompletionStates []chatCompletionResponseState justFinished chatCompletionResponseState } type FinishedChatCompletionToolCall struct { ChatCompletionMessageToolCallFunction Index int Id string } type chatCompletionResponseState struct { state chatCompletionResponseStateEnum index int } type chatCompletionResponseStateEnum int const ( emptyResponseState chatCompletionResponseStateEnum = iota contentResponseState refusalResponseState toolResponseState finishedResponseState ) // AddChunk incorporates a chunk into the accumulation. Chunks must be added in order. // Returns false if the chunk could not be successfully accumulated. // // The ChatCompletion field JSON does not get accumulated. func (acc *ChatCompletionAccumulator) AddChunk(chunk ChatCompletionChunk) bool { acc.justFinished = chatCompletionResponseState{} if !acc.accumulateDelta(chunk) { return false } // only chunks with choices can cause finished events if len(chunk.Choices) == 0 { return true } chunkIndex := int(chunk.Choices[0].Index) acc.choiceChatCompletionStates = expandToFit(acc.choiceChatCompletionStates, chunkIndex) acc.justFinished = acc.choiceChatCompletionStates[chunkIndex].update(chunk) return true } // JustFinishedRefusal retrieves the chat completion refusal when it is known to have just been completed. // The content is "just completed" when the last added chunk no longer contains a content // delta. If the content is just completed, the content is returned and the boolean is true. Otherwise, // an empty string is returned and the boolean will be false. func (acc *ChatCompletionAccumulator) JustFinishedContent() (content string, ok bool) { if acc.justFinished.state == contentResponseState { return acc.Choices[0].Message.Content, true } return "", false } // JustFinishedRefusal retrieves the chat completion refusal when it is known to have just been completed. // The refusal is "just completed" when the last added chunk no longer contains a refusal // delta. If the refusal is just completed, the refusal is returned and the boolean is true. Otherwise, // an empty string is returned and the boolean will be false. func (acc *ChatCompletionAccumulator) JustFinishedRefusal() (refusal string, ok bool) { if acc.justFinished.state == refusalResponseState { return acc.Choices[0].Message.Refusal, true } return "", false } // JustFinishedToolCall retrieves a tool call when it is known to have just been completed. // A tool call is "just completed" when the last added chunk no longer contains a tool call // delta or contains a delta for a different tool call. If the tool call is just completed, // a FinishedChatCompletionToolCall is returned and the boolean is true. Otherwise, an empty // tool call is returned and the boolean will be false. // // You cannot rely on this with a stream that has ParallelToolCalls enabled. func (acc *ChatCompletionAccumulator) JustFinishedToolCall() (toolcall FinishedChatCompletionToolCall, ok bool) { if acc.justFinished.state == toolResponseState { f := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].Function id := acc.Choices[0].Message.ToolCalls[acc.justFinished.index].ID return FinishedChatCompletionToolCall{ Id: id, Index: acc.justFinished.index, ChatCompletionMessageToolCallFunction: ChatCompletionMessageToolCallFunction{ Name: f.Name, Arguments: f.Arguments, }, }, true } return FinishedChatCompletionToolCall{}, false } // Concatenates a ChatCompletionChunk onto a ChatCompletion. Returns false and // does nothing if a mismatch is detected. // // Ignores the JSON field func (cc *ChatCompletion) accumulateDelta(chunk ChatCompletionChunk) bool { if len(cc.ID) == 0 { cc.ID = chunk.ID } else if cc.ID != chunk.ID { return false } for _, delta := range chunk.Choices { cc.Choices = expandToFit(cc.Choices, int(delta.Index)) choice := &cc.Choices[delta.Index] choice.Index = delta.Index choice.FinishReason = delta.FinishReason if delta.Delta.Role != "" { choice.Message.Role = constant.Assistant(delta.Delta.Role) } choice.Message.Content += delta.Delta.Content choice.Message.Refusal += delta.Delta.Refusal for j := range delta.Delta.ToolCalls { deltaTool := &delta.Delta.ToolCalls[j] choice.Message.ToolCalls = expandToFit(choice.Message.ToolCalls, int(deltaTool.Index)) tool := &choice.Message.ToolCalls[deltaTool.Index] if deltaTool.ID != "" { tool.ID = deltaTool.ID } if deltaTool.Type != "" { tool.Type = constant.Function(deltaTool.Type) } tool.Function.Name += deltaTool.Function.Name tool.Function.Arguments += deltaTool.Function.Arguments } choice.Logprobs.Content = append(choice.Logprobs.Content, delta.Logprobs.Content...) choice.Logprobs.Refusal = append(choice.Logprobs.Refusal, delta.Logprobs.Refusal...) } cc.Usage.CompletionTokens += chunk.Usage.CompletionTokens cc.Usage.PromptTokens += chunk.Usage.PromptTokens cc.Usage.TotalTokens += chunk.Usage.TotalTokens cc.Model = chunk.Model cc.Created = chunk.Created cc.SystemFingerprint = chunk.SystemFingerprint cc.ServiceTier = ChatCompletionServiceTier(chunk.ServiceTier) if chunk.Object == chunk.Object.Default() { cc.Object = cc.Object.Default() } return true } // Updates the internal response state and returns the previous state if // the state changed. This ensures that JustFinished events only fire once. func (prev *chatCompletionResponseState) update(chunk ChatCompletionChunk) (justFinished chatCompletionResponseState) { delta := chunk.Choices[0].Delta new := chatCompletionResponseState{} switch { case delta.JSON.Content.IsPresent(): new.state = contentResponseState case delta.JSON.Refusal.IsPresent(): new.state = refusalResponseState case delta.JSON.ToolCalls.IsPresent(): new.state = toolResponseState new.index = int(delta.ToolCalls[0].Index) default: new.state = finishedResponseState } if *prev != new { justFinished = *prev } *prev = new return } func expandToFit[T any](slice []T, index int) []T { if index < len(slice) { return slice } if index < cap(slice) { return slice[:index+1] } newSlice := make([]T, index+1) copy(newSlice, slice) return newSlice }