vulndb/sqlutil/stmt.go (188 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlutil
import (
"strings"
)
// baseStmt is the token list embedded by the statement types in this file
// (InsertStmt, UpdateStmt, SelectStmt and DeleteStmt). The tokens are
// rendered into a single string by String.
type baseStmt struct {
// q holds the statement's tokens in the order they were added.
q []string
}
// add appends every element of parts to the statement, each as its own
// token, and returns the receiver so that calls can be chained.
func (s *baseStmt) add(parts ...string) *baseStmt {
	tokens := append(s.q, parts...)
	s.q = tokens
	return s
}
// join appends parts to the statement as one single token in which the
// elements are separated by ", ", and returns the receiver for chaining.
//
// An empty parts list appends nothing. An empty token would otherwise be
// treated by String like any other token and surrounded by separators,
// leaving a stray extra space in the rendered statement.
func (s *baseStmt) join(parts []string) *baseStmt {
	if len(parts) == 0 {
		return s
	}
	s.q = append(s.q, strings.Join(parts, ", "))
	return s
}
// group appends the tokens of g to the statement, wrapped in a pair of
// parentheses. When as is non-empty the closing parenthesis is followed by
// the tokens "AS" and as. It returns the receiver for chaining.
func (s *baseStmt) group(g *baseStmt, as string) *baseStmt {
	s.add("(")
	s.add(g.q...)
	s.add(")")
	if as == "" {
		return s
	}
	s.add("AS", as)
	return s
}
// String renders the statement by concatenating its tokens in order.
// Consecutive tokens are separated by a single space, except that no space
// is written directly after a "(" token or directly before a ")" or ","
// token. A statement without tokens renders as the empty string.
func (s *baseStmt) String() string {
	var out strings.Builder
	prev := ""
	for i, tok := range s.q {
		switch {
		case i == 0, prev == "(", tok == ")", tok == ",":
			// No separator in front of this token.
		default:
			out.WriteByte(' ')
		}
		out.WriteString(tok)
		prev = tok
	}
	return out.String()
}
// genbindgroup returns a parenthesised group of n "?" placeholders
// separated by ", ", e.g. "(?, ?, ?)" for n == 3.
//
// It returns the empty string when n is zero or negative, so that a
// nonsensical count can never produce a group with a placeholder in it.
func genbindgroup(n int) string {
	if n <= 0 {
		return ""
	}
	// n-1 placeholders carry a trailing separator; the last one closes the group.
	return "(" + strings.Repeat("?, ", n-1) + "?)"
}
// InsertStmt represents the INSERT statement.
// The same type is returned by Replace for REPLACE statements.
type InsertStmt struct {
// baseStmt holds the statement's tokens.
baseStmt
// values holds the arguments recorded for the statement's bindings, in
// the order they were added; QueryArgs returns it.
values []interface{}
}
// Insert returns a new InsertStmt whose only token so far is the INSERT
// keyword.
func Insert() *InsertStmt {
	return &InsertStmt{
		baseStmt: baseStmt{q: []string{"INSERT"}},
	}
}
// Replace returns a new statement whose only token so far is the REPLACE
// keyword. It is backed by InsertStmt so that the two can be used
// interchangeably.
func Replace() *InsertStmt {
	return &InsertStmt{
		baseStmt: baseStmt{q: []string{"REPLACE"}},
	}
}
// Into appends the INTO keyword followed by table to the statement and
// returns the statement for chaining.
func (s *InsertStmt) Into(table string) *InsertStmt {
	s.add("INTO").add(table)
	return s
}
// Fields appends the parenthesised, comma-separated field list (f1, fN) to
// the statement and returns the statement for chaining.
func (s *InsertStmt) Fields(fields ...string) *InsertStmt {
	s.add("(")
	s.join(fields)
	s.add(")")
	return s
}
// Values adds the (v1, vN) part of the insert part of the statement.
// Each record is added to the statement as a set of bindings (?, ?), and
// their values recorded. Use QueryArgs to get the values recorded.
//
// Records whose Values() is empty are skipped entirely. If no record
// contributes a binding group the statement is left unchanged, so neither a
// dangling VALUES keyword nor an empty, comma-separated slot is emitted.
// When no values have been recorded on the statement yet the groups are
// preceded by VALUES; otherwise they are preceded by "," to continue the
// existing list.
func (s *InsertStmt) Values(records ...Record) *InsertStmt {
	// Decide on the leading token before any value of this call is recorded.
	continuing := len(s.values) > 0

	groups := make([]string, 0, len(records))
	for _, r := range records {
		values := r.Values()
		if len(values) == 0 {
			continue
		}
		groups = append(groups, genbindgroup(len(values)))
		s.values = append(s.values, values...)
	}
	if len(groups) == 0 {
		return s
	}

	if continuing {
		s.add(",")
	} else {
		s.add("VALUES")
	}
	s.join(groups)
	return s
}
// Select appends the tokens of stmt to the statement and records the values
// stmt has collected, so that they are returned by QueryArgs as well.
func (s *InsertStmt) Select(stmt *SelectStmt) *InsertStmt {
	s.values = append(s.values, stmt.values...)
	s.add(stmt.q...)
	return s
}
// Literal appends the string l, unchanged, as a single token of the
// statement and returns the statement for chaining.
func (s *InsertStmt) Literal(l string) *InsertStmt {
	s.q = append(s.q, l)
	return s
}
// QueryArgs returns the values corresponding to bindings (?, ?) recorded by
// all calls to Values and Select, in the order they were recorded.
// e.g. db.Exec(stmt.String(), stmt.QueryArgs()...)
//
// The returned slice is the statement's own slice, not a copy.
func (s *InsertStmt) QueryArgs() []interface{} {
return s.values
}
// UpdateStmt represents the UPDATE statement.
type UpdateStmt struct {
// baseStmt holds the statement's tokens.
baseStmt
// values holds the arguments recorded by Set and Where, in the order
// they were added; QueryArgs returns it.
values []interface{}
}
// Update returns a new UpdateStmt consisting of the UPDATE keyword followed
// by the comma-separated list of tables.
func Update(tables ...string) *UpdateStmt {
	stmt := new(UpdateStmt)
	stmt.add("UPDATE")
	stmt.join(tables)
	return stmt
}
// Set appends the SET keyword followed by the rendered assignment list
// al.String(), and records al.Values() so that QueryArgs returns them.
func (s *UpdateStmt) Set(al *AssignmentList) *UpdateStmt {
	s.add("SET", al.String())
	s.values = append(s.values, al.Values()...)
	return s
}
// Where appends the WHERE keyword followed by the rendered conditions
// cond.String(), and records cond.Values() so that QueryArgs returns them.
func (s *UpdateStmt) Where(cond *QueryConditionSet) *UpdateStmt {
	s.add("WHERE", cond.String())
	s.values = append(s.values, cond.Values()...)
	return s
}
// QueryArgs returns the values corresponding to bindings (?, ?) recorded by
// all calls to Set and Where, in the order they were recorded.
// e.g. db.Exec(stmt.String(), stmt.QueryArgs()...)
//
// The returned slice is the statement's own slice, not a copy.
func (s *UpdateStmt) QueryArgs() []interface{} {
return s.values
}
// SelectStmt represents a SELECT statement.
type SelectStmt struct {
// baseStmt holds the statement's tokens.
baseStmt
// values holds the arguments recorded by Where, Select and SelectGroup,
// in the order they were added; QueryArgs returns it.
values []interface{}
}
// Select returns a new SelectStmt consisting of the SELECT keyword followed
// by the comma-separated list of fields.
func Select(fields ...string) *SelectStmt {
	stmt := new(SelectStmt)
	stmt.add("SELECT")
	stmt.join(fields)
	return stmt
}
// Select appends the tokens of another Select statement a to the statement
// and records the values a has collected.
// e.g. Select("*").From().Select(...)
func (s *SelectStmt) Select(a *SelectStmt) *SelectStmt {
	s.values = append(s.values, a.values...)
	s.add(a.q...)
	return s
}
// SelectGroup appends the tokens of g to the statement as a parenthesised
// group, followed by "AS" and the alias when as is non-empty, and records
// the values g has collected.
func (s *SelectStmt) SelectGroup(as string, g *SelectStmt) *SelectStmt {
	inner := &g.baseStmt
	s.group(inner, as)
	s.values = append(s.values, g.values...)
	return s
}
// From appends the FROM keyword to the statement, followed by the
// comma-separated list of tables when any are given. Called without tables
// it appends the keyword alone.
func (s *SelectStmt) From(tables ...string) *SelectStmt {
	s.add("FROM")
	if len(tables) == 0 {
		return s
	}
	s.join(tables)
	return s
}
// Where appends the WHERE keyword followed by the rendered conditions
// cond.String(), and records cond.Values() so that QueryArgs returns them.
func (s *SelectStmt) Where(cond *QueryConditionSet) *SelectStmt {
	s.add("WHERE", cond.String())
	s.values = append(s.values, cond.Values()...)
	return s
}
// Literal appends the string l, unchanged, as a single token of the
// statement. Useful for e.g. UNION, JOIN, LIMIT.
func (s *SelectStmt) Literal(l string) *SelectStmt {
	s.q = append(s.q, l)
	return s
}
// QueryArgs returns the values corresponding to bindings (?, ?) recorded by
// all calls to Where, Select and SelectGroup, in the order they were
// recorded. e.g. db.Query(stmt.String(), stmt.QueryArgs()...)
//
// The returned slice is the statement's own slice, not a copy.
func (s *SelectStmt) QueryArgs() []interface{} {
return s.values
}
// DeleteStmt represents a DELETE statement.
type DeleteStmt struct {
// baseStmt holds the statement's tokens.
baseStmt
// values holds the arguments recorded by Where, in the order they were
// added; QueryArgs returns it.
values []interface{}
}
// Delete returns a new DeleteStmt whose only token so far is the DELETE
// keyword.
func Delete() *DeleteStmt {
	return &DeleteStmt{
		baseStmt: baseStmt{q: []string{"DELETE"}},
	}
}
// From appends the FROM keyword to the statement, followed by the
// comma-separated list of tables when any are given.
//
// Called without tables it appends the keyword alone, matching
// SelectStmt.From, instead of appending an empty token that would be
// rendered as a stray space.
func (s *DeleteStmt) From(tables ...string) *DeleteStmt {
	s.add("FROM")
	if len(tables) > 0 {
		s.join(tables)
	}
	return s
}
// Where appends the WHERE keyword followed by the rendered conditions
// cond.String(), and records cond.Values() so that QueryArgs returns them.
func (s *DeleteStmt) Where(cond *QueryConditionSet) *DeleteStmt {
	s.add("WHERE", cond.String())
	s.values = append(s.values, cond.Values()...)
	return s
}
// QueryArgs returns the values corresponding to bindings (?, ?) recorded by
// all calls to Where, in the order they were recorded.
// e.g. db.Exec(stmt.String(), stmt.QueryArgs()...)
//
// The returned slice is the statement's own slice, not a copy.
func (s *DeleteStmt) QueryArgs() []interface{} {
return s.values
}
// Literal appends the string l, unchanged, as a single token of the
// statement and returns the statement for chaining.
func (s *DeleteStmt) Literal(l string) *DeleteStmt {
	s.q = append(s.q, l)
	return s
}