internal/database/postgres/postgres.go (398 lines of code) (raw):

/* * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package postgres import ( "context" "database/sql" "fmt" "net" "os" "strings" "cloud.google.com/go/cloudsqlconn" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/stdlib" "github.com/lib/pq" "github.com/GoogleCloudPlatform/db-context-enrichment/internal/config" "github.com/GoogleCloudPlatform/db-context-enrichment/internal/database" ) // postgresHandler struct implements database.DialectHandler for PostgreSQL. type postgresHandler struct{} var _ database.DialectHandler = (*postgresHandler)(nil) // CreateCloudSQLPool for PostgreSQL func (h postgresHandler) CreateCloudSQLPool(cfg config.DatabaseConfig) (*sql.DB, error) { mustGetenv := func(k string, cfg config.DatabaseConfig) string { // Keep mustGetenv here as it's specific to connection v := "" switch k { case "user_name": v = cfg.User case "password": v = cfg.Password case "database_name": v = cfg.DBName case "instance_name": v = cfg.CloudSQLInstanceConnectionName case "PRIVATE_IP": if cfg.UsePrivateIP { v = "true" } } if v == "" { return os.Getenv(k) // Fallback to environment variable if not in Config } return v } dbUser := mustGetenv("user_name", cfg) dbPwd := mustGetenv("password", cfg) dbName := mustGetenv("database_name", cfg) instanceConnectionName := mustGetenv("instance_name", cfg) usePrivate := mustGetenv("PRIVATE_IP", cfg) dsn := fmt.Sprintf("user=%s password=%s database=%s", dbUser, dbPwd, dbName) config, err := pgx.ParseConfig(dsn) if err != nil { return nil, err } var opts []cloudsqlconn.Option if usePrivate != "" && strings.ToLower(usePrivate) != "false" && usePrivate != "0" { // Handle boolean-like env vars opts = append(opts, cloudsqlconn.WithDefaultDialOptions(cloudsqlconn.WithPrivateIP())) } d, err := cloudsqlconn.NewDialer(context.Background(), opts...) if err != nil { return nil, err } config.DialFunc = func(ctx context.Context, network, instance string) (net.Conn, error) { return d.Dial(ctx, instanceConnectionName) } dbURI := stdlib.RegisterConnConfig(config) dbPool, err := sql.Open("pgx", dbURI) if err != nil { return nil, fmt.Errorf("sql.Open: %w", err) } return dbPool, nil } // CreateStandardPool creates a standard PostgreSQL connection pool func (h postgresHandler) CreateStandardPool(cfg config.DatabaseConfig) (*sql.DB, error) { connStr := fmt.Sprintf( "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.DBName, cfg.SSLMode, ) dbPool, err := sql.Open("postgres", connStr) if err != nil { return nil, fmt.Errorf("error opening database: %w", err) } return dbPool, err } // QuoteIdentifier for PostgreSQL func (h postgresHandler) QuoteIdentifier(name string) string { // Replace any existing quotes with double quotes to escape them name = strings.Replace(name, `"`, `""`, -1) // Wrap the entire name in quotes return fmt.Sprintf(`"%s"`, name) } // ListTables for PostgreSQL func (h postgresHandler) ListTables(db *database.DB) ([]string, error) { query := ` SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE' ORDER BY table_name;` rows, err := db.Query(query) if err != nil { return nil, fmt.Errorf("error querying tables: %w", err) } defer rows.Close() var tables []string for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { return nil, fmt.Errorf("error scanning table name: %w", err) } tables = append(tables, tableName) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating table rows: %w", err) } return tables, nil } // ListColumns for PostgreSQL func (h postgresHandler) ListColumns(db *database.DB, tableName string) ([]database.ColumnInfo, error) { query := ` SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'public' AND table_name = $1 ORDER BY ordinal_position;` rows, err := db.Query(query, tableName) if err != nil { return nil, fmt.Errorf("error querying columns for table %s: %w", tableName, err) } defer rows.Close() var columns []database.ColumnInfo for rows.Next() { var colInfo database.ColumnInfo if err := rows.Scan(&colInfo.Name, &colInfo.DataType); err != nil { return nil, fmt.Errorf("error scanning column name and data type: %w", err) } columns = append(columns, colInfo) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating column rows: %w", err) } return columns, nil } // GetColumnMetadata for PostgreSQL func (h postgresHandler) GetColumnMetadata(db *database.DB, tableName string, columnName string) (map[string]interface{}, error) { // Quote table and column names to handle special characters and spaces quotedTable := h.QuoteIdentifier(tableName) // Use handler's QuoteIdentifier quotedColumn := h.QuoteIdentifier(columnName) // Get distinct count distinctQuery := fmt.Sprintf("SELECT COUNT(DISTINCT %s::text) FROM %s", quotedColumn, quotedTable) var distinctCount int err := db.QueryRow(distinctQuery).Scan(&distinctCount) if err != nil { return nil, fmt.Errorf("failed to get distinct count: %w", err) } // Get null count nullQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s IS NULL", quotedTable, quotedColumn) var nullCount int err = db.QueryRow(nullQuery).Scan(&nullCount) if err != nil { return nil, fmt.Errorf("failed to get null count: %w", err) } // Get example values (top 3) exampleQuery := fmt.Sprintf("SELECT DISTINCT %s::text FROM %s WHERE %s IS NOT NULL LIMIT 3", quotedColumn, quotedTable, quotedColumn) rows, err := db.Query(exampleQuery) if err != nil { return nil, fmt.Errorf("failed to get example values: %w", err) } defer rows.Close() var examples []string for rows.Next() { var value string if err := rows.Scan(&value); err != nil { return nil, fmt.Errorf("error scanning example value: %w", err) } examples = append(examples, value) } return map[string]interface{}{ "DistinctCount": distinctCount, "NullCount": nullCount, "ExampleValues": examples, }, nil } // formatExampleValues formats a slice of example values for SQL comment in PostgreSQL func (h postgresHandler) formatExampleValues(values []string) string { if len(values) == 0 { return "[]" } // Quote each value and join with comma quoted := make([]string, len(values)) for i, v := range values { quoted[i] = "\"" + v + "\"" } return fmt.Sprintf("[%s]", strings.Join(quoted, ", ")) } func (h postgresHandler) generateMetadataComment(data *database.CommentData, enrichments map[string]bool) string { if data == nil { return "" } if data.TableName == "" || data.ColumnName == "" { return "" } var commentParts []string // Helper function to check if enrichment is requested isEnrichmentRequested := func(enrichment string) bool { if len(enrichments) == 0 { return true // If no enrichments specified, include all } return enrichments[enrichment] } if isEnrichmentRequested("description") && data.Description != "" { commentParts = append(commentParts, fmt.Sprintf("*Important Note*: %s", data.Description)) } if isEnrichmentRequested("examples") && len(data.ExampleValues) > 0 { commentParts = append(commentParts, fmt.Sprintf("Example Values: %s", h.formatExampleValues(data.ExampleValues))) } if isEnrichmentRequested("distinct_values") { commentParts = append(commentParts, fmt.Sprintf("Count Distinct Values: %d", data.DistinctCount)) } if isEnrichmentRequested("null_count") { commentParts = append(commentParts, fmt.Sprintf("Count Null: %d", data.NullCount)) } return strings.Join(commentParts, " | ") } func (h postgresHandler) generateTableMetadataComment(data *database.TableCommentData, enrichments map[string]bool) string { if data == nil || data.TableName == "" { return "" } var commentParts []string isEnrichmentRequested := func(enrichment string) bool { if len(enrichments) == 0 { return true } return enrichments[enrichment] } if isEnrichmentRequested("description") && data.Description != "" { commentParts = append(commentParts, fmt.Sprintf("*Important Note*: %s", data.Description)) } return strings.Join(commentParts, " | ") } func (h postgresHandler) mergeComments(existingComment string, newMetadataComment string, updateExistingMode string) string { startTag := "<gemini>" endTag := "</gemini>" startIndex := strings.Index(existingComment, startTag) endIndex := strings.LastIndex(existingComment, endTag) comment := "" if startIndex == -1 || endIndex == -1 || endIndex <= startIndex { // No Gemini tag found, append new comment with tags if existingComment != "" { comment = existingComment + " " + startTag + newMetadataComment + endTag } else { comment = startTag + newMetadataComment + endTag // Just add new comment with tags } } else if updateExistingMode == "append" { currentGeminiComment := existingComment[startIndex+len(startTag) : endIndex] if currentGeminiComment != "" { comment = existingComment[:endIndex] + " " + newMetadataComment + endTag + existingComment[endIndex+len(endTag):] // Append to existing gemini comment } } else { // Gemini tag found, replace content inside tags prefix := existingComment[:startIndex] suffix := existingComment[endIndex+len(endTag):] comment = prefix + startTag + newMetadataComment + endTag + suffix } if comment == "" { comment = existingComment } if comment == "<gemini></gemini>" { comment = "" } return comment } // GenerateCommentSQL creates SQL statements for column comments in PostgreSQL func (h postgresHandler) GenerateCommentSQL(db *database.DB, data *database.CommentData, enrichments map[string]bool) (string, error) { if data == nil { return "", fmt.Errorf("comment data cannot be nil") } if data.TableName == "" || data.ColumnName == "" { return "", fmt.Errorf("table and column names cannot be empty") } config := database.GetConfig() // Retrieve global config // Pass the enrichments map to generateMetadataComment newMetadataComment := h.generateMetadataComment(data, enrichments) existingComment, err := h.GetColumnComment(context.Background(), db, data.TableName, data.ColumnName) if err != nil { return "", err } finalComment := h.mergeComments(existingComment, newMetadataComment, config.UpdateExistingMode) quotedComment := pq.QuoteLiteral(finalComment) if finalComment == "" { return "", nil } return fmt.Sprintf( "COMMENT ON COLUMN %s.%s IS %s;", h.QuoteIdentifier(data.TableName), h.QuoteIdentifier(data.ColumnName), quotedComment, ), nil } // GenerateDeleteCommentSQL for PostgreSQL func (h postgresHandler) GenerateDeleteCommentSQL(ctx context.Context, db *database.DB, tableName string, columnName string) (string, error) { if tableName == "" || columnName == "" { return "", fmt.Errorf("table and column names cannot be empty") } existingComment, err := h.GetColumnComment(ctx, db, tableName, columnName) if err != nil { return "", err } startTag := "<gemini>" endTag := "</gemini>" startIndex := strings.Index(existingComment, startTag) endIndex := strings.LastIndex(existingComment, endTag) var finalComment string if startIndex != -1 && endIndex != -1 && endIndex > startIndex { // Gemini tags found, remove content within tags prefix := existingComment[:startIndex] suffix := existingComment[endIndex+len(endTag):] finalComment = strings.TrimSpace(prefix + suffix) // Trim leading/trailing spaces after removing gemini part } else { // Gemini tags not found, or invalid tags, keep original comment finalComment = existingComment } quotedComment := pq.QuoteLiteral(finalComment) return fmt.Sprintf( "COMMENT ON COLUMN %s.%s IS %s;", h.QuoteIdentifier(tableName), h.QuoteIdentifier(columnName), quotedComment, ), nil } // GetColumnComment for PostgreSQL retrieves the comment for a specific column. func (h postgresHandler) GetColumnComment(ctx context.Context, db *database.DB, tableName string, columnName string) (string, error) { query := ` SELECT description FROM pg_catalog.pg_description JOIN pg_catalog.pg_class c ON pg_description.objoid = c.oid JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid JOIN pg_catalog.pg_attribute a ON pg_description.objoid = a.attrelid AND pg_description.objsubid = a.attnum WHERE n.nspname = 'public' -- Assuming public schema AND c.relname = $1 AND a.attname = $2; ` var comment sql.NullString // Use sql.NullString to handle NULL values err := db.QueryRowContext(ctx, query, tableName, columnName).Scan(&comment) if err != nil { if err == sql.ErrNoRows { return "", nil // No comment found, return empty string } return "", fmt.Errorf("failed to retrieve column comment: %w", err) } if comment.Valid { return comment.String, nil } return "", nil // Comment is NULL in DB, return empty string } // GenerateTableCommentSQL generates the SQL to comment on a table. func (h postgresHandler) GenerateTableCommentSQL(db *database.DB, data *database.TableCommentData, enrichments map[string]bool) (string, error) { if data == nil || data.TableName == "" { return "", fmt.Errorf("table comment data cannot be nil or empty") } config := database.GetConfig() newMetadataComment := h.generateTableMetadataComment(data, enrichments) existingComment, err := h.GetTableComment(context.Background(), db, data.TableName) if err != nil { return "", err } finalComment := h.mergeComments(existingComment, newMetadataComment, config.UpdateExistingMode) quotedComment := pq.QuoteLiteral(finalComment) if finalComment == "" { return "", nil } return fmt.Sprintf( "COMMENT ON TABLE %s IS %s;", h.QuoteIdentifier(data.TableName), quotedComment, ), nil } // GetTableComment retrieves the existing comment for a table. func (h postgresHandler) GetTableComment(ctx context.Context, db *database.DB, tableName string) (string, error) { query := ` SELECT pg_catalog.obj_description(c.oid, 'pg_class') FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = 'public' AND c.relname = $1; ` var comment sql.NullString err := db.QueryRowContext(ctx, query, tableName).Scan(&comment) if err != nil { if err == sql.ErrNoRows { return "", nil // No comment, return empty string. } return "", fmt.Errorf("failed to retrieve table comment: %w", err) } if comment.Valid { return comment.String, nil } return "", nil // Comment is NULL. } // GenerateDeleteTableCommentSQL generates SQL to remove the Gemini-generated part of a table comment. func (h postgresHandler) GenerateDeleteTableCommentSQL(ctx context.Context, db *database.DB, tableName string) (string, error) { if tableName == "" { return "", fmt.Errorf("table name cannot be empty") } existingComment, err := h.GetTableComment(ctx, db, tableName) if err != nil { return "", err } startTag := "<gemini>" endTag := "</gemini>" startIndex := strings.Index(existingComment, startTag) endIndex := strings.LastIndex(existingComment, endTag) var finalComment string if startIndex != -1 && endIndex != -1 && endIndex > startIndex { prefix := existingComment[:startIndex] suffix := existingComment[endIndex+len(endTag):] finalComment = strings.TrimSpace(prefix + suffix) } else { finalComment = existingComment // No gemini tags, keep original. } quotedComment := pq.QuoteLiteral(finalComment) return fmt.Sprintf("COMMENT ON TABLE %s IS %s;", h.QuoteIdentifier(tableName), quotedComment), nil } func init() { database.RegisterDialectHandler("postgres", postgresHandler{}) database.RegisterDialectHandler("cloudsqlpostgres", postgresHandler{}) }