func extractSQLParams()

in pkg/rules/databasesql/databasesql_extractor.go [109:168]


// extractSQLParams parses query and gathers the parameter values it carries,
// keyed by column name.
//
// Only DML statements are inspected:
//   - SELECT and DELETE: the WHERE clause expression is passed to
//     extractConditions together with the result map.
//   - UPDATE: every SET assignment whose right-hand side is a string literal
//     is recorded under the assigned column's name; afterwards the WHERE
//     clause is passed to extractConditions (if that helper writes
//     column-name keys, a WHERE entry may overwrite a SET entry with the same
//     name — NOTE(review): confirm this precedence is intended).
//   - INSERT ... VALUES: string literals are paired positionally with the
//     explicit column list. Values beyond the end of that list are ignored,
//     so an INSERT without a column list records nothing. When more than one
//     row is inserted, each key gets a 1-based "_rowN" suffix (e.g.
//     "name_row2").
//
// In SET and VALUES positions, anything other than a string literal (numbers,
// bind variables, expressions) is skipped. Other statement kinds yield an
// empty, non-nil map. On a parse failure the error is logged and returned
// along with a nil map.
func extractSQLParams(query string) (map[string]string, error) {
	parsed, err := sqlparser.Parse(query)
	if err != nil {
		log.Printf("failed to fetch sql params: %v", err)
		return nil, err
	}

	// strLiteral returns the text of e and true when e is a quoted string
	// literal; otherwise it returns "" and false.
	strLiteral := func(e sqlparser.Expr) (string, bool) {
		lit, isLit := e.(*sqlparser.SQLVal)
		if !isLit || lit.Type != sqlparser.StrVal {
			return "", false
		}
		return string(lit.Val), true
	}

	values := make(map[string]string)

	// The WHERE clause is scanned once, after the statement-specific work, so
	// that for UPDATE the SET values are recorded before the WHERE conditions.
	var where *sqlparser.Where

	switch s := parsed.(type) {
	case *sqlparser.Select:
		where = s.Where
	case *sqlparser.Delete:
		where = s.Where
	case *sqlparser.Update:
		for _, assign := range s.Exprs {
			if text, ok := strLiteral(assign.Expr); ok {
				values[assign.Name.Name.String()] = text
			}
		}
		where = s.Where
	case *sqlparser.Insert:
		tuples, isValues := s.Rows.(sqlparser.Values)
		if !isValues {
			// Row sources other than a VALUES list (e.g. INSERT ... SELECT)
			// are not inspected; this break leaves the switch.
			break
		}
		multiRow := len(tuples) > 1
		for rowIdx, tuple := range tuples {
			for colIdx, expr := range tuple {
				if colIdx >= len(s.Columns) {
					// No column name for this position, nor for any later one.
					break
				}
				text, ok := strLiteral(expr)
				if !ok {
					continue
				}
				key := s.Columns[colIdx].String()
				if multiRow {
					key = fmt.Sprintf("%s_row%d", key, rowIdx+1)
				}
				values[key] = text
			}
		}
	}

	if where != nil {
		extractConditions(where.Expr, values)
	}
	return values, nil
}