cmd/root.go (102 lines of code) (raw):
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package cmd
import (
"fmt"
"log"
"os"
"strings"
"github.com/GoogleCloudPlatform/db-context-enrichment/internal/config"
"github.com/GoogleCloudPlatform/db-context-enrichment/internal/database"
_ "github.com/GoogleCloudPlatform/db-context-enrichment/internal/database/mysql"
_ "github.com/GoogleCloudPlatform/db-context-enrichment/internal/database/postgres"
_ "github.com/GoogleCloudPlatform/db-context-enrichment/internal/database/sqlserver"
"github.com/spf13/cobra"
)
// Package-level storage for the persistent command-line flags. Every variable
// below is bound to a flag on rootCmd in init.
var (
	// General behaviour.
	dryRun             bool   // --dry-run
	updateExistingMode string // --update_existing
	geminiAPIKey       string // --gemini-api-key

	// Database connection.
	dialect, host              string // --dialect, --host
	port                       int    // --port
	username, password, dbName string // --username, --password, --database

	// Cloud SQL specifics.
	cloudSQLInstanceConnectionName string // --cloudsql-instance-connection-name
	cloudSQLUsePrivateIP           bool   // --cloudsql-use-private-ip
)
// rootCmd is the top-level cobra command of the db_schema_enricher CLI.
// Its persistent flags and subcommands are attached in init.
var rootCmd = &cobra.Command{
Use: "db_schema_enricher",
Short: "A tool to enrich database schema with metadata",
Long: `db_schema_enricher is a CLI tool that helps enrich database schemas
with metadata like column descriptions, example values, and foreign key relationships.`,
// initFlagsAndConfig is wired in as the persistent pre-run hook so the
// flag values are copied into the configuration before a command runs.
PersistentPreRunE: initFlagsAndConfig,
}
// initFlagsAndConfig is the root command's PersistentPreRunE hook. It copies
// the parsed command-line flag values into the database configuration,
// resolves the Gemini API key, and publishes the results through the database
// and config packages.
//
// When cmd is nil the flag values are not applied and the database settings
// obtained from config.GetConfig are published as they are. args is unused.
// The function always returns nil.
func initFlagsAndConfig(cmd *cobra.Command, args []string) error {
	cfg := config.GetConfig()

	dbCfg := cfg.Database
	if cmd != nil {
		dbCfg.Dialect = dialect
		dbCfg.Host = host
		dbCfg.Port = port
		dbCfg.User = username
		dbCfg.Password = password
		dbCfg.DBName = dbName
		dbCfg.CloudSQLInstanceConnectionName = cloudSQLInstanceConnectionName
		dbCfg.UsePrivateIP = cloudSQLUsePrivateIP
		// The mode is normalised to lower case so later comparisons do not
		// depend on how the user typed it.
		dbCfg.UpdateExistingMode = strings.ToLower(updateExistingMode)
	}
	// NOTE(review): dbCfg appears to be a copy of cfg.Database, so the flag
	// values above are visible through database.GetConfig but are not written
	// back into cfg.Database. Confirm that this is intended.
	database.SetConfig(&dbCfg)

	// An explicit --gemini-api-key wins; otherwise fall back to the
	// GEMINI_API_KEY environment variable.
	if geminiAPIKey == "" {
		geminiAPIKey = os.Getenv("GEMINI_API_KEY")
	}
	cfg.GeminiAPIKey = geminiAPIKey

	// Publish the configuration once, after every field has its final value.
	config.SetConfig(cfg)
	return nil
}
// validateDialect reports whether dialect is one of the supported database
// dialects. The comparison is exact and case-sensitive. It returns nil for a
// supported dialect and otherwise an error that names the rejected value and
// lists every supported one.
func validateDialect(dialect string) error {
	known := []string{"postgres", "cloudsqlpostgres", "mysql", "cloudsqlmysql", "sqlserver", "cloudsqlsqlserver"}

	for _, candidate := range known {
		if candidate == dialect {
			return nil
		}
	}

	return fmt.Errorf("unsupported dialect: %s (only %s are supported)", dialect, strings.Join(known, ", "))
}
// setupDatabase builds a *database.DB from the configuration currently held
// by the database package. It returns an error when no configuration has been
// set, or when database.New fails; in the latter case the failure is also
// written to the standard logger and the returned error wraps the cause.
func setupDatabase() (*database.DB, error) {
	cfg := database.GetConfig()
	if cfg == nil {
		return nil, fmt.Errorf("database config is not initialized")
	}

	conn, connErr := database.New(*cfg)
	if connErr == nil {
		return conn, nil
	}

	log.Println("ERROR: Failed to connect to database:", connErr)
	return nil, fmt.Errorf("failed to connect to database: %w", connErr)
}
// Execute runs the root command and returns whatever error it reports.
// Flags and subcommands are registered beforehand in this package's init.
func Execute() error {
	err := rootCmd.Execute()
	return err
}
// init binds every persistent flag to its package-level variable and attaches
// the subcommands to rootCmd.
func init() {
	flags := rootCmd.PersistentFlags()

	// Global behaviour.
	flags.BoolVar(&dryRun, "dry-run", true, "Enable dry-run mode (no database modifications)")

	// Database connection.
	dialectNames := []string{"postgres", "mysql", "sqlserver", "cloudsqlpostgres", "cloudsqlmysql", "cloudsqlsqlserver"}
	dialectHelp := fmt.Sprintf("Database dialect (%s) - MANDATORY", strings.Join(dialectNames, ", "))
	flags.StringVar(&dialect, "dialect", "", dialectHelp)
	flags.StringVar(&host, "host", "", "Database host - MANDATORY")
	flags.IntVar(&port, "port", 0, "Database port - MANDATORY")
	flags.StringVar(&username, "username", "", "Database username - MANDATORY")
	flags.StringVar(&password, "password", "", "Database password - MANDATORY")
	flags.StringVar(&dbName, "database", "", "Database name - MANDATORY")

	// Cloud SQL specifics.
	flags.StringVar(&cloudSQLInstanceConnectionName, "cloudsql-instance-connection-name", "", "Cloud SQL instance connection name (for Cloud SQL dialects) - MANDATORY for CloudSQL")
	flags.BoolVar(&cloudSQLUsePrivateIP, "cloudsql-use-private-ip", false, "Use private IP for Cloud SQL connection (Cloud SQL)")

	// Handling of comments that already exist.
	flags.StringVar(&updateExistingMode, "update_existing", "overwrite", "Mode to update existing comments ('overwrite' or 'append')")

	// Gemini API key.
	flags.StringVar(&geminiAPIKey, "gemini-api-key", "", "Gemini API key (can also be set via GEMINI_API_KEY environment variable)")

	// Subcommands, attached in the same order as before.
	rootCmd.AddCommand(addCommentsCmd, getCommentsCmd, deleteCommentsCmd, applyCommentsCmd)
}