dax/internal/cbor/attrval.go (272 lines of code) (raw):

/* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at http://www.apache.org/licenses/LICENSE-2.0 or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package cbor import ( "errors" "fmt" "math/big" "strconv" "strings" "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/aws/smithy-go" ) const ( tagStringSet = 3321 + iota tagNumberSet tagBinarySet tagDocumentPathOrdinal ) func EncodeAttributeValue(value types.AttributeValue, writer *Writer) error { if value == nil { return &smithy.SerializationError{Err: errors.New("invalid attribute value: nil")} } var err error switch v := value.(type) { case *types.AttributeValueMemberS: err = writer.WriteString(v.Value) case *types.AttributeValueMemberN: err = writeStringNumber(v.Value, writer) case *types.AttributeValueMemberB: err = writer.WriteBytes(v.Value) case *types.AttributeValueMemberSS: if len(v.Value) == 0 { return &smithy.SerializationError{Err: errors.New("invalid string set: nil or empty")} } if err = writer.writeType(Tag, tagStringSet); err != nil { return err } if err = writer.WriteArrayHeader(len(v.Value)); err != nil { return err } for _, sp := range v.Value { if err := writer.WriteString(sp); err != nil { return err } } case *types.AttributeValueMemberNS: if len(v.Value) == 0 { return &smithy.SerializationError{Err: errors.New("invalid number set: nil or empty")} } if err = writer.writeType(Tag, tagNumberSet); err != nil { return err } if err = writer.WriteArrayHeader(len(v.Value)); err != nil { return err } for _, sp := range v.Value { if err := writeStringNumber(sp, writer); err != nil { return err } } case *types.AttributeValueMemberBS: if len(v.Value) == 0 { return &smithy.SerializationError{Err: errors.New("invalid binary set: nil or empty")} } if err = writer.writeType(Tag, tagBinarySet); err != nil { return err } if err = writer.WriteArrayHeader(len(v.Value)); err != nil { return err } for _, bp := range v.Value { if err := writer.WriteBytes(bp); err != nil { return err } } case *types.AttributeValueMemberL: if err = writer.WriteArrayHeader(len(v.Value)); err != nil { return err } for _, v := range v.Value { if err := EncodeAttributeValue(v, writer); err != nil { return err } } case *types.AttributeValueMemberM: if err = writer.WriteMapHeader(len(v.Value)); err != nil { return err } for k, v := range v.Value { if err := writer.WriteString(k); err != nil { return err } if err = EncodeAttributeValue(v, writer); err != nil { return err } } case *types.AttributeValueMemberBOOL: err = writer.WriteBoolean(v.Value) case *types.AttributeValueMemberNULL: if !v.Value { return &smithy.SerializationError{Err: errors.New("invalid null attribute value")} } err = writer.WriteNull() } return err } func writeStringNumber(val string, writer *Writer) error { if strings.IndexAny(val, ".eE") >= 0 { dec := new(Decimal) if _, ok := dec.SetString(val); !ok { return &smithy.SerializationError{Err: fmt.Errorf("invalid number %v", val)} } err := writer.WriteDecimal(dec) return err } if len(val) > 18 { bint := new(big.Int) bint.SetString(val, 10) err := writer.WriteBigInt(bint) return err } i, err := strconv.ParseInt(val, 10, 64) if err != nil { return &smithy.SerializationError{Err: fmt.Errorf("invalid number %v", val)} } err = writer.WriteInt64(i) return err } func DecodeAttributeValue(reader *Reader) (types.AttributeValue, error) { hdr, err := reader.PeekHeader() if err != nil { return nil, err } major := hdr & MajorTypeMask minor := hdr & MinorTypeMask switch major { case Utf: s, err := reader.ReadString() if err != nil { return nil, err } return &types.AttributeValueMemberS{Value: s}, nil case Bytes: b, err := reader.ReadBytes() if err != nil { return nil, err } return &types.AttributeValueMemberB{Value: b}, nil case Array: len, err := reader.ReadArrayLength() if err != nil { return nil, err } as := make([]types.AttributeValue, len) for i := 0; i < len; i++ { a, err := DecodeAttributeValue(reader) if err != nil { return nil, err } as[i] = a } return &types.AttributeValueMemberL{Value: as}, nil case Map: len, err := reader.ReadMapLength() if err != nil { return nil, err } m := make(map[string]types.AttributeValue, len) for i := 0; i < len; i++ { k, err := reader.ReadString() if err != nil { return nil, err } v, err := DecodeAttributeValue(reader) if err != nil { return nil, err } m[k] = v } return &types.AttributeValueMemberM{Value: m}, nil case PosInt, NegInt: s, err := reader.ReadCborIntegerToString() if err != nil { return nil, err } return &types.AttributeValueMemberN{Value: s}, nil case Simple: if _, _, err := reader.readTypeHeader(); err != nil { return nil, err } switch hdr { case False: return &types.AttributeValueMemberBOOL{Value: false}, nil case True: return &types.AttributeValueMemberBOOL{Value: true}, nil case Nil: return &types.AttributeValueMemberNULL{Value: true}, nil default: return nil, &smithy.DeserializationError{Err: fmt.Errorf("unknown minor type %d for simple major type", minor)} } case Tag: switch minor { case TagPosBigInt, TagNegBigInt: i, err := reader.ReadBigInt() if err != nil { return nil, err } return &types.AttributeValueMemberN{Value: i.String()}, nil case TagDecimal: d, err := reader.ReadDecimal() if err != nil { return nil, err } return &types.AttributeValueMemberN{Value: d.String()}, nil default: _, tag, err := reader.readTypeHeader() if err != nil { return nil, err } switch tag { case tagStringSet: len, err := reader.ReadArrayLength() if err != nil { return nil, err } ss := make([]string, len) for i := 0; i < len; i++ { s, err := reader.ReadString() if err != nil { return nil, err } ss[i] = s } return &types.AttributeValueMemberSS{Value: ss}, nil case tagNumberSet: len, err := reader.ReadArrayLength() if err != nil { return nil, err } ss := make([]string, len) for i := 0; i < len; i++ { av, err := DecodeAttributeValue(reader) if err != nil { return nil, err } n, ok := av.(*types.AttributeValueMemberN) if !ok { return nil, &smithy.DeserializationError{Err: fmt.Errorf("attribute type is not number. type: %T", av)} } ss[i] = n.Value } return &types.AttributeValueMemberNS{Value: ss}, nil case tagBinarySet: len, err := reader.ReadArrayLength() if err != nil { return nil, err } bs := make([][]byte, len) for i := 0; i < len; i++ { b, err := reader.ReadBytes() if err != nil { return nil, err } bs[i] = b } return &types.AttributeValueMemberBS{Value: bs}, nil default: return nil, &smithy.DeserializationError{Err: fmt.Errorf("unknown minor type %d or tag %d", minor, tag)} } } default: return nil, &smithy.DeserializationError{Err: fmt.Errorf("unknown major type %d", major)} } }