internal/encoding/encode.go (495 lines of code) (raw):
package encoding
import (
"encoding/binary"
"errors"
"fmt"
"math"
"time"
"unicode/utf8"
"github.com/Azure/go-amqp/internal/buffer"
)
type marshaler interface {
Marshal(*buffer.Buffer) error
}
func Marshal(wr *buffer.Buffer, i any) error {
switch t := i.(type) {
case nil:
wr.AppendByte(byte(TypeCodeNull))
case bool:
if t {
wr.AppendByte(byte(TypeCodeBoolTrue))
} else {
wr.AppendByte(byte(TypeCodeBoolFalse))
}
case *bool:
if *t {
wr.AppendByte(byte(TypeCodeBoolTrue))
} else {
wr.AppendByte(byte(TypeCodeBoolFalse))
}
case uint:
writeUint64(wr, uint64(t))
case *uint:
writeUint64(wr, uint64(*t))
case uint64:
writeUint64(wr, t)
case *uint64:
writeUint64(wr, *t)
case uint32:
writeUint32(wr, t)
case *uint32:
writeUint32(wr, *t)
case uint16:
wr.AppendByte(byte(TypeCodeUshort))
wr.AppendUint16(t)
case *uint16:
wr.AppendByte(byte(TypeCodeUshort))
wr.AppendUint16(*t)
case uint8:
wr.Append([]byte{
byte(TypeCodeUbyte),
t,
})
case *uint8:
wr.Append([]byte{
byte(TypeCodeUbyte),
*t,
})
case int:
writeInt64(wr, int64(t))
case *int:
writeInt64(wr, int64(*t))
case int8:
wr.Append([]byte{
byte(TypeCodeByte),
uint8(t),
})
case *int8:
wr.Append([]byte{
byte(TypeCodeByte),
uint8(*t),
})
case int16:
wr.AppendByte(byte(TypeCodeShort))
wr.AppendUint16(uint16(t))
case *int16:
wr.AppendByte(byte(TypeCodeShort))
wr.AppendUint16(uint16(*t))
case int32:
writeInt32(wr, t)
case *int32:
writeInt32(wr, *t)
case int64:
writeInt64(wr, t)
case *int64:
writeInt64(wr, *t)
case float32:
writeFloat(wr, t)
case *float32:
writeFloat(wr, *t)
case float64:
writeDouble(wr, t)
case *float64:
writeDouble(wr, *t)
case string:
return writeString(wr, t)
case *string:
return writeString(wr, *t)
case []byte:
return WriteBinary(wr, t)
case *[]byte:
return WriteBinary(wr, *t)
case map[any]any:
return writeMap(wr, t)
case *map[any]any:
return writeMap(wr, *t)
case map[string]any:
return writeMap(wr, t)
case *map[string]any:
return writeMap(wr, *t)
case map[Symbol]any:
return writeMap(wr, t)
case *map[Symbol]any:
return writeMap(wr, *t)
case Unsettled:
return writeMap(wr, t)
case *Unsettled:
return writeMap(wr, *t)
case time.Time:
writeTimestamp(wr, t)
case *time.Time:
writeTimestamp(wr, *t)
case []int8:
return arrayInt8(t).Marshal(wr)
case *[]int8:
return arrayInt8(*t).Marshal(wr)
case []uint16:
return arrayUint16(t).Marshal(wr)
case *[]uint16:
return arrayUint16(*t).Marshal(wr)
case []int16:
return arrayInt16(t).Marshal(wr)
case *[]int16:
return arrayInt16(*t).Marshal(wr)
case []uint32:
return arrayUint32(t).Marshal(wr)
case *[]uint32:
return arrayUint32(*t).Marshal(wr)
case []int32:
return arrayInt32(t).Marshal(wr)
case *[]int32:
return arrayInt32(*t).Marshal(wr)
case []uint64:
return arrayUint64(t).Marshal(wr)
case *[]uint64:
return arrayUint64(*t).Marshal(wr)
case []int64:
return arrayInt64(t).Marshal(wr)
case *[]int64:
return arrayInt64(*t).Marshal(wr)
case []float32:
return arrayFloat(t).Marshal(wr)
case *[]float32:
return arrayFloat(*t).Marshal(wr)
case []float64:
return arrayDouble(t).Marshal(wr)
case *[]float64:
return arrayDouble(*t).Marshal(wr)
case []bool:
return arrayBool(t).Marshal(wr)
case *[]bool:
return arrayBool(*t).Marshal(wr)
case []string:
return arrayString(t).Marshal(wr)
case *[]string:
return arrayString(*t).Marshal(wr)
case []Symbol:
return arraySymbol(t).Marshal(wr)
case *[]Symbol:
return arraySymbol(*t).Marshal(wr)
case [][]byte:
return arrayBinary(t).Marshal(wr)
case *[][]byte:
return arrayBinary(*t).Marshal(wr)
case []time.Time:
return arrayTimestamp(t).Marshal(wr)
case *[]time.Time:
return arrayTimestamp(*t).Marshal(wr)
case []UUID:
return arrayUUID(t).Marshal(wr)
case *[]UUID:
return arrayUUID(*t).Marshal(wr)
case []any:
return list(t).Marshal(wr)
case *[]any:
return list(*t).Marshal(wr)
case marshaler:
return t.Marshal(wr)
default:
return fmt.Errorf("marshal not implemented for %T", i)
}
return nil
}
func writeInt32(wr *buffer.Buffer, n int32) {
if n < 128 && n >= -128 {
wr.Append([]byte{
byte(TypeCodeSmallint),
byte(n),
})
return
}
wr.AppendByte(byte(TypeCodeInt))
wr.AppendUint32(uint32(n))
}
func writeInt64(wr *buffer.Buffer, n int64) {
if n < 128 && n >= -128 {
wr.Append([]byte{
byte(TypeCodeSmalllong),
byte(n),
})
return
}
wr.AppendByte(byte(TypeCodeLong))
wr.AppendUint64(uint64(n))
}
func writeUint32(wr *buffer.Buffer, n uint32) {
if n == 0 {
wr.AppendByte(byte(TypeCodeUint0))
return
}
if n < 256 {
wr.Append([]byte{
byte(TypeCodeSmallUint),
byte(n),
})
return
}
wr.AppendByte(byte(TypeCodeUint))
wr.AppendUint32(n)
}
func writeUint64(wr *buffer.Buffer, n uint64) {
if n == 0 {
wr.AppendByte(byte(TypeCodeUlong0))
return
}
if n < 256 {
wr.Append([]byte{
byte(TypeCodeSmallUlong),
byte(n),
})
return
}
wr.AppendByte(byte(TypeCodeUlong))
wr.AppendUint64(n)
}
func writeFloat(wr *buffer.Buffer, f float32) {
wr.AppendByte(byte(TypeCodeFloat))
wr.AppendUint32(math.Float32bits(f))
}
func writeDouble(wr *buffer.Buffer, f float64) {
wr.AppendByte(byte(TypeCodeDouble))
wr.AppendUint64(math.Float64bits(f))
}
func writeTimestamp(wr *buffer.Buffer, t time.Time) {
wr.AppendByte(byte(TypeCodeTimestamp))
ms := t.UnixMilli()
wr.AppendUint64(uint64(ms))
}
// marshalField is a field to be marshaled
type MarshalField struct {
Value any // value to be marshaled, use pointers to avoid interface conversion overhead
Omit bool // indicates that this field should be omitted (set to null)
}
// marshalComposite is a helper for us in a composite's marshal() function.
//
// The returned bytes include the composite header and fields. Fields with
// omit set to true will be encoded as null or omitted altogether if there are
// no non-null fields after them.
func MarshalComposite(wr *buffer.Buffer, code AMQPType, fields []MarshalField) error {
// lastSetIdx is the last index to have a non-omitted field.
// start at -1 as it's possible to have no fields in a composite
lastSetIdx := -1
// marshal each field into it's index in rawFields,
// null fields are skipped, leaving the index nil.
for i, f := range fields {
if f.Omit {
continue
}
lastSetIdx = i
}
// write header only
if lastSetIdx == -1 {
wr.Append([]byte{
0x0,
byte(TypeCodeSmallUlong),
byte(code),
byte(TypeCodeList0),
})
return nil
}
// write header
WriteDescriptor(wr, code)
// write fields
wr.AppendByte(byte(TypeCodeList32))
// write temp size, replace later
sizeIdx := wr.Len()
wr.Append([]byte{0, 0, 0, 0})
preFieldLen := wr.Len()
// field count
wr.AppendUint32(uint32(lastSetIdx + 1))
// write null to each index up to lastSetIdx
for _, f := range fields[:lastSetIdx+1] {
if f.Omit {
wr.AppendByte(byte(TypeCodeNull))
continue
}
err := Marshal(wr, f.Value)
if err != nil {
return err
}
}
// fix size
size := uint32(wr.Len() - preFieldLen)
buf := wr.Bytes()
binary.BigEndian.PutUint32(buf[sizeIdx:], size)
return nil
}
func WriteDescriptor(wr *buffer.Buffer, code AMQPType) {
wr.Append([]byte{
0x0,
byte(TypeCodeSmallUlong),
byte(code),
})
}
func writeString(wr *buffer.Buffer, str string) error {
if !utf8.ValidString(str) {
return errors.New("not a valid UTF-8 string")
}
l := len(str)
switch {
// Str8
case l < 256:
wr.Append([]byte{
byte(TypeCodeStr8),
byte(l),
})
wr.AppendString(str)
return nil
// Str32
case uint(l) < math.MaxUint32:
wr.AppendByte(byte(TypeCodeStr32))
wr.AppendUint32(uint32(l))
wr.AppendString(str)
return nil
default:
return errors.New("too long")
}
}
func WriteBinary(wr *buffer.Buffer, bin []byte) error {
l := len(bin)
switch {
// List8
case l < 256:
wr.Append([]byte{
byte(TypeCodeVbin8),
byte(l),
})
wr.Append(bin)
return nil
// List32
case uint(l) < math.MaxUint32:
wr.AppendByte(byte(TypeCodeVbin32))
wr.AppendUint32(uint32(l))
wr.Append(bin)
return nil
default:
return errors.New("too long")
}
}
func writeMap(wr *buffer.Buffer, m any) error {
startIdx := wr.Len()
wr.Append([]byte{
byte(TypeCodeMap32), // type
0, 0, 0, 0, // size placeholder
0, 0, 0, 0, // length placeholder
})
var pairs int
switch m := m.(type) {
case map[any]any:
pairs = len(m) * 2
for key, val := range m {
err := Marshal(wr, key)
if err != nil {
return err
}
err = Marshal(wr, val)
if err != nil {
return err
}
}
case map[string]any:
pairs = len(m) * 2
for key, val := range m {
err := writeString(wr, key)
if err != nil {
return err
}
err = Marshal(wr, val)
if err != nil {
return err
}
}
case map[Symbol]any:
pairs = len(m) * 2
for key, val := range m {
err := key.Marshal(wr)
if err != nil {
return err
}
err = Marshal(wr, val)
if err != nil {
return err
}
}
case Unsettled:
pairs = len(m) * 2
for key, val := range m {
err := writeString(wr, key)
if err != nil {
return err
}
err = Marshal(wr, val)
if err != nil {
return err
}
}
case Filter:
pairs = len(m) * 2
for key, val := range m {
err := key.Marshal(wr)
if err != nil {
return err
}
err = val.Marshal(wr)
if err != nil {
return err
}
}
case Annotations:
pairs = len(m) * 2
for key, val := range m {
switch key := key.(type) {
case string:
err := Symbol(key).Marshal(wr)
if err != nil {
return err
}
case Symbol:
err := key.Marshal(wr)
if err != nil {
return err
}
case int64:
writeInt64(wr, key)
case int:
writeInt64(wr, int64(key))
default:
return fmt.Errorf("unsupported Annotations key type %T", key)
}
err := Marshal(wr, val)
if err != nil {
return err
}
}
default:
return fmt.Errorf("unsupported map type %T", m)
}
if uint(pairs) > math.MaxUint32-4 {
return errors.New("map contains too many elements")
}
// overwrite placeholder size and length
bytes := wr.Bytes()[startIdx+1 : startIdx+9]
_ = bytes[7] // bounds check hint
length := wr.Len() - startIdx - 1 - 4 // -1 for type, -4 for length
binary.BigEndian.PutUint32(bytes[:4], uint32(length))
binary.BigEndian.PutUint32(bytes[4:8], uint32(pairs))
return nil
}
// type length sizes
const (
array8TLSize = 2
array32TLSize = 5
)
func writeArrayHeader(wr *buffer.Buffer, length, typeSize int, type_ AMQPType) {
size := length * typeSize
// array type
if size+array8TLSize <= math.MaxUint8 {
wr.Append([]byte{
byte(TypeCodeArray8), // type
byte(size + array8TLSize), // size
byte(length), // length
byte(type_), // element type
})
} else {
wr.AppendByte(byte(TypeCodeArray32)) //type
wr.AppendUint32(uint32(size + array32TLSize)) // size
wr.AppendUint32(uint32(length)) // length
wr.AppendByte(byte(type_)) // element type
}
}
func writeVariableArrayHeader(wr *buffer.Buffer, length, elementsSizeTotal int, type_ AMQPType) {
// 0xA_ == 1, 0xB_ == 4
// http://docs.oasis-open.org/amqp/core/v1.0/os/amqp-core-types-v1.0-os.html#doc-idp82960
elementTypeSize := 1
if type_&0xf0 == 0xb0 {
elementTypeSize = 4
}
size := elementsSizeTotal + (length * elementTypeSize) // size excluding array length
if size+array8TLSize <= math.MaxUint8 {
wr.Append([]byte{
byte(TypeCodeArray8), // type
byte(size + array8TLSize), // size
byte(length), // length
byte(type_), // element type
})
} else {
wr.AppendByte(byte(TypeCodeArray32)) // type
wr.AppendUint32(uint32(size + array32TLSize)) // size
wr.AppendUint32(uint32(length)) // length
wr.AppendByte(byte(type_)) // element type
}
}