main.go (326 lines of code) (raw):

// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package main import ( "bufio" "bytes" "context" "flag" "fmt" "io" "log" "os" "os/signal" "path/filepath" "strings" "syscall" "github.com/GoogleCloudPlatform/kubectl-ai/gollm" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/agent" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/journal" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/tools" "github.com/GoogleCloudPlatform/kubectl-ai/pkg/ui" "k8s.io/klog/v2" "sigs.k8s.io/yaml" ) // Using the defaults from goreleaser as per https://goreleaser.com/cookbooks/using-main.version/ var ( version = "dev" commit = "none" date = "unknown" ) // models var geminiModels = []string{ "gemini-2.5-pro-preview-03-25", "gemini-2.0-flash", } func main() { ctx := context.Background() sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) go func() { sig := <-sigCh fmt.Fprintf(os.Stderr, "Received signal, shutting down... %s\n", sig) klog.Flush() os.Exit(0) }() if err := run(ctx); err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } } type Options struct { ProviderID string `json:"llmProvider,omitempty"` ModelID string `json:"model,omitempty"` // SkipPermissions is a flag to skip asking for confirmation before executing kubectl commands // that modifies resources in the cluster. SkipPermissions bool `json:"skipPermissions,omitempty"` // EnableToolUseShim is a flag to enable tool use shim. // TODO(droot): figure out a better way to discover if the model supports tool use // and set this automatically. EnableToolUseShim bool `json:"enableToolUseShim,omitempty"` // Quiet flag indicates if the agent should run in non-interactive mode. // It requires a query to be provided as a positional argument. Quiet bool `json:"quiet,omitempty"` MCPServer bool } func (o *Options) InitDefaults() { o.ProviderID = "gemini" o.ModelID = geminiModels[0] // by default, confirm before executing kubectl commands that modify resources in the cluster. o.SkipPermissions = false o.MCPServer = false // We now default to our strongest model (gemini-2.5-pro-exp-03-25) which supports tool use natively. // so we don't need shim. o.EnableToolUseShim = false } func (o *Options) LoadConfiguration(b []byte) error { if err := yaml.Unmarshal(b, &o); err != nil { return fmt.Errorf("parsing configuration: %w", err) } return nil } func (o *Options) LoadConfigurationFile() error { configPaths := []string{ "{CONFIG}/kubectl-ai/config.yaml", "{HOME}/.config/kubectl-ai/config.yaml", } for _, configPath := range configPaths { // Try to load configuration tokens := strings.Split(configPath, "/") for i, token := range tokens { if token == "{CONFIG}" { configDir, err := os.UserConfigDir() if err != nil { return fmt.Errorf("getting user config directory: %w", err) } tokens[i] = configDir } if token == "{HOME}" { homeDir, err := os.UserHomeDir() if err != nil { return fmt.Errorf("getting user home directory: %w", err) } tokens[i] = homeDir } } configPath = filepath.Join(tokens...) configBytes, err := os.ReadFile(configPath) if err != nil { if os.IsNotExist(err) { // ignore } else { fmt.Fprintf(os.Stderr, "warning: could not load defaults from %q: %v\n", configPath, err) } } if len(configBytes) > 0 { if err := o.LoadConfiguration(configBytes); err != nil { fmt.Fprintf(os.Stderr, "warning: error loading configuration from %q: %v\n", configPath, err) } } } return nil } func run(ctx context.Context) error { // Command line flags var opt Options opt.InitDefaults() if err := opt.LoadConfigurationFile(); err != nil { return fmt.Errorf("loading configuration file: %w", err) } maxIterations := flag.Int("max-iterations", 20, "maximum number of iterations agent will try before giving up") kubeconfig := flag.String("kubeconfig", "", "path to the kubeconfig file") promptTemplateFile := flag.String("prompt-template-file", "", "path to custom prompt template file") tracePath := flag.String("trace-path", "trace.log", "path to the trace file") removeWorkDir := flag.Bool("remove-workdir", false, "remove the temporary working directory after execution") flag.StringVar(&opt.ProviderID, "llm-provider", opt.ProviderID, "language model provider") flag.StringVar(&opt.ModelID, "model", opt.ModelID, "language model e.g. gemini-2.0-flash-thinking-exp-01-21, gemini-2.0-flash") flag.BoolVar(&opt.SkipPermissions, "skip-permissions", opt.SkipPermissions, "(dangerous) skip asking for confirmation before executing kubectl commands that modify resources") flag.BoolVar(&opt.MCPServer, "mcp-server", opt.MCPServer, "run in MCP server mode") flag.BoolVar(&opt.EnableToolUseShim, "enable-tool-use-shim", opt.EnableToolUseShim, "enable tool use shim") flag.BoolVar(&opt.Quiet, "quiet", opt.Quiet, "run in non-interactive mode, requires a query to be provided as a positional argument") // add commandline flags for logging klog.InitFlags(nil) flag.Set("logtostderr", "false") // disable logging to stderr flag.Set("log_file", filepath.Join(os.TempDir(), "kubectl-ai.log")) flag.Parse() defer klog.Flush() // Do this early, before the third-party code logs anything. redirectStdLogToKlog() // Handle kubeconfig with priority: command-line arg > env var > default path kubeconfigPath := *kubeconfig if kubeconfigPath == "" { // Check environment variable kubeconfigPath = os.Getenv("KUBECONFIG") if kubeconfigPath == "" { // Use default path homeDir, err := os.UserHomeDir() if err != nil { return fmt.Errorf("error getting user home directory: %w", err) } kubeconfigPath = filepath.Join(homeDir, ".kube", "config") } } if opt.MCPServer { workDir := filepath.Join(os.TempDir(), "kubectl-ai-mcp") if err := os.MkdirAll(workDir, 0755); err != nil { return fmt.Errorf("error creating work directory: %w", err) } mcpServer, err := newKubectlMCPServer(ctx, kubeconfigPath, tools.Default(), workDir) if err != nil { return fmt.Errorf("creating mcp server: %w", err) } return mcpServer.Serve(ctx) } // Check for positional arguments (after all flags are parsed) args := flag.Args() var queryFromCmd string // Check if stdin has data (is not a terminal) stdinInfo, _ := os.Stdin.Stat() stdinHasData := (stdinInfo.Mode() & os.ModeCharDevice) == 0 // Handle positional arguments and stdin if len(args) > 1 { return fmt.Errorf("only one positional argument (query) is allowed") } else if stdinHasData { // Read from stdin scanner := bufio.NewScanner(os.Stdin) var queryBuilder strings.Builder // If we have a positional argument, use it as a prefix if len(args) == 1 { queryBuilder.WriteString(args[0]) queryBuilder.WriteString("\n") } // Read the rest from stdin for scanner.Scan() { queryBuilder.WriteString(scanner.Text()) queryBuilder.WriteString("\n") } if err := scanner.Err(); err != nil { return fmt.Errorf("error reading from stdin: %w", err) } queryFromCmd = strings.TrimSpace(queryBuilder.String()) if queryFromCmd == "" { return fmt.Errorf("no query provided from stdin") } } else if len(args) == 1 { // Just use the positional argument as the query queryFromCmd = args[0] } klog.Info("Application started", "pid", os.Getpid()) llmClient, err := gollm.NewClient(ctx, opt.ProviderID) if err != nil { return fmt.Errorf("creating llm client: %w", err) } defer llmClient.Close() var recorder journal.Recorder if *tracePath != "" { fileRecorder, err := journal.NewFileRecorder(*tracePath) if err != nil { return fmt.Errorf("creating trace recorder: %w", err) } defer fileRecorder.Close() recorder = fileRecorder } else { // Ensure we always have a recorder, to avoid nil checks recorder = &journal.LogRecorder{} defer recorder.Close() } doc := ui.NewDocument() // since stdin is already consumed, we use TTY for taking input from user useTTYForInput := stdinHasData u, err := ui.NewTerminalUI(doc, recorder, useTTYForInput) if err != nil { return err } conversation := &agent.Conversation{ Model: opt.ModelID, Kubeconfig: kubeconfigPath, LLM: llmClient, MaxIterations: *maxIterations, PromptTemplateFile: *promptTemplateFile, Tools: tools.Default(), Recorder: recorder, RemoveWorkDir: *removeWorkDir, SkipPermissions: opt.SkipPermissions, EnableToolUseShim: opt.EnableToolUseShim, } err = conversation.Init(ctx, doc) if err != nil { return fmt.Errorf("starting conversation: %w", err) } defer conversation.Close() chatSession := session{ model: opt.ModelID, doc: doc, ui: u, conversation: conversation, LLM: llmClient, } if opt.Quiet { if queryFromCmd == "" { return fmt.Errorf("quiet mode requires a query to be provided as a positional argument") } return chatSession.answerQuery(ctx, queryFromCmd) } return chatSession.repl(ctx, queryFromCmd) } // session represents the user chat session (interactive/non-interactive both) type session struct { model string ui ui.UI doc *ui.Document conversation *agent.Conversation availableModels []string LLM gollm.Client } // repl is a read-eval-print loop for the chat session. func (s *session) repl(ctx context.Context, initialQuery string) error { query := initialQuery if query == "" { s.doc.AddBlock(ui.NewAgentTextBlock().SetText("Hey there, what can I help you with today?")) } for { if query == "" { input := ui.NewInputTextBlock() s.doc.AddBlock(input) userInput, err := input.Observable().Wait() if err != nil { if err == io.EOF { // Use hit control-D, or was piping and we reached the end of stdin. // Not a "big" problem return nil } return fmt.Errorf("reading input: %w", err) } query = strings.TrimSpace(userInput) } switch { case query == "": continue case query == "reset": err := s.conversation.Init(ctx, s.doc) if err != nil { return err } case query == "clear": s.ui.ClearScreen() case query == "exit" || query == "quit": // s.ui.RenderOutput(ctx, "Allright...bye.\n") return nil default: if err := s.answerQuery(ctx, query); err != nil { errorBlock := &ui.ErrorBlock{} errorBlock.SetText(fmt.Sprintf("Error: %v\n", err)) s.doc.AddBlock(errorBlock) } } // Reset query to empty string so that we prompt for input again query = "" } } func (s *session) listModels(ctx context.Context) ([]string, error) { if s.availableModels == nil { modelNames, err := s.LLM.ListModels(ctx) if err != nil { return nil, fmt.Errorf("listing models: %w", err) } s.availableModels = modelNames } return s.availableModels, nil } func (s *session) answerQuery(ctx context.Context, query string) error { switch { case query == "model": infoBlock := &ui.AgentTextBlock{} infoBlock.AppendText(fmt.Sprintf("Current model is `%s`\n", s.model)) s.doc.AddBlock(infoBlock) case query == "version": infoBlock := &ui.AgentTextBlock{} infoBlock.AppendText(fmt.Sprintf("Version: `%s`\n", version)) s.doc.AddBlock(infoBlock) case query == "models": models, err := s.listModels(ctx) if err != nil { return fmt.Errorf("listing models: %w", err) } infoBlock := &ui.AgentTextBlock{} infoBlock.AppendText("\n Available models:\n") infoBlock.AppendText(strings.Join(models, "\n")) s.doc.AddBlock(infoBlock) default: return s.conversation.RunOneRound(ctx, query) } return nil } // Redirect standard log output to our custom klog writer // This is primarily to suppress warning messages from // genai library https://github.com/googleapis/go-genai/blob/6ac4afc0168762dc3b7a4d940fc463cc1854f366/types.go#L1633 func redirectStdLogToKlog() { log.SetOutput(klogWriter{}) // Disable standard log's prefixes (date, time, file info) // because klog will add its own more detailed prefix. log.SetFlags(0) } // Define a custom writer that forwards messages to klog.Warning type klogWriter struct{} // Implement the io.Writer interface func (writer klogWriter) Write(data []byte) (n int, err error) { // We trim the trailing newline because klog adds its own. message := string(bytes.TrimSuffix(data, []byte("\n"))) klog.Warning(message) return len(data), nil }