common/pinot/pinotQueryValidator.go (345 lines of code) (raw):
// The MIT License (MIT)
// Copyright (c) 2017-2020 Uber Technologies Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
package pinot
import (
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/xwb1989/sqlparser"
"github.com/uber/cadence/common"
"github.com/uber/cadence/common/definition"
"github.com/uber/cadence/common/log"
"github.com/uber/cadence/common/types"
)
// VisibilityQueryValidator validates the where clause of a visibility query
// against a set of registered search attributes and rewrites it into the form
// produced by ValidateQuery.
type VisibilityQueryValidator struct {
// validSearchAttributes is keyed by search attribute name. Key presence decides
// whether an attribute may be used in a query (see isValidSearchAttributes); the
// value is handed to common.ConvertIndexedValueTypeToInternalType by
// processCustomKey to determine the attribute's indexed value type.
validSearchAttributes map[string]interface{}
}
// timeSystemKeys is the set of system keys whose query values are timestamps.
// validateRangeExpr and processSystemKey pass literal values compared against
// these keys through trimTimeFieldValueFromNanoToMilliSeconds before the
// expression is formatted. Only key presence is checked; the bool value is unused.
var timeSystemKeys = map[string]bool{
"StartTime": true,
"CloseTime": true,
"ExecutionTime": true,
"UpdateTime": true,
}
// NewPinotQueryValidator returns a VisibilityQueryValidator that accepts only
// the search attributes present as keys in validSearchAttributes. The map is
// stored as given, not copied.
func NewPinotQueryValidator(validSearchAttributes map[string]interface{}) *VisibilityQueryValidator {
	validator := new(VisibilityQueryValidator)
	validator.validSearchAttributes = validSearchAttributes
	return validator
}
// ValidateQuery validates the search attributes used in whereClause and returns
// the rewritten query.
//
// whereClause is either a where expression (optionally followed by ORDER BY) or
// a bare ORDER BY clause. Surrounding whitespace is ignored; an empty or
// whitespace-only clause yields an empty result and no error.
//
// The result is the rewritten where expression followed by the ORDER BY clause
// as formatted by sqlparser. Only the WHERE expression and the ORDER BY clause
// of the parsed statement are carried into the result.
//
// Any parse or validation failure is reported as *types.BadRequestError.
func (qv *VisibilityQueryValidator) ValidateQuery(whereClause string) (string, error) {
	// Trim before the emptiness check so that a whitespace-only clause is handled
	// like an empty one instead of producing an unparsable "... WHERE " query.
	whereClause = strings.TrimSpace(whereClause)
	if len(whereClause) == 0 {
		return "", nil
	}

	// Build a placeholder query that allows us to easily parse the contents of the where clause.
	// IMPORTANT: This query is never executed, it is just used to parse and validate whereClause
	var placeholderQuery string
	if common.IsJustOrderByClause(whereClause) { // just order by
		placeholderQuery = fmt.Sprintf("SELECT * FROM dummy %s", whereClause)
	} else {
		placeholderQuery = fmt.Sprintf("SELECT * FROM dummy WHERE %s", whereClause)
	}

	stmt, err := sqlparser.Parse(placeholderQuery)
	if err != nil {
		return "", &types.BadRequestError{Message: "Invalid query."}
	}
	sel, ok := stmt.(*sqlparser.Select)
	if !ok {
		return "", &types.BadRequestError{Message: "Invalid select query."}
	}

	// validate where expr
	var res string
	if sel.Where != nil {
		res, err = qv.validateWhereExpr(sel.Where.Expr)
		if err != nil {
			return "", &types.BadRequestError{Message: err.Error()}
		}
	}

	// Append the ORDER BY clause exactly as sqlparser formats it.
	buf := sqlparser.NewTrackedBuffer(nil)
	sel.OrderBy.Format(buf)
	return res + buf.String(), nil
}
// validateWhereExpr dispatches a where-clause node to the validator for its
// kind and returns the rewritten text for that node.
//
// A nil expression yields an empty string. Parenthesised expressions are
// validated through their inner expression. Any node kind other than AND/OR,
// comparison, range or parentheses is rejected.
func (qv *VisibilityQueryValidator) validateWhereExpr(expr sqlparser.Expr) (string, error) {
	switch node := expr.(type) {
	case nil:
		return "", nil
	case *sqlparser.ParenExpr:
		return qv.validateWhereExpr(node.Expr)
	case *sqlparser.RangeCond:
		return qv.validateRangeExpr(node)
	case *sqlparser.ComparisonExpr:
		return qv.validateComparisonExpr(node)
	case *sqlparser.AndExpr, *sqlparser.OrExpr:
		return qv.validateAndOrExpr(node)
	}
	return "", errors.New("invalid where clause")
}
// validateRangeExpr validates a "[NOT] BETWEEN ... AND ..." condition and
// returns its rewritten text. Plain comparisons (<, >, >=, <=) are handled by
// validateComparisonExpr.
//
// The left side must be a column naming a registered search attribute.
//
// System keys keep their SQL form as formatted by sqlparser; for keys in
// timeSystemKeys, literal bounds are first converted with
// trimTimeFieldValueFromNanoToMilliSeconds (rangeCond is modified in place).
//
// Any other attribute is rewritten into comparisons on the Attr JSON column,
// guarded by a not-null JSON_MATCH. Both bounds must be SQL literals. The
// extracted value is always cast to INT here, whatever the attribute's
// registered type. BETWEEN becomes ">= lower AND <= upper"; NOT BETWEEN becomes
// "< lower OR > upper".
func (qv *VisibilityQueryValidator) validateRangeExpr(expr sqlparser.Expr) (string, error) {
	rangeCond := expr.(*sqlparser.RangeCond)
	colName, ok := rangeCond.Left.(*sqlparser.ColName)
	if !ok {
		return "", errors.New("invalid range expression: fail to get colname")
	}
	colNameStr := colName.Name.String()
	if !qv.isValidSearchAttributes(colNameStr) {
		return "", fmt.Errorf("invalid search attribute %q", colNameStr)
	}

	if definition.IsSystemIndexedKey(colNameStr) {
		if _, isTimeKey := timeSystemKeys[colNameStr]; isTimeKey {
			// Non-literal bounds are left untouched and formatted as written.
			if lowerBound, ok := rangeCond.From.(*sqlparser.SQLVal); ok {
				trimmed, err := trimTimeFieldValueFromNanoToMilliSeconds(lowerBound)
				if err != nil {
					return "", fmt.Errorf("trim time field %s got error: %w", colNameStr, err)
				}
				rangeCond.From = trimmed
			}
			if upperBound, ok := rangeCond.To.(*sqlparser.SQLVal); ok {
				trimmed, err := trimTimeFieldValueFromNanoToMilliSeconds(upperBound)
				if err != nil {
					return "", fmt.Errorf("trim time field %s got error: %w", colNameStr, err)
				}
				rangeCond.To = trimmed
			}
		}
		// sqlparser formats the operator itself, so NOT BETWEEN is preserved here.
		buf := sqlparser.NewTrackedBuffer(nil)
		expr.Format(buf)
		return buf.String(), nil
	}

	lowerBound, ok := rangeCond.From.(*sqlparser.SQLVal)
	if !ok {
		return "", errors.New("invalid range expression: fail to get lowerbound")
	}
	upperBound, ok := rangeCond.To.(*sqlparser.SQLVal)
	if !ok {
		return "", errors.New("invalid range expression: fail to get upperbound")
	}
	// NOTE(review): the bounds are interpolated as-is, with no check that they
	// are numeric — confirm whether non-numeric literals should be rejected here.
	lowerBoundString := string(lowerBound.Val)
	upperBoundString := string(upperBound.Val)

	if rangeCond.Operator == sqlparser.NotBetweenStr {
		// NOT (lower <= x AND x <= upper)  ==  x < lower OR x > upper.
		// The not-null guard is kept so that rows without the attribute do not match.
		return fmt.Sprintf("(JSON_MATCH(Attr, '\"$.%s\" is not null') "+
			"AND (CAST(JSON_EXTRACT_SCALAR(Attr, '$.%s') AS INT) < %s "+
			"OR CAST(JSON_EXTRACT_SCALAR(Attr, '$.%s') AS INT) > %s))",
			colNameStr, colNameStr, lowerBoundString, colNameStr, upperBoundString), nil
	}

	return fmt.Sprintf("(JSON_MATCH(Attr, '\"$.%s\" is not null') "+
		"AND CAST(JSON_EXTRACT_SCALAR(Attr, '$.%s') AS INT) >= %s "+
		"AND CAST(JSON_EXTRACT_SCALAR(Attr, '$.%s') AS INT) <= %s)",
		colNameStr, colNameStr, lowerBoundString, colNameStr, upperBoundString), nil
}
// validateAndOrExpr validates both operands of an AND or OR node and joins the
// rewritten operands. AND results are joined as "left and right"; OR results
// are wrapped in parentheses as "(left or right)". The left operand is
// validated first, and the first error encountered is returned.
func (qv *VisibilityQueryValidator) validateAndOrExpr(expr sqlparser.Expr) (string, error) {
	var (
		lhs, rhs sqlparser.Expr
		joinFmt  = "(%s or %s)"
	)
	switch node := expr.(type) {
	case *sqlparser.AndExpr:
		lhs, rhs = node.Left, node.Right
		joinFmt = "%s and %s"
	case *sqlparser.OrExpr:
		lhs, rhs = node.Left, node.Right
	}

	left, err := qv.validateWhereExpr(lhs)
	if err != nil {
		return "", err
	}
	right, err := qv.validateWhereExpr(rhs)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf(joinFmt, left, right), nil
}
// validateComparisonExpr validates a comparison whose left side is a column
// naming a registered search attribute, then routes it by attribute kind:
//   - system-indexed keys keep their SQL structure and go to processSystemKey;
//   - every other key lives in Attr and goes to processCustomKey, which
//     rewrites the comparison into JSON-index form.
func (qv *VisibilityQueryValidator) validateComparisonExpr(expr sqlparser.Expr) (string, error) {
	cmp := expr.(*sqlparser.ComparisonExpr)
	column, isColumn := cmp.Left.(*sqlparser.ColName)
	if !isColumn {
		return "", errors.New("invalid comparison expression, left")
	}

	key := column.Name.String()
	switch {
	case !qv.isValidSearchAttributes(key):
		return "", fmt.Errorf("invalid search attribute %q", key)
	case definition.IsSystemIndexedKey(key):
		return qv.processSystemKey(expr)
	default:
		return qv.processCustomKey(expr)
	}
}
// isValidSearchAttributes reports whether key is present in the validator's
// registered search attributes.
func (qv *VisibilityQueryValidator) isValidSearchAttributes(key string) bool {
	_, registered := qv.validSearchAttributes[key]
	return registered
}
// processSystemKey validates a comparison on a system-indexed key and returns
// it as SQL text formatted by sqlparser. The expression is modified in place
// when its right-hand side is rewritten.
//
// Handling depends on the type of the right-hand side:
//   - SQL literal: for keys in timeSystemKeys the value is converted with
//     trimTimeFieldValueFromNanoToMilliSeconds (any operator); for
//     "CloseStatus" with = or != the value is converted with parseCloseStatus;
//     anything else is left as written.
//   - bare identifier: sqlparser parses an unquoted word as a column name. With
//     = or != the only accepted word is "missing", which is replaced by 0 for
//     HistoryLength and by -1 for every other key. With other operators it is
//     rejected for time keys and left as written otherwise.
//   - anything else: rejected for = and != and for time keys, otherwise left
//     as written.
func (qv *VisibilityQueryValidator) processSystemKey(expr sqlparser.Expr) (string, error) {
	comparisonExpr := expr.(*sqlparser.ComparisonExpr)
	colName, ok := comparisonExpr.Left.(*sqlparser.ColName)
	if !ok {
		return "", fmt.Errorf("left comparison is invalid: %v", comparisonExpr.Left)
	}
	colNameStr := colName.Name.String()

	_, isTimeKey := timeSystemKeys[colNameStr]
	isEquality := comparisonExpr.Operator == sqlparser.EqualStr ||
		comparisonExpr.Operator == sqlparser.NotEqualStr

	switch right := comparisonExpr.Right.(type) {
	case *sqlparser.SQLVal:
		switch {
		case isTimeKey:
			trimmed, err := trimTimeFieldValueFromNanoToMilliSeconds(right)
			if err != nil {
				return "", fmt.Errorf("trim time field %s got error: %w", colNameStr, err)
			}
			comparisonExpr.Right = trimmed
		case isEquality && colNameStr == "CloseStatus":
			closeStatus, err := parseCloseStatus(right)
			if err != nil {
				return "", fmt.Errorf("parse CloseStatus field got error: %w", err)
			}
			comparisonExpr.Right = closeStatus
		}

	case *sqlparser.ColName:
		switch {
		case isEquality:
			// need to deal with missing value e.g. CloseTime = missing
			if colValStr := right.Name.String(); colValStr != "missing" {
				return "", fmt.Errorf("right comparison is invalid string value: %s", colValStr)
			}
			newColVal := "-1" // -1 is the default value for all Closed workflows related fields
			if strings.EqualFold(colNameStr, "historylength") {
				newColVal = "0"
			}
			comparisonExpr.Right = &sqlparser.SQLVal{
				Type: sqlparser.IntVal,
				Val:  []byte(newColVal),
			}
		case isTimeKey:
			return "", fmt.Errorf("right comparison is invalid: %v", comparisonExpr.Right)
		}

	default:
		if isEquality || isTimeKey {
			return "", fmt.Errorf("right comparison is invalid: %v", comparisonExpr.Right)
		}
	}

	// The expression is still an AST node here, so format it to obtain the string.
	buf := sqlparser.NewTrackedBuffer(nil)
	comparisonExpr.Format(buf)
	return buf.String(), nil
}
// processCustomKey rewrites a comparison on a non-system search attribute into
// a clause over the Attr JSON column.
//
// The left side must be a column registered in validSearchAttributes and the
// right side must be a SQL literal. The attribute's indexed value type selects
// the rewrite:
//   - String:   processCustomString
//   - Keyword:  processCustomKeyword
//   - Datetime: value converted with trimTimeFieldValueFromNanoToMilliSeconds,
//     then processCustomNum with BIGINT
//   - Double:   processCustomNum with DOUBLE
//   - Int:      processCustomNum with INT
//   - any other type: processEqual
func (qv *VisibilityQueryValidator) processCustomKey(expr sqlparser.Expr) (string, error) {
	cmp := expr.(*sqlparser.ComparisonExpr)

	column, isColumn := cmp.Left.(*sqlparser.ColName)
	if !isColumn {
		return "", errors.New("invalid comparison expression, left")
	}
	key := column.Name.String()

	attrType, registered := qv.validSearchAttributes[key]
	if !registered {
		return "", errors.New("invalid search attribute")
	}

	literal, isLiteral := cmp.Right.(*sqlparser.SQLVal)
	if !isLiteral {
		return "", errors.New("invalid comparison expression, right")
	}

	op := cmp.Operator
	rawValue := string(literal.Val)

	switch common.ConvertIndexedValueTypeToInternalType(attrType, log.NewNoop()) {
	case types.IndexedValueTypeString:
		return processCustomString(cmp, key, rawValue), nil
	case types.IndexedValueTypeKeyword:
		return processCustomKeyword(op, key, rawValue), nil
	case types.IndexedValueTypeDatetime:
		millis, err := trimTimeFieldValueFromNanoToMilliSeconds(literal)
		if err != nil {
			return "", fmt.Errorf("trim time field %s got error: %w", key, err)
		}
		return processCustomNum(op, key, string(millis.Val), "BIGINT"), nil
	case types.IndexedValueTypeDouble:
		return processCustomNum(op, key, rawValue, "DOUBLE"), nil
	case types.IndexedValueTypeInt:
		return processCustomNum(op, key, rawValue, "INT"), nil
	}
	return processEqual(key, rawValue), nil
}
// processCustomNum builds the clause for a numeric custom attribute stored in
// Attr. The equality operator is delegated to processEqual. Every other
// operator yields a not-null JSON_MATCH guard combined with a comparison of the
// extracted value, cast to valType, against colValStr.
func processCustomNum(operator string, colNameStr string, colValStr string, valType string) string {
	if operator != sqlparser.EqualStr {
		const castComparison = "(JSON_MATCH(Attr, '\"$.%s\" is not null') " +
			"AND CAST(JSON_EXTRACT_SCALAR(Attr, '$.%s') AS %s) %s %s)"
		return fmt.Sprintf(castComparison, colNameStr, colNameStr, valType, operator, colValStr)
	}
	return processEqual(colNameStr, colValStr)
}
// processEqual builds a JSON_MATCH equality filter on attribute colNameStr of
// the Attr JSON column.
//
// colValStr ends up inside a string literal that is itself nested inside the
// JSON_MATCH filter literal, so every single quote in it has to be doubled
// twice (' becomes ''''). Without that, a value containing a quote would
// terminate the literal early and produce a malformed or injectable filter.
// Values without single quotes are emitted unchanged.
func processEqual(colNameStr string, colValStr string) string {
	escapedVal := strings.ReplaceAll(colValStr, "'", "''''")
	return fmt.Sprintf("JSON_MATCH(Attr, '\"$.%s\"=''%s''')", colNameStr, escapedVal)
}
// processCustomKeyword builds the clause for a keyword custom attribute. The
// result is two JSON_MATCH filters joined with "or": one applies operator to
// the attribute itself ("$.name"), the other to its array elements
// ("$.name[*]").
//
// colValStr ends up inside a string literal nested inside each JSON_MATCH
// filter literal, so every single quote in it is doubled twice (' becomes '''')
// to keep a quote in the value from terminating the literal early. Values
// without single quotes are emitted unchanged.
func processCustomKeyword(operator string, colNameStr string, colValStr string) string {
	escapedVal := strings.ReplaceAll(colValStr, "'", "''''")
	return fmt.Sprintf("(JSON_MATCH(Attr, '\"$.%s\"%s''%s''') or JSON_MATCH(Attr, '\"$.%s[*]\"%s''%s'''))",
		colNameStr, operator, escapedVal, colNameStr, operator, escapedVal)
}
// processCustomString builds the clause for a string custom attribute: a
// not-null JSON_MATCH guard combined with a REGEXP_LIKE on the extracted value,
// using colValStr followed by "*" as the pattern. The comparison operator of
// the incoming expression does not influence the generated clause.
//
// As a side effect, comparisonExpr is rewritten in place into a LIKE
// comparison against "%colValStr%".
//
// Single quotes in colValStr are doubled before it is placed in the pattern
// literal so that a quote in the value cannot terminate the literal early.
// Regular-expression metacharacters in the value are passed through unchanged.
func processCustomString(comparisonExpr *sqlparser.ComparisonExpr, colNameStr string, colValStr string) string {
	// change to like statement for partial match
	comparisonExpr.Operator = sqlparser.LikeStr
	comparisonExpr.Right = &sqlparser.SQLVal{
		Type: sqlparser.StrVal,
		Val:  []byte("%" + colValStr + "%"),
	}

	escapedVal := strings.ReplaceAll(colValStr, "'", "''")
	return fmt.Sprintf("(JSON_MATCH(Attr, '\"$.%s\" is not null') "+
		"AND REGEXP_LIKE(JSON_EXTRACT_SCALAR(Attr, '$.%s', 'string'), '%s*'))", colNameStr, colNameStr, escapedVal)
}
// trimTimeFieldValueFromNanoToMilliSeconds converts a time literal to an
// integer literal holding the value computed by parseTime.
//
// On success a new IntVal SQLVal is returned and original is left untouched.
// If parseTime rejects the value, original is returned together with an error
// that names the offending text.
func trimTimeFieldValueFromNanoToMilliSeconds(original *sqlparser.SQLVal) (*sqlparser.SQLVal, error) {
	raw := string(original.Val)

	millis, parseErr := parseTime(raw)
	if parseErr != nil {
		return original, fmt.Errorf("error: failed to parse int from SQLVal %s", raw)
	}

	converted := new(sqlparser.SQLVal)
	converted.Type = sqlparser.IntVal
	converted.Val = []byte(strconv.FormatInt(millis, 10))
	return converted, nil
}
// parseTime converts the text of a time literal into an integer timestamp,
// normalising nanosecond values to milliseconds.
//
// Two input forms are accepted:
//   - an RFC 3339 timestamp, returned as Unix milliseconds;
//   - a base-10 integer. Negative values are returned unchanged (open workflows
//     carry -1 in their time fields). Values of 10^13 or more, i.e. with more
//     than 13 significant digits, are assumed to be nanoseconds and divided by
//     1,000,000. All other values are returned unchanged.
//
// The nanosecond test looks at the magnitude of the parsed number rather than
// the length of the text, so a leading '+' or zero padding does not cause a
// millisecond value to be mistaken for nanoseconds.
//
// An empty string, or text matching neither form, yields an error.
func parseTime(timeStr string) (int64, error) {
	if len(timeStr) == 0 {
		return 0, errors.New("invalid time string")
	}

	// try to parse
	if parsedTime, err := time.Parse(time.RFC3339, timeStr); err == nil {
		return parsedTime.UnixMilli(), nil
	}

	// treat as raw time
	valInt, err := strconv.ParseInt(timeStr, 10, 64)
	if err != nil {
		return 0, errors.New("invalid time string")
	}

	// Smallest value with more than 13 decimal digits.
	const minNanoTimestamp = int64(1e13)
	if valInt >= minNanoTimestamp {
		return valInt / 1000000, nil // Convert time to milliseconds
	}
	return valInt, nil
}
// parseCloseStatus normalises a CloseStatus literal to its integer form.
//
// A value that already parses as a base-10 integer is returned as-is (the same
// SQLVal). Otherwise the text is decoded with
// types.WorkflowExecutionCloseStatus.UnmarshalText and a new IntVal SQLVal
// holding the numeric status is returned. A decoding failure is returned
// unchanged with a nil SQLVal.
func parseCloseStatus(original *sqlparser.SQLVal) (*sqlparser.SQLVal, error) {
	text := string(original.Val)

	// Already numeric: nothing to translate.
	if _, numErr := strconv.ParseInt(text, 10, 64); numErr == nil {
		return original, nil
	}

	// Otherwise interpret the text as a close status name.
	var status types.WorkflowExecutionCloseStatus
	if err := status.UnmarshalText([]byte(text)); err != nil {
		return nil, err
	}

	numeric := strconv.FormatInt(int64(status), 10)
	return &sqlparser.SQLVal{Type: sqlparser.IntVal, Val: []byte(numeric)}, nil
}