internal/utils/files_utils.go (113 lines of code) (raw):
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package utils
import (
"bufio"
"fmt"
"os"
"strings"
)
// ReadSQLStatementsFromFile reads the file at filePath and splits it into
// individual SQL statements.
//
// Statements are delimited by a semicolon at the end of a line (";" followed
// by a newline). Windows "\r\n" line endings are normalized to "\n" first so
// ";\r\n" terminators are recognized too. Each statement is whitespace-trimmed
// and returned without its terminating semicolon. This includes the final
// statement even when the file does not end with a newline. Empty statements
// are skipped, so an empty file yields a nil slice.
//
// The split is purely textual: a ";" at the end of a line inside a string
// literal or comment is also treated as a statement delimiter.
func ReadSQLStatementsFromFile(filePath string) ([]string, error) {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}
	// Normalize CRLF so files written on Windows split the same way.
	text := strings.ReplaceAll(string(content), "\r\n", "\n")
	var statements []string
	for _, stmt := range strings.Split(text, ";\n") {
		// The last statement may still carry its ";" when the file has no
		// trailing newline; strip it so every statement is returned alike.
		stmt = strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(stmt), ";"))
		if stmt != "" {
			statements = append(statements, stmt)
		}
	}
	return statements, nil
}
// ReadContextFiles reads every file named in the comma-separated filePaths
// list and concatenates their contents into a single string. Each file's
// content is preceded by a "\n-- Context from file: <path> --\n" header.
//
// Whitespace around each path is trimmed. Empty entries, for example from a
// trailing comma or "a.sql,,b.sql", are ignored rather than read as "". An
// empty filePaths (or one containing only separators) yields "" and a nil
// error. The first file that cannot be read aborts the whole operation with
// an error naming that path.
func ReadContextFiles(filePaths string) (string, error) {
	if filePaths == "" {
		return "", nil // No context files provided
	}
	var combinedContext strings.Builder
	for _, path := range strings.Split(filePaths, ",") {
		path = strings.TrimSpace(path)
		if path == "" {
			// Tolerate stray commas instead of failing on os.ReadFile("").
			continue
		}
		content, err := os.ReadFile(path)
		if err != nil {
			return "", fmt.Errorf("failed to read context file '%s': %w", path, err)
		}
		combinedContext.WriteString("\n-- Context from file: " + path + " --\n")
		combinedContext.Write(content)
	}
	return combinedContext.String(), nil
}
// GetDefaultOutputFilePath returns the output file name used when the caller
// does not supply one: "<dbName>_comments.txt" for the "get-comments" command
// and "<dbName>_comments.sql" for every other command.
func GetDefaultOutputFilePath(dbName, commandName string) string {
	extension := "sql" // add-comments, delete-comments and any other command
	if commandName == "get-comments" {
		extension = "txt"
	}
	return fmt.Sprintf("%s_comments.%s", dbName, extension)
}
// ConfirmAction prints a banner naming the generated actionDescription and
// asks on stdout whether the changes should be applied to the database. It
// reads one line from standard input and returns true only when that line,
// lower-cased and whitespace-trimmed, is "yes" or "y".
//
// The read error is deliberately ignored: any partial input read before it
// (e.g. "y" followed by EOF) is still evaluated, and otherwise the answer
// counts as "no".
//
// NOTE(review): a fresh bufio.Reader is created on each call, so any input it
// buffers beyond the first line is discarded. With piped stdin, a later call
// may therefore see EOF. Verify whether callers prompt more than once.
func ConfirmAction(actionDescription string) bool {
	fmt.Print("\n-------------------------------------------------------------\n")
	fmt.Printf("Generated %s:\n", actionDescription)
	fmt.Print("Do you want to apply these changes to the database? (yes/no): ")
	answer, _ := bufio.NewReader(os.Stdin).ReadString('\n')
	switch strings.TrimSpace(strings.ToLower(answer)) {
	case "yes", "y":
		return true
	default:
		return false
	}
}
// ParseTablesFlag parses a tables flag value into a map from table name to
// the columns requested for that table.
//
// The value is a comma-separated list of entries. Each entry is either a bare
// table name ("orders") or a table name followed by a bracketed,
// comma-separated column list ("users[id,email]"). All spaces are removed
// before parsing. A table given without a column list maps to a nil slice.
// Empty entries (e.g. from "a,,b") and empty column names (e.g. "t[a,,b]") are
// ignored, so "t[]" is equivalent to "t". If a table appears more than once,
// the last occurrence wins. An empty input yields an empty, non-nil map.
//
// An error is returned when an entry has "[" without a following "]", has no
// table name before "[", contains a "]" outside a column list, or has text
// after the closing "]".
func ParseTablesFlag(tablesFlag string) (map[string][]string, error) {
	tableColumns := make(map[string][]string)
	if tablesFlag == "" {
		return tableColumns, nil
	}
	// Spaces are not significant anywhere in the flag, so drop them up front.
	tablesFlag = strings.ReplaceAll(tablesFlag, " ", "")
	// Split on commas between tables, not on commas inside a column list.
	for _, part := range SplitOutsideBrackets(tablesFlag) {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		bracketStart := strings.Index(part, "[")
		if bracketStart == -1 {
			// No columns specified, just a table name.
			if strings.Contains(part, "]") {
				return nil, fmt.Errorf("unexpected closing bracket in: %s", part)
			}
			tableColumns[part] = nil
			continue
		}
		tableName := strings.TrimSpace(part[:bracketStart])
		if strings.Contains(tableName, "]") {
			return nil, fmt.Errorf("unexpected closing bracket in: %s", part)
		}
		if tableName == "" {
			return nil, fmt.Errorf("missing table name in: %s", part)
		}
		// Look for "]" only after "[" so a stray earlier "]" can never yield
		// inverted slice bounds (which would panic).
		relEnd := strings.Index(part[bracketStart+1:], "]")
		if relEnd == -1 {
			return nil, fmt.Errorf("missing closing bracket in: %s", part)
		}
		bracketEnd := bracketStart + 1 + relEnd
		if trailing := strings.TrimSpace(part[bracketEnd+1:]); trailing != "" {
			return nil, fmt.Errorf("unexpected text after closing bracket in: %s", part)
		}
		var columns []string
		for _, col := range strings.Split(part[bracketStart+1:bracketEnd], ",") {
			if col = strings.TrimSpace(col); col != "" {
				columns = append(columns, col)
			}
		}
		tableColumns[tableName] = columns
	}
	return tableColumns, nil
}

// SplitOutsideBrackets splits s on commas that are not inside square
// brackets. Bracketed text, including the brackets themselves, stays intact.
// Nesting is not tracked: any "]" ends the bracketed region. An empty trailing
// segment (e.g. from a trailing comma) is dropped. Empty leading or interior
// segments are kept as "". An empty s yields a nil slice.
func SplitOutsideBrackets(s string) []string {
	var result []string
	var current strings.Builder
	inBrackets := false
	for _, char := range s {
		switch char {
		case '[':
			inBrackets = true
		case ']':
			inBrackets = false
		case ',':
			if !inBrackets {
				result = append(result, current.String())
				current.Reset()
				continue
			}
		}
		current.WriteRune(char)
	}
	// Add the last part.
	if current.Len() > 0 {
		result = append(result, current.String())
	}
	return result
}