in pkg/rules/databasesql/databasesql_extractor.go [109:168]
// extractSQLParams parses query and collects literal parameter values keyed by
// column name. Only DML statements are inspected:
//
//   - SELECT and DELETE: the WHERE clause, via extractConditions.
//   - UPDATE: string literals assigned in SET, then the WHERE clause.
//   - INSERT ... VALUES: string literals matched by position against the
//     explicit column list; values past the end of that list are skipped.
//     When more than one row is inserted, each key gets a 1-based "_row<N>"
//     suffix (e.g. "name_row2").
//
// Any other statement kind, an INSERT whose rows are not a VALUES list, or an
// INSERT without explicit columns yields an empty, non-nil map. A parse error
// is logged and returned unchanged together with a nil map.
func extractSQLParams(query string) (map[string]string, error) {
	parsed, err := sqlparser.Parse(query)
	if err != nil {
		log.Printf("failed to fetch sql params: %v", err)
		return nil, err
	}

	params := make(map[string]string)

	// collectWhere feeds an optional WHERE clause to extractConditions.
	collectWhere := func(where *sqlparser.Where) {
		if where == nil {
			return
		}
		extractConditions(where.Expr, params)
	}

	switch node := parsed.(type) {
	case *sqlparser.Select:
		collectWhere(node.Where)
	case *sqlparser.Delete:
		collectWhere(node.Where)
	case *sqlparser.Update:
		for _, assign := range node.Exprs {
			lit, isLit := assign.Expr.(*sqlparser.SQLVal)
			if !isLit || lit.Type != sqlparser.StrVal {
				continue
			}
			params[assign.Name.Name.String()] = string(lit.Val)
		}
		// WHERE is processed after SET, so any key extractConditions writes
		// replaces a SET entry of the same name.
		// NOTE(review): confirm whether that overwrite is intended.
		collectWhere(node.Where)
	case *sqlparser.Insert:
		tuples, isValues := node.Rows.(sqlparser.Values)
		if !isValues {
			break
		}
		multiRow := len(tuples) > 1
		for rowIdx, tuple := range tuples {
			suffix := ""
			if multiRow {
				suffix = fmt.Sprintf("_row%d", rowIdx+1)
			}
			// Only positions that have a matching named column are considered.
			limit := len(tuple)
			if len(node.Columns) < limit {
				limit = len(node.Columns)
			}
			for col := 0; col < limit; col++ {
				lit, isLit := tuple[col].(*sqlparser.SQLVal)
				if !isLit || lit.Type != sqlparser.StrVal {
					continue
				}
				params[node.Columns[col].String()+suffix] = string(lit.Val)
			}
		}
	}
	return params, nil
}