pkg/agent/conversation.go (406 lines of code) (raw):

// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package agent import ( "context" _ "embed" "encoding/json" "fmt" "html/template" "io" "os" "sort" "strings" "time" "github.com/GoogleCloudPlatform/kubectl-ai/gollm" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/journal" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/tools" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/ui" "k8s.io/klog/v2" ) //go:embed systemprompt_template_default.txt var defaultSystemPromptTemplate string type Conversation struct { LLM gollm.Client // PromptTemplateFile allows specifying a custom template file PromptTemplateFile string Model string RemoveWorkDir bool MaxIterations int Kubeconfig string SkipPermissions bool Tools tools.Tools EnableToolUseShim bool // Recorder captures events for diagnostics Recorder journal.Recorder // doc is the document which renders the conversation doc *ui.Document llmChat gollm.Chat workDir string } func (s *Conversation) Init(ctx context.Context, doc *ui.Document) error { log := klog.FromContext(ctx) // Create a temporary working directory workDir, err := os.MkdirTemp("", "agent-workdir-*") if err != nil { log.Error(err, "Failed to create temporary working directory") return err } log.Info("Created temporary working directory", "workDir", workDir) systemPrompt, err := s.generatePrompt(ctx, defaultSystemPromptTemplate, PromptData{ Tools: s.Tools, EnableToolUseShim: s.EnableToolUseShim, }) if err != nil { return fmt.Errorf("generating system prompt: %w", err) } // Start a new chat session s.llmChat = gollm.NewRetryChat( s.LLM.StartChat(systemPrompt, s.Model), gollm.RetryConfig{ MaxAttempts: 3, InitialBackoff: 10 * time.Second, MaxBackoff: 60 * time.Second, BackoffFactor: 2, Jitter: true, }, ) if !s.EnableToolUseShim { var functionDefinitions []*gollm.FunctionDefinition for _, tool := range s.Tools.AllTools() { functionDefinitions = append(functionDefinitions, tool.FunctionDefinition()) } // Sort function definitions to help KV cache reuse sort.Slice(functionDefinitions, func(i, j int) bool { return functionDefinitions[i].Name < functionDefinitions[j].Name }) if err := s.llmChat.SetFunctionDefinitions(functionDefinitions); err != nil { return fmt.Errorf("setting function definitions: %w", err) } } s.workDir = workDir s.doc = doc return nil } func (c *Conversation) Close() error { if c.workDir != "" { if c.RemoveWorkDir { if err := os.RemoveAll(c.workDir); err != nil { klog.Warningf("error cleaning up directory %q: %v", c.workDir, err) } } } return nil } // RunOneRound executes a chat-based agentic loop with the LLM using function calling. func (a *Conversation) RunOneRound(ctx context.Context, query string) error { log := klog.FromContext(ctx) log.Info("Starting chat loop for query:", "query", query) // currChatContent tracks chat content that needs to be sent // to the LLM in each iteration of the agentic loop below var currChatContent []any // Set the initial message to start the conversation currChatContent = []any{query} currentIteration := 0 maxIterations := a.MaxIterations for currentIteration < maxIterations { log.Info("Starting iteration", "iteration", currentIteration) a.Recorder.Write(ctx, &journal.Event{ Timestamp: time.Now(), Action: "llm-chat", Payload: []any{currChatContent}, }) stream, err := a.llmChat.SendStreaming(ctx, currChatContent...) if err != nil { return err } // Clear our "response" now that we sent the last response currChatContent = nil if a.EnableToolUseShim { // convert the candidate response into a gollm.ChatResponse stream, err = candidateToShimCandidate(stream) if err != nil { return err } } // Process each part of the response // only applicable is not using tooluse shim var functionCalls []gollm.FunctionCall var agentTextBlock *ui.AgentTextBlock for response, err := range stream { if err != nil { return fmt.Errorf("reading streaming LLM response: %w", err) } if response == nil { // end of streaming response break } klog.Infof("response: %+v", response) a.Recorder.Write(ctx, &journal.Event{ Timestamp: time.Now(), Action: "llm-response", Payload: response, }) if len(response.Candidates()) == 0 { log.Error(nil, "No candidates in response") return fmt.Errorf("no candidates in LLM response") } candidate := response.Candidates()[0] for _, part := range candidate.Parts() { // Check if it's a text response if text, ok := part.AsText(); ok { log.Info("text response", "text", text) if agentTextBlock == nil { agentTextBlock = ui.NewAgentTextBlock() agentTextBlock.SetStreaming(true) a.doc.AddBlock(agentTextBlock) } agentTextBlock.AppendText(text) } // Check if it's a function call if calls, ok := part.AsFunctionCalls(); ok && len(calls) > 0 { log.Info("function calls", "calls", calls) functionCalls = append(functionCalls, calls...) } } } if agentTextBlock != nil { agentTextBlock.SetStreaming(false) } // TODO(droot): Run all function calls in parallel // (may have to specify in the prompt to make these function calls independent) for _, call := range functionCalls { toolCall, err := a.Tools.ParseToolInvocation(ctx, call.Name, call.Arguments) if err != nil { return fmt.Errorf("building tool call: %w", err) } s := toolCall.PrettyPrint() a.doc.AddBlock(ui.NewFunctionCallRequestBlock().SetText(fmt.Sprintf(" Running: %s\n", s))) // Ask for confirmation only if SkipPermissions is false AND the tool modifies resources. if !a.SkipPermissions && call.Arguments["modifies_resource"] != "no" { confirmationPrompt := ` Do you want to proceed ? 1) Yes 2) Yes, and don't ask me again 3) No` optionsBlock := ui.NewInputOptionBlock().SetPrompt(confirmationPrompt) optionsBlock.SetOptions([]string{"1", "2", "3"}) a.doc.AddBlock(optionsBlock) selectedChoice, err := optionsBlock.Observable().Wait() if err != nil { if err == io.EOF { // Use hit control-D, or was piping and we reached the end of stdin. // Not a "big" problem return nil } return fmt.Errorf("reading input: %w", err) } switch selectedChoice { case "1": // Proceed with the operation case "2": a.SkipPermissions = true case "3": a.doc.AddBlock(ui.NewAgentTextBlock().SetText("Operation was skipped.")) observation := fmt.Sprintf("User didn't approve running %q.\n", call.Name) currChatContent = append(currChatContent, observation) continue default: // This case should technically not be reachable due to AskForConfirmation loop err := fmt.Errorf("invalid confirmation choice: %q", selectedChoice) log.Error(err, "Invalid choice received from AskForConfirmation") a.doc.AddBlock(ui.NewErrorBlock().SetText("Invalid choice received. Cancelling operation.")) return err } } ctx := journal.ContextWithRecorder(ctx, a.Recorder) output, err := toolCall.InvokeTool(ctx, tools.InvokeToolOptions{ Kubeconfig: a.Kubeconfig, WorkDir: a.workDir, }) if err != nil { return fmt.Errorf("executing action: %w", err) } if a.EnableToolUseShim { observation := fmt.Sprintf("Result of running %q:\n%s", call.Name, output) currChatContent = append(currChatContent, observation) } else { result, err := tools.ToolResultToMap(output) if err != nil { return err } currChatContent = append(currChatContent, gollm.FunctionCallResult{ ID: call.ID, Name: call.Name, Result: result, }) } } // If no function calls were made, we're done if len(functionCalls) == 0 { log.Info("No function calls were made, so most likely the task is completed, so we're done.") return nil } currentIteration++ } // If we've reached the maximum number of iterations log.Info("Max iterations reached", "iterations", maxIterations) errorBlock := ui.NewErrorBlock().SetText(fmt.Sprintf("Sorry, couldn't complete the task after %d iterations.\n", maxIterations)) a.doc.AddBlock(errorBlock) return fmt.Errorf("max iterations reached") } // toResult converts an arbitrary result to a map[string]any func toResult(v any) (map[string]any, error) { b, err := json.Marshal(v) if err != nil { return nil, fmt.Errorf("converting result to json: %w", err) } m := make(map[string]any) if err := json.Unmarshal(b, &m); err != nil { return nil, fmt.Errorf("converting json result to map: %w", err) } return m, nil } // generateFromTemplate generates a prompt for LLM. It uses the prompt from the provides template file or default. func (a *Conversation) generatePrompt(_ context.Context, defaultPromptTemplate string, data PromptData) (string, error) { promptTemplate := defaultPromptTemplate if a.PromptTemplateFile != "" { content, err := os.ReadFile(a.PromptTemplateFile) if err != nil { return "", fmt.Errorf("error reading template file: %v", err) } promptTemplate = string(content) } tmpl, err := template.New("promptTemplate").Parse(promptTemplate) if err != nil { return "", fmt.Errorf("building template for prompt: %w", err) } var result strings.Builder err = tmpl.Execute(&result, &data) if err != nil { return "", fmt.Errorf("evaluating template for prompt: %w", err) } return result.String(), nil } // PromptData represents the structure of the data to be filled into the template. type PromptData struct { Query string Tools tools.Tools EnableToolUseShim bool } func (a *PromptData) ToolsAsJSON() string { var toolDefinitions []*gollm.FunctionDefinition for _, tool := range a.Tools.AllTools() { toolDefinitions = append(toolDefinitions, tool.FunctionDefinition()) } json, err := json.MarshalIndent(toolDefinitions, "", " ") if err != nil { return "" } return string(json) } func (a *PromptData) ToolNames() string { return strings.Join(a.Tools.Names(), ", ") } type ReActResponse struct { Thought string `json:"thought"` Answer string `json:"answer,omitempty"` Action *Action `json:"action,omitempty"` } type Action struct { Name string `json:"name"` Reason string `json:"reason"` Command string `json:"command"` ModifiesResource string `json:"modifies_resource"` } func extractJSON(s string) (string, bool) { const jsonBlockMarker = "```json" first := strings.Index(s, jsonBlockMarker) last := strings.LastIndex(s, "```") if first == -1 || last == -1 || first == last { return "", false } data := s[first+len(jsonBlockMarker) : last] return data, true } // parseReActResponse parses the LLM response into a ReActResponse struct // This function assumes the input contains exactly one JSON code block // formatted with ```json and ``` markers. The JSON block is expected to // contain a valid ReActResponse object. func parseReActResponse(input string) (*ReActResponse, error) { cleaned, found := extractJSON(input) if !found { return nil, fmt.Errorf("no JSON code block found in %q", cleaned) } cleaned = strings.ReplaceAll(cleaned, "\n", "") cleaned = strings.TrimSpace(cleaned) var reActResp ReActResponse if err := json.Unmarshal([]byte(cleaned), &reActResp); err != nil { return nil, fmt.Errorf("parsing JSON %q: %w", cleaned, err) } return &reActResp, nil } // toMap converts the value to a map, going via JSON func toMap(v any) (map[string]any, error) { j, err := json.Marshal(v) if err != nil { return nil, fmt.Errorf("converting %T to json: %w", v, err) } m := make(map[string]any) if err := json.Unmarshal(j, &m); err != nil { return nil, fmt.Errorf("converting json to map: %w", err) } return m, nil } func candidateToShimCandidate(iterator gollm.ChatResponseIterator) (gollm.ChatResponseIterator, error) { return func(yield func(gollm.ChatResponse, error) bool) { buffer := "" for response, err := range iterator { if err != nil { yield(nil, err) return } if len(response.Candidates()) == 0 { yield(nil, fmt.Errorf("no candidates in LLM response")) return } candidate := response.Candidates()[0] for _, part := range candidate.Parts() { if text, ok := part.AsText(); ok { buffer += text klog.Infof("text is %q", text) } else { yield(nil, fmt.Errorf("no text part found in candidate")) return } } if _, found := extractJSON(buffer); found { break } } if buffer == "" { yield(nil, nil) return } parsedReActResp, err := parseReActResponse(buffer) if err != nil { yield(nil, fmt.Errorf("parsing ReAct response %q: %w", buffer, err)) return } buffer = "" // TODO: any trailing text? yield(&ShimResponse{candidate: parsedReActResp}, nil) }, nil } type ShimResponse struct { candidate *ReActResponse } func (r *ShimResponse) UsageMetadata() any { return nil } func (r *ShimResponse) Candidates() []gollm.Candidate { return []gollm.Candidate{&ShimCandidate{candidate: r.candidate}} } type ShimCandidate struct { candidate *ReActResponse } func (c *ShimCandidate) String() string { return fmt.Sprintf("Thought: %s\nAnswer: %s\nAction: %s", c.candidate.Thought, c.candidate.Answer, c.candidate.Action) } func (c *ShimCandidate) Parts() []gollm.Part { var parts []gollm.Part if c.candidate.Thought != "" { parts = append(parts, &ShimPart{text: c.candidate.Thought}) } if c.candidate.Answer != "" { parts = append(parts, &ShimPart{text: c.candidate.Answer}) } if c.candidate.Action != nil { parts = append(parts, &ShimPart{action: c.candidate.Action}) } return parts } type ShimPart struct { text string action *Action } func (p *ShimPart) AsText() (string, bool) { return p.text, p.text != "" } func (p *ShimPart) AsFunctionCalls() ([]gollm.FunctionCall, bool) { if p.action != nil { functionCallArgs, err := toMap(p.action) if err != nil { return nil, false } delete(functionCallArgs, "name") // passed separately // delete(functionCallArgs, "reason") // delete(functionCallArgs, "modifies_resource") return []gollm.FunctionCall{ { Name: p.action.Name, Arguments: functionCallArgs, }, }, true } return nil, false }