internal/database/sqlserver/sqlserver.go (452 lines of code) (raw):

/* * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package sqlserver import ( "context" "database/sql" "fmt" "net" "os" "strings" "cloud.google.com/go/cloudsqlconn" "github.com/GoogleCloudPlatform/db-context-enrichment/internal/config" "github.com/GoogleCloudPlatform/db-context-enrichment/internal/database" mssql "github.com/denisenkom/go-mssqldb" ) // sqlServerHandler struct implements database.DialectHandler for SQL Server. type sqlServerHandler struct{} var _ database.DialectHandler = (*sqlServerHandler)(nil) type csqlDialer struct { dialer *cloudsqlconn.Dialer connName string usePrivate bool } // DialContext adheres to the mssql.Dialer interface. func (c *csqlDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { var opts []cloudsqlconn.DialOption if c.usePrivate { opts = append(opts, cloudsqlconn.WithPrivateIP()) } return c.dialer.Dial(ctx, c.connName, opts...) } // CreateCloudSQLPool for SQL Server func (h sqlServerHandler) CreateCloudSQLPool(cfg config.DatabaseConfig) (*sql.DB, error) { mustGetenv := func(k string, cfg config.DatabaseConfig) string { v := "" switch k { case "user_name": v = cfg.User case "password": v = cfg.Password case "database_name": v = cfg.DBName case "instance_name": v = cfg.CloudSQLInstanceConnectionName case "PRIVATE_IP": if cfg.UsePrivateIP { v = "true" } } if v == "" { return os.Getenv(k) } return v } dbUser := mustGetenv("user_name", cfg) dbPwd := mustGetenv("password", cfg) dbName := mustGetenv("database_name", cfg) instanceConnectionName := mustGetenv("instance_name", cfg) usePrivate := mustGetenv("PRIVATE_IP", cfg) // WithLazyRefresh() Option is used to perform refresh // when needed, rather than on a scheduled interval. // This is recommended for serverless environments to // avoid background refreshes from throttling CPU. dialer, err := cloudsqlconn.NewDialer(context.Background(), cloudsqlconn.WithLazyRefresh()) if err != nil { return nil, fmt.Errorf("cloudsqlconn.NewDailer: %w", err) } connector, err := mssql.NewConnector(fmt.Sprintf("sqlserver://%s:%s@localhost:1433?database=%s&dial=cloudsqlconn&instance=%s", dbUser, dbPwd, dbName, instanceConnectionName)) if err != nil { return nil, fmt.Errorf("mssql.NewConnector: %w", err) } connector.Dialer = &csqlDialer{ dialer: dialer, connName: instanceConnectionName, usePrivate: usePrivate != "", } dbPool := sql.OpenDB(connector) return dbPool, nil } // CreateStandardPool creates a standard SQL Server connection pool func (h sqlServerHandler) CreateStandardPool(cfg config.DatabaseConfig) (*sql.DB, error) { port := cfg.Port if port == 0 { port = 1433 // Default SQL Server port } connStr := fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s", cfg.User, cfg.Password, cfg.Host, port, cfg.DBName) dbPool, err := sql.Open("sqlserver", connStr) if err != nil { return nil, fmt.Errorf("sql.Open (standard sqlserver): %w", err) } return dbPool, nil } // QuoteIdentifier for SQL Server // SQL Server uses square brackets [] for identifiers. // Double quotes "" are also accepted in some contexts but square brackets are standard and safer. func (h sqlServerHandler) QuoteIdentifier(name string) string { return fmt.Sprintf("[%s]", name) } // ListTables for SQL Server func (h sqlServerHandler) ListTables(db *database.DB) ([]string, error) { query := "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE' AND TABLE_CATALOG = DB_NAME()" rows, err := db.Query(query) if err != nil { return nil, fmt.Errorf("error querying tables: %w", err) } defer rows.Close() var tables []string for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { return nil, fmt.Errorf("error scanning table name: %w", err) } tables = append(tables, tableName) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating table rows: %w", err) } return tables, nil } // ListColumns for SQL Server func (h sqlServerHandler) ListColumns(db *database.DB, tableName string) ([]database.ColumnInfo, error) { query := fmt.Sprintf("SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '%s' AND TABLE_CATALOG = DB_NAME()", tableName) rows, err := db.Query(query) if err != nil { return nil, fmt.Errorf("error querying columns for table %s: %w", tableName, err) } defer rows.Close() var columns []database.ColumnInfo for rows.Next() { var colInfo database.ColumnInfo if err := rows.Scan(&colInfo.Name, &colInfo.DataType); err != nil { return nil, fmt.Errorf("error scanning column details: %w", err) } columns = append(columns, colInfo) } if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating column rows: %w", err) } return columns, nil } // GetColumnMetadata for SQL Server func (h sqlServerHandler) GetColumnMetadata(db *database.DB, tableName string, columnName string) (map[string]interface{}, error) { quotedTable := h.QuoteIdentifier(tableName) quotedColumn := h.QuoteIdentifier(columnName) distinctQuery := fmt.Sprintf("SELECT COUNT(DISTINCT %s) FROM %s", quotedColumn, quotedTable) var distinctCount int err := db.QueryRow(distinctQuery).Scan(&distinctCount) if err != nil { return nil, fmt.Errorf("failed to get distinct count: %w", err) } nullQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE %s IS NULL", quotedTable, quotedColumn) var nullCount int err = db.QueryRow(nullQuery).Scan(&nullCount) if err != nil { return nil, fmt.Errorf("failed to get null count: %w", err) } exampleQuery := fmt.Sprintf("SELECT TOP 3 %s FROM %s WHERE %s IS NOT NULL", quotedColumn, quotedTable, quotedColumn) rows, err := db.Query(exampleQuery) if err != nil { return nil, fmt.Errorf("failed to get example values: %w", err) } defer rows.Close() var examples []string for rows.Next() { var value string if err := rows.Scan(&value); err != nil { return nil, fmt.Errorf("error scanning example value: %w", err) } examples = append(examples, value) } return map[string]interface{}{ "DistinctCount": distinctCount, "NullCount": nullCount, "ExampleValues": examples, }, nil } // formatExampleValues formats a slice of example values for SQL comment in SQL Server func (h sqlServerHandler) formatExampleValues(values []string) string { if len(values) == 0 { return "[]" } quoted := make([]string, len(values)) for i, v := range values { // Use %q to add double quotes and escape internal double quotes and backslashes. quoted[i] = fmt.Sprintf("%q", v) } return fmt.Sprintf("[%s]", strings.Join(quoted, ", ")) // Format as array } func (h sqlServerHandler) generateMetadataComment(data *database.CommentData, enrichments map[string]bool) string { if data == nil { return "" } if data.TableName == "" || data.ColumnName == "" { return "" } var commentParts []string // Helper function to check if enrichment is requested isEnrichmentRequested := func(enrichment string) bool { if len(enrichments) == 0 { return true // If no enrichments specified, include all } return enrichments[enrichment] } if isEnrichmentRequested("description") && data.Description != "" { commentParts = append(commentParts, fmt.Sprintf("Description: %s", data.Description)) } if isEnrichmentRequested("examples") && len(data.ExampleValues) > 0 { commentParts = append(commentParts, fmt.Sprintf("Examples: %s", h.formatExampleValues(data.ExampleValues))) } if isEnrichmentRequested("distinct_values") { commentParts = append(commentParts, fmt.Sprintf("Distinct Values: %d", data.DistinctCount)) } if isEnrichmentRequested("null_count") { commentParts = append(commentParts, fmt.Sprintf("Null Count: %d", data.NullCount)) } return strings.Join(commentParts, " | ") } func (h sqlServerHandler) generateTableMetadataComment(data *database.TableCommentData, enrichments map[string]bool) string { if data == nil || data.TableName == "" { return "" } var commentParts []string isEnrichmentRequested := func(enrichment string) bool { if len(enrichments) == 0 { return true // If no enrichments specified, include all } return enrichments[enrichment] } if isEnrichmentRequested("description") && data.Description != "" { commentParts = append(commentParts, fmt.Sprintf("Description: %s", data.Description)) } return strings.Join(commentParts, " | ") } func (h sqlServerHandler) mergeComments(existingComment string, newMetadataComment string, updateExistingMode string) string { startTag := "<gemini>" endTag := "</gemini>" startIndex := strings.Index(existingComment, startTag) endIndex := strings.LastIndex(existingComment, endTag) comment := "" if startIndex == -1 || endIndex == -1 || endIndex <= startIndex { // No Gemini tag found, append new comment with tags if existingComment != "" { comment = existingComment + " " + startTag + newMetadataComment + endTag } else { comment = startTag + newMetadataComment + endTag // Just add new comment with tags } } else if updateExistingMode == "append" { currentGeminiComment := existingComment[startIndex+len(startTag) : endIndex] if currentGeminiComment != "" { comment = existingComment[:endIndex] + " " + newMetadataComment + endTag + existingComment[endIndex+len(endTag):] // Append to existing gemini comment } } else { // Gemini tag found, replace content inside tags prefix := existingComment[:startIndex] suffix := existingComment[endIndex+len(endTag):] comment = prefix + startTag + newMetadataComment + endTag + suffix } if comment == "" { comment = existingComment } if comment == "<gemini></gemini>" { comment = "" } return comment } // GenerateCommentSQL creates SQL statements for column comments in SQL Server func (h sqlServerHandler) GenerateCommentSQL(db *database.DB, data *database.CommentData, enrichments map[string]bool) (string, error) { if data == nil { return "", fmt.Errorf("metadata cannot be nil") } if data.TableName == "" || data.ColumnName == "" { return "", fmt.Errorf("table and column names cannot be empty") } config := database.GetConfig() // Retrieve global config newMetadataComment := h.generateMetadataComment(data, enrichments) existingComment, err := h.GetColumnComment(context.Background(), db, data.TableName, data.ColumnName) if err != nil { return "", err } finalComment := h.mergeComments(existingComment, newMetadataComment, config.UpdateExistingMode) // Pass updateExistingMode // Check if comment already exists to decide between sp_addextendedproperty and sp_updateextendedproperty query := ` SELECT CAST(value as NVARCHAR(MAX)) FROM fn_listextendedproperty (N'MS_Description', N'SCHEMA', N'dbo', N'TABLE', @tableName, N'COLUMN', @columnName) ` var existingCommentDB string err = db.QueryRow(query, sql.Named("tableName", data.TableName), sql.Named("columnName", data.ColumnName)).Scan(&existingCommentDB) if err != nil && err != sql.ErrNoRows { return "", fmt.Errorf("failed to check for existing comment: %w", err) } var sqlStmt string if err == sql.ErrNoRows { // No existing comment, use sp_addextendedproperty sqlStmt = fmt.Sprintf( "EXEC sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'dbo', N'TABLE', %s, N'COLUMN', %s;", finalComment, h.QuoteIdentifier(data.TableName), h.QuoteIdentifier(data.ColumnName), ) } else { // Existing comment found, use sp_updateextendedproperty sqlStmt = fmt.Sprintf( "EXEC sp_updateextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'dbo', N'TABLE', %s, N'COLUMN', %s;", finalComment, // Use the merged comment h.QuoteIdentifier(data.TableName), h.QuoteIdentifier(data.ColumnName), ) } return sqlStmt, nil } // GenerateDeleteCommentSQL for SQL Server func (h sqlServerHandler) GenerateDeleteCommentSQL(ctx context.Context, db *database.DB, tableName string, columnName string) (string, error) { if tableName == "" || columnName == "" { return "", fmt.Errorf("table and column names cannot be empty") } existingComment, err := h.GetColumnComment(ctx, db, tableName, columnName) if err != nil { return "", err } startTag := "<gemini>" endTag := "</gemini>" startIndex := strings.Index(existingComment, startTag) endIndex := strings.LastIndex(existingComment, endTag) var finalComment string if startIndex != -1 && endIndex != -1 && endIndex > startIndex { // Gemini tags found, remove content within tags prefix := existingComment[:startIndex] suffix := existingComment[endIndex+len(endTag):] finalComment = strings.TrimSpace(prefix + suffix) // Trim leading/trailing spaces after removing gemini part } else { // Gemini tags not found, or invalid tags, keep original comment finalComment = existingComment } escapedComment := strings.ReplaceAll(finalComment, "'", "''") // Generate SQL to update or add extended property (same as in GenerateCommentSQL, but with the modified comment) sqlStmt := fmt.Sprintf( "EXEC sp_updateextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'dbo', N'TABLE', %s, N'COLUMN', %s;", escapedComment, // Use the modified comment (Gemini part removed) h.QuoteIdentifier(tableName), h.QuoteIdentifier(columnName), ) return sqlStmt, nil } // GetColumnComment for SQL Server retrieves the comment for a specific column. func (h sqlServerHandler) GetColumnComment(ctx context.Context, db *database.DB, tableName string, columnName string) (string, error) { query := ` SELECT CAST(value as NVARCHAR(MAX)) FROM fn_listextendedproperty (N'MS_Description', N'SCHEMA', N'dbo', N'TABLE', @tableName, N'COLUMN', @columnName) ` var comment sql.NullString // Use sql.NullString to handle NULL values err := db.QueryRowContext(ctx, query, sql.Named("tableName", tableName), sql.Named("columnName", columnName)).Scan(&comment) if err != nil { if err == sql.ErrNoRows { return "", nil // No comment found, return empty string } return "", fmt.Errorf("failed to retrieve column comment: %w", err) } if comment.Valid { return comment.String, nil } else { return "", nil // Comment is NULL in DB, return nil, nil } } // GenerateTableCommentSQL generates the SQL to comment on a table. func (h sqlServerHandler) GenerateTableCommentSQL(db *database.DB, data *database.TableCommentData, enrichments map[string]bool) (string, error) { if data == nil || data.TableName == "" { return "", fmt.Errorf("table comment data cannot be nil or empty") } config := database.GetConfig() newMetadataComment := h.generateTableMetadataComment(data, enrichments) existingComment, err := h.GetTableComment(context.Background(), db, data.TableName) if err != nil { return "", err } finalComment := h.mergeComments(existingComment, newMetadataComment, config.UpdateExistingMode) if finalComment == "" { return "", nil } // Check if the extended property already exists for the table checkQuery := ` SELECT 1 FROM sys.extended_properties WHERE class = 1 -- Object or column AND class_desc = 'OBJECT_OR_COLUMN' AND major_id = OBJECT_ID(@tableName) AND minor_id = 0 -- Table level (minor_id is 0 for table) AND name = N'MS_Description'; ` var exists int err = db.QueryRow(checkQuery, sql.Named("tableName", data.TableName)).Scan(&exists) var sqlStmt string if err != nil && err != sql.ErrNoRows { // An actual error occurred during the check return "", fmt.Errorf("failed to check for existing table comment: %w", err) } else if err == sql.ErrNoRows { // No existing comment, use sp_addextendedproperty sqlStmt = fmt.Sprintf(` EXEC sp_addextendedproperty @name = N'MS_Description', @value = N'%s', @level0type = N'SCHEMA', @level0name = N'dbo', @level1type = N'TABLE', @level1name = %s;`, finalComment, h.QuoteIdentifier(data.TableName), ) } else { // Existing comment, use sp_updateextendedproperty sqlStmt = fmt.Sprintf(` EXEC sp_updateextendedproperty @name = N'MS_Description', @value = N'%s', @level0type = N'SCHEMA', @level0name = N'dbo', @level1type = N'TABLE', @level1name = %s;`, finalComment, h.QuoteIdentifier(data.TableName), ) } return sqlStmt, nil } // GetTableComment retrieves the existing comment for a table. func (h sqlServerHandler) GetTableComment(ctx context.Context, db *database.DB, tableName string) (string, error) { query := ` SELECT CAST(ep.value AS NVARCHAR(MAX)) FROM sys.extended_properties AS ep INNER JOIN sys.tables AS t ON ep.major_id = t.object_id INNER JOIN sys.schemas AS s ON t.schema_id = s.schema_id WHERE ep.minor_id = 0 AND ep.name = 'MS_Description' AND t.name = @tableName AND s.name = 'dbo'; ` var comment sql.NullString err := db.QueryRowContext(ctx, query, sql.Named("tableName", tableName)).Scan(&comment) if err != nil { if err == sql.ErrNoRows { return "", nil // No comment, return empty string. } return "", fmt.Errorf("failed to retrieve table comment: %w", err) } if comment.Valid { return comment.String, nil } return "", nil // Comment is NULL. } // GenerateDeleteTableCommentSQL generates SQL to remove the Gemini-generated part of a table comment. func (h sqlServerHandler) GenerateDeleteTableCommentSQL(ctx context.Context, db *database.DB, tableName string) (string, error) { if tableName == "" { return "", fmt.Errorf("table name cannot be empty") } existingComment, err := h.GetTableComment(ctx, db, tableName) if err != nil { return "", err } startTag := "<gemini>" endTag := "</gemini>" startIndex := strings.Index(existingComment, startTag) endIndex := strings.LastIndex(existingComment, endTag) var finalComment string if startIndex != -1 && endIndex != -1 && endIndex > startIndex { prefix := existingComment[:startIndex] suffix := existingComment[endIndex+len(endTag):] finalComment = strings.TrimSpace(prefix + suffix) } else { finalComment = existingComment // No gemini tags, keep original. } // Use sp_updateextendedproperty to update the comment (removing the Gemini part) sqlStmt := fmt.Sprintf(` EXEC sp_updateextendedproperty @name = N'MS_Description', @value = N'%s', @level0type = N'SCHEMA', @level0name = N'dbo', @level1type = N'TABLE', @level1name = %s;`, finalComment, h.QuoteIdentifier(tableName), ) return sqlStmt, nil } func init() { database.RegisterDialectHandler("sqlserver", sqlServerHandler{}) database.RegisterDialectHandler("cloudsqlsqlserver", sqlServerHandler{}) }