pkg/storage/querybuilder/expr.go (202 lines of code) (raw):
// Copyright (c) 2019 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querybuilder
import (
	"database/sql/driver"
	"fmt"
	"io"
	"reflect"
	"sort"
	"strings"

	"github.com/gocql/gocql"
)
// expr is a raw SQL fragment together with the bind arguments for the
// placeholders it contains.
type expr struct {
	sql  string
	args []interface{}
}

// expression builds a raw value expression for InsertBuilder and
// UpdateBuilder. sql is emitted verbatim and args are bound, in order, to
// the placeholders inside it.
//
// Ex:
// .Values(expression("FROM_UNIXTIME(?)", t))
func expression(sql string, args ...interface{}) expr {
	return expr{sql: sql, args: args}
}

// ToSQL returns the fragment's SQL and arguments unchanged; it never fails.
func (e expr) ToSQL() (sql string, args []interface{}, err error) {
	return e.sql, e.args, nil
}

// exprs is an ordered list of raw SQL fragments.
type exprs []expr

// AppendToSQL writes the SQL of every fragment to w, separated by sep, and
// returns args extended with each fragment's arguments in the same order.
// If a write to w fails, it returns nil and the write error.
func (es exprs) AppendToSQL(w io.Writer, sep string, args []interface{}) ([]interface{}, error) {
	for i := range es {
		// The separator goes between fragments, never before the first one.
		if i != 0 {
			if _, err := io.WriteString(w, sep); err != nil {
				return nil, err
			}
		}
		if _, err := io.WriteString(w, es[i].sql); err != nil {
			return nil, err
		}
		args = append(args, es[i].args...)
	}
	return args, nil
}
// aliasExpr wraps a Sqlizer so that its generated SQL is emitted as
// "(<sql>) AS <alias>".
type aliasExpr struct {
	expr  Sqlizer
	alias string
}

// alias wraps expr so that it renders as "(<sql>) AS <alias>". It is useful
// in SelectBuilder when a column is defined by a complex expression such as
// IF or CASE.
//
// Ex:
// .Column(alias(caseStmt, "case_column"))
func alias(expr Sqlizer, alias string) aliasExpr {
	return aliasExpr{expr: expr, alias: alias}
}

// ToSQL renders the wrapped expression as "(<sql>) AS <alias>" and returns
// its arguments unchanged. If the wrapped expression returns an error, its
// results are passed through untouched together with that error.
//
// NOTE(review): the alias is inserted verbatim (not quoted or escaped);
// presumably it is always a developer-supplied identifier — confirm.
func (e aliasExpr) ToSQL() (sql string, args []interface{}, err error) {
	sql, args, err = e.expr.ToSQL()
	if err != nil {
		return sql, args, err
	}
	return fmt.Sprintf("(%s) AS %s", sql, e.alias), args, nil
}
// Eq maps column names to values and is meant for the Where/Having/Set
// methods. Each entry renders as an equality test; slices and arrays become
// IN lists and nil becomes IS NULL.
//
// Ex:
// .Where(Eq{"id": 1})
type Eq map[string]interface{}
// UUID represents the Cassandra uuid data type. It embeds gocql.UUID, so
// the methods of gocql.UUID are promoted onto it.
type UUID struct {
	gocql.UUID
}
// ParseUUID parses input with gocql.ParseUUID and returns the result wrapped
// in a UUID. Any error from gocql.ParseUUID is returned unchanged, alongside
// whatever value gocql produced.
func ParseUUID(input string) (parsed UUID, err error) {
	parsed.UUID, err = gocql.ParseUUID(input)
	return parsed, err
}
// IsUUID reports whether value is a UUID, a gocql.UUID, or a pointer to
// either one. Eq.toSQL relies on it so that UUID values are bound as a
// single argument instead of being expanded into an IN list.
func IsUUID(value interface{}) bool {
	switch value.(type) {
	case UUID, *UUID, gocql.UUID, *gocql.UUID:
		return true
	default:
		return false
	}
}
// toSQL renders eq as one comparison per map entry, joined with " AND ".
// useNotOpr selects the negated operators (used by NotEq).
//
// Keys are processed in sorted order. Go randomizes map iteration, so
// without sorting the same Eq could render to differently ordered SQL on
// every call. That makes output non-reproducible and defeats any cache keyed
// on the query text. Args are appended in the same order as their
// placeholders.
//
// For each entry, a driver.Valuer is first resolved through Value(), then:
//   - nil                      -> "key IS NULL"   ("key IS NOT NULL")
//   - empty array/slice        -> "key IN (NULL)" ("key NOT IN (NULL)")
//   - non-empty array/slice    -> "key IN (<Placeholders(n)>)" ("NOT IN"),
//     with one arg per element; values accepted by IsUUID are never
//     treated as lists
//   - anything else            -> "key = ?"       ("key <> ?")
//
// If a Value() call fails, rendering stops and "", nil and that error are
// returned.
func (eq Eq) toSQL(useNotOpr bool) (sql string, args []interface{}, err error) {
	equalOpr, inOpr, nullOpr := "=", "IN", "IS"
	if useNotOpr {
		equalOpr, inOpr, nullOpr = "<>", "NOT IN", "IS NOT"
	}

	keys := make([]string, 0, len(eq))
	for key := range eq {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	clauses := make([]string, 0, len(keys))
	for _, key := range keys {
		val := eq[key]
		if valuer, ok := val.(driver.Valuer); ok {
			if val, err = valuer.Value(); err != nil {
				return "", nil, err
			}
		}

		if val == nil {
			clauses = append(clauses, fmt.Sprintf("%s %s NULL", key, nullOpr))
			continue
		}

		valVal := reflect.ValueOf(val)
		isList := valVal.Kind() == reflect.Array || valVal.Kind() == reflect.Slice
		if IsUUID(val) || !isList {
			clauses = append(clauses, fmt.Sprintf("%s %s ?", key, equalOpr))
			args = append(args, val)
			continue
		}

		// NOTE(review): []byte values also take this path and are expanded
		// element-wise into an IN list; confirm that is intended for blob
		// columns.
		if valVal.Len() == 0 {
			clauses = append(clauses, fmt.Sprintf("%s %s (NULL)", key, inOpr))
			// Preserve the original behavior of returning a non-nil args
			// slice here; presumably some caller depends on it — verify.
			if args == nil {
				args = []interface{}{}
			}
			continue
		}
		for i := 0; i < valVal.Len(); i++ {
			args = append(args, valVal.Index(i).Interface())
		}
		clauses = append(clauses, fmt.Sprintf("%s %s (%s)", key, inOpr, Placeholders(valVal.Len())))
	}
	return strings.Join(clauses, " AND "), args, nil
}
// ToSQL renders eq with the plain operators ("=", "IN", "IS NULL") and
// returns the SQL text together with its bind arguments.
func (eq Eq) ToSQL() (sql string, args []interface{}, err error) {
	const negate = false
	sql, args, err = eq.toSQL(negate)
	return sql, args, err
}
// NotEq is the negated form of Eq, for use with Where/Having/Set methods:
// "=" becomes "<>", "IN" becomes "NOT IN" and "IS NULL" becomes
// "IS NOT NULL".
//
// Ex:
// .Where(NotEq{"id": 1}) == "id <> ?" with args [1]
type NotEq Eq

// ToSQL renders neq through Eq's renderer with the negated operators.
func (neq NotEq) ToSQL() (sql string, args []interface{}, err error) {
	const negate = true
	sql, args, err = Eq(neq).toSQL(negate)
	return sql, args, err
}
// Lt is syntactic sugar for use with Where/Having/Set methods. Each entry
// renders as "key < ?" and the entries are joined with " AND ".
//
// Ex:
// .Where(Lt{"id": 1}) == "id < ?" with args [1]
type Lt map[string]interface{}

// toSQL renders lt as one comparison per map entry, joined with " AND ".
//
// The operator is "<", or ">" when opposite is set. orEq appends "=" to give
// "<=" or ">=". Lt, LtOrEq, Gt and GtOrEq all render through this method.
//
// Keys are processed in sorted order. Go randomizes map iteration, so
// without sorting the same map could render to differently ordered SQL on
// every call. Args are appended in placeholder order.
//
// A driver.Valuer is first resolved through Value(). An error is returned,
// with empty SQL and nil args, in three cases:
//   - Value() fails;
//   - a value is nil;
//   - a value is an array or slice.
func (lt Lt) toSQL(opposite, orEq bool) (sql string, args []interface{}, err error) {
	opr := "<"
	if opposite {
		opr = ">"
	}
	if orEq {
		opr += "="
	}

	keys := make([]string, 0, len(lt))
	for key := range lt {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	clauses := make([]string, 0, len(keys))
	for _, key := range keys {
		val := lt[key]
		if valuer, ok := val.(driver.Valuer); ok {
			if val, err = valuer.Value(); err != nil {
				return "", nil, err
			}
		}
		if val == nil {
			return "", nil, fmt.Errorf("cannot use null with less than or greater than operators")
		}
		// NOTE(review): the UUID wrapper struct passes this check, while a
		// bare gocql.UUID is rejected if it is array-backed. Confirm which
		// behavior is intended for range queries on UUID columns.
		if kind := reflect.ValueOf(val).Kind(); kind == reflect.Array || kind == reflect.Slice {
			return "", nil, fmt.Errorf("cannot use UUID, array or slice with less than or greater than operators")
		}
		clauses = append(clauses, fmt.Sprintf("%s %s ?", key, opr))
		args = append(args, val)
	}
	return strings.Join(clauses, " AND "), args, nil
}

// ToSQL renders lt with the "<" operator.
func (lt Lt) ToSQL() (sql string, args []interface{}, err error) {
	return lt.toSQL(false, false)
}
// LtOrEq is syntactic sugar for use with Where/Having/Set methods. Each
// entry renders as "key <= ?".
//
// Ex:
// .Where(LtOrEq{"id": 1}) == "id <= ?" with args [1]
type LtOrEq Lt

// ToSQL renders ltOrEq through Lt's renderer with the "<=" operator.
func (ltOrEq LtOrEq) ToSQL() (sql string, args []interface{}, err error) {
	const opposite, orEq = false, true
	return Lt(ltOrEq).toSQL(opposite, orEq)
}
// Gt is syntactic sugar for use with Where/Having/Set methods. Each entry
// renders as "key > ?".
//
// Ex:
// .Where(Gt{"id": 1}) == "id > ?" with args [1]
type Gt Lt

// ToSQL renders gt through Lt's renderer with the ">" operator.
func (gt Gt) ToSQL() (sql string, args []interface{}, err error) {
	const opposite, orEq = true, false
	return Lt(gt).toSQL(opposite, orEq)
}
// GtOrEq is syntactic sugar for use with Where/Having/Set methods. Each
// entry renders as "key >= ?".
//
// Ex:
// .Where(GtOrEq{"id": 1}) == "id >= ?" with args [1]
type GtOrEq Lt

// ToSQL renders gtOrEq through Lt's renderer with the ">=" operator.
func (gtOrEq GtOrEq) ToSQL() (sql string, args []interface{}, err error) {
	const opposite, orEq = true, true
	return Lt(gtOrEq).toSQL(opposite, orEq)
}
// conj is an ordered list of Sqlizers that are combined with a shared
// separator (see And and Or).
type conj []Sqlizer

// join renders every part of c and joins the non-empty results with sep,
// wrapped in a single pair of parentheses. Arguments of the kept parts are
// concatenated in order. Parts that render to empty SQL are skipped along
// with their arguments. If nothing remains, the result is "" with nil args.
// The first error from a part aborts the join and is returned with "" and
// nil args.
func (c conj) join(sep string) (sql string, args []interface{}, err error) {
	parts := make([]string, 0, len(c))
	for _, part := range c {
		partSQL, partArgs, partErr := part.ToSQL()
		if partErr != nil {
			return "", nil, partErr
		}
		if partSQL == "" {
			continue
		}
		parts = append(parts, partSQL)
		args = append(args, partArgs...)
	}
	if len(parts) == 0 {
		return "", args, nil
	}
	return "(" + strings.Join(parts, sep) + ")", args, nil
}
// And combines Sqlizers with AND. Parts that render to empty SQL are
// skipped, and the rest are wrapped in one pair of parentheses.
type And conj

// ToSQL renders the parts of a joined with " AND ".
func (a And) ToSQL() (string, []interface{}, error) {
	const sep = " AND "
	return conj(a).join(sep)
}
// Or combines Sqlizers with OR. Parts that render to empty SQL are
// skipped, and the rest are wrapped in one pair of parentheses.
type Or conj

// ToSQL renders the parts of o joined with " OR ".
func (o Or) ToSQL() (string, []interface{}, error) {
	const sep = " OR "
	return conj(o).join(sep)
}