gollm/gemini.go

// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package gollm

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"iter"
	"net"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"strings"

	"google.golang.org/genai"
	"k8s.io/klog/v2"
)

func init() {
	RegisterProvider("gemini", geminiFactory)
	RegisterProvider("vertexai", vertexaiViaGeminiFactory)
}

func geminiFactory(ctx context.Context, u *url.URL) (Client, error) {
	opt := GeminiAPIClientOptions{}
	return NewGeminiAPIClient(ctx, opt)
}

// GeminiAPIClientOptions are the options for the Gemini API client.
type GeminiAPIClientOptions struct {
	// APIKey is the API key for GenAI. Required for BackendGeminiAPI.
	APIKey string
}

// NewGeminiAPIClient builds a client for the Gemini API.
func NewGeminiAPIClient(ctx context.Context, opt GeminiAPIClientOptions) (*GoogleAIClient, error) {
	apiKey := opt.APIKey
	if apiKey == "" {
		apiKey = os.Getenv("GEMINI_API_KEY")
	}
	if apiKey == "" {
		return nil, fmt.Errorf("GEMINI_API_KEY environment variable not set")
	}

	cc := &genai.ClientConfig{
		APIKey:  apiKey,
		Backend: genai.BackendGeminiAPI,
	}
	client, err := genai.NewClient(ctx, cc)
	if err != nil {
		return nil, fmt.Errorf("building gemini client: %w", err)
	}
	return &GoogleAIClient{
		client: client,
	}, nil
}

// VertexAIClientOptions are the options for using the Vertex AI API.
type VertexAIClientOptions struct {
	// Project is the GCP project ID for Vertex AI. Required for BackendVertexAI.
	Project string

	// Location is the GCP location/region for Vertex AI. Required for BackendVertexAI.
	// See https://cloud.google.com/vertex-ai/docs/general/locations
	Location string
}

func vertexaiViaGeminiFactory(ctx context.Context, u *url.URL) (Client, error) {
	opt := VertexAIClientOptions{}
	return NewVertexAIClient(ctx, opt)
}

// findDefaultGCPProject gets the default GCP project ID from the
// GOOGLE_CLOUD_PROJECT environment variable, falling back to gcloud.
func findDefaultGCPProject(ctx context.Context) (string, error) {
	log := klog.FromContext(ctx)

	// First check env vars.
	// GOOGLE_CLOUD_PROJECT is the default for the genai library and a GCP convention.
	projectID := ""
	for _, env := range []string{"GOOGLE_CLOUD_PROJECT"} {
		if v := os.Getenv(env); v != "" {
			projectID = v
			log.Info("got project for vertex client from env var", "project", projectID, "env", env)
			return projectID, nil
		}
	}

	// Now check the default project in gcloud.
	{
		cmd := exec.CommandContext(ctx, "gcloud", "config", "get", "project")
		var stdout bytes.Buffer
		cmd.Stdout = &stdout
		if err := cmd.Run(); err != nil {
			return "", fmt.Errorf("cannot get project (using gcloud config get project): %w", err)
		}
		projectID = strings.TrimSpace(stdout.String())
		if projectID != "" {
			log.Info("got project from gcloud config", "project", projectID)
			return projectID, nil
		}
	}

	return "", fmt.Errorf("project was not set in gcloud config (or GOOGLE_CLOUD_PROJECT env var)")
}
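
// Illustrative sketch (not part of this file's API surface): constructing a
// Gemini-backed client directly with NewGeminiAPIClient, defined above. The
// context and error handling are placeholders.
//
//	ctx := context.Background()
//	// An empty APIKey falls back to the GEMINI_API_KEY environment variable.
//	client, err := NewGeminiAPIClient(ctx, GeminiAPIClientOptions{})
//	if err != nil {
//		// handle error
//	}
//	defer client.Close()
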
// NewVertexAIClient builds a client for the Vertex AI API.
func NewVertexAIClient(ctx context.Context, opt VertexAIClientOptions) (*GoogleAIClient, error) {
	log := klog.FromContext(ctx)

	cc := &genai.ClientConfig{
		// Project ID is loaded from the GOOGLE_CLOUD_PROJECT environment variable.
		// Location/Region is loaded from either the GOOGLE_CLOUD_LOCATION or GOOGLE_CLOUD_REGION environment variable.
		Backend:  genai.BackendVertexAI,
		Project:  opt.Project,
		Location: opt.Location,
	}

	// Project ID is required.
	if cc.Project == "" {
		projectID, err := findDefaultGCPProject(ctx)
		if err != nil {
			return nil, fmt.Errorf("finding default GCP project ID: %w", err)
		}
		cc.Project = projectID
	}

	// Location is also required.
	if cc.Location == "" {
		location := ""
		// Check well-known env vars.
		for _, env := range []string{"GOOGLE_CLOUD_LOCATION", "GOOGLE_CLOUD_REGION"} {
			if v := os.Getenv(env); v != "" {
				location = v
				log.Info("got location for vertex client from env var", "location", location, "env", env)
				break
			}
		}
		// Fall back to us-central1.
		if location == "" {
			location = "us-central1"
			log.Info("defaulted location for vertex client", "location", location)
		}
		cc.Location = location
	}

	client, err := genai.NewClient(ctx, cc)
	if err != nil {
		return nil, fmt.Errorf("building vertexai client: %w", err)
	}
	return &GoogleAIClient{
		client: client,
	}, nil
}

// GoogleAIClient is a client for the Google AI APIs.
// It implements the Client interface.
type GoogleAIClient struct {
	client *genai.Client

	// responseSchema will constrain the output to match the given schema.
	responseSchema *genai.Schema
}

var _ Client = &GoogleAIClient{}

// ListModels lists the models available in the Gemini API.
func (c *GoogleAIClient) ListModels(ctx context.Context) (modelNames []string, err error) {
	for model, err := range c.client.Models.All(ctx) {
		if err != nil {
			return nil, fmt.Errorf("error listing models: %w", err)
		}
		modelNames = append(modelNames, strings.TrimPrefix(model.Name, "models/"))
	}
	return modelNames, nil
}

// Close frees the resources used by the client.
func (c *GoogleAIClient) Close() error {
	return nil
}

// SetResponseSchema constrains LLM responses to match the provided schema.
// Calling with nil will clear the current schema.
func (c *GoogleAIClient) SetResponseSchema(responseSchema *Schema) error {
	if responseSchema == nil {
		c.responseSchema = nil
		return nil
	}

	geminiSchema, err := toGeminiSchema(responseSchema)
	if err != nil {
		return err
	}

	c.responseSchema = geminiSchema
	return nil
}

// GenerateCompletion sends a one-shot completion request to the model.
func (c *GoogleAIClient) GenerateCompletion(ctx context.Context, request *CompletionRequest) (CompletionResponse, error) {
	log := klog.FromContext(ctx)

	var config *genai.GenerateContentConfig
	if c.responseSchema != nil {
		config = &genai.GenerateContentConfig{
			ResponseSchema:   c.responseSchema,
			ResponseMIMEType: "application/json",
		}
	}

	content := []*genai.Content{
		{Role: "user", Parts: []*genai.Part{{Text: request.Prompt}}},
	}

	log.Info("sending GenerateContent request to gemini", "content", content)
	result, err := c.client.Models.GenerateContent(ctx, request.Model, content, config)
	if err != nil {
		return nil, err
	}
	return &GeminiCompletionResponse{geminiResponse: result, text: result.Text()}, nil
}
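
// Illustrative sketch: constraining completions to a JSON schema via
// SetResponseSchema and GenerateCompletion, both defined above. The schema
// shape and the model name "gemini-2.0-flash" are hypothetical examples.
//
//	schema := &Schema{
//		Type: TypeObject,
//		Properties: map[string]*Schema{
//			"summary": {Type: TypeString, Description: "One-sentence summary"},
//		},
//		Required: []string{"summary"},
//	}
//	if err := client.SetResponseSchema(schema); err != nil {
//		// handle error
//	}
//	resp, err := client.GenerateCompletion(ctx, &CompletionRequest{
//		Model:  "gemini-2.0-flash",
//		Prompt: "Summarize the release notes...",
//	})
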
// StartChat starts a new chat with the model.
func (c *GoogleAIClient) StartChat(systemPrompt string, model string) Chat {
	// Some values that are recommended by aistudio.
	temperature := float32(1.0)
	topK := float32(40)
	topP := float32(0.95)
	maxOutputTokens := int32(8192)

	chat := &GeminiChat{
		model:  model,
		client: c.client,
		genConfig: &genai.GenerateContentConfig{
			SystemInstruction: &genai.Content{
				Parts: []*genai.Part{
					{Text: systemPrompt},
				},
			},
			Temperature:      &temperature,
			TopK:             &topK,
			TopP:             &topP,
			MaxOutputTokens:  maxOutputTokens,
			ResponseMIMEType: "text/plain",
		},
		history: []*genai.Content{},
	}

	if chat.model == "gemma-3-27b-it" {
		// Note: gemma-3-27b-it does not allow a system prompt.
		// xref: https://discuss.ai.google.dev/t/gemma-3-missing-features-despite-announcement/71692
		// TODO: remove this hack once gemma-3-27b-it supports system prompt
		chat.genConfig.SystemInstruction = nil
		chat.history = []*genai.Content{
			{Role: "user", Parts: []*genai.Part{{Text: systemPrompt}}},
		}
	}

	if c.responseSchema != nil {
		chat.genConfig.ResponseSchema = c.responseSchema
		chat.genConfig.ResponseMIMEType = "application/json"
	}
	return chat
}

// GeminiChat is a chat with the model.
// It implements the Chat interface.
type GeminiChat struct {
	model     string
	client    *genai.Client
	history   []*genai.Content
	genConfig *genai.GenerateContentConfig
}

// SetFunctionDefinitions sets the function definitions for the chat.
// This allows the LLM to call user-defined functions.
func (c *GeminiChat) SetFunctionDefinitions(functionDefinitions []*FunctionDefinition) error {
	var genaiFunctionDeclarations []*genai.FunctionDeclaration
	for _, functionDefinition := range functionDefinitions {
		parameters, err := toGeminiSchema(functionDefinition.Parameters)
		if err != nil {
			return err
		}
		genaiFunctionDeclarations = append(genaiFunctionDeclarations, &genai.FunctionDeclaration{
			Name:        functionDefinition.Name,
			Description: functionDefinition.Description,
			Parameters:  parameters,
		})
	}
	c.genConfig.Tools = []*genai.Tool{
		{
			FunctionDeclarations: genaiFunctionDeclarations,
		},
	}
	return nil
}

// toGeminiSchema converts our generic Schema to a genai.Schema.
func toGeminiSchema(schema *Schema) (*genai.Schema, error) {
	ret := &genai.Schema{
		Description: schema.Description,
		Required:    schema.Required,
	}

	switch schema.Type {
	case TypeObject:
		ret.Type = genai.TypeObject
	case TypeString:
		ret.Type = genai.TypeString
	case TypeBoolean:
		ret.Type = genai.TypeBoolean
	case TypeInteger:
		ret.Type = genai.TypeInteger
	case TypeArray:
		ret.Type = genai.TypeArray
	default:
		return nil, fmt.Errorf("type %q not handled by genai.Schema", schema.Type)
	}
	if schema.Properties != nil {
		ret.Properties = make(map[string]*genai.Schema)
		for k, v := range schema.Properties {
			geminiValue, err := toGeminiSchema(v)
			if err != nil {
				return nil, err
			}
			ret.Properties[k] = geminiValue
		}
	}
	if schema.Items != nil {
		geminiValue, err := toGeminiSchema(schema.Items)
		if err != nil {
			return nil, err
		}
		ret.Items = geminiValue
	}
	return ret, nil
}

func (c *GeminiChat) partsToGemini(contents ...any) ([]*genai.Part, error) {
	var parts []*genai.Part

	for _, content := range contents {
		switch v := content.(type) {
		case string:
			parts = append(parts, genai.NewPartFromText(v))
		case FunctionCallResult:
			parts = append(parts, &genai.Part{
				FunctionResponse: &genai.FunctionResponse{
					ID:       v.ID,
					Name:     v.Name,
					Response: v.Result,
				},
			})
		default:
			return nil, fmt.Errorf("unexpected type of content: %T", content)
		}
	}
	return parts, nil
}
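
// Illustrative sketch: registering a tool with a chat via
// SetFunctionDefinitions, defined above. This assumes the Chat interface
// exposes SetFunctionDefinitions; the function name, parameters, and model
// name are hypothetical.
//
//	chat := client.StartChat("You are a helpful assistant.", "gemini-2.0-flash")
//	err := chat.SetFunctionDefinitions([]*FunctionDefinition{{
//		Name:        "get_weather",
//		Description: "Look up the current weather for a city.",
//		Parameters: &Schema{
//			Type: TypeObject,
//			Properties: map[string]*Schema{
//				"city": {Type: TypeString},
//			},
//			Required: []string{"city"},
//		},
//	}})
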
// Send sends a message to the model.
// It returns a ChatResponse object containing the response from the model.
func (c *GeminiChat) Send(ctx context.Context, contents ...any) (ChatResponse, error) {
	log := klog.FromContext(ctx)
	log.V(1).Info("sending LLM request", "user", contents)

	parts, err := c.partsToGemini(contents...)
	if err != nil {
		return nil, err
	}

	genaiContent := &genai.Content{
		Role:  "user",
		Parts: parts,
	}
	c.history = append(c.history, genaiContent)

	result, err := c.client.Models.GenerateContent(ctx, c.model, c.history, c.genConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to generate content: %w", err)
	}
	if result == nil || len(result.Candidates) == 0 {
		return nil, fmt.Errorf("no response from Gemini")
	}
	c.history = append(c.history, result.Candidates[0].Content)

	geminiResponse := result
	log.V(1).Info("got LLM response", "response", geminiResponse)
	return &GeminiChatResponse{geminiResponse: geminiResponse}, nil
}

// SendStreaming sends a message to the model and streams back the response.
func (c *GeminiChat) SendStreaming(ctx context.Context, contents ...any) (ChatResponseIterator, error) {
	log := klog.FromContext(ctx)
	log.V(1).Info("sending LLM streaming request", "user", contents)

	parts, err := c.partsToGemini(contents...)
	if err != nil {
		return nil, err
	}

	genaiContent := &genai.Content{
		Role:  "user",
		Parts: parts,
	}
	c.history = append(c.history, genaiContent)

	stream := c.client.Models.GenerateContentStream(ctx, c.model, c.history, c.genConfig)

	return func(yield func(ChatResponse, error) bool) {
		next, stop := iter.Pull2(stream)
		defer stop()

		for {
			geminiResponse, err, ok := next()
			if !ok {
				return
			}
			var response *GeminiChatResponse
			if geminiResponse != nil {
				response = &GeminiChatResponse{geminiResponse: geminiResponse}
				if len(geminiResponse.Candidates) > 0 {
					// TODO: Should we try to coalesce parts when we have a streaming response?
					c.history = append(c.history, geminiResponse.Candidates[0].Content)
				}
			}
			if !yield(response, err) {
				return
			}
		}
	}, nil
}

// GeminiChatResponse is a response from the Gemini API.
// It implements the ChatResponse interface.
type GeminiChatResponse struct {
	geminiResponse *genai.GenerateContentResponse
}

var _ ChatResponse = &GeminiChatResponse{}

// MarshalJSON implements json.Marshaler.
func (r *GeminiChatResponse) MarshalJSON() ([]byte, error) {
	formatted := RecordChatResponse{
		Raw: r.geminiResponse,
	}
	return json.Marshal(&formatted)
}

// String returns a string representation of the response.
func (r *GeminiChatResponse) String() string {
	return r.geminiResponse.Text()
}

// UsageMetadata returns the usage metadata for the response.
func (r *GeminiChatResponse) UsageMetadata() any {
	return r.geminiResponse.UsageMetadata
}

// Candidates returns the candidates for the response.
func (r *GeminiChatResponse) Candidates() []Candidate {
	var candidates []Candidate
	for _, candidate := range r.geminiResponse.Candidates {
		candidates = append(candidates, &GeminiCandidate{candidate: candidate})
	}
	return candidates
}
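
// Illustrative sketch: draining a streaming response from SendStreaming,
// defined above. This assumes ChatResponseIterator matches the
// func(yield func(ChatResponse, error) bool) shape returned there, so it can
// be consumed with Go 1.23 range-over-func; the prompt is a placeholder.
//
//	stream, err := chat.SendStreaming(ctx, "Tell me about Kubernetes.")
//	if err != nil {
//		// handle error
//	}
//	for resp, err := range stream {
//		if err != nil {
//			// handle error
//			break
//		}
//		for _, candidate := range resp.Candidates() {
//			fmt.Println(candidate)
//		}
//	}
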
// GeminiCandidate is a candidate for the response.
// It implements the Candidate interface.
type GeminiCandidate struct {
	candidate *genai.Candidate
}

// String returns a string representation of the candidate.
func (r *GeminiCandidate) String() string {
	var response strings.Builder
	response.WriteString("[")
	for i, part := range r.Parts() {
		if i > 0 {
			response.WriteString(", ")
		}
		if text, ok := part.AsText(); ok {
			response.WriteString(text)
		}
		if functionCalls, ok := part.AsFunctionCalls(); ok {
			response.WriteString("functionCalls=[")
			for _, functionCall := range functionCalls {
				response.WriteString(fmt.Sprintf("%q(args=%v)", functionCall.Name, functionCall.Arguments))
			}
			response.WriteString("]")
		}
	}
	response.WriteString("]")
	return response.String()
}

// Parts returns the parts of the candidate.
func (r *GeminiCandidate) Parts() []Part {
	var parts []Part
	if r.candidate.Content != nil {
		for _, part := range r.candidate.Content.Parts {
			parts = append(parts, &GeminiPart{part: *part})
		}
	}
	return parts
}

// GeminiPart is a part of a candidate.
// It implements the Part interface.
type GeminiPart struct {
	part genai.Part
}

// AsText returns the text of the part.
func (p *GeminiPart) AsText() (string, bool) {
	if p.part.Text != "" {
		return p.part.Text, true
	}
	return "", false
}

// AsFunctionCalls returns the function calls of the part.
func (p *GeminiPart) AsFunctionCalls() ([]FunctionCall, bool) {
	if p.part.FunctionCall != nil {
		return []FunctionCall{
			{
				ID:        p.part.FunctionCall.ID,
				Name:      p.part.FunctionCall.Name,
				Arguments: p.part.FunctionCall.Args,
			},
		}, true
	}
	return nil, false
}

// GeminiCompletionResponse is a response from a completion request.
// It implements the CompletionResponse interface.
type GeminiCompletionResponse struct {
	geminiResponse *genai.GenerateContentResponse
	text           string
}

var _ CompletionResponse = &GeminiCompletionResponse{}

// MarshalJSON implements json.Marshaler.
func (r *GeminiCompletionResponse) MarshalJSON() ([]byte, error) {
	formatted := RecordCompletionResponse{
		Text: r.text,
		Raw:  r.geminiResponse,
	}
	return json.Marshal(&formatted)
}

// Response returns the text of the response.
func (r *GeminiCompletionResponse) Response() string {
	return r.text
}

// UsageMetadata returns the usage metadata for the response.
func (r *GeminiCompletionResponse) UsageMetadata() any {
	return r.geminiResponse.UsageMetadata
}

// String returns a string representation of the response.
func (r *GeminiCompletionResponse) String() string {
	return fmt.Sprintf("{text=%q}", r.text)
}

// IsRetryableError reports whether the given error is likely transient,
// so that the request can reasonably be retried.
func (c *GeminiChat) IsRetryableError(err error) bool {
	if err == nil {
		return false
	}

	var apiErr genai.APIError
	if errors.As(err, &apiErr) {
		switch apiErr.Code {
		case http.StatusConflict,
			http.StatusTooManyRequests,
			http.StatusInternalServerError,
			http.StatusBadGateway,
			http.StatusServiceUnavailable,
			http.StatusGatewayTimeout:
			return true
		default:
			return false
		}
	}

	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		return true
	}

	// Add other error checks specific to LLM clients if needed,
	// e.g., if errors.Is(err, specificLLMRateLimitError) { return true }
	return false
}
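
// Illustrative sketch: a caller-side retry loop built on IsRetryableError,
// defined above. The attempt count and fixed backoff are hypothetical, and
// this assumes the chat value is a *GeminiChat (or that the Chat interface
// exposes IsRetryableError).
//
//	var resp ChatResponse
//	var err error
//	for attempt := 0; attempt < 3; attempt++ {
//		resp, err = chat.Send(ctx, "retry-sensitive prompt")
//		if err == nil || !chat.IsRetryableError(err) {
//			break
//		}
//		time.Sleep(time.Duration(attempt+1) * time.Second)
//	}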