k8s-bench/main.go

// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/GoogleCloudPlatform/kubectl-ai/k8s-bench/pkg/model"
	"sigs.k8s.io/yaml"
)

type Task struct {
	Setup      string        `json:"setup,omitempty"`
	Verifier   string        `json:"verifier,omitempty"`
	Cleanup    string        `json:"cleanup,omitempty"`
	Difficulty string        `json:"difficulty"`
	Disabled   bool          `json:"disabled,omitempty"`
	Expect     []Expectation `json:"expect,omitempty"`
	Script     []ScriptStep  `json:"script,omitempty"`

	// Isolation can be set to automatically create an isolated cluster
	// TODO: support namespaces also
	Isolation IsolationMode `json:"isolation,omitempty"`
}

type IsolationMode string

const (
	// IsolationModeCluster will create a cluster for the task evaluation.
	IsolationModeCluster IsolationMode = "cluster"
)

type ScriptStep struct {
	Prompt string `json:"prompt"`
}

type Expectation struct {
	Contains string `json:"contains,omitempty"`
}
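
// For reference, a task file that unmarshals into Task might look like the
// following. This is a sketch based on the json tags above; the commands,
// prompt, and expectation are illustrative placeholders, not a real task
// from the suite:
//
//	setup: kubectl apply -f artifacts/deployment.yaml
//	verifier: kubectl rollout status deployment/web --timeout=60s
//	cleanup: kubectl delete -f artifacts/deployment.yaml
//	difficulty: easy
//	isolation: cluster
//	script:
//	  - prompt: "Scale the web deployment to 3 replicas"
//	expect:
//	  - contains: "scaled"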

type EvalConfig struct {
	LLMConfigs  []model.LLMConfig
	KubeConfig  string
	TasksDir    string
	TaskPattern string
	AgentBin    string
	OutputDir   string
}

type AnalyzeConfig struct {
	InputDir          string
	OutputFormat      string
	IgnoreToolUseShim bool
}

// expandPath expands a leading "~/" to the user's home directory, expands
// environment variables, and cleans the result.
func expandPath(path string) (string, error) {
	if strings.HasPrefix(path, "~/") {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", err
		}
		path = filepath.Join(home, path[2:])
	}
	return filepath.Clean(os.ExpandEnv(path)), nil
}

func main() {
	// Print top-level usage if help is requested directly
	if len(os.Args) > 1 && (os.Args[1] == "--help" || os.Args[1] == "-h") {
		printUsage()
		return
	}

	ctx := context.Background()
	if err := run(ctx); err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}
}

// printUsage prints custom top-level usage text showing the available subcommands.
func printUsage() {
	fmt.Fprintf(os.Stderr, "Usage: %s <command> [options]\n\n", os.Args[0])
	fmt.Fprintf(os.Stderr, "Commands:\n")
	fmt.Fprintf(os.Stderr, "  run      Run evaluation benchmarks\n")
	fmt.Fprintf(os.Stderr, "  analyze  Analyze results from previous benchmark runs\n\n")
	fmt.Fprintf(os.Stderr, "Run '%s <command> --help' for more information on a command.\n", os.Args[0])
}

// Strings is a flag.Value that collects repeated string flags.
type Strings []string

func (f *Strings) String() string {
	return strings.Join(*f, ",")
}

func (f *Strings) Set(s string) error {
	*f = append(*f, s)
	return nil
}

func run(ctx context.Context) error {
	// Default to the "run" subcommand if no subcommand is provided
	subCommand := "run"
	if len(os.Args) > 1 && !strings.HasPrefix(os.Args[1], "-") {
		subCommand = os.Args[1]
		// Shift the arguments so that flag parsing sees only the flags
		os.Args = append(os.Args[:1], os.Args[2:]...)
	}

	switch subCommand {
	case "run":
		return runEvals(ctx)
	case "analyze":
		return runAnalyze()
	default:
		printUsage()
		return fmt.Errorf("unknown subcommand: %s, valid options are 'run' or 'analyze'", subCommand)
	}
}
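
// Example invocations, assuming the binary is built as "k8s-bench" (the
// paths, pattern, and model name below are illustrative placeholders):
//
//	k8s-bench run --agent-bin ./kubectl-ai --task-pattern pod --output-dir ./results
//	k8s-bench run --llm-provider gemini --models gemini-2.5-pro-preview-03-25
//	k8s-bench analyze --input-dir ./results --output-format markdown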

func runEvals(ctx context.Context) error {
	config := EvalConfig{
		TasksDir: "./tasks",
	}

	// Set custom usage for 'run' subcommand
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s run [options]\n\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "Run K8s-bench evaluation benchmarks.\n\n")
		fmt.Fprintf(os.Stderr, "Options:\n")
		flag.PrintDefaults()
	}

	llmProvider := "gemini"
	modelList := ""
	defaultKubeConfig := "~/.kube/config"
	enableToolUseShim := true
	quiet := true

	flag.StringVar(&config.TasksDir, "tasks-dir", config.TasksDir, "Directory containing evaluation tasks")
	flag.StringVar(&config.KubeConfig, "kubeconfig", config.KubeConfig, "Path to kubeconfig file")
	flag.StringVar(&config.TaskPattern, "task-pattern", config.TaskPattern, "Pattern to filter tasks (e.g. 'pod' or 'redis')")
	flag.StringVar(&config.AgentBin, "agent-bin", config.AgentBin, "Path to kubernetes agent binary")
	flag.StringVar(&llmProvider, "llm-provider", llmProvider, "Specific LLM provider to evaluate (e.g. 'gemini' or 'ollama')")
	flag.StringVar(&modelList, "models", modelList, "Comma-separated list of models to evaluate (e.g. 'gemini-1.0,gemini-2.0')")
	flag.BoolVar(&enableToolUseShim, "enable-tool-use-shim", enableToolUseShim, "Enable tool use shim")
	flag.BoolVar(&quiet, "quiet", quiet, "Quiet mode (non-interactive mode)")
	flag.StringVar(&config.OutputDir, "output-dir", config.OutputDir, "Directory to write results to")
	flag.Parse()

	if config.KubeConfig == "" {
		config.KubeConfig = defaultKubeConfig
	}
	expandedKubeconfig, err := expandPath(config.KubeConfig)
	if err != nil {
		return fmt.Errorf("failed to expand kubeconfig path %q: %w", config.KubeConfig, err)
	}
	config.KubeConfig = expandedKubeconfig

	defaultModels := map[string][]string{
		"gemini": {"gemini-2.5-pro-preview-03-25"},
	}

	models := defaultModels
	if modelList != "" {
		if llmProvider == "" {
			return fmt.Errorf("--llm-provider is required when --models is specified")
		}
		modelSlice := strings.Split(modelList, ",")
		models = map[string][]string{
			llmProvider: modelSlice,
		}
	}

	for llmProviderID, models := range models {
		var toolUseShimStr string
		if enableToolUseShim {
			toolUseShimStr = "shim_enabled"
		} else {
			toolUseShimStr = "shim_disabled"
		}
		for _, modelID := range models {
			id := fmt.Sprintf("%s-%s-%s", toolUseShimStr, llmProviderID, modelID)
			config.LLMConfigs = append(config.LLMConfigs, model.LLMConfig{
				ID:                id,
				ProviderID:        llmProviderID,
				ModelID:           modelID,
				EnableToolUseShim: enableToolUseShim,
				Quiet:             quiet,
			})
		}
	}

	if err := runEvaluation(ctx, config); err != nil {
		return fmt.Errorf("running evaluation: %w", err)
	}

	return nil
}
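
// With the default flags (shim enabled, provider "gemini", model
// "gemini-2.5-pro-preview-03-25"), the loop above produces a single config:
//
//	model.LLMConfig{
//		ID:                "shim_enabled-gemini-gemini-2.5-pro-preview-03-25",
//		ProviderID:        "gemini",
//		ModelID:           "gemini-2.5-pro-preview-03-25",
//		EnableToolUseShim: true,
//		Quiet:             true,
//	}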

func runAnalyze() error {
	config := AnalyzeConfig{
		InputDir:     "",
		OutputFormat: "markdown",
	}

	// Set custom usage for 'analyze' subcommand
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: %s analyze --input-dir <directory> [options]\n\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "Analyze results from previous K8s-bench runs.\n\n")
		fmt.Fprintf(os.Stderr, "Options:\n")
		flag.PrintDefaults()
	}

	var resultsFilePath string
	flag.StringVar(&config.InputDir, "input-dir", config.InputDir, "Directory containing evaluation results (required)")
	flag.StringVar(&config.OutputFormat, "output-format", config.OutputFormat, "Output format (markdown or json)")
	flag.BoolVar(&config.IgnoreToolUseShim, "ignore-tool-use-shim", true, "Ignore tool use shim")
	flag.StringVar(&resultsFilePath, "results-filepath", "", "Optional file path to write results to")
	flag.Parse()

	// Check if input-dir is provided
	if config.InputDir == "" {
		flag.Usage()
		return fmt.Errorf("--input-dir is required")
	}

	// Check if output format is valid
	if config.OutputFormat != "markdown" && config.OutputFormat != "json" {
		return fmt.Errorf("invalid output format: %s, valid options are 'markdown' or 'json'", config.OutputFormat)
	}

	// Check if input directory exists
	if _, err := os.Stat(config.InputDir); os.IsNotExist(err) {
		return fmt.Errorf("input directory does not exist: %s", config.InputDir)
	}

	allResults, err := collectResults(config.InputDir)
	if err != nil {
		return fmt.Errorf("collecting results: %w", err)
	}

	// Format and output results
	if config.OutputFormat == "markdown" {
		if err := printMarkdownResults(config, allResults, resultsFilePath); err != nil {
			return fmt.Errorf("printing markdown results: %w", err)
		}
	} else {
		if err := printJSONResults(allResults, resultsFilePath); err != nil {
			return fmt.Errorf("printing JSON results: %w", err)
		}
	}

	return nil
}

func collectResults(inputDir string) ([]model.TaskResult, error) {
	var allResults []model.TaskResult

	// Walk through the directory structure to find all results.yaml files
	err := filepath.Walk(inputDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Only process results.yaml files
		if !info.IsDir() && info.Name() == "results.yaml" {
			// Read and parse the results file
			data, err := os.ReadFile(path)
			if err != nil {
				return fmt.Errorf("reading file %s: %w", path, err)
			}

			var result model.TaskResult
			if err := yaml.Unmarshal(data, &result); err != nil {
				return fmt.Errorf("parsing yaml from %s: %w", path, err)
			}

			allResults = append(allResults, result)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	return allResults, nil
}
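
// Sketch of a results.yaml file as read back by collectResults. The YAML key
// names depend on the model.TaskResult struct tags in pkg/model, so treat this
// shape as an assumption; the values are placeholders:
//
//	task: create-pod
//	result: success
//	llmConfig:
//	  id: shim_enabled-gemini-gemini-2.5-pro-preview-03-25
//	  providerID: gemini
//	  modelID: gemini-2.5-pro-preview-03-25
//	  enableToolUseShim: true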

func printMarkdownResults(config AnalyzeConfig, results []model.TaskResult, resultsFilePath string) error {
	// Create a buffer to hold the output
	var buffer strings.Builder

	buffer.WriteString("# K8s-bench Evaluation Results\n\n")

	allModels := make(map[string]bool) // Track all unique models
	for _, result := range results {
		allModels[result.LLMConfig.ModelID] = true
	}

	// Convert allModels map to a sorted slice
	models := make([]string, 0, len(allModels))
	for model := range allModels {
		models = append(models, model)
	}
	sort.Strings(models)

	// Overall summary across all results
	totalCount := len(results)
	overallSuccessCount := 0
	overallFailCount := 0
	for _, result := range results {
		if strings.Contains(strings.ToLower(result.Result), "success") {
			overallSuccessCount++
		} else {
			overallFailCount++
		}
	}

	// --- Model Performance Summary ---
	buffer.WriteString("## Model Performance Summary\n\n")

	if config.IgnoreToolUseShim {
		// Simplified table ignoring shim status
		buffer.WriteString("| Model | Success | Fail |\n")
		buffer.WriteString("|-------|---------|------|\n")
		for _, model := range models {
			successCount := 0
			failCount := 0
			for _, result := range results {
				if result.LLMConfig.ModelID == model {
					if strings.Contains(strings.ToLower(result.Result), "success") {
						successCount++
					} else {
						failCount++
					}
				}
			}
			buffer.WriteString(fmt.Sprintf("| %s | %d | %d |\n", model, successCount, failCount))
		}
		// Overall totals row
		buffer.WriteString("| **Total** |")
		buffer.WriteString(fmt.Sprintf(" %d | %d |\n\n", overallSuccessCount, overallFailCount))
	} else {
		// Original table grouped by tool use shim status
		resultsByToolUseShim := make(map[string][]model.TaskResult)
		for _, result := range results {
			var toolUseShimStr string
			if result.LLMConfig.EnableToolUseShim {
				toolUseShimStr = "shim_enabled"
			} else {
				toolUseShimStr = "shim_disabled"
			}
			resultsByToolUseShim[toolUseShimStr] = append(resultsByToolUseShim[toolUseShimStr], result)
		}

		toolUseShimStrs := make([]string, 0, len(resultsByToolUseShim))
		for toolUseShimStr := range resultsByToolUseShim {
			toolUseShimStrs = append(toolUseShimStrs, toolUseShimStr)
		}
		sort.Strings(toolUseShimStrs)

		// Create header row with success/fail columns for each toolUseShimStr
		buffer.WriteString("| Model |")
		for _, toolUseShimStr := range toolUseShimStrs {
			buffer.WriteString(fmt.Sprintf(" %s Success | %s Fail |", toolUseShimStr, toolUseShimStr))
		}
		buffer.WriteString("\n|-------|")
		for range toolUseShimStrs {
			buffer.WriteString("------------|-----------|")
		}
		buffer.WriteString("\n")

		// Add a row for each model with success/fail counts for each strategy
		for _, model := range models {
			buffer.WriteString(fmt.Sprintf("| %s |", model))
			for _, toolUseShimStr := range toolUseShimStrs {
				successCount := 0
				failCount := 0
				// Count success/fail for this model and toolUseShimStr
				for _, result := range resultsByToolUseShim[toolUseShimStr] {
					if result.LLMConfig.ModelID == model {
						if strings.Contains(strings.ToLower(result.Result), "success") {
							successCount++
						} else {
							failCount++
						}
					}
				}
				buffer.WriteString(fmt.Sprintf(" %d | %d |", successCount, failCount))
			}
			buffer.WriteString("\n")
		}

		// Add a row showing overall totals for each toolUseShimStr
		buffer.WriteString("| **Total** |")
		for _, toolUseShimStr := range toolUseShimStrs {
			successCount := 0
			failCount := 0
			for _, result := range resultsByToolUseShim[toolUseShimStr] {
				if strings.Contains(strings.ToLower(result.Result), "success") {
					successCount++
				} else {
					failCount++
				}
			}
			buffer.WriteString(fmt.Sprintf(" %d | %d |", successCount, failCount))
		}
		buffer.WriteString("\n\n")
	}

	// --- Overall Summary ---
	buffer.WriteString("## Overall Summary\n\n")
	buffer.WriteString(fmt.Sprintf("- Total Runs: %d\n", totalCount))
	buffer.WriteString(fmt.Sprintf("- Overall Success: %d (%d%%)\n", overallSuccessCount, calculatePercentage(overallSuccessCount, totalCount)))
	buffer.WriteString(fmt.Sprintf("- Overall Fail: %d (%d%%)\n\n", overallFailCount, calculatePercentage(overallFailCount, totalCount)))

	// --- Detailed Results ---
	if config.IgnoreToolUseShim {
		// Group results by model for detailed view
		resultsByModel := make(map[string][]model.TaskResult)
		for _, result := range results {
			resultsByModel[result.LLMConfig.ModelID] = append(resultsByModel[result.LLMConfig.ModelID], result)
		}

		for _, model := range models {
			buffer.WriteString(fmt.Sprintf("## Model: %s\n\n", model))
			buffer.WriteString("| Task | Provider | Result |\n")
			buffer.WriteString("|------|----------|--------|\n")

			modelSuccessCount := 0
			modelFailCount := 0
			modelResults := resultsByModel[model]
			modelTotalCount := len(modelResults)

			// Sort results within the model group for consistent output (e.g., by Task)
			sort.Slice(modelResults, func(i, j int) bool {
				return modelResults[i].Task < modelResults[j].Task
			})

			for _, result := range modelResults {
				resultEmoji := "❌" // Default to failure
				if strings.Contains(strings.ToLower(result.Result), "success") {
					resultEmoji = "✅"
					modelSuccessCount++
				} else {
					modelFailCount++
				}
				buffer.WriteString(fmt.Sprintf("| %s | %s | %s %s |\n", result.Task, result.LLMConfig.ProviderID, resultEmoji, result.Result))
			}

			// Add summary for this model
			buffer.WriteString(fmt.Sprintf("\n**%s Summary**\n\n", model))
			buffer.WriteString(fmt.Sprintf("- Total: %d\n", modelTotalCount))
			buffer.WriteString(fmt.Sprintf("- Success: %d (%d%%)\n", modelSuccessCount, calculatePercentage(modelSuccessCount, modelTotalCount)))
			buffer.WriteString(fmt.Sprintf("- Fail: %d (%d%%)\n\n", modelFailCount, calculatePercentage(modelFailCount, modelTotalCount)))
		}
	} else {
		// Original detailed results grouped by tool use shim status
		resultsByToolUseShim := make(map[string][]model.TaskResult)
		for _, result := range results {
			var toolUseShimStr string
			if result.LLMConfig.EnableToolUseShim {
				toolUseShimStr = "shim_enabled"
			} else {
				toolUseShimStr = "shim_disabled"
			}
			resultsByToolUseShim[toolUseShimStr] = append(resultsByToolUseShim[toolUseShimStr], result)
		}

		toolUseShimStrs := make([]string, 0, len(resultsByToolUseShim))
		for toolUseShimStr := range resultsByToolUseShim {
			toolUseShimStrs = append(toolUseShimStrs, toolUseShimStr)
		}
		sort.Strings(toolUseShimStrs)

		for _, toolUseShimStr := range toolUseShimStrs {
			toolUseShimStrResults := resultsByToolUseShim[toolUseShimStr]

			// Print a header for this toolUseShimStr
			buffer.WriteString(fmt.Sprintf("## Tool Use: %s\n\n", toolUseShimStr))

			// Create the table header
			buffer.WriteString("| Task | Provider | Model | Result |\n")
			buffer.WriteString("|------|----------|-------|--------|\n")

			// Track success and failure counts for this strategy
			successCount := 0
			failCount := 0
			totalCount := len(toolUseShimStrResults)

			// Sort results within the group for consistent output (by Model, then Task)
			sort.Slice(toolUseShimStrResults, func(i, j int) bool {
				if toolUseShimStrResults[i].LLMConfig.ModelID != toolUseShimStrResults[j].LLMConfig.ModelID {
					return toolUseShimStrResults[i].LLMConfig.ModelID < toolUseShimStrResults[j].LLMConfig.ModelID
				}
				return toolUseShimStrResults[i].Task < toolUseShimStrResults[j].Task
			})

			// Add each result as a row in the table
			for _, result := range toolUseShimStrResults {
				resultEmoji := "❌" // Default to failure
				if strings.Contains(strings.ToLower(result.Result), "success") {
					resultEmoji = "✅"
					successCount++
				} else {
					failCount++
				}
				buffer.WriteString(fmt.Sprintf("| %s | %s | %s | %s %s |\n", result.Task, result.LLMConfig.ProviderID, result.LLMConfig.ModelID, resultEmoji, result.Result))
			}

			// Add summary for this toolUseShimStr
			buffer.WriteString(fmt.Sprintf("\n**%s Summary**\n\n", toolUseShimStr))
			buffer.WriteString(fmt.Sprintf("- Total: %d\n", totalCount))
			buffer.WriteString(fmt.Sprintf("- Success: %d (%d%%)\n", successCount, calculatePercentage(successCount, totalCount)))
			buffer.WriteString(fmt.Sprintf("- Fail: %d (%d%%)\n\n", failCount, calculatePercentage(failCount, totalCount)))
		}
	}

	// --- Footer ---
	buffer.WriteString("---\n\n")
	buffer.WriteString(fmt.Sprintf("_Report generated on %s_\n", time.Now().Format("January 2, 2006 at 3:04 PM")))

	// Get the final output
	output := buffer.String()

	// Write to file if path is provided, otherwise print to stdout
	if resultsFilePath != "" {
		if err := os.WriteFile(resultsFilePath, []byte(output), 0644); err != nil {
			return fmt.Errorf("writing to file %q: %w", resultsFilePath, err)
		}
		fmt.Printf("Results written to %s\n", resultsFilePath)
	} else {
		// Print to stdout only if no file path is specified
		fmt.Print(output)
	}

	return nil
}
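
// Sketch of the report produced above when --ignore-tool-use-shim is set (the
// model name and counts are placeholders):
//
//	# K8s-bench Evaluation Results
//
//	## Model Performance Summary
//
//	| Model | Success | Fail |
//	|-------|---------|------|
//	| gemini-2.5-pro-preview-03-25 | 9 | 1 |
//	| **Total** | 9 | 1 |
//
//	## Overall Summary
//
//	- Total Runs: 10
//	- Overall Success: 9 (90%)
//	- Overall Fail: 1 (10%)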

// calculatePercentage returns part as a whole-number percentage of total.
// The result is truncated rather than rounded: calculatePercentage(2, 3) is 66.
func calculatePercentage(part, total int) int {
	if total == 0 {
		return 0
	}
	return int((float64(part) / float64(total)) * 100)
}

func printJSONResults(results []model.TaskResult, resultsFilePath string) error {
	// Convert the results to JSON
	jsonData, err := json.MarshalIndent(results, "", " ")
	if err != nil {
		return fmt.Errorf("marshaling results to JSON: %w", err)
	}

	// Write to file if path is provided, otherwise print to stdout
	if resultsFilePath != "" {
		if err := os.WriteFile(resultsFilePath, jsonData, 0644); err != nil {
			return fmt.Errorf("writing to file %q: %w", resultsFilePath, err)
		}
		fmt.Printf("Results written to %s\n", resultsFilePath)
	} else {
		// Print to stdout only if no file path is specified
		fmt.Println(string(jsonData))
	}

	return nil
}
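
// Note: sigs.k8s.io/yaml marshals and unmarshals via json struct tags, so the
// JSON emitted here uses the same field names as the results.yaml files
// sketched above collectResults; the records are identical, just JSON-encoded.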