gollm/openai.go
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package gollm

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/url"
"os"
openai "github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"k8s.io/klog/v2"
)
// Register the OpenAI provider factory on package initialization.
func init() {
if err := RegisterProvider("openai", newOpenAIClientFactory); err != nil {
klog.Fatalf("Failed to register openai provider: %v", err)
}
}
// newOpenAIClientFactory is the factory function for creating OpenAI clients.
func newOpenAIClientFactory(ctx context.Context, _ *url.URL) (Client, error) {
	// The URL is not currently used for OpenAI configuration; the client relies on environment variables.
return NewOpenAIClient(ctx)
}
// OpenAIClient implements the gollm.Client interface for OpenAI models.
type OpenAIClient struct {
client openai.Client
}
// Ensure OpenAIClient implements the Client interface.
var _ Client = &OpenAIClient{}
// NewOpenAIClient creates a new client for interacting with OpenAI.
// It reads the API key and optional endpoint from environment variables
// OPENAI_API_KEY and OPENAI_ENDPOINT.
func NewOpenAIClient(ctx context.Context) (*OpenAIClient, error) {
apiKey := os.Getenv("OPENAI_API_KEY")
if apiKey == "" {
		// openai.NewClient might reject a missing key itself, but an explicit check gives a clearer error.
return nil, errors.New("OPENAI_API_KEY environment variable not set")
}
endpoint := os.Getenv("OPENAI_ENDPOINT")
if endpoint != "" {
klog.Infof("Using custom OpenAI endpoint: %s", endpoint)
return &OpenAIClient{
client: openai.NewClient(option.WithBaseURL(endpoint)),
}, nil
}
return &OpenAIClient{
client: openai.NewClient(),
}, nil
}
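
// Example usage (a hedged sketch; assumes OPENAI_API_KEY is set in the
// environment and trims error handling for brevity):
//
//	ctx := context.Background()
//	client, err := NewOpenAIClient(ctx)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer client.Close()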
// Close cleans up any resources used by the client.
func (c *OpenAIClient) Close() error {
// No specific cleanup needed for the OpenAI client currently.
return nil
}
// StartChat starts a new chat session.
func (c *OpenAIClient) StartChat(systemPrompt, model string) Chat {
	// Default to gpt-4o if no model is specified.
if model == "" {
model = "gpt-4o"
klog.V(1).Info("No model specified, defaulting to gpt-4o")
}
klog.V(1).Infof("Starting new OpenAI chat session with model: %s", model)
// Initialize history with system prompt if provided
history := []openai.ChatCompletionMessageParamUnion{}
if systemPrompt != "" {
history = append(history, openai.SystemMessage(systemPrompt))
}
return &openAIChatSession{
client: c.client, // Pass the client from OpenAIClient
history: history,
model: model,
// functionDefinitions and tools will be set later via SetFunctionDefinitions
}
}
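
// Example usage (illustrative only; the system prompt and model name are
// arbitrary placeholders):
//
//	chat := client.StartChat("You are a helpful assistant.", "gpt-4o")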
// simpleCompletionResponse is a basic implementation of CompletionResponse.
type simpleCompletionResponse struct {
content string
}
// Response returns the completion content.
func (r *simpleCompletionResponse) Response() string {
return r.content
}
// UsageMetadata returns nil for now.
func (r *simpleCompletionResponse) UsageMetadata() any {
return nil
}
// GenerateCompletion sends a completion request to the OpenAI API.
func (c *OpenAIClient) GenerateCompletion(ctx context.Context, req *CompletionRequest) (CompletionResponse, error) {
klog.Infof("OpenAI GenerateCompletion called with model: %s", req.Model)
klog.V(1).Infof("Prompt:\n%s", req.Prompt)
	// Use the Chat Completions API for the one-shot completion.
chatReq := openai.ChatCompletionNewParams{
Model: openai.ChatModel(req.Model), // Use the model specified in the request
Messages: []openai.ChatCompletionMessageParamUnion{
// Assuming a simple user message structure for now
openai.UserMessage(req.Prompt),
},
}
completion, err := c.client.Chat.Completions.New(ctx, chatReq)
if err != nil {
return nil, fmt.Errorf("failed to generate OpenAI completion: %w", err)
}
// Check if there are choices and a message
if len(completion.Choices) == 0 || completion.Choices[0].Message.Content == "" {
return nil, errors.New("received an empty response from OpenAI")
}
// Return the content of the first choice
resp := &simpleCompletionResponse{
content: completion.Choices[0].Message.Content,
}
return resp, nil
}
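
// Example usage (a minimal sketch, assuming CompletionRequest carries just
// the Model and Prompt fields referenced above):
//
//	resp, err := client.GenerateCompletion(ctx, &CompletionRequest{
//		Model:  "gpt-4o",
//		Prompt: "Explain the Kubernetes reconciliation loop in one sentence.",
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(resp.Response())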
// SetResponseSchema is not implemented yet.
func (c *OpenAIClient) SetResponseSchema(schema *Schema) error {
klog.Warning("OpenAIClient.SetResponseSchema is not implemented yet")
return nil
}
// ListModels is not implemented yet.
func (c *OpenAIClient) ListModels(ctx context.Context) ([]string, error) {
// TODO: Implement listing OpenAI models using c.client
klog.Warning("OpenAIClient.ListModels is not implemented yet")
return []string{}, nil
}
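
// A possible implementation (untested sketch; assumes the openai-go models
// endpoint, where each entry in the returned page exposes an ID):
//
//	page, err := c.client.Models.List(ctx)
//	if err != nil {
//		return nil, fmt.Errorf("failed to list OpenAI models: %w", err)
//	}
//	models := make([]string, 0, len(page.Data))
//	for _, m := range page.Data {
//		models = append(models, m.ID)
//	}
//	return models, nil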
// --- Chat Session Implementation ---
type openAIChatSession struct {
client openai.Client
history []openai.ChatCompletionMessageParamUnion
model string
functionDefinitions []*FunctionDefinition // Stored in gollm format
tools []openai.ChatCompletionToolParam // Stored in OpenAI format
}
// Ensure openAIChatSession implements the Chat interface.
var _ Chat = (*openAIChatSession)(nil)
// SetFunctionDefinitions stores the function definitions and converts them to OpenAI format.
func (cs *openAIChatSession) SetFunctionDefinitions(defs []*FunctionDefinition) error {
cs.functionDefinitions = defs
cs.tools = nil // Clear previous tools
if len(defs) > 0 {
cs.tools = make([]openai.ChatCompletionToolParam, len(defs))
for i, gollmDef := range defs {
// Basic conversion, assuming schema is compatible or nil
var params openai.FunctionParameters
if gollmDef.Parameters != nil {
// NOTE: This assumes gollmDef.Parameters is directly marshalable to JSON
// that fits openai.FunctionParameters. May need refinement.
bytes, err := gollmDef.Parameters.ToRawSchema()
if err != nil {
return fmt.Errorf("failed to convert schema for function %s: %w", gollmDef.Name, err)
}
				if err := json.Unmarshal(bytes, &params); err != nil {
return fmt.Errorf("failed to unmarshal schema for function %s: %w", gollmDef.Name, err)
}
}
cs.tools[i] = openai.ChatCompletionToolParam{
Function: openai.FunctionDefinitionParam{
Name: gollmDef.Name,
Description: openai.String(gollmDef.Description),
Parameters: params,
},
}
}
}
klog.V(1).Infof("Set %d function definitions for OpenAI chat session", len(cs.functionDefinitions))
return nil
}
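
// Example usage (hypothetical; "get_weather" and weatherSchema are
// placeholders, weatherSchema being a *Schema constructed elsewhere):
//
//	defs := []*FunctionDefinition{{
//		Name:        "get_weather",
//		Description: "Look up the current weather for a city.",
//		Parameters:  weatherSchema,
//	}}
//	if err := chat.SetFunctionDefinitions(defs); err != nil {
//		log.Fatal(err)
//	}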
// Send sends the user message(s), appends to history, and gets the LLM response.
func (cs *openAIChatSession) Send(ctx context.Context, contents ...any) (ChatResponse, error) {
klog.V(1).InfoS("openAIChatSession.Send called", "model", cs.model, "history_len", len(cs.history))
// 1. Append user message(s) to history
for _, content := range contents {
switch c := content.(type) {
case string:
klog.V(2).Infof("Adding user message to history: %s", c)
cs.history = append(cs.history, openai.UserMessage(c))
case FunctionCallResult:
klog.V(2).Infof("Adding tool call result to history: Name=%s, ID=%s", c.Name, c.ID)
// Marshal the result map into a JSON string for the message content
resultJSON, err := json.Marshal(c.Result)
if err != nil {
klog.Errorf("Failed to marshal function call result: %v", err)
return nil, fmt.Errorf("failed to marshal function call result %q: %w", c.Name, err)
}
cs.history = append(cs.history, openai.ToolMessage(string(resultJSON), c.ID))
default:
			// TODO: Handle other content types if necessary.
klog.Warningf("Unhandled content type in Send: %T", content)
return nil, fmt.Errorf("unhandled content type: %T", content)
}
}
// 2. Prepare the API request
chatReq := openai.ChatCompletionNewParams{
Model: openai.ChatModel(cs.model),
Messages: cs.history,
}
if len(cs.tools) > 0 {
chatReq.Tools = cs.tools
// chatReq.ToolChoice = openai.ToolChoiceAuto // Or specify if needed
}
// 3. Call the OpenAI API
klog.V(1).InfoS("Sending request to OpenAI Chat API", "model", cs.model, "messages", len(chatReq.Messages), "tools", len(chatReq.Tools))
completion, err := cs.client.Chat.Completions.New(ctx, chatReq)
if err != nil {
// TODO: Check if error is retryable using cs.IsRetryableError
klog.Errorf("OpenAI ChatCompletion API error: %v", err)
return nil, fmt.Errorf("OpenAI chat completion failed: %w", err)
}
klog.V(1).InfoS("Received response from OpenAI Chat API", "id", completion.ID, "choices", len(completion.Choices))
// 4. Process the response
if len(completion.Choices) == 0 {
klog.Warning("Received response with no choices from OpenAI")
return nil, errors.New("received empty response from OpenAI (no choices)")
}
// Add assistant's response (first choice) to history
assistantMsg := completion.Choices[0].Message
// Convert to param type before appending to history
cs.history = append(cs.history, assistantMsg.ToParam())
klog.V(2).InfoS("Added assistant message to history", "content_present", assistantMsg.Content != "", "tool_calls", len(assistantMsg.ToolCalls))
// Wrap the response
resp := &openAIChatResponse{
openaiCompletion: completion,
}
return resp, nil
}
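
// Example round trip (a hedged sketch; runTool is a hypothetical dispatcher
// that executes the named tool and returns a result map):
//
//	resp, err := chat.Send(ctx, "What's the weather in Berlin?")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, cand := range resp.Candidates() {
//		for _, part := range cand.Parts() {
//			calls, ok := part.AsFunctionCalls()
//			if !ok {
//				continue
//			}
//			for _, call := range calls {
//				result := runTool(call.Name, call.Arguments)
//				resp, err = chat.Send(ctx, FunctionCallResult{
//					ID:     call.ID,
//					Name:   call.Name,
//					Result: result,
//				})
//			}
//		}
//	}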
// SendStreaming sends the user message(s) and returns an iterator for the LLM response stream.
// NOTE: Due to limitations in the openai-go v0.1.0-beta.10 library, this function
// simulates streaming by making a single non-streaming call and returning an iterator
// that yields the single response. This satisfies the agent's interface requirement.
func (cs *openAIChatSession) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) {
klog.V(1).InfoS("openAIChatSession.SendStreaming called (simulated)", "model", cs.model)
// Call the non-streaming Send method we implemented earlier
singleResponse, err := cs.Send(ctx, contents...)
if err != nil {
// Send already logs errors, just wrap it
return nil, fmt.Errorf("simulated streaming failed during non-streaming call: %w", err)
}
// Return an iterator function that yields the single response once.
var yielded bool
iteratorFunc := func(yield func(ChatResponse, error) bool) {
if !yielded {
yielded = true
// Yield the single response. If yield returns false, stop early.
if !yield(singleResponse, nil) {
return
}
}
// Subsequent calls do nothing, effectively ending the stream.
}
return iteratorFunc, nil
}
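
// A real streaming implementation could later use the library's SSE support
// (hedged sketch; assumes openai-go's NewStreaming helper and
// ChatCompletionAccumulator as described in its documentation):
//
//	stream := cs.client.Chat.Completions.NewStreaming(ctx, chatReq)
//	acc := openai.ChatCompletionAccumulator{}
//	for stream.Next() {
//		chunk := stream.Current()
//		acc.AddChunk(chunk)
//		// Yield each chunk's delta to the iterator here.
//	}
//	if err := stream.Err(); err != nil {
//		// Surface the stream error to the caller.
//	}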
// IsRetryableError returns false for now.
func (cs *openAIChatSession) IsRetryableError(err error) bool {
// TODO: Implement actual retry logic if needed
return false
}
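
// A possible classification (hedged sketch; assumes openai-go's *openai.Error
// exposes the HTTP status code of the failed request):
//
//	var apierr *openai.Error
//	if errors.As(err, &apierr) {
//		switch apierr.StatusCode {
//		case 429, 500, 502, 503, 504: // rate limits and transient server errors
//			return true
//		}
//	}
//	return false
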
// --- Helper structs for ChatResponse interface ---
type openAIChatResponse struct {
openaiCompletion *openai.ChatCompletion
}
var _ ChatResponse = (*openAIChatResponse)(nil)
func (r *openAIChatResponse) UsageMetadata() any {
// Check if the main completion object and Usage exist
if r.openaiCompletion != nil && r.openaiCompletion.Usage.TotalTokens > 0 { // Check a field within Usage
return r.openaiCompletion.Usage
}
return nil
}
func (r *openAIChatResponse) Candidates() []Candidate {
if r.openaiCompletion == nil {
return nil
}
	candidates := make([]Candidate, len(r.openaiCompletion.Choices))
	for i := range r.openaiCompletion.Choices {
		// Index into the Choices slice directly; taking the address of a loop
		// variable would alias a single value on Go versions before 1.22.
		candidates[i] = &openAICandidate{openaiChoice: &r.openaiCompletion.Choices[i]}
	}
return candidates
}
type openAICandidate struct {
openaiChoice *openai.ChatCompletionChoice
}
var _ Candidate = (*openAICandidate)(nil)
func (c *openAICandidate) Parts() []Part {
// Check if the choice exists before accessing Message
if c.openaiChoice == nil {
return nil
}
// OpenAI message can have Content AND ToolCalls
var parts []Part
if c.openaiChoice.Message.Content != "" {
parts = append(parts, &openAIPart{content: c.openaiChoice.Message.Content})
}
if len(c.openaiChoice.Message.ToolCalls) > 0 {
parts = append(parts, &openAIPart{toolCalls: c.openaiChoice.Message.ToolCalls})
}
return parts
}
// String provides a simple string representation for logging/debugging.
func (c *openAICandidate) String() string {
if c.openaiChoice == nil {
return "<nil candidate>"
}
content := "<no content>"
if c.openaiChoice.Message.Content != "" {
content = c.openaiChoice.Message.Content
}
toolCalls := len(c.openaiChoice.Message.ToolCalls)
finishReason := string(c.openaiChoice.FinishReason)
return fmt.Sprintf("Candidate(FinishReason: %s, ToolCalls: %d, Content: %q)", finishReason, toolCalls, content)
}
type openAIPart struct {
content string
	toolCalls []openai.ChatCompletionMessageToolCall
}
var _ Part = (*openAIPart)(nil)
func (p *openAIPart) AsText() (string, bool) {
return p.content, p.content != ""
}
func (p *openAIPart) AsFunctionCalls() ([]FunctionCall, bool) {
if len(p.toolCalls) == 0 {
return nil, false
}
	gollmCalls := make([]FunctionCall, 0, len(p.toolCalls))
	for _, tc := range p.toolCalls {
		// Only function tool calls carry a populated function name; skip anything else.
		if tc.Function.Name == "" {
			klog.V(2).Infof("Skipping non-function tool call ID: %s", tc.ID)
			continue
		}
		var args map[string]any
		// Attempt to unmarshal the arguments; on failure, log and proceed with nil args.
		if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil {
			klog.V(2).Infof("Failed to unmarshal arguments for tool call %s: %v", tc.ID, err)
		}
		// Append rather than assign by index so skipped entries leave no zero-value gaps.
		gollmCalls = append(gollmCalls, FunctionCall{
			ID:        tc.ID, // Pass the tool call ID through so results can be correlated.
			Name:      tc.Function.Name,
			Arguments: args,
		})
	}
	return gollmCalls, len(gollmCalls) > 0
}