testing/cbor.go (137 lines of code) (raw):

package testing

import (
	"encoding/base64"
	"fmt"
	"io"
	"math"
	"reflect"
	"sort"

	"github.com/aws/smithy-go/encoding/cbor"
)

// CompareCBOR checks whether two CBOR values are equivalent.
//
// The function signature is tailored for use in smithy protocol tests, where
// the expected encoding is given in base64, and the actual value to check is
// passed from the mock HTTP request body.
//
// actual is read to EOF and decoded with cbor.Decode; expect64 is decoded
// with standard base64 and then with cbor.Decode. The decoded values are
// compared structurally: lists element by element, maps key by key, tags by
// ID and then by content, and floats by exact bit pattern except that any
// NaN matches any other NaN.
//
// It returns nil when the values are equivalent. Otherwise the returned error
// describes the first difference found, prefixed with its location in the
// value tree (rooted at "<root>", e.g. `<root>["key"][0]`).
func CompareCBOR(actual io.Reader, expect64 string) error {
	ap, err := io.ReadAll(actual)
	if err != nil {
		return fmt.Errorf("read actual: %w", err)
	}
	av, err := cbor.Decode(ap)
	if err != nil {
		return fmt.Errorf("decode actual: %w", err)
	}

	ep, err := base64.StdEncoding.DecodeString(expect64)
	if err != nil {
		return fmt.Errorf("decode expect64: %w", err)
	}
	ev, err := cbor.Decode(ep)
	if err != nil {
		return fmt.Errorf("decode expect: %w", err)
	}

	return cmpCBOR(ev, av, "<root>")
}

// cmpCBOR compares actual value a against expected value e, dispatching on
// the dynamic type of e. path identifies the position of e within the
// overall value and prefixes any returned error.
func cmpCBOR(e, a cbor.Value, path string) error {
	switch v := e.(type) {
	case cbor.Uint, cbor.NegInt, cbor.Slice, cbor.String, cbor.Bool, *cbor.Nil, *cbor.Undefined:
		if !reflect.DeepEqual(e, a) {
			// Include the dynamic types: differently-typed values can format
			// identically under %v, which would yield a confusing "x != x".
			return fmt.Errorf("%s: %T(%v) != %T(%v)", path, e, e, a, a)
		}
		return nil
	case cbor.List:
		return cmpList(v, a, path)
	case cbor.Map:
		return cmpMap(v, a, path)
	case *cbor.Tag:
		return cmpTag(v, a, path)
	case cbor.Float32:
		return cmpF32(v, a, path)
	case cbor.Float64:
		return cmpF64(v, a, path)
	default:
		return fmt.Errorf("%s: unrecognized variant %T", path, e)
	}
}

// cmpList requires a to be a cbor.List of the same length as e and compares
// the elements pairwise, in order, stopping at the first mismatch.
func cmpList(e cbor.List, a cbor.Value, path string) error {
	av, ok := a.(cbor.List)
	if !ok {
		return fmt.Errorf("%s: %T != %T", path, e, a)
	}
	if len(e) != len(av) {
		return fmt.Errorf("%s: length %d != %d", path, len(e), len(av))
	}
	for i := range e {
		ipath := fmt.Sprintf("%s[%d]", path, i)
		if err := cmpCBOR(e[i], av[i], ipath); err != nil {
			return err
		}
	}
	return nil
}

// cmpMap requires a to be a cbor.Map with the same number of entries as e,
// every key of e present in a, and each pair of values equivalent.
func cmpMap(e cbor.Map, a cbor.Value, path string) error {
	av, ok := a.(cbor.Map)
	if !ok {
		return fmt.Errorf("%s: %T != %T", path, e, a)
	}
	if len(e) != len(av) {
		return fmt.Errorf("%s: length %d != %d", path, len(e), len(av))
	}

	// Visit keys in sorted order so that, when several entries differ, the
	// reported one is the same on every run (Go map iteration is randomized).
	keys := make([]string, 0, len(e))
	for k := range e {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	// Equal lengths plus every expected key being present implies av has no
	// extra keys, so checking one direction is sufficient.
	for _, k := range keys {
		avv, ok := av[k]
		if !ok {
			return fmt.Errorf("%s: missing key %q", path, k)
		}
		kpath := fmt.Sprintf("%s[%q]", path, k)
		if err := cmpCBOR(e[k], avv, kpath); err != nil {
			return err
		}
	}
	return nil
}

// cmpTag requires a to be a *cbor.Tag with the same ID as e, then compares
// the tagged contents at the same path.
func cmpTag(e *cbor.Tag, a cbor.Value, path string) error {
	av, ok := a.(*cbor.Tag)
	if !ok {
		return fmt.Errorf("%s: %T != %T", path, e, a)
	}
	if e.ID != av.ID {
		return fmt.Errorf("%s: tag ID %d != %d", path, e.ID, av.ID)
	}
	return cmpCBOR(e.Value, av.Value, path)
}

// cmpF32 requires a to be a cbor.Float32 and compares the raw bit patterns,
// so +0 and -0 are distinct. If either side is NaN, the values match only
// when both are NaN; the NaN payload is ignored.
func cmpF32(e cbor.Float32, a cbor.Value, path string) error {
	av, ok := a.(cbor.Float32)
	if !ok {
		return fmt.Errorf("%s: %T != %T", path, e, a)
	}

	ebits, abits := math.Float32bits(float32(e)), math.Float32bits(float32(av))
	if enan, anan := isNaN32(ebits), isNaN32(abits); enan || anan {
		if enan != anan {
			return fmt.Errorf("%s: NaN: float32(%x) != float32(%x)", path, ebits, abits)
		}
		return nil
	}
	if ebits != abits {
		return fmt.Errorf("%s: float32(%x) != float32(%x)", path, ebits, abits)
	}
	return nil
}

// cmpF64 requires a to be a cbor.Float64 and compares the raw bit patterns,
// so +0 and -0 are distinct. If either side is NaN, the values match only
// when both are NaN; the NaN payload is ignored.
func cmpF64(e cbor.Float64, a cbor.Value, path string) error {
	av, ok := a.(cbor.Float64)
	if !ok {
		return fmt.Errorf("%s: %T != %T", path, e, a)
	}

	ebits, abits := math.Float64bits(float64(e)), math.Float64bits(float64(av))
	if enan, anan := isNaN64(ebits), isNaN64(abits); enan || anan {
		if enan != anan {
			return fmt.Errorf("%s: NaN: float64(%x) != float64(%x)", path, ebits, abits)
		}
		return nil
	}
	if ebits != abits {
		return fmt.Errorf("%s: float64(%x) != float64(%x)", path, ebits, abits)
	}
	return nil
}

// isNaN32 reports whether the IEEE 754 binary32 bit pattern f encodes a NaN:
// an all-ones exponent with a non-zero fraction. (An all-ones exponent with a
// zero fraction is +Inf or -Inf.)
func isNaN32(f uint32) bool {
	const (
		expMask  = 0x7f800000
		fracMask = 0x007fffff
	)
	return f&expMask == expMask && f&fracMask != 0
}

// isNaN64 reports whether the IEEE 754 binary64 bit pattern f encodes a NaN:
// an all-ones exponent with a non-zero fraction. (An all-ones exponent with a
// zero fraction is +Inf or -Inf.)
func isNaN64(f uint64) bool {
	const (
		expMask  = 0x7ff00000_00000000
		fracMask = 0x000fffff_ffffffff
	)
	return f&expMask == expMask && f&fracMask != 0
}