pkg/rules/databasesql/databasesql_extractor.go (156 lines of code) (raw):

// Copyright (c) 2025 Alibaba Group Holding Ltd. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package databasesql import ( "fmt" "log" "github.com/xwb1989/sqlparser" ) func extractSQLMetadata(request databaseSqlRequest) { sql := request.sql if sqlCache.Contains(sql) { return } collection := extractCollection(sql) sqlMeta := SQLMeta{ stmt: request.sql, operation: request.opType, collection: collection, } sqlCache.Add(sql, sqlMeta) } func getCollection(sql string) string { if meta, found := sqlCache.Get(sql); found { return meta.collection } // Attempt to retrieve the collection again. return extractCollection(sql) } func getParams(sql string) []any { meta, found := sqlCache.Get(sql) if found && len(meta.params) > 0 { return meta.params } // db params did not been retrieved in `extractSQLMetadata()`, so parse it and update the db meta here. paramsMap, _ := extractSQLParams(sql) params := []any{} for _, v := range paramsMap { params = append(params, v) } if found { // `meta` was found already, only need to update it's params. updatedMeta := meta updatedMeta.params = params sqlCache.Add(sql, updatedMeta) } return params } func extractCollection(query string) string { stmt, err := sqlparser.Parse(query) if err != nil { return "" } // Only support DML currently switch stmt := stmt.(type) { case *sqlparser.Select: return getTableName(stmt.From) case *sqlparser.Update: return getTableName(stmt.TableExprs) case *sqlparser.Insert: return stmt.Table.Name.String() case *sqlparser.Delete: return getTableName(stmt.TableExprs) default: return "" } } func getTableName(node sqlparser.SQLNode) string { switch n := node.(type) { case sqlparser.TableName: return n.Name.String() case sqlparser.TableExprs: for _, expr := range n { aliasedExpr, ok := expr.(*sqlparser.AliasedTableExpr) if !ok { continue } tableName, ok := aliasedExpr.Expr.(sqlparser.TableName) if ok { return tableName.Name.String() } } } return "" } // Extract SQL parameters func extractSQLParams(query string) (map[string]string, error) { stmt, err := sqlparser.Parse(query) if err != nil { log.Printf("failed to fetch sql params: %v", err) return nil, err } values := make(map[string]string) // Only support DML currently switch stmt := stmt.(type) { case *sqlparser.Select: if stmt.Where != nil { extractConditions(stmt.Where.Expr, values) } case *sqlparser.Update: for _, expr := range stmt.Exprs { if sqlVal, ok := expr.Expr.(*sqlparser.SQLVal); ok && sqlVal.Type == sqlparser.StrVal { values[expr.Name.Name.String()] = string(sqlVal.Val) } } if stmt.Where != nil { extractConditions(stmt.Where.Expr, values) } case *sqlparser.Delete: if stmt.Where != nil { extractConditions(stmt.Where.Expr, values) } case *sqlparser.Insert: columns := make([]string, 0, len(stmt.Columns)) for _, col := range stmt.Columns { columns = append(columns, col.String()) } rows, ok := stmt.Rows.(sqlparser.Values) if ok { for i, row := range rows { rowSuffix := "" if len(rows) > 1 { rowSuffix = fmt.Sprintf("_row%d", i+1) } for j, val := range row { if j >= len(columns) { continue } if sqlVal, ok := val.(*sqlparser.SQLVal); ok && sqlVal.Type == sqlparser.StrVal { colName := columns[j] if rowSuffix != "" { colName += rowSuffix } values[colName] = string(sqlVal.Val) } } } } } return values, nil } func extractConditions(expr sqlparser.Expr, values map[string]string) { switch expr := expr.(type) { case *sqlparser.ComparisonExpr: if expr.Operator == "=" { colName, ok := expr.Left.(*sqlparser.ColName) if !ok { return } valExpr, ok := expr.Right.(*sqlparser.SQLVal) if !ok { return } if valExpr.Type == sqlparser.StrVal { values[colName.Name.String()] = string(valExpr.Val) } } case *sqlparser.AndExpr: extractConditions(expr.Left, values) extractConditions(expr.Right, values) case *sqlparser.OrExpr: extractConditions(expr.Left, values) extractConditions(expr.Right, values) } }