document/cbor/decode.go (286 lines of code) (raw):

package cbor import ( "fmt" "math/big" "reflect" "github.com/aws/smithy-go/document" "github.com/aws/smithy-go/document/internal/serde" "github.com/aws/smithy-go/encoding/cbor" ) // decoderOptions is the set of options that can be configured for a Decoder. // // FUTURE(rpc2cbor): document support is currently disabled. This API is // unexported until that changes. type decoderOptions struct{} // decoder is a Smithy document decoder for CBOR-based protocols. // // FUTURE(rpc2cbor): document support is currently disabled. This API is // unexported until that changes. type decoder struct { options decoderOptions } // newDecoder returns a Decoder for deserializing Smithy documents. // // FUTURE(rpc2cbor): document support is currently disabled. This API is // unexported until that changes. func newDecoder(optFns ...func(options *decoderOptions)) *decoder { o := decoderOptions{} for _, fn := range optFns { fn(&o) } return &decoder{ options: o, } } // Decode unmarshals a CBOR Value into the target. func (d *decoder) Decode(v cbor.Value, to interface{}) error { if document.IsNoSerde(to) { return fmt.Errorf("unsupported type: %T", to) } rv := reflect.ValueOf(to) if rv.Kind() != reflect.Ptr || rv.IsNil() || !rv.IsValid() { return &document.InvalidUnmarshalError{Type: reflect.TypeOf(to)} } return d.decode(v, rv, serde.Tag{}) } func (d *decoder) decode(cv cbor.Value, rv reflect.Value, tag serde.Tag) error { if _, ok := cv.(*cbor.Nil); ok { return d.decodeNil(serde.Indirect(rv, true)) } rv = serde.Indirect(rv, false) if err := d.unsupportedType(rv); err != nil { return err } switch v := cv.(type) { case cbor.Uint, cbor.NegInt: return d.decodeInt(v, rv) case cbor.Float64: return d.decodeFloat(float64(v), rv) case cbor.String: return d.decodeString(string(v), rv) case cbor.Bool: return d.decodeBool(bool(v), rv) case cbor.List: return d.decodeList(v, rv) case cbor.Map: return d.decodeMap(v, rv) case *cbor.Tag: return d.decodeTag(v, rv) default: return fmt.Errorf("unsupported cbor document type %T", v) } } func (d *decoder) decodeInt(v cbor.Value, rv reflect.Value) error { switch rv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, err := cbor.AsInt64(v) if err != nil { return err } if rv.OverflowInt(i) { return &document.UnmarshalTypeError{ Value: fmt.Sprintf("number overflow, %d", i), Type: rv.Type(), } } rv.SetInt(i) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: u, ok := v.(cbor.Uint) if !ok { return &document.UnmarshalTypeError{Value: "number", Type: rv.Type()} } if rv.OverflowUint(uint64(u)) { return &document.UnmarshalTypeError{ Value: fmt.Sprintf("number overflow, %d", u), Type: rv.Type(), } } rv.SetUint(uint64(u)) default: return &document.UnmarshalTypeError{Value: "number", Type: rv.Type()} } return nil } func (d *decoder) decodeNil(rv reflect.Value) error { if rv.IsValid() && rv.CanSet() { rv.Set(reflect.Zero(rv.Type())) } return nil } func (d *decoder) decodeBool(v bool, rv reflect.Value) error { switch rv.Kind() { case reflect.Bool, reflect.Interface: rv.Set(reflect.ValueOf(v).Convert(rv.Type())) default: return &document.UnmarshalTypeError{Value: "bool", Type: rv.Type()} } return nil } func (d *decoder) decodeFloat(v float64, rv reflect.Value) error { switch rv.Kind() { case reflect.Interface: rv.Set(reflect.ValueOf(v)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: i, accuracy := big.NewFloat(v).Int64() if accuracy != big.Exact || rv.OverflowInt(i) { return &document.UnmarshalTypeError{ Value: fmt.Sprintf("int overflow, %e", v), Type: rv.Type(), } } rv.SetInt(i) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: u, accuracy := big.NewFloat(v).Uint64() if accuracy != big.Exact || rv.OverflowUint(u) { return &document.UnmarshalTypeError{ Value: fmt.Sprintf("uint overflow, %e", v), Type: rv.Type(), } } rv.SetUint(u) case reflect.Float32, reflect.Float64: if rv.OverflowFloat(v) { return &document.UnmarshalTypeError{ Value: fmt.Sprintf("float overflow, %e", v), Type: rv.Type(), } } rv.SetFloat(v) default: return &document.UnmarshalTypeError{Value: "number", Type: rv.Type()} } return nil } func (d *decoder) decodeList(v cbor.List, rv reflect.Value) error { var isArray bool switch rv.Kind() { case reflect.Slice: // Make room for the slice elements if needed if rv.IsNil() || rv.Cap() < len(v) { rv.Set(reflect.MakeSlice(rv.Type(), 0, len(v))) } case reflect.Array: // Limited to capacity of existing array. isArray = true case reflect.Interface: s := make([]interface{}, len(v)) for i, av := range v { if err := d.decode(av, reflect.ValueOf(&s[i]).Elem(), serde.Tag{}); err != nil { return err } } rv.Set(reflect.ValueOf(s)) return nil default: return &document.UnmarshalTypeError{Value: "list", Type: rv.Type()} } // If rv is not a slice, array for i := 0; i < rv.Cap() && i < len(v); i++ { if !isArray { rv.SetLen(i + 1) } if err := d.decode(v[i], rv.Index(i), serde.Tag{}); err != nil { return err } } return nil } func (d *decoder) decodeString(v string, rv reflect.Value) error { switch rv.Kind() { case reflect.String: rv.SetString(v) case reflect.Interface: rv.Set(reflect.ValueOf(v).Convert(rv.Type())) default: return &document.UnmarshalTypeError{Value: "string", Type: rv.Type()} } return nil } func (d *decoder) decodeMap(tv cbor.Map, rv reflect.Value) error { switch rv.Kind() { case reflect.Map: t := rv.Type() if t.Key().Kind() != reflect.String { return &document.UnmarshalTypeError{Value: "map string key", Type: t.Key()} } if rv.IsNil() { rv.Set(reflect.MakeMap(t)) } case reflect.Struct: if rv.CanInterface() && document.IsNoSerde(rv.Interface()) { return &document.UnmarshalTypeError{ Value: fmt.Sprintf("unsupported type"), Type: rv.Type(), } } case reflect.Interface: rv.Set(reflect.MakeMap(serde.ReflectTypeOf.MapStringToInterface)) rv = rv.Elem() default: return &document.UnmarshalTypeError{Value: "map", Type: rv.Type()} } if rv.Kind() == reflect.Map { for k, kv := range tv { key := reflect.New(rv.Type().Key()).Elem() key.SetString(k) elem := reflect.New(rv.Type().Elem()).Elem() if err := d.decode(kv, elem, serde.Tag{}); err != nil { return err } rv.SetMapIndex(key, elem) } } else if rv.Kind() == reflect.Struct { fields := serde.GetStructFields(rv.Type()) for k, kv := range tv { if f, ok := fields.FieldByName(k); ok { fv := serde.DecoderFieldByIndex(rv, f.Index) if err := d.decode(kv, fv, f.Tag); err != nil { return err } } } } return nil } func (d *decoder) decodeTag(tv *cbor.Tag, rv reflect.Value) error { rvt := rv.Type() switch { case rvt.ConvertibleTo(serde.ReflectTypeOf.BigInt): i, err := cbor.AsBigInt(tv) if err != nil { return &document.UnmarshalTypeError{Value: "tag", Type: rv.Type()} } rv.Set(reflect.ValueOf(*i).Convert(rvt)) return nil case rvt.ConvertibleTo(serde.ReflectTypeOf.BigFloat): i, err := asBigFloat(tv) if err != nil { return &document.UnmarshalTypeError{Value: "tag", Type: rv.Type()} } rv.Set(reflect.ValueOf(*i).Convert(rvt)) return nil default: return &document.UnmarshalTypeError{Value: "tag", Type: rv.Type()} } } func (d *decoder) unsupportedType(rv reflect.Value) error { if rv.Kind() == reflect.Interface && rv.NumMethod() != 0 { return &document.UnmarshalTypeError{Value: "non-empty interface", Type: rv.Type()} } if rv.Type().ConvertibleTo(serde.ReflectTypeOf.Time) { return &document.UnmarshalTypeError{ Type: rv.Type(), Value: fmt.Sprintf("time value"), } } return nil } func asBigFloat(tv *cbor.Tag) (*big.Float, error) { const tagbase10 = 4 if tv.ID != tagbase10 { return nil, fmt.Errorf("invalid tag: %d", tv.ID) } pcs, ok := tv.Value.(cbor.List) if !ok { return nil, fmt.Errorf("invalid tagged type: %T", tv.Value) } if len(pcs) != 2 { return nil, fmt.Errorf("invalid tagged list len: %d", len(pcs)) } eval, mval := pcs[0], pcs[1] exp, err := cbor.AsBigInt(eval) if err != nil { return nil, fmt.Errorf("invalid exp: %w", err) } mant, err := cbor.AsBigInt(mval) if !ok { return nil, fmt.Errorf("invalid mant: %w", err) } // We literally re-express this as <mant>e<exp> and send it through // bigfloat parse. Not mathematically amazing, but ensures that // string-borne bignums and this are computed identically. str := fmt.Sprintf("%se%s", mant.String(), exp.String()) x, _, err := new(big.Float).Parse(str, 0) return x, err }