internal/enricher/enricher.go (734 lines of code) (raw):

/* * Copyright 2025 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package enricher import ( "bytes" "context" "fmt" "log" "os" "sort" "strings" "sync" "time" "github.com/GoogleCloudPlatform/db-context-enrichment/internal/database" "github.com/google/generative-ai-go/genai" "google.golang.org/api/option" "google.golang.org/grpc/status" ) // MetadataCollector collects metadata and applies comments. type MetadataCollector struct { db *database.DB // Database connection retryOpts *RetryOptions // Retry configuration DryRun bool // Dry run mode flag Metadata []*ColumnMetadata // Collected metadata TableMetadata []*TableMetadata // Collected table metadata mu sync.Mutex // Mutex to protect Metadata slice TableFilters map[string][]string // Table and column filters Enrichments map[string]bool // Enrichment types to include GeminiAPIKey string // Gemini API Key Model string // Gemini model name AdditionalContext string // Additional context from files schemaContext string // Database schema context (generated once) } // NewMetadataCollector creates a new MetadataCollector instance. func NewMetadataCollector(db *database.DB, retryOpts *RetryOptions, dryRun bool, geminiAPIKey string, additionalContext string, model string) *MetadataCollector { return &MetadataCollector{ db: db, retryOpts: retryOpts, DryRun: dryRun, Metadata: []*ColumnMetadata{}, TableMetadata: []*TableMetadata{}, TableFilters: make(map[string][]string), Enrichments: make(map[string]bool), GeminiAPIKey: geminiAPIKey, AdditionalContext: additionalContext, schemaContext: "", Model: model, } } // IsGeminiAPIKeyValid checks if the Gemini API key is valid by attempting to list models. func (mc *MetadataCollector) IsGeminiAPIKeyValid(ctx context.Context) error { if mc.GeminiAPIKey == "" { return fmt.Errorf("gemini api key is not configured") } client, err := genai.NewClient(ctx, option.WithAPIKey(mc.GeminiAPIKey)) if err != nil { return fmt.Errorf("failed to create Gemini client: %w", err) } defer client.Close() modelIterator := client.ListModels(ctx) // Iterate to trigger the API call and check for errors during iteration. _, err = modelIterator.Next() if err != nil { // Check if the error is related to authentication (invalid API key) if st, ok := status.FromError(err); ok { if st.Code() == 16 || // Unauthenticated st.Code() == 7 { // PermissionDenied (sometimes used for invalid keys) return fmt.Errorf("invalid Gemini API key: %w", err) } } return fmt.Errorf("failed to list Gemini models, potentially due to API key issue or other error: %w", err) } return nil } // Helper function to check if a specific enrichment is requested func (mc *MetadataCollector) isEnrichmentRequested(enrichment string) bool { if len(mc.Enrichments) == 0 { return true // If no enrichments are specified, include all } return mc.Enrichments[enrichment] } // CollectColumnMetadata gathers comprehensive metadata for a specific database column. func (mc *MetadataCollector) CollectColumnMetadata(ctx context.Context, tableName string, colInfo database.ColumnInfo, schemaContext string) (*ColumnMetadata, error) { if tableName == "" || colInfo.Name == "" { return nil, &ErrInvalidInput{ Msg: "table name and column name cannot be empty", } } var ( exampleValues []string distinctCount int64 nullCount int64 err error ) dbMetadata := make(map[string]interface{}) if mc.isEnrichmentRequested("examples") || mc.isEnrichmentRequested("distinct_values") || mc.isEnrichmentRequested("null_count") { dbMetadata, err = mc.db.GetColumnMetadata(tableName, colInfo.Name) if err != nil { return nil, &ErrQueryExecution{ Msg: "failed to get column metadata", Err: err, } } } if mc.isEnrichmentRequested("examples") { retrivedExampleValues, ok := dbMetadata["ExampleValues"].([]string) if !ok { log.Println("WARN: unexpected type for ExampleValues") } exampleValues = retrivedExampleValues if mc.GeminiAPIKey != "" { processedExampleValues, err := mc.generateExampleValuesWithGemini(ctx, colInfo, tableName, exampleValues) if err != nil { log.Printf("WARN: Failed to generate/process example values with Gemini for column %s.%s: %v, using original examples.", tableName, colInfo.Name, err) } else { exampleValues = processedExampleValues } } } if mc.isEnrichmentRequested("distinct_values") { distinctCountFloat, ok := dbMetadata["DistinctCount"].(int) if !ok { log.Println("WARN: unexpected type for DistinctCount") } distinctCount = int64(distinctCountFloat) } if mc.isEnrichmentRequested("null_count") { nullCountFloat, ok := dbMetadata["NullCount"].(int) if !ok { log.Println("WARN: unexpected type for NullCount") } nullCount = int64(nullCountFloat) } metadata := &ColumnMetadata{ Table: tableName, Column: colInfo.Name, DataType: colInfo.DataType, ExampleValues: exampleValues, DistinctCount: distinctCount, NullCount: nullCount, } // Generate description using Gemini if API key is available and "description" enrichment is requested if mc.GeminiAPIKey != "" && mc.isEnrichmentRequested("description") { description, err := mc.generateDescriptionWithGemini(ctx, metadata, schemaContext) if err != nil { log.Printf("WARN: Failed to generate description with Gemini for %s.%s: %v", tableName, colInfo.Name, err) } else { metadata.Description = description } } return metadata, nil } // CollectTableMetadata collects metadata for a table func (mc *MetadataCollector) CollectTableMetadata(ctx context.Context, tableName string, schemaContext string) (*TableMetadata, error) { if tableName == "" { return nil, &ErrInvalidInput{Msg: "table name cannot be empty"} } metadata := &TableMetadata{ Table: tableName, } // Generate description using Gemini if API key is available.s if mc.GeminiAPIKey != "" && mc.isEnrichmentRequested("description") { description, err := mc.generateTableDescriptionWithGemini(ctx, metadata, schemaContext) if err != nil { log.Printf("WARN: Failed to generate table description with Gemini for %s: %v", tableName, err) } else { metadata.Description = description } } return metadata, nil } // GenerateCommentSQLs collects metadata and generates SQL statements, no application. func (mc *MetadataCollector) GenerateCommentSQLs(ctx context.Context) ([]string, error) { startTime := time.Now() log.Println("INFO: Starting metadata collection and SQL comment generation...") tables, err := mc.db.ListTables() if err != nil { return nil, fmt.Errorf("failed to list tables: %w", err) } // Apply table filtering filteredTables := []string{} if len(mc.TableFilters) > 0 { for table := range mc.TableFilters { filteredTables = append(filteredTables, table) } } else { filteredTables = tables } tables = filteredTables // Generate schema context once schemaContext, err := mc.generateSchemaContext() if err != nil { log.Printf("WARN: Failed to generate schema context: %v", err) schemaContext = "" } mc.schemaContext = schemaContext var orderedSQLs []OrderedSQL var wg sync.WaitGroup errorChannel := make(chan error, len(tables)) var mu sync.Mutex for _, table := range tables { wg.Add(1) go func(table string) { defer wg.Done() // Collect TABLE metadata and generate SQL tableMetadata, err := withRetry[*TableMetadata](ctx, DefaultRetryOptions, func(ctx context.Context) (*TableMetadata, error) { return mc.CollectTableMetadata(ctx, table, mc.schemaContext) }) if err != nil { log.Println("ERROR: Failed to collect metadata for table:", table, "error:", err) errorChannel <- err return } tableCommentData := &database.TableCommentData{ TableName: tableMetadata.Table, Description: tableMetadata.Description, } tableSQL, genTableErr := mc.db.GenerateTableCommentSQL(tableCommentData, mc.Enrichments) if genTableErr != nil { log.Printf("WARN: Failed to generate table comment SQL for %s: %v", table, genTableErr) errorChannel <- genTableErr return } if tableSQL != "" { mu.Lock() orderedSQLs = append(orderedSQLs, OrderedSQL{SQL: tableSQL, Table: table, IsTableComment: true}) mu.Unlock() } columnInfos, err := mc.db.ListColumns(table) if err != nil { log.Println("ERROR: Failed to list columns for table:", table, "error:", err) errorChannel <- err return } // Apply column filtering filteredColumnInfos := []database.ColumnInfo{} if columnFilters, ok := mc.TableFilters[table]; ok && columnFilters != nil { for _, colInfo := range columnInfos { for _, filteredCol := range columnFilters { if colInfo.Name == filteredCol { filteredColumnInfos = append(filteredColumnInfos, colInfo) break } } } } else { filteredColumnInfos = columnInfos } columnInfos = filteredColumnInfos for _, colInfo := range columnInfos { wg.Add(1) go func(colInfo database.ColumnInfo) { defer wg.Done() metadata, err := withRetry[*ColumnMetadata](ctx, DefaultRetryOptions, func(ctx context.Context) (*ColumnMetadata, error) { return mc.CollectColumnMetadata(ctx, table, colInfo, mc.schemaContext) }) if err != nil { log.Println("ERROR: Failed to collect metadata for column:", colInfo.Name, "in table:", table, "error:", err) errorChannel <- err return } commentData := &database.CommentData{ TableName: metadata.Table, ColumnName: metadata.Column, ColumnDataType: metadata.DataType, ExampleValues: metadata.ExampleValues, DistinctCount: metadata.DistinctCount, NullCount: metadata.NullCount, Description: metadata.Description, } // Pass mc.Enrichments to GenerateCommentSQL sql, genErr := mc.db.GenerateCommentSQL(commentData, mc.Enrichments) if genErr != nil { log.Printf("WARN: Failed to generate comment SQL for %s.%s: %v", metadata.Table, metadata.Column, genErr) errorChannel <- genErr return } if sql != "" { mu.Lock() sql = "\t" + sql // Lock before modifying orderedSQLs orderedSQLs = append(orderedSQLs, OrderedSQL{SQL: sql, Table: table, Column: colInfo.Name, IsTableComment: false}) // Mark as column comment mu.Unlock() // Unlock after modifying orderedSQLs } }(colInfo) } }(table) } wg.Wait() close(errorChannel) // Sort orderedSQLs to maintain order sort.Slice(orderedSQLs, func(i, j int) bool { if orderedSQLs[i].Table != orderedSQLs[j].Table { return orderedSQLs[i].Table < orderedSQLs[j].Table } if orderedSQLs[i].IsTableComment != orderedSQLs[j].IsTableComment { return orderedSQLs[i].IsTableComment } return orderedSQLs[i].Column < orderedSQLs[j].Column }) allSQLs := make([]string, 0, len(orderedSQLs)) for _, osql := range orderedSQLs { allSQLs = append(allSQLs, osql.SQL) } log.Println("INFO: Metadata collection and SQL comment generation completed in:", time.Since(startTime)) return allSQLs, nil } // GenerateDeleteCommentSQLs collects metadata and generates SQL for deletion. func (mc *MetadataCollector) GenerateDeleteCommentSQLs(ctx context.Context) ([]string, error) { startTime := time.Now() log.Println("INFO: Starting metadata collection and SQL comment deletion generation...") tables, err := mc.db.ListTables() if err != nil { return nil, fmt.Errorf("failed to list tables: %w", err) } // Apply table filtering filteredTables := []string{} if len(mc.TableFilters) > 0 { for table := range mc.TableFilters { filteredTables = append(filteredTables, table) } } else { filteredTables = tables } tables = filteredTables var allSQLs []string var wg sync.WaitGroup errorChannel := make(chan error, len(tables)) for _, table := range tables { wg.Add(1) go func(table string) { defer wg.Done() // Generate SQL for deleting TABLE comments tableSQL, genTableErr := withRetry[string](ctx, DefaultRetryOptions, func(ctx context.Context) (string, error) { return mc.db.GenerateDeleteTableCommentSQL(ctx, table) }) if genTableErr != nil { log.Printf("WARN: Failed to generate delete table comment SQL for %s: %v", table, genTableErr) errorChannel <- genTableErr return } if tableSQL != "" { mc.mu.Lock() allSQLs = append(allSQLs, tableSQL) mc.mu.Unlock() } else { log.Printf("INFO: No SQL generated for deleting table comment in %s, possibly no gemini tag.", table) } columnInfos, err := mc.db.ListColumns(table) if err != nil { log.Println("ERROR: Failed to list columns for table:", table, "error:", err) errorChannel <- err return } // Apply column filtering filteredColumnInfos := []database.ColumnInfo{} if columnFilters, ok := mc.TableFilters[table]; ok && columnFilters != nil { for _, colInfo := range columnInfos { for _, filteredCol := range columnFilters { if colInfo.Name == filteredCol { filteredColumnInfos = append(filteredColumnInfos, colInfo) break } } } } else { filteredColumnInfos = columnInfos } columnInfos = filteredColumnInfos for _, colInfo := range columnInfos { wg.Add(1) go func(colInfo database.ColumnInfo) { defer wg.Done() sql, genErr := withRetry[string](ctx, DefaultRetryOptions, func(ctx context.Context) (string, error) { return mc.db.GenerateDeleteCommentSQL(ctx, table, colInfo.Name) }) if genErr != nil { log.Printf("WARN: Failed to generate delete comment SQL for %s.%s: %v", table, colInfo.Name, genErr) errorChannel <- genErr return } if sql != "" { mc.mu.Lock() allSQLs = append(allSQLs, sql) mc.mu.Unlock() } else { log.Printf("INFO: No SQL generated for deleting comment in %s.%s, possibly no gemini tag.", table, colInfo.Name) } }(colInfo) } }(table) } wg.Wait() close(errorChannel) // Check for and combine errors var combinedErr error for err := range errorChannel { if err != nil { if combinedErr == nil { combinedErr = err } else { combinedErr = fmt.Errorf("%w; %v", combinedErr, err) } } } if combinedErr != nil { return nil, combinedErr } if len(allSQLs) == 0 { log.Println("INFO: No SQL statements generated for deleting comments. Possibly no gemini tags found.") } log.Println("INFO: SQL comment deletion generation completed in:", time.Since(startTime)) return allSQLs, nil } // ColumnComment represents a comment for a specific column. type ColumnComment struct { Table string `json:"table"` Column string `json:"column"` Comment string `json:"comment"` } // GetComments retrieves all column and table comments. func (mc *MetadataCollector) GetComments(ctx context.Context) ([]*ColumnComment, error) { tables, err := mc.db.ListTables() if err != nil { return nil, fmt.Errorf("failed to list tables: %w", err) } var allComments []*ColumnComment var wg sync.WaitGroup errorChannel := make(chan error, len(tables)) for _, table := range tables { wg.Add(1) go func(table string) { defer wg.Done() tableComment, err := withRetry[string](ctx, DefaultRetryOptions, func(ctx context.Context) (string, error) { return mc.db.GetTableComment(ctx, table) }) if err != nil { log.Printf("WARN: Failed to get table comment for table: %s, error: %v", table, err) errorChannel <- err return } if tableComment != "" { mc.mu.Lock() allComments = append(allComments, &ColumnComment{ Table: table, Column: "", // Leave Column empty for table comments Comment: tableComment, }) mc.mu.Unlock() } columnInfos, err := mc.db.ListColumns(table) if err != nil { log.Printf("ERROR: Failed to list columns for table: %s, error: %v", table, err) errorChannel <- err return } for _, colInfo := range columnInfos { wg.Add(1) go func(colInfo database.ColumnInfo) { defer wg.Done() comment, err := withRetry[string](ctx, DefaultRetryOptions, func(ctx context.Context) (string, error) { return mc.db.GetColumnComment(ctx, table, colInfo.Name) }) if err != nil { log.Printf("WARN: Failed to get comment for column: %s in table: %s, error: %v", colInfo.Name, table, err) errorChannel <- err return } if comment != "" { mc.mu.Lock() allComments = append(allComments, &ColumnComment{ Table: table, Column: colInfo.Name, Comment: comment, }) mc.mu.Unlock() } }(colInfo) } }(table) } wg.Wait() close(errorChannel) var combinedErr error for err := range errorChannel { if err != nil { if combinedErr == nil { combinedErr = err } else { combinedErr = fmt.Errorf("%w; %v", combinedErr, err) } } } if combinedErr != nil { return nil, combinedErr } return allComments, nil } // FormatCommentsAsText formats the comments as plain text. func FormatCommentsAsText(comments []*ColumnComment) string { var buffer bytes.Buffer // Sort comments by table and column sort.Slice(comments, func(i, j int) bool { if comments[i].Table != comments[j].Table { return comments[i].Table < comments[j].Table } return comments[i].Column < comments[j].Column }) // Print comments in order for _, comment := range comments { if comment.Column == "" { buffer.WriteString(fmt.Sprintf("Table: %s\n", comment.Table)) } else { buffer.WriteString(fmt.Sprintf("Table: %s, Column: %s\n", comment.Table, comment.Column)) } buffer.WriteString(fmt.Sprintf("Comment: %s\n", comment.Comment)) buffer.WriteString("\n") // Add an empty line } return buffer.String() } // WriteCommentsToFile writes the comments to a file. func WriteCommentsToFile(comments string, filePath string) error { log.Printf("INFO: Writing comments to file: %s", filePath) file, err := os.Create(filePath) if err != nil { return fmt.Errorf("failed to create file: %w", err) } defer file.Close() _, err = file.WriteString(comments) return err } // generateDescriptionWithGemini calls the Gemini API to generate a column description. func (mc *MetadataCollector) generateDescriptionWithGemini(ctx context.Context, metadata *ColumnMetadata, schemaContext string) (string, error) { if mc.GeminiAPIKey == "" { return "", nil } if mc.AdditionalContext == "" { return "", nil } prompt := fmt.Sprintf(` Your task is to generate a brief and concise description for a database column based on the provided context. The context might be irrelevant, so you need to firstly read through the context and decide if there is any relevant information for the target table.column. ********** Knowledge Context ********** %s ********** End Knowledge Context ********** **Instructions:** - Response starting with your analysis. Then output the description in between <result></result> tags. - Focus on the column's purpose and meaning within the database. - Be concise and informative, no more than 50 words. - Important: Only provide a description for the column if there is information related to this column in additional context. Otherwise in all other cases, return empty <result></result> tags. The target table and column is: Column Name: %s in Table: %s Now start your response. Remember, only give description when the knowledge context provides useful information about the column. `, mc.AdditionalContext, metadata.Column, metadata.Table) client, err := genai.NewClient(ctx, option.WithAPIKey(mc.GeminiAPIKey)) if err != nil { return "", fmt.Errorf("failed to create Gemini client: %w", err) } defer client.Close() model_name := mc.Model if model_name == "" { model_name = "gemini-1.5-pro-002" } model := client.GenerativeModel(model_name) model.SetTemperature(0.4) model.SetMaxOutputTokens(500) model.SetTopP(0.8) model.SetTopK(40) resp, err := model.GenerateContent(ctx, genai.Text(prompt)) if err != nil { return "", fmt.Errorf("Gemini API call failed: %w", err) } description, err := extractTextFromResponse(resp) if err != nil { return "", err } return description, nil } // generateTableDescriptionWithGemini generates a description for a table. func (mc *MetadataCollector) generateTableDescriptionWithGemini(ctx context.Context, metadata *TableMetadata, schemaContext string) (string, error) { if mc.GeminiAPIKey == "" { return "", nil } if mc.AdditionalContext == "" { return "", nil } prompt := fmt.Sprintf(` Your task is to generate a brief and concise description for a database table based on the provided context. The context might be irrelevant, so you need to firstly read through the context and decide if there is any relevant information for the target table. ********** Knowledge Context ********** %s ********** End Knowledge Context ********** **Instructions:** - Response starting with your analysis. Then output the description in between <result></result> tags. - Be concise and informative, no more than 50 words. - Important: Only provide a description for the table if there is information related to this table in additional context. Otherwise in all other cases, return empty <result></result> tags. The target table is: Table: %s Now start your response. Remember, only give description when the knowledge context provides useful information about the table. `, mc.AdditionalContext, metadata.Table) client, err := genai.NewClient(ctx, option.WithAPIKey(mc.GeminiAPIKey)) if err != nil { return "", fmt.Errorf("failed to create Gemini client: %w", err) } defer client.Close() model_name := mc.Model if model_name == "" { model_name = "gemini-1.5-pro-002" } model := client.GenerativeModel(model_name) model.SetTemperature(0.4) model.SetMaxOutputTokens(500) model.SetTopP(0.8) model.SetTopK(40) resp, err := model.GenerateContent(ctx, genai.Text(prompt)) if err != nil { return "", fmt.Errorf("Gemini API call failed: %w", err) } description, err := extractTextFromResponse(resp) if err != nil { return "", err } return description, nil } func (mc *MetadataCollector) generateExampleValuesWithGemini(ctx context.Context, colInfo database.ColumnInfo, tableName string, originalExampleValues []string) ([]string, error) { if mc.GeminiAPIKey == "" { return nil, nil } if len(originalExampleValues) == 0 { return []string{}, nil } dataTypeDescription := colInfo.DataType exampleValuesStr := strings.Join(originalExampleValues, ", ") prompt := fmt.Sprintf(` You are an expert in data privacy and database metadata. Your task is to analyze a database column and determine if it likely contains Personally Identifiable Information (PII). Based on your analysis, you will either return synthetic, representative example values or the original example values. **Column Information:** - Column Name: %s - Table Name: %s - Data Type: %s - Original Example Values: [%s] **Instructions:** 1. **Analyze for PII:** Based on the column name, data type, and example values, determine if this column is likely to contain PII. Consider common patterns and keywords that indicate personal information (names, emails, phone numbers, addresses, IDs, etc.). 2. **Decision:** - **If likely PII:** Generate equal number of synthetic example values that are representative of the data in the "%s" column but are completely fabricated and do not resemble real personal data. The synthetic values should be consistent with the "%s" data type. - **If NOT likely PII:** Return the original example values provided. 3. **Output Format:** - If you generated synthetic values, output them as a comma-separated list enclosed in <synthetic_examples>...</synthetic_examples> tags. - If you are returning the original values, output them as a comma-separated list enclosed in <original_examples>...</original_examples> tags. Example Output for Synthetic Values: <synthetic_examples>Fake Name 1, Fake Name 2, Fake Name 3</synthetic_examples> Example Output for Original Values: <original_examples>Value 1, Value 2, Value 3</original_examples> Now, analyze the column and provide the appropriate output. `, colInfo.Name, tableName, dataTypeDescription, exampleValuesStr, colInfo.Name, dataTypeDescription) client, err := genai.NewClient(ctx, option.WithAPIKey(mc.GeminiAPIKey)) if err != nil { return nil, fmt.Errorf("failed to create Gemini client: %w", err) } defer client.Close() model_name := mc.Model if model_name == "" { model_name = "gemini-1.5-pro-002" } model := client.GenerativeModel(model_name) model.SetTemperature(0.4) model.SetMaxOutputTokens(500) model.SetTopP(0.8) model.SetTopK(40) resp, err := model.GenerateContent(ctx, genai.Text(prompt)) if err != nil { return nil, fmt.Errorf("Gemini API call failed for example value generation: %w", err) } responseString, err := extractTextFromResponseForExampleValues(resp) if err != nil { return nil, err } var exampleValues []string if strings.Contains(responseString, "<synthetic_examples>") { startTag := "<synthetic_examples>" endTag := "</synthetic_examples>" startIndex := strings.Index(responseString, startTag) endIndex := strings.Index(responseString, endTag) if startIndex != -1 && endIndex != -1 && startIndex < endIndex { syntheticValueString := responseString[startIndex+len(startTag) : endIndex] exampleValues = strings.Split(syntheticValueString, ",") for i := range exampleValues { exampleValues[i] = strings.TrimSpace(exampleValues[i]) } log.Printf("INFO: Gemini determined column '%s' table '%s' is PII and generated synthetic examples.", colInfo.Name, tableName) } else { return nil, fmt.Errorf("invalid response format for synthetic examples from Gemini: tags not found") } } else if strings.Contains(responseString, "<original_examples>") { startTag := "<original_examples>" endTag := "</original_examples>" startIndex := strings.Index(responseString, startTag) endIndex := strings.Index(responseString, endTag) if startIndex != -1 && endIndex != -1 && startIndex < endIndex { originalValueString := responseString[startIndex+len(startTag) : endIndex] exampleValues = strings.Split(originalValueString, ",") for i := range exampleValues { exampleValues[i] = strings.TrimSpace(exampleValues[i]) } } else { return nil, fmt.Errorf("invalid response format for original examples from Gemini: tags not found") } } else { return nil, fmt.Errorf("unexpected response format from Gemini for example values: %s", responseString) } return exampleValues, nil } // Helper function to extract text func extractTextFromResponse(resp *genai.GenerateContentResponse) (string, error) { if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 { return "", fmt.Errorf("empty response from Gemini API") } // Safely access and return the text if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { // Extract the text between <result> tags resp := strings.TrimSpace(string(textPart)) if idx1 := strings.LastIndex(resp, "<result>"); idx1 != -1 { if idx2 := strings.Index(resp[idx1+len("<result>"):], "</result>"); idx2 != -1 { resp = resp[idx1+len("<result>") : idx1+len("<result>")+idx2] } else { return "", nil } } else { return "", nil } return resp, nil } return "", fmt.Errorf("unexpected response format from Gemini API: %+v", resp) } func extractTextFromResponseForExampleValues(resp *genai.GenerateContentResponse) (string, error) { if len(resp.Candidates) == 0 { return "", fmt.Errorf("no candidates in response") } part := resp.Candidates[0].Content.Parts[0] text, ok := part.(genai.Text) if !ok { return "", fmt.Errorf("unexpected response type: %T", part) } return string(text), nil } // generateSchemaContext generates a string representation of the database schema. func (mc *MetadataCollector) generateSchemaContext() (string, error) { tables, err := mc.db.ListTables() if err != nil { return "", fmt.Errorf("failed to list tables for schema context: %w", err) } var schemaContext strings.Builder schemaContext.WriteString("Tables:\n") for _, table := range tables { columns, err := mc.db.ListColumns(table) if err != nil { log.Printf("WARN: Failed to list columns for table %s in schema context: %v", table, err) continue } schemaContext.WriteString(fmt.Sprintf("- %s: [", table)) columnNames := []string{} for _, col := range columns { columnNames = append(columnNames, col.Name) } schemaContext.WriteString(strings.Join(columnNames, ", ")) schemaContext.WriteString("]\n") } return schemaContext.String(), nil }