pkg/storage/querybuilder/update.go (179 lines of code) (raw):
// Copyright (c) 2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querybuilder
import (
	"bytes"
	"errors"
	"fmt"
	"sort"
	"strings"

	"github.com/lann/builder"
)
// updateData is the state accumulated by an UpdateBuilder. ToSQL renders it
// as:
//
//	UPDATE <Table> [USING <Usings>] SET <assignments> [WHERE <WhereParts>] [IF <IfOnlyParts>]
type updateData struct {
	// PlaceholderFormat rewrites the "?" placeholders of the rendered
	// statement (for example Question or Dollar).
	PlaceholderFormat PlaceholderFormat

	// Table is the table being updated; ToSQL fails with ErrMissingTable
	// when it is empty.
	Table string

	// SetClauses are plain assignments, rendered "column = ?" (or with the
	// SQL of an expr value inlined).
	SetClauses []setClause

	// SetClausesAdd are collection appends, rendered "column = column + ?".
	SetClausesAdd []setClause

	// SetClausesRemove are collection removals, rendered
	// "column = column - ?".
	SetClausesRemove []setClause

	// WhereParts are ANDed together into the WHERE clause.
	WhereParts []Sqlizer

	// IfOnlyParts are ANDed together into the IF clause; a non-empty list
	// makes the update a compare-and-set (lightweight transaction).
	IfOnlyParts []Sqlizer

	// Usings are rendered as a USING clause directly after the table name,
	// joined by single spaces.
	Usings exprs

	// STApplyMetadata is bound to "st_apply_metadata = ?" when
	// STApplyMetadataApplied is true. This is ignored by the ToThrift() API.
	STApplyMetadata []byte

	// STApplyMetadataApplied records whether the st_apply_metadata
	// assignment is emitted (set by AddSTApplyMetadata).
	STApplyMetadataApplied bool
}

// setClause is one column assignment. In SetClauses, a value of type expr is
// inlined as SQL together with its own args; every other value (and every
// value in SetClausesAdd/SetClausesRemove) is bound to a single "?"
// placeholder.
type setClause struct {
	column string
	value  interface{}
}
// Sentinel errors returned by updateData.ToSQL; compare with errors.Is.
var (
	// ErrMissingTable is returned when an update has no target table.
	ErrMissingTable = errors.New("update statements must specify a table")

	// ErrMalformedSetClause is returned when an update assigns nothing:
	// no Set, Add or Remove clause and no st_apply_metadata value.
	ErrMalformedSetClause = errors.New("update statements must have at least one Set clause")
)
// ToSQL renders the update as a statement string plus its bound arguments:
//
//	UPDATE <Table> [USING <Usings>] SET <assignments> [WHERE <WhereParts>] [IF <IfOnlyParts>]
//
// Assignments are emitted in a fixed order:
//   - SetClauses ("col = ?", or the inlined SQL of an expr value);
//   - "st_apply_metadata = ?" when STApplyMetadataApplied is set;
//   - SetClausesAdd ("col = col + ?");
//   - SetClausesRemove ("col = col - ?").
//
// Arguments are appended in exactly the order their placeholders appear. The
// finished string is passed through d.PlaceholderFormat.
//
// It returns ErrMissingTable when Table is empty and ErrMalformedSetClause
// when there is nothing to assign. It also returns any error from rendering
// the USING, WHERE or IF parts or from placeholder replacement. On error the
// returned string and args are empty and must not be used.
func (d *updateData) ToSQL() (sqlStr string, args []interface{}, err error) {
	if d.Table == "" {
		return "", nil, ErrMissingTable
	}

	// Validate before rendering anything so an empty update fails cleanly.
	setCount := len(d.SetClauses) + len(d.SetClausesAdd) + len(d.SetClausesRemove)
	if d.STApplyMetadataApplied {
		setCount++
	}
	if setCount == 0 {
		return "", nil, ErrMalformedSetClause
	}

	sql := &bytes.Buffer{}
	sql.WriteString("UPDATE ")
	sql.WriteString(d.Table)

	if len(d.Usings) > 0 {
		sql.WriteString(" USING ")
		// Handle the rendering error like the WHERE/IF parts below instead of
		// discarding it and continuing with possibly-invalid args.
		if args, err = d.Usings.AppendToSQL(sql, " ", args); err != nil {
			return "", nil, err
		}
	}

	// setSQLs and args are built in lockstep so that every "?" lines up with
	// its argument.
	setSQLs := make([]string, 0, setCount)
	for _, clause := range d.SetClauses {
		if e, ok := clause.value.(expr); ok {
			// Expression values are inlined verbatim and carry their own args.
			setSQLs = append(setSQLs, fmt.Sprintf("%s = %s", clause.column, e.sql))
			args = append(args, e.args...)
			continue
		}
		setSQLs = append(setSQLs, fmt.Sprintf("%s = ?", clause.column))
		args = append(args, clause.value)
	}
	if d.STApplyMetadataApplied {
		setSQLs = append(setSQLs, "st_apply_metadata = ?")
		args = append(args, d.STApplyMetadata)
	}
	for _, clause := range d.SetClausesAdd { // e.g. SET emails = emails + ?
		setSQLs = append(setSQLs, fmt.Sprintf("%s = %s + ?", clause.column, clause.column))
		args = append(args, clause.value)
	}
	for _, clause := range d.SetClausesRemove { // e.g. SET emails = emails - ?
		setSQLs = append(setSQLs, fmt.Sprintf("%s = %s - ?", clause.column, clause.column))
		args = append(args, clause.value)
	}
	sql.WriteString(" SET ")
	sql.WriteString(strings.Join(setSQLs, ", "))

	if len(d.WhereParts) > 0 {
		sql.WriteString(" WHERE ")
		if args, err = appendToSQL(d.WhereParts, sql, " AND ", args); err != nil {
			return "", nil, err
		}
	}
	if len(d.IfOnlyParts) > 0 {
		sql.WriteString(" IF ")
		if args, err = appendToSQL(d.IfOnlyParts, sql, " AND ", args); err != nil {
			return "", nil, err
		}
	}

	if sqlStr, err = d.PlaceholderFormat.ReplacePlaceholders(sql.String()); err != nil {
		return "", nil, err
	}
	return sqlStr, args, nil
}
// GetResource reports the name of the table the update targets.
func (d updateData) GetResource() (table string) {
	table = d.Table
	return table
}

// GetWhereParts reports the predicates that make up the WHERE clause.
func (d updateData) GetWhereParts() (parts []Sqlizer) {
	parts = d.WhereParts
	return parts
}

// GetColumns always reports a nil column list. Presumably this is because an
// UPDATE selects no result columns; NOTE(review): confirm callers treat nil
// as "no columns".
func (d updateData) GetColumns() (columns []Sqlizer) {
	return columns
}
// Builder

// UpdateBuilder builds SQL UPDATE statements. Its methods take the builder by
// value and return the updated builder, so callers must keep the return
// value (e.g. b = b.Set(...)).
type UpdateBuilder builder.Builder

// init registers UpdateBuilder with the builder package, pairing it with
// updateData as its backing struct, so builder.GetStruct on an UpdateBuilder
// yields an updateData.
func init() {
	builder.Register(UpdateBuilder{}, updateData{})
}
// Format methods

// PlaceholderFormat returns a builder whose rendered statement uses f (for
// example Question or Dollar) to format its bind placeholders.
func (b UpdateBuilder) PlaceholderFormat(f PlaceholderFormat) UpdateBuilder {
	updated := builder.Set(b, "PlaceholderFormat", f)
	return updated.(UpdateBuilder)
}
// SQL methods

// ToSQL renders the accumulated update into a SQL string and its bound args.
// The errors it can return are those of updateData.ToSQL.
func (b UpdateBuilder) ToSQL() (string, []interface{}, error) {
	data := builder.GetStruct(b).(updateData)
	sqlStr, args, err := data.ToSQL()
	return sqlStr, args, err
}
// ToUql builds the query into a UQL string and bound args. As a runtime
// optimization it also returns query options. Currently the only option is
// "IsCAS", which is true when the update carries IfOnly conditions. options is
// populated even when err is non-nil.
func (b UpdateBuilder) ToUql() (query string, args []interface{},
	options map[string]interface{}, err error) {
	data := builder.GetStruct(b).(updateData)
	isCAS := len(data.IfOnlyParts) > 0
	query, args, err = data.ToSQL()
	return query, args, map[string]interface{}{"IsCAS": isCAS}, err
}
// StmtType reports that this builder produces an UPDATE statement
// (UpdateStmtType).
func (UpdateBuilder) StmtType() StmtType {
	return UpdateStmtType
}
// GetData exposes the builder's accumulated updateData through the
// StatementAccessor interface.
func (b UpdateBuilder) GetData() StatementAccessor {
	data := builder.GetStruct(b).(updateData)
	return data
}
// Table returns a builder that targets the given table. A table is required:
// rendering an update without one fails with ErrMissingTable.
func (b UpdateBuilder) Table(table string) UpdateBuilder {
	next := builder.Set(b, "Table", table)
	return next.(UpdateBuilder)
}
// Set appends a "column = value" assignment to the SET clause. If value is an
// expr, its SQL is inlined and its args are bound. Any other value is bound
// to a single "?" placeholder.
func (b UpdateBuilder) Set(column string, value interface{}) UpdateBuilder {
	clause := setClause{column: column, value: value}
	return builder.Append(b, "SetClauses", clause).(UpdateBuilder)
}
// Add appends value to a collection column. It is rendered as
// "column = column + ?", with value bound to the placeholder.
func (b UpdateBuilder) Add(column string, value interface{}) UpdateBuilder {
	clause := setClause{column: column, value: value}
	return builder.Append(b, "SetClausesAdd", clause).(UpdateBuilder)
}
// Remove discards value from a collection column. It is rendered as
// "column = column - ?", with value bound to the placeholder.
func (b UpdateBuilder) Remove(column string, value interface{}) UpdateBuilder {
	clause := setClause{column: column, value: value}
	return builder.Append(b, "SetClausesRemove", clause).(UpdateBuilder)
}
// SetMap is a convenience method which calls Set for each key/value pair in
// clauses. Keys are applied in sorted order, so the generated SET clause (and
// the order of its bound args) is deterministic despite Go's randomized map
// iteration. A nil or empty map returns the builder unchanged.
func (b UpdateBuilder) SetMap(clauses map[string]interface{}) UpdateBuilder {
	keys := make([]string, 0, len(clauses))
	for key := range clauses {
		keys = append(keys, key)
	}
	sort.Strings(keys)
	for _, key := range keys {
		b = b.Set(key, clauses[key])
	}
	return b
}
// Where adds a WHERE predicate to the update. Predicates from repeated calls
// are ANDed together.
//
// See SelectBuilder.Where for the accepted forms of pred and args.
func (b UpdateBuilder) Where(pred interface{}, args ...interface{}) UpdateBuilder {
	part := newWherePart(pred, args...)
	return builder.Append(b, "WhereParts", part).(UpdateBuilder)
}
// IfOnly adds an IF condition, turning the update into a lightweight
// transaction (LWT, i.e. compare-and-set). Conditions from repeated calls are
// ANDed together, and pred/rest accept the same forms as Where.
func (b UpdateBuilder) IfOnly(pred interface{}, rest ...interface{}) UpdateBuilder {
	condition := newWherePart(pred, rest...)
	return builder.Append(b, "IfOnlyParts", condition).(UpdateBuilder)
}
// IsCAS reports whether the update has a compare-and-set part, i.e. at least
// one IfOnly condition.
func (b UpdateBuilder) IsCAS() bool {
	return len(builder.GetStruct(b).(updateData).IfOnlyParts) != 0
}
// AddSTApplyMetadata sets the value for the special st_apply_metadata column,
// which is rendered as "st_apply_metadata = ?". Because it uses builder.Set,
// calling it again replaces the earlier value instead of adding a second
// assignment.
func (b UpdateBuilder) AddSTApplyMetadata(value []byte) UpdateBuilder {
	withValue := builder.Set(b, "STApplyMetadata", value).(UpdateBuilder)
	return builder.Set(withValue, "STApplyMetadataApplied", true).(UpdateBuilder)
}
// Using adds an expression to the update's USING clause. The USING clause is
// rendered directly after the table name, before SET, not at the end of the
// statement. Expressions from repeated calls are joined with a single space,
// so any separator keyword must be part of sql itself. args are bound to the
// placeholders in sql.
func (b UpdateBuilder) Using(sql string, args ...interface{}) UpdateBuilder {
	return builder.Append(b, "Usings", expression(sql, args...)).(UpdateBuilder)
}