dax/internal/parser/expression.go (590 lines of code) (raw):
/*
Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
A copy of the License is located at
http://www.apache.org/licenses/LICENSE-2.0
or in the "license" file accompanying this file. This file is distributed
on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied. See the License for the specific language governing
permissions and limitations under the License.
*/
package parser
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
antlr "github.com/antlr4-go/antlr/v4"
"github.com/aws/aws-dax-go-v2/dax/internal/cbor"
"github.com/aws/aws-dax-go-v2/dax/internal/parser/generated"
"github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
)
// Constants used while encoding expressions.
const (
// encodingVersion is written as the first element of every fully encoded
// expression (see fullExpr).
encodingVersion = 1
// variablePrefix is prepended to a variable name before it is looked up in
// the variables map (see encodeVariable).
variablePrefix = ":"
// substitutionPrefix marks an identifier that has to be resolved through
// the substitutes map (see ExitId).
substitutionPrefix = "#"
)
// tagDocumentPathOrdinal is the CBOR tag that encodeListAccess writes in front
// of a list index inside a document path.
const tagDocumentPathOrdinal = 3324 // TODO remove once merged
// ExpressionEncoder is a grammar listener that converts DynamoDB expression
// strings into CBOR-encoded s-expressions. Create it with NewExpressionEncoder
// and obtain the result through Parse or Write.
type ExpressionEncoder struct {
generated.BaseDynamoDbGrammarListener
// input
// expressions holds the expression text to encode, keyed by expression type.
expressions map[int]string
// substitutes maps "#name" placeholders to the attribute names they stand for.
substitutes map[string]string
// variables maps ":name" placeholders to their attribute values.
variables map[string]types.AttributeValue
// output
// encoded holds the encoded bytes of each expression, keyed by expression type.
encoded map[int][]byte
// variableValues lists the values referenced by the expression currently
// being encoded, in order of first use; it is rebuilt by reset.
variableValues []types.AttributeValue
// book keeping
// stack is the work stack the listener callbacks push to and pop from.
stack []sexpr
// unusedSubstitutes and unusedVariables start with every provided key and
// shrink as keys are referenced; validate(true) fails if any remain.
unusedSubstitutes stringSet
unusedVariables stringSet
// variableIdByName maps a variable name to its index in variableValues.
variableIdByName map[string]int
// exprType is the type of the expression currently being encoded.
exprType int
// nestingLevel counts the enclosing comparator/IN/BETWEEN/function/arithmetic
// contexts; EnterFunctionCall uses it to restrict where functions may appear.
nestingLevel int
// err is the first error recorded by a callback; once set, callbacks bail out.
err error
// temporary buffer/writer
// cborWriter writes into buf; bytes() drains buf after each encoding step.
cborWriter *cbor.Writer
buf *bytes.Buffer
}
// NewExpressionEncoder builds an encoder for the given expressions.
//
// expr maps an expression type to its text, subs maps "#name" placeholders to
// attribute names and vars maps ":name" placeholders to attribute values.
// Every key of subs and vars is initially recorded as unused.
func NewExpressionEncoder(expr map[int]string, subs map[string]string, vars map[string]types.AttributeValue) *ExpressionEncoder {
	pendingSubs := make(stringSet, len(subs))
	pendingSubs.addKeysStrVal(subs)

	pendingVars := make(stringSet, len(vars))
	pendingVars.addKeysAttrVal(vars)

	// The CBOR writer and the encoder share one scratch buffer.
	scratch := new(bytes.Buffer)

	enc := &ExpressionEncoder{
		expressions:       expr,
		substitutes:       subs,
		variables:         vars,
		encoded:           make(map[int][]byte),
		variableValues:    make([]types.AttributeValue, 0, len(vars)),
		variableIdByName:  make(map[string]int),
		unusedSubstitutes: pendingSubs,
		unusedVariables:   pendingVars,
		cborWriter:        cbor.NewWriter(scratch),
		buf:               scratch,
	}
	return enc
}
// Parse encodes every configured expression and returns the encoded bytes
// keyed by expression type.
//
// Results are cached: once all expressions have been encoded successfully,
// later calls return the cached map. When encoding or validation fails, the
// error is returned with a nil map and any partial results are discarded, so
// that a subsequent call re-runs the encoding instead of mistaking the
// partially filled cache for a successful parse.
func (e *ExpressionEncoder) Parse() (map[int][]byte, error) {
	if len(e.expressions) == 0 || len(e.encoded) == len(e.expressions) {
		return e.encoded, nil
	}

	// fail drops whatever was encoded so far; otherwise the length check
	// above could report success on the next call (for example after the
	// final unused-names validation failed).
	fail := func(err error) (map[int][]byte, error) {
		e.encoded = make(map[int][]byte)
		return nil, err
	}

	for k, v := range e.expressions {
		e.reset(k)
		if err := walkDynamoDbExpr(k, v, e); err != nil {
			return fail(err)
		}
		if err := e.validate(false); err != nil {
			return fail(err)
		}
		expr := e.genSExpr(e.pop())
		full, err := e.fullExpr(k, expr)
		if err != nil {
			return fail(err)
		}
		e.encoded[k] = full
	}

	// Only after all expressions were walked can unused names/values be detected.
	if err := e.validate(true); err != nil {
		return fail(err)
	}
	return e.encoded, nil
}
// Write parses the expressions (if not done already) and writes the encoded
// form of the expression of type typ to writer. It returns the parse error,
// or otherwise the error reported by writer.
func (e *ExpressionEncoder) Write(typ int, writer io.Writer) error {
	_, err := e.Parse()
	if err == nil {
		_, err = writer.Write(e.encoded[typ])
	}
	return err
}
// reset prepares the encoder for walking one expression of type typ.
//
// Besides the per-expression variable table it also clears the work stack and
// the scratch buffer: a walk that failed midway leaves entries behind, and
// without clearing them the next parse would report a misleading stack-size
// error or prepend stale bytes to its output.
func (e *ExpressionEncoder) reset(typ int) {
	e.exprType = typ
	e.nestingLevel = 0
	e.variableIdByName = make(map[string]int)
	e.variableValues = make([]types.AttributeValue, 0, len(e.variables))
	e.err = nil

	// Drop leftovers of an earlier, failed walk.
	e.stack = e.stack[:0]
	e.cborWriter.Flush()
	e.buf.Reset()
}
// validate reports the first recorded callback error, if any. Otherwise, with
// final set it checks that every substitution and variable was referenced;
// with final unset it checks that walking one expression left exactly one
// entry on the stack and a nesting level of zero.
func (e *ExpressionEncoder) validate(final bool) error {
	switch {
	case e.err != nil:
		return e.err
	case final && len(e.unusedSubstitutes) != 0:
		return newInvalidParameterError(fmt.Sprintf("Value provided in ExpressionAttributeNames unused in expressions: keys: {%v}", e.unusedSubstitutes))
	case final && len(e.unusedVariables) != 0:
		return newInvalidParameterError(fmt.Sprintf("Value provided in ExpressionAttributeValues unused in expressions: keys: {%v}", e.unusedVariables))
	case final:
		return nil
	case len(e.stack) != 1:
		return newInvalidParameterError(fmt.Sprintf("Invalid %sExpression, Stack size = %d", exprTypeString(e.exprType), len(e.stack)))
	case e.nestingLevel != 0:
		return newInvalidParameterError(fmt.Sprintf("Invalid %sExpression, Nesting level = %d", exprTypeString(e.exprType), e.nestingLevel))
	}
	return nil
}
// genSExpr serializes expr through the CBOR writer and returns the bytes
// produced, leaving the scratch buffer empty.
func (e *ExpressionEncoder) genSExpr(expr sexpr) []byte {
	e.writeSExpr(expr)
	serialized := e.bytes()
	return serialized
}
// writeSExpr writes expr to the CBOR writer: an atom is written verbatim, a
// list is written as an array header followed by its elements in order. An
// sexpr with neither atom nor list set writes nothing.
func (e *ExpressionEncoder) writeSExpr(expr sexpr) {
	if expr.atom != nil {
		e.cborWriter.Write(expr.atom)
		return
	}
	if expr.expr == nil {
		return
	}
	children := expr.expr
	e.cborWriter.WriteArrayHeader(len(children))
	for i := range children {
		e.writeSExpr(children[i])
	}
}
// fullExpr wraps an already serialized expression into its final envelope:
// an array of the encoding version, the expression and, for every type other
// than ProjectionExpr, the array of variable values the expression refers to.
func (e *ExpressionEncoder) fullExpr(typ int, expr []byte) ([]byte, error) {
	withValues := typ != ProjectionExpr
	parts := 2
	if withValues {
		parts = 3
	}

	w := e.cborWriter
	w.WriteArrayHeader(parts)
	w.WriteInt(encodingVersion)
	w.Write(expr)

	if withValues {
		w.WriteArrayHeader(len(e.variableValues))
		for i := range e.variableValues {
			if err := cbor.EncodeAttributeValue(e.variableValues[i], w); err != nil {
				return nil, err
			}
		}
	}
	return e.bytes(), nil
}
// ExitId pushes the document path element for an identifier. An identifier
// starting with substitutionPrefix is replaced by its entry in the substitutes
// map and marked as used; a missing entry is recorded as an error.
func (e *ExpressionEncoder) ExitId(ctx *generated.IdContext) {
	if e.err != nil { // bail out
		return
	}
	id := ctx.GetText()
	// HasPrefix rather than id[0:1]: slicing would panic on empty text.
	if !strings.HasPrefix(id, substitutionPrefix) {
		e.push(e.encodeDocumentPathElement(id))
		return
	}
	s, ok := e.substitutes[id]
	if !ok {
		e.err = newInvalidParameterError(fmt.Sprintf("Invalid %sExpression. Substitution value not provided for %s", exprTypeString(e.exprType), id))
		return
	}
	e.unusedSubstitutes.remove(id)
	e.push(e.encodeDocumentPathElement(s))
}
// ExitListAccess pushes the encoded list index of a "[n]" path element.
//
// The original code relied on the parser to reject malformed text and ignored
// the conversion error. An index that does not fit into an int64 is not a
// syntax error, though, and was silently clamped; it is now recorded as an
// invalid parameter error.
func (e *ExpressionEncoder) ExitListAccess(ctx *generated.ListAccessContext) {
	if e.err != nil { // bail out
		return
	}
	s := ctx.GetText()
	if len(s) < 2 { // no room for the surrounding brackets
		e.err = newInvalidParameterError(fmt.Sprintf("Invalid %sExpression: Invalid list index: %s", exprTypeString(e.exprType), s))
		return
	}
	o, err := strconv.ParseInt(s[1:len(s)-1], 10, 64)
	if err != nil {
		e.err = newInvalidParameterError(fmt.Sprintf("Invalid %sExpression: Invalid list index: %s", exprTypeString(e.exprType), s))
		return
	}
	e.push(e.encodeListAccess(o))
}
// ExitPath replaces the path components on top of the stack (one per child of
// the path node) by a single opDocumentPath function holding them in source
// order.
func (e *ExpressionEncoder) ExitPath(ctx *generated.PathContext) {
	if e.err != nil { // bail out
		return
	}
	parts := make([]sexpr, ctx.GetChildCount())
	// The last component is on top of the stack, so fill from the back.
	for i := len(parts) - 1; i >= 0; i-- {
		parts[i] = e.pop()
	}
	e.push(e.encodeFunction(opDocumentPath, parts))
}
// ExitProjection collects the projected paths from the stack into one array,
// preserving their source order.
func (e *ExpressionEncoder) ExitProjection(ctx *generated.ProjectionContext) {
	if e.err != nil { // bail out
		return
	}
	// Children alternate between paths and separators: path, path, ... path.
	count := (ctx.GetChildCount() + 1) / 2
	paths := make([]sexpr, count)
	for i := count - 1; i >= 0; i-- {
		paths[i] = e.pop()
	}
	e.push(e.encodeArray(paths))
}
// ExitLiteralSub pushes the encoded reference for a value placeholder. The
// first character of the token text is dropped before the name is handed to
// encodeVariable, which adds variablePrefix again.
func (e *ExpressionEncoder) ExitLiteralSub(ctx *generated.LiteralSubContext) {
	if e.err != nil { // bail out
		return
	}
	name := ctx.GetText()[1:]
	e.push(e.encodeVariable(name))
}
// ExitComparator_symbol pushes the function code matching the comparison
// token. For any other token an empty sexpr is pushed; the original author
// notes that the parser reports such input as a syntax error.
func (e *ExpressionEncoder) ExitComparator_symbol(ctx *generated.Comparator_symbolContext) {
	if e.err != nil { // bail out
		return
	}
	var code int
	switch ctx.GetStart().GetTokenType() {
	case generated.DynamoDbGrammarParserEQ:
		code = opEqual
	case generated.DynamoDbGrammarParserNE:
		code = opNotEqual
	case generated.DynamoDbGrammarParserLT:
		code = opLessThan
	case generated.DynamoDbGrammarParserLE:
		code = opLessEqual
	case generated.DynamoDbGrammarParserGT:
		code = opGreaterThan
	case generated.DynamoDbGrammarParserGE:
		code = opGreaterEqual
	default:
		// parser will detect syntax error
		e.push(sexpr{})
		return
	}
	e.push(e.encodeFunctionCode(code))
}
// EnterComparator notes that the walk has entered a comparison by raising the
// nesting level, unless an error was already recorded.
func (e *ExpressionEncoder) EnterComparator(ctx *generated.ComparatorContext) {
	if e.err == nil {
		e.nestingLevel++
	}
}
// ExitComparator leaves the comparison context and replaces the operands and
// operator on the stack by the array [operator, left, right].
func (e *ExpressionEncoder) ExitComparator(ctx *generated.ComparatorContext) {
	if e.err != nil { // bail out
		return
	}
	e.nestingLevel--
	// Stack order (top first): right operand, operator code, left operand.
	right := e.pop()
	operator := e.pop()
	left := e.pop()
	comparison := e.encodeArray([]sexpr{operator, left, right})
	e.push(comparison)
}
// ExitAnd replaces the two operands on top of the stack by an opAnd function
// over them, left operand first.
func (e *ExpressionEncoder) ExitAnd(ctx *generated.AndContext) {
	if e.err != nil { // bail out
		return
	}
	operands := make([]sexpr, 2)
	for i := len(operands) - 1; i >= 0; i-- {
		operands[i] = e.pop()
	}
	e.push(e.encodeFunction(opAnd, operands))
}
// ExitOr replaces the two operands on top of the stack by an opOr function
// over them, left operand first.
func (e *ExpressionEncoder) ExitOr(ctx *generated.OrContext) {
	if e.err != nil { // bail out
		return
	}
	operands := make([]sexpr, 2)
	for i := len(operands) - 1; i >= 0; i-- {
		operands[i] = e.pop()
	}
	e.push(e.encodeFunction(opOr, operands))
}
// ExitNegation wraps the operand on top of the stack into an opNot function.
func (e *ExpressionEncoder) ExitNegation(ctx *generated.NegationContext) {
	if e.err != nil { // bail out
		return
	}
	operand := e.pop()
	negated := e.encodeFunction(opNot, []sexpr{operand})
	e.push(negated)
}
// EnterIn notes that the walk has entered an IN clause by raising the nesting
// level, unless an error was already recorded.
func (e *ExpressionEncoder) EnterIn(ctx *generated.InContext) {
	if e.err == nil {
		e.nestingLevel++
	}
}
// ExitIn leaves the IN context and replaces the tested operand and its
// candidate values on the stack by an opIn function of the form
// [opIn, operand, [candidates...]].
func (e *ExpressionEncoder) ExitIn(ctx *generated.InContext) {
	if e.err != nil { // bail out
		return
	}
	e.nestingLevel--
	// children: arg + IN + '(' + args*2-1 + ')'
	count := (ctx.GetChildCount() - 3) / 2
	candidates := make([]sexpr, count)
	for i := count - 1; i >= 0; i-- {
		candidates[i] = e.pop()
	}
	subject := e.pop()
	e.push(e.encodeFunction(opIn, []sexpr{subject, {expr: candidates}}))
}
// EnterBetween notes that the walk has entered a BETWEEN clause by raising the
// nesting level, unless an error was already recorded.
func (e *ExpressionEncoder) EnterBetween(ctx *generated.BetweenContext) {
	if e.err == nil {
		e.nestingLevel++
	}
}
// ExitBetween leaves the BETWEEN context and replaces the three operands on
// top of the stack by an opBetween function over them, in source order.
func (e *ExpressionEncoder) ExitBetween(ctx *generated.BetweenContext) {
	if e.err != nil { // bail out
		return
	}
	e.nestingLevel--
	operands := make([]sexpr, 3)
	for i := len(operands) - 1; i >= 0; i-- {
		operands[i] = e.pop()
	}
	e.push(e.encodeFunction(opBetween, operands))
}
// EnterFunctionCall checks that the called function may appear at this point
// of the current expression type and, if so, raises the nesting level.
//
// In an update expression only update functions are accepted and only
// if_not_exists() may be nested. In filter, condition and key condition
// expressions update functions are rejected, size() must be nested and
// nothing but size() may be nested. Every rejection records an error and
// returns without touching the nesting level.
func (e *ExpressionEncoder) EnterFunctionCall(ctx *generated.FunctionCallContext) {
	if e.err != nil { // bail out
		return
	}
	fname := strings.ToLower(ctx.ID().GetText())
	notAllowed := func() error {
		eName := exprTypeString(e.exprType)
		return newInvalidParameterError(fmt.Sprintf("Invalid %sExpression: The function is not allowed in a %s expression", eName, strings.ToLower(eName)))
	}
	switch e.exprType {
	case UpdateExpr:
		if !e.funcAllowedInUpdate(fname) {
			e.err = notAllowed()
			return
		}
		if e.nestingLevel > 0 && fname != "if_not_exists" {
			e.err = newInvalidParameterError("Only if_not_exists() function can be nested")
			return
		}
	case FilterExpr, ConditionExpr, KeyConditionExpr:
		if !e.funcAllowedInFilter(fname) {
			e.err = notAllowed()
			return
		}
		if e.nestingLevel == 0 && fname == "size" {
			e.err = notAllowed()
			return
		}
		if e.nestingLevel > 0 && fname != "size" {
			e.err = newInvalidParameterError("Only size() function is allowed to be nested")
			return // was missing: every other rejection returns here too
		}
	default:
	}
	e.nestingLevel++
}
// ExitFunctionCall replaces the function arguments on the stack by the encoded
// function call and lowers the nesting level raised by EnterFunctionCall.
//
// Function names are matched case-insensitively. An unrecognised name is
// recorded as an error: previously it fell through with the zero opcode, which
// is opEqual, so an unknown function was silently encoded as an equality test.
func (e *ExpressionEncoder) ExitFunctionCall(ctx *generated.FunctionCallContext) {
	if e.err != nil { // bail out
		return
	}
	name := ctx.ID().GetText()
	var fn int
	switch strings.ToLower(name) {
	case "attribute_exists":
		fn = opAttributeExists
	case "attribute_not_exists":
		fn = opAttributeNotExists
	case "attribute_type":
		fn = opAttributeType
	case "begins_with":
		fn = opBeginsWith
	case "contains":
		fn = opContains
	case "size":
		fn = opSize
	case "if_not_exists":
		fn = opIfNotExists
	case "list_append":
		fn = opListAppend
	default:
		e.err = newInvalidParameterError(fmt.Sprintf("Invalid %sExpression: Invalid function name; function: %s", exprTypeString(e.exprType), name))
		return
	}
	n := (ctx.GetChildCount() - 2) / 2 // children = fname + '(' + numOperands*2-1 + ')'
	args := make([]sexpr, n)
	// The last argument is on top of the stack, so fill from the back.
	for n > 0 {
		n--
		args[n] = e.pop()
	}
	e.push(e.encodeFunction(fn, args))
	e.nestingLevel--
}
// ExitSet_action replaces the target path and the assigned operand on the
// stack by an opSetAction function of the form [opSetAction, path, operand].
func (e *ExpressionEncoder) ExitSet_action(ctx *generated.Set_actionContext) {
	if e.err != nil { // bail out
		return
	}
	// The operand was pushed last, the path before it.
	operand := e.pop()
	target := e.pop()
	action := e.encodeFunction(opSetAction, []sexpr{target, operand})
	e.push(action)
}
// ExitRemove_action wraps the path on top of the stack into an opRemoveAction
// function.
func (e *ExpressionEncoder) ExitRemove_action(ctx *generated.Remove_actionContext) {
	if e.err != nil { // bail out
		return
	}
	target := e.pop()
	action := e.encodeFunction(opRemoveAction, []sexpr{target})
	e.push(action)
}
// ExitAdd_action replaces the target path and the value on the stack by an
// opAddAction function of the form [opAddAction, path, value].
func (e *ExpressionEncoder) ExitAdd_action(ctx *generated.Add_actionContext) {
	if e.err != nil { // bail out
		return
	}
	// The value was pushed last, the path before it.
	value := e.pop()
	target := e.pop()
	action := e.encodeFunction(opAddAction, []sexpr{target, value})
	e.push(action)
}
// ExitDelete_action replaces the target path and the value on the stack by an
// opDeleteAction function of the form [opDeleteAction, path, value].
func (e *ExpressionEncoder) ExitDelete_action(ctx *generated.Delete_actionContext) {
	if e.err != nil { // bail out
		return
	}
	// The value was pushed last, the path before it.
	value := e.pop()
	target := e.pop()
	action := e.encodeFunction(opDeleteAction, []sexpr{target, value})
	e.push(action)
}
// EnterPlusMinus notes that the walk has entered an arithmetic operation by
// raising the nesting level, unless an error was already recorded.
func (e *ExpressionEncoder) EnterPlusMinus(ctx *generated.PlusMinusContext) {
	if e.err == nil {
		e.nestingLevel++
	}
}
// ExitPlusMinus leaves the arithmetic context and replaces the two operands on
// the stack by an opPlus or opMinus function, depending on the text of the
// node's second child. Any other operator text is recorded as an error.
func (e *ExpressionEncoder) ExitPlusMinus(ctx *generated.PlusMinusContext) {
	if e.err != nil { // bail out
		return
	}
	e.nestingLevel--
	right := e.pop()
	left := e.pop()

	operator := ctx.GetChild(1).(antlr.ParseTree).GetText() // leak antlr construct
	var code int
	if operator == "+" {
		code = opPlus
	} else if operator == "-" {
		code = opMinus
	} else {
		e.err = newInvalidParameterError("Must be +/-")
		return
	}
	e.push(e.encodeFunction(code, []sexpr{left, right}))
}
// ExitUpdate replaces the entire stack content by a single array holding the
// former entries in bottom-to-top order.
func (e *ExpressionEncoder) ExitUpdate(ctx *generated.UpdateContext) {
	if e.err != nil { // bail out
		return
	}
	actions := make([]sexpr, len(e.stack))
	copy(actions, e.stack)
	// Emptying the stack in one step is equivalent to popping every entry.
	e.stack = e.stack[:0]
	e.push(e.encodeArray(actions))
}
// encodeFunction builds the list [code, args...] where code is the encoded
// function code. args is copied, not retained.
func (e *ExpressionEncoder) encodeFunction(code int, args []sexpr) sexpr {
	call := make([]sexpr, 1, len(args)+1)
	call[0] = e.encodeFunctionCode(code)
	call = append(call, args...)
	return sexpr{expr: call}
}
// encodeDocumentPathElement returns an atom holding s encoded as a CBOR string.
func (e *ExpressionEncoder) encodeDocumentPathElement(s string) sexpr {
	e.cborWriter.WriteString(s)
	element := e.bytes()
	return sexpr{atom: element}
}
// encodeListAccess returns an atom holding the list index o, written as an
// integer preceded by the tagDocumentPathOrdinal tag.
func (e *ExpressionEncoder) encodeListAccess(o int64) sexpr {
	w := e.cborWriter
	w.WriteTag(tagDocumentPathOrdinal)
	w.WriteInt64(o)
	ordinal := e.bytes()
	return sexpr{atom: ordinal}
}
// encodeVariable returns an opVariable function referring to the value of the
// variable l (given without its prefix). The value gets an index in
// variableValues on first use and keeps it for later uses within the same
// expression; the variable is also marked as used. An undefined variable is
// recorded as an error and yields an empty sexpr.
func (e *ExpressionEncoder) encodeVariable(l string) sexpr {
	name := variablePrefix + l
	value, defined := e.variables[name]
	if !defined {
		e.err = newInvalidParameterError(fmt.Sprintf("Invalid %sExpression: An expression attribute value used in expression is not defined: attribute value %s", exprTypeString(e.exprType), name))
		return sexpr{}
	}
	e.unusedVariables.remove(name)

	index, seen := e.variableIdByName[name]
	if !seen {
		// First reference in this expression: append the value and remember its slot.
		index = len(e.variableValues)
		e.variableValues = append(e.variableValues, value)
		e.variableIdByName[name] = index
	}
	ref := e.encodeId(index)
	return e.encodeFunction(opVariable, []sexpr{ref})
}
// encodeId returns an atom holding id encoded as a CBOR integer.
func (e *ExpressionEncoder) encodeId(id int) sexpr {
	e.cborWriter.WriteInt(id)
	encoded := e.bytes()
	return sexpr{atom: encoded}
}
// encodeFunctionCode returns an atom holding the function code c encoded as a
// CBOR integer.
func (e *ExpressionEncoder) encodeFunctionCode(c int) sexpr {
	e.cborWriter.WriteInt(c)
	encoded := e.bytes()
	return sexpr{atom: encoded}
}
// encodeArray returns a list sexpr holding a copy of args. The list is never
// nil, even for empty input, so that writeSExpr still emits an (empty) array.
func (e *ExpressionEncoder) encodeArray(args []sexpr) sexpr {
	items := make([]sexpr, len(args))
	copy(items, args)
	return sexpr{expr: items}
}
// bytes flushes the CBOR writer and returns a copy of everything accumulated
// in the scratch buffer, which is emptied afterwards. The result is non-nil
// even when nothing was written.
func (e *ExpressionEncoder) bytes() []byte {
	e.cborWriter.Flush()
	out := make([]byte, e.buf.Len())
	copy(out, e.buf.Bytes())
	e.buf.Reset()
	return out
}
// funcAllowedInUpdate reports whether fname may be called in an update
// expression: everything except the listed condition functions is accepted.
// fname is compared as given, so callers pass it lower-cased.
func (e *ExpressionEncoder) funcAllowedInUpdate(fname string) bool {
	conditionOnly := [...]string{
		"attribute_exists",
		"attribute_not_exists",
		"attribute_type",
		"begins_with",
		"contains",
		"size",
	}
	for _, name := range conditionOnly {
		if fname == name {
			return false
		}
	}
	return true
}
// funcAllowedInFilter reports whether fname may be called in a filter,
// condition or key condition expression: everything except the update-only
// functions if_not_exists and list_append is accepted. fname is compared as
// given, so callers pass it lower-cased.
func (e *ExpressionEncoder) funcAllowedInFilter(fname string) bool {
	updateOnly := fname == "if_not_exists" || fname == "list_append"
	return !updateOnly
}
// push places v on top of the work stack.
func (e *ExpressionEncoder) push(v sexpr) {
	grown := append(e.stack, v)
	e.stack = grown
}
// pop removes and returns the top of the work stack. It does not guard
// against an empty stack: per the original author's note, input that would
// underflow the stack is rejected by the parser as a syntax error.
func (e *ExpressionEncoder) pop() sexpr {
	top := len(e.stack) - 1
	v := e.stack[top]
	e.stack = e.stack[:top]
	return v
}
// newInvalidParameterError returns an error whose text is msg prefixed with
// "invalid parameter: ".
func newInvalidParameterError(msg string) error {
	const format = "invalid parameter: %s"
	return fmt.Errorf(format, msg)
}
// sexpr is a node of the s-expression built while walking an expression. It
// is either an atom or a list; writeSExpr checks atom first and writes
// nothing when both fields are nil.
type sexpr struct {
// atom holds already encoded bytes that are written out verbatim.
atom []byte
// expr holds the child nodes, written as an array header followed by each child.
expr []sexpr
}
// stringSet is a set of strings backed by a map with empty values.
type stringSet map[string]struct{}

// addKeysStrVal adds every key of m to the set.
func (s stringSet) addKeysStrVal(m map[string]string) {
	var present struct{}
	for key := range m {
		s[key] = present
	}
}
// addKeysAttrVal adds every key of m to the set; the attribute values are
// ignored.
func (s stringSet) addKeysAttrVal(m map[string]types.AttributeValue) {
	var present struct{}
	for key := range m {
		s[key] = present
	}
}
// remove deletes v from the set; removing an absent element is a no-op.
func (s stringSet) remove(v string) {
	if _, present := s[v]; present {
		delete(s, v)
	}
}
// String returns the elements of the set separated by ", ". The order follows
// map iteration and is therefore unspecified. It is used when the set is
// formatted into an error message.
//
// The elements are gathered once and joined with strings.Join instead of
// growing a string by repeated concatenation.
func (s stringSet) String() string {
	keys := make([]string, 0, len(s))
	for k := range s {
		keys = append(keys, k)
	}
	return strings.Join(keys, ", ")
}
// Function codes written into the encoded expression by encodeFunctionCode.
// The values are assigned by iota starting at 0, so the declaration order
// determines the encoded numbers: do not reorder or insert entries.
// NOTE(review): presumably these must match the opcode table of the peer that
// decodes the expression - confirm before changing.
const (
// Comparison operators
opEqual = iota
opNotEqual
opLessThan
opGreaterEqual
opGreaterThan
opLessEqual
// Logical operators
opAnd
opOr
opNot
// Range operators
opBetween
// Enumeration operators
opIn
// Functions
opAttributeExists
opAttributeNotExists
opAttributeType
opBeginsWith
opContains
opSize
// Document path elements
opVariable // takes 1 argument which is a placeholder for a value
opDocumentPath // takes an array of string/number which form a document path
// Update Actions
opSetAction
opAddAction
opDeleteAction
opRemoveAction
// Update operations
opIfNotExists
opListAppend
opPlus
opMinus
)