internal/database/database.go (144 lines of code) (raw):
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package database
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"sync"
"time"
"github.com/GoogleCloudPlatform/db-context-enrichment/internal/config"
)
// DB wraps the sql.DB instance and provides additional functionality
type DB struct {
*sql.DB
DialectHandler
}
// ForeignKeyInfo holds foreign key relationship details
type ForeignKeyInfo struct {
RefTable string
RefColumn string
}
// ColumnInfo holds column name and datatype
type ColumnInfo struct {
Name string
DataType string
}
// CommentData holds the data required to generate column comments.
// It's defined in the database package to avoid cyclic dependencies.
type CommentData struct {
TableName string
ColumnName string
ColumnDataType string
ExampleValues []string
DistinctCount int64
NullCount int64
Description string
}
// TableCommentData holds data for table comments.
type TableCommentData struct {
TableName string
Description string
}
// DialectHandler interface
type DialectHandler interface {
CreateCloudSQLPool(cfg config.DatabaseConfig) (*sql.DB, error)
CreateStandardPool(cfg config.DatabaseConfig) (*sql.DB, error)
QuoteIdentifier(name string) string
ListTables(db *DB) ([]string, error)
ListColumns(db *DB, tableName string) ([]ColumnInfo, error)
GetColumnMetadata(db *DB, tableName string, columnName string) (map[string]interface{}, error)
GenerateCommentSQL(db *DB, data *CommentData, enrichments map[string]bool) (string, error)
GetColumnComment(ctx context.Context, db *DB, tableName string, columnName string) (string, error)
GenerateDeleteCommentSQL(ctx context.Context, db *DB, tableName string, columnName string) (string, error)
// Added these methods for table-level comments
GenerateTableCommentSQL(db *DB, data *TableCommentData, enrichments map[string]bool) (string, error)
GetTableComment(ctx context.Context, db *DB, tableName string) (string, error)
GenerateDeleteTableCommentSQL(ctx context.Context, db *DB, tableName string) (string, error)
}
var (
globalConfig *config.DatabaseConfig
mu sync.RWMutex
)
// SetConfig sets the global database configuration
func SetConfig(cfg *config.DatabaseConfig) {
mu.Lock()
defer mu.Unlock()
globalConfig = cfg
}
// GetConfig returns the current database configuration
func GetConfig() *config.DatabaseConfig {
mu.RLock()
defer mu.RUnlock()
return globalConfig
}
var dialectHandlers = make(map[string]DialectHandler)
// RegisterDialectHandler registers a DialectHandler for a given dialect.
func RegisterDialectHandler(dialect string, handler DialectHandler) {
dialectHandlers[dialect] = handler
}
// createPool sets up common connection pool parameters and pings the database.
func createPool(db *sql.DB) (*sql.DB, error) {
db.SetMaxOpenConns(5)
db.SetMaxIdleConns(2)
db.SetConnMaxLifetime(time.Hour)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := db.PingContext(ctx); err != nil {
db.Close()
return nil, fmt.Errorf("database ping failed: %w", err)
}
return db, nil
}
// New creates a new database connection based on the specified dialect
func New(cfg config.DatabaseConfig) (*DB, error) {
var db *sql.DB
var err error
handler, ok := dialectHandlers[cfg.Dialect]
if !ok {
return nil, fmt.Errorf("unsupported dialect: %s", cfg.Dialect)
}
if strings.HasPrefix(cfg.Dialect, "cloudsql") && cfg.CloudSQLInstanceConnectionName != "" {
db, err = handler.CreateCloudSQLPool(cfg)
} else {
db, err = handler.CreateStandardPool(cfg)
}
if err != nil {
return nil, err
}
// Test the connection and setup pool
db, err = createPool(db)
if err != nil {
log.Println("ERROR: Failed to create connection pool:", err)
return nil, err
}
return &DB{
DB: db,
DialectHandler: handler,
}, nil
}
// ListTables returns all table names
func (db *DB) ListTables() ([]string, error) {
return db.DialectHandler.ListTables(db)
}
// ListColumns returns all column names for the given table
func (db *DB) ListColumns(tableName string) ([]ColumnInfo, error) { // Modified to return []ColumnInfo
return db.DialectHandler.ListColumns(db, tableName)
}
// GetColumnMetadata collects metadata for a specific column
func (db *DB) GetColumnMetadata(tableName string, columnName string) (map[string]interface{}, error) {
return db.DialectHandler.GetColumnMetadata(db, tableName, columnName)
}
// GenerateCommentSQL generates the SQL query to add comment to a column
func (db *DB) GenerateCommentSQL(data *CommentData, enrichments map[string]bool) (string, error) {
return db.DialectHandler.GenerateCommentSQL(db, data, enrichments)
}
// Close closes the database connection
func (db *DB) Close() error {
return db.DB.Close()
}
// GetColumnComment retrieves the comment for a specific column
func (db *DB) GetColumnComment(ctx context.Context, tableName string, columnName string) (string, error) {
return db.DialectHandler.GetColumnComment(ctx, db, tableName, columnName)
}
// GenerateDeleteCommentSQL generates the SQL query to delete gemini comment from a column
func (db *DB) GenerateDeleteCommentSQL(ctx context.Context, tableName string, columnName string) (string, error) {
return db.DialectHandler.GenerateDeleteCommentSQL(ctx, db, tableName, columnName)
}
// GenerateTableCommentSQL generates the SQL query to add a comment to a table.
func (db *DB) GenerateTableCommentSQL(data *TableCommentData, enrichments map[string]bool) (string, error) {
return db.DialectHandler.GenerateTableCommentSQL(db, data, enrichments)
}
// GetTableComment retrieves the comment for a specific table.
func (db *DB) GetTableComment(ctx context.Context, tableName string) (string, error) {
return db.DialectHandler.GetTableComment(ctx, db, tableName)
}
// GenerateDeleteTableCommentSQL generates the SQL query to delete a gemini comment from a table.
func (db *DB) GenerateDeleteTableCommentSQL(ctx context.Context, tableName string) (string, error) {
return db.DialectHandler.GenerateDeleteTableCommentSQL(ctx, db, tableName)
}
// ExecuteSQLStatements executes a batch of SQL statements from a string slice
func (db *DB) ExecuteSQLStatements(ctx context.Context, sqlStatements []string) error {
for _, sqlStmt := range sqlStatements {
_, err := db.ExecContext(ctx, sqlStmt)
if err != nil {
return fmt.Errorf("failed to execute SQL statement: %w\nStatement: %s", err, sqlStmt)
}
}
return nil
}