internal/fake/net_conn.go (391 lines of code) (raw):
package fake
import (
"errors"
"math"
"net"
"time"
"github.com/Azure/go-amqp/internal/buffer"
"github.com/Azure/go-amqp/internal/encoding"
"github.com/Azure/go-amqp/internal/frames"
)
// NetConnOptions contains options when creating a NetConn.
// Pass the zero-value to accept the default values.
type NetConnOptions struct {
// ChunkSize is the size of chunks to split responses into.
// A zero or negative value means no chunking.
// The default value is zero.
ChunkSize int
}
// NewNetConn creates a new instance of NetConn.
// Responder is invoked by Write when a frame is received.
// Return a zero-value Response/nil error to swallow the frame.
// Return a non-nil error to simulate a write error.
// NOTE: resp is called on a separate goroutine so it MUST NOT access any *testing.T etc
func NewNetConn(resp func(remoteChannel uint16, fr frames.FrameBody) (Response, error), opts NetConnOptions) *NetConn {
netConn := &NetConn{
ReadErr: make(chan error),
WriteErr: make(chan error, 1),
opts: opts,
resp: resp,
// during shutdown, connReader can close before connWriter as they both
// both return on c.Done being closed, so there is some non-determinism
// here. this means that sometimes writes can still happen but there's
// no reader to consume them. we used a buffered channel to prevent these
// writes from blocking shutdown. the size was arbitrarily picked.
readData: make(chan []byte, 10),
// used to serialize writes so the frames are returned in their specified order.
// buffering is necessary because write() will sleep when a write delay was
// specified and we don't want to stall Write(). the size was arbitrarily picked.
writeResp: make(chan Response, 10),
close: make(chan struct{}),
readDL: newNopTimer(), // default, no deadline
}
go netConn.write()
return netConn
}
// NetConn is a fake network connection that satisfies the net.Conn interface.
type NetConn struct {
// OnClose is called from Close() before it returns.
// The value returned from OnClose is returned from Close().
OnClose func() error
// ReadErr is used to simulate a connReader error.
// The error written to this channel is returned
// from the call to NetConn.Read.
ReadErr chan error
// WriteErr is used to simulate a connWriter error.
// The error sent here is returned from the call to NetConn.Write.
// Has a buffer of one so setting a pending error won't block.
WriteErr chan error
opts NetConnOptions
resp func(uint16, frames.FrameBody) (Response, error)
readDL readTimer
readData chan []byte
writeResp chan Response
close chan struct{}
closed bool
}
// SendFrame sends the encoded frame to the client.
// Use this to send a frame at an arbitrary time.
func (n *NetConn) SendFrame(f []byte) {
n.readData <- f
}
// SendKeepAlive sends a keep-alive frame to the client.
func (n *NetConn) SendKeepAlive() {
// empty frame
n.readData <- []uint8{0, 0, 0, 8, 2, 0, 0, 0}
}
// SendMultiFrameTransfer splits payload into 32-byte chunks, encodes, and sends to the client.
// Payload must be big enough for at least two chunks.
func (n *NetConn) SendMultiFrameTransfer(channel uint16, linkHandle, deliveryID uint32, payload []byte, edit func(int, *frames.PerformTransfer)) error {
bb, err := encodeMultiFrameTransfer(channel, linkHandle, deliveryID, payload, edit)
if err != nil {
return err
}
for _, b := range bb {
n.readData <- b
}
return nil
}
// Response is the response returned from a responder function.
type Response struct {
// Payload is the marshalled frame to send to Conn.connReader
Payload []byte
// WriteDelay is the duration to wait before writing Payload.
// Use this to introduce a delay when waiting for a response.
WriteDelay time.Duration
// ChunkSize is the size of chunks to split Payload into.
// A zero or negative value means no chunking.
// This value supercedes the NetConnOptions.ChunkSize.
ChunkSize int
}
// ErrAlreadyClosed is returned by Close() if [NetConn] is already closed.
var ErrAlreadyClosed = errors.New("fake already closed")
///////////////////////////////////////////////////////
// following methods are for the net.Conn interface
///////////////////////////////////////////////////////
// NOTE: Read, Write, and Close are all called by separate goroutines!
// Read is invoked by conn.connReader to recieve frame data.
// It blocks until Write or Close are called, or the read
// deadline expires which will return an error.
func (n *NetConn) Read(b []byte) (int, error) {
select {
case <-n.close:
return 0, net.ErrClosed
default:
// not closed yet
}
select {
case <-n.close:
return 0, net.ErrClosed
case <-n.readDL.C():
return 0, errors.New("fake connection read deadline exceeded")
case rd := <-n.readData:
return copy(b, rd), nil
case err := <-n.ReadErr:
return 0, err
}
}
// Write is invoked by conn.connWriter when we're being sent frame
// data. Every call to Write will invoke the responder callback that
// must reply with one of three possibilities.
// 1. an encoded frame and nil error
// 2. a non-nil error to similate a write failure
// 3. a nil slice and nil error indicating the frame should be ignored
func (n *NetConn) Write(b []byte) (int, error) {
select {
case <-n.close:
return 0, net.ErrClosed
default:
// not closed yet
}
select {
case err := <-n.WriteErr:
return 0, err
default:
// no fake write error
}
remoteChannel, frame, err := decodeFrame(b)
if err != nil {
return 0, err
}
resp, err := n.resp(remoteChannel, frame)
if err != nil {
return 0, err
}
if resp.Payload != nil {
select {
case n.writeResp <- resp:
// resp was sent to write()
default:
// this means we incorrectly sized writeResp.
// we do this to ensure that we never stall
// waiting to write to writeResp.
panic("writeResp full")
}
}
return len(b), nil
}
func (n *NetConn) write() {
for {
select {
case <-n.close:
return
case resp := <-n.writeResp:
// any write delay MUST happen outside of NetConn.Write
// else all we do is stall Conn.connWriter() which doesn't
// actually simulate a delayed response to a frame.
time.Sleep(resp.WriteDelay)
if resp.ChunkSize < 1 {
// no chunk size for this response, fall back to options
resp.ChunkSize = n.opts.ChunkSize
}
if resp.ChunkSize < 1 {
// send in one chunk
resp.ChunkSize = len(resp.Payload)
}
remaining := resp.Payload
for {
if l := len(remaining); l < resp.ChunkSize {
resp.ChunkSize = l
}
chunk := remaining[:resp.ChunkSize]
n.readData <- chunk
remaining = remaining[resp.ChunkSize:]
if len(remaining) == 0 {
break
}
}
}
}
}
// Close is called by conn.close.
func (n *NetConn) Close() error {
if n.closed {
return ErrAlreadyClosed
}
n.closed = true
close(n.close)
if n.OnClose != nil {
return n.OnClose()
}
return nil
}
func (n *NetConn) LocalAddr() net.Addr {
return &net.IPAddr{
IP: net.IPv4(127, 0, 0, 2),
}
}
func (n *NetConn) RemoteAddr() net.Addr {
return &net.IPAddr{
IP: net.IPv4(127, 0, 0, 2),
}
}
func (n *NetConn) SetDeadline(t time.Time) error {
return errors.New("not used")
}
func (n *NetConn) SetReadDeadline(t time.Time) error {
// called by conn.connReader before calling Read
// stop the last timer if available
if n.readDL != nil && !n.readDL.Stop() {
<-n.readDL.C()
}
n.readDL = timer{t: time.NewTimer(time.Until(t))}
return nil
}
func (n *NetConn) SetWriteDeadline(t time.Time) error {
// called by conn.connWriter before calling Write
return nil
}
///////////////////////////////////////////////////////
///////////////////////////////////////////////////////
// ProtoID indicates the type of protocol (copied from conn.go)
type ProtoID uint8
const (
ProtoAMQP ProtoID = 0x0
ProtoTLS ProtoID = 0x2
ProtoSASL ProtoID = 0x3
)
// ProtoHeader adds the initial handshake frame to the list of responses.
// This frame, and PerformOpen, are needed when calling amqp.New() to create a client.
func ProtoHeader(id ProtoID) ([]byte, error) {
return []byte{'A', 'M', 'Q', 'P', byte(id), 1, 0, 0}, nil
}
// PerformOpen appends a PerformOpen frame with the specified container ID.
// This frame, and ProtoHeader, are needed when calling amqp.New() to create a client.
func PerformOpen(containerID string) ([]byte, error) {
// send the default values for max channels and frame size
return EncodeFrame(frames.TypeAMQP, 0, &frames.PerformOpen{
ChannelMax: 65535,
ContainerID: containerID,
IdleTimeout: time.Minute,
MaxFrameSize: 4294967295,
})
}
// PerformBegin appends a PerformBegin frame with the specified remote channel ID.
// This frame is needed when making a call to Client.NewSession().
func PerformBegin(channel, remoteChannel uint16) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformBegin{
RemoteChannel: &remoteChannel,
NextOutgoingID: 1,
IncomingWindow: 5000,
OutgoingWindow: 1000,
HandleMax: math.MaxInt16,
})
}
// SenderAttach encodes a PerformAttach frame with the specified values.
// This frame is needed when making a call to Session.NewSender().
func SenderAttach(channel uint16, linkName string, linkHandle uint32, mode encoding.SenderSettleMode) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformAttach{
Name: linkName,
Handle: linkHandle,
Role: encoding.RoleReceiver,
Target: &frames.Target{
Address: "test",
Durable: encoding.DurabilityNone,
ExpiryPolicy: encoding.ExpirySessionEnd,
},
SenderSettleMode: &mode,
MaxMessageSize: math.MaxUint32,
})
}
// ReceiverAttach appends a PerformAttach frame with the specified values.
// This frame is needed when making a call to Session.NewReceiver().
func ReceiverAttach(channel uint16, linkName string, linkHandle uint32, mode encoding.ReceiverSettleMode, filter encoding.Filter) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformAttach{
Name: linkName,
Handle: linkHandle,
Role: encoding.RoleSender,
Source: &frames.Source{
Address: "test",
Durable: encoding.DurabilityNone,
ExpiryPolicy: encoding.ExpirySessionEnd,
Filter: filter,
},
ReceiverSettleMode: &mode,
MaxMessageSize: math.MaxUint32,
})
}
// PerformTransfer appends a PerformTransfer frame with the specified values.
// The linkHandle MUST match the linkHandle value specified in ReceiverAttach.
func PerformTransfer(channel uint16, linkHandle, deliveryID uint32, payload []byte) ([]byte, error) {
format := uint32(0)
payloadBuf := &buffer.Buffer{}
encoding.WriteDescriptor(payloadBuf, encoding.TypeCodeApplicationData)
err := encoding.WriteBinary(payloadBuf, payload)
if err != nil {
return nil, err
}
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformTransfer{
Handle: linkHandle,
DeliveryID: &deliveryID,
DeliveryTag: []byte("tag"),
MessageFormat: &format,
Payload: payloadBuf.Detach(),
})
}
// PerformDisposition appends a PerformDisposition frame with the specified values.
// The firstID MUST match the deliveryID value specified in PerformTransfer.
func PerformDisposition(role encoding.Role, channel uint16, firstID uint32, lastID *uint32, state encoding.DeliveryState) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformDisposition{
Role: role,
First: firstID,
Last: lastID,
Settled: true,
State: state,
})
}
// PerformDetach encodes a PerformDetach frame with an optional error.
func PerformDetach(channel uint16, linkHandle uint32, e *encoding.Error) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformDetach{Handle: linkHandle, Closed: true, Error: e})
}
// PerformEnd encodes a PerformEnd frame with an optional error.
func PerformEnd(channel uint16, e *encoding.Error) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, channel, &frames.PerformEnd{Error: e})
}
// PerformClose encodes a PerformClose frame with an optional error.
func PerformClose(e *encoding.Error) ([]byte, error) {
return EncodeFrame(frames.TypeAMQP, 0, &frames.PerformClose{Error: e})
}
// AMQPProto is the frame type passed to FrameCallback() for the initial protocal handshake.
type AMQPProto struct {
frames.FrameBody
}
// KeepAlive is the frame type passed to FrameCallback() for keep-alive frames.
type KeepAlive struct {
frames.FrameBody
}
type frameHeader frames.Header
func (f frameHeader) Marshal(wr *buffer.Buffer) error {
wr.AppendUint32(f.Size)
wr.AppendByte(f.DataOffset)
wr.AppendByte(byte(f.FrameType))
wr.AppendUint16(f.Channel)
return nil
}
// EncodeFrame encodes the specified frame to be sent over the wire.
func EncodeFrame(t frames.Type, channel uint16, f frames.FrameBody) ([]byte, error) {
bodyBuf := buffer.New([]byte{})
if err := encoding.Marshal(bodyBuf, f); err != nil {
return nil, err
}
// create the frame header, needs size of the body plus itself
header := frameHeader{
Size: uint32(bodyBuf.Len()) + 8,
DataOffset: 2,
FrameType: uint8(t),
Channel: channel,
}
headerBuf := buffer.New([]byte{})
if err := encoding.Marshal(headerBuf, header); err != nil {
return nil, err
}
// concatenate header + body
raw := headerBuf.Detach()
raw = append(raw, bodyBuf.Detach()...)
return raw, nil
}
func decodeFrame(b []byte) (uint16, frames.FrameBody, error) {
if len(b) > 3 && b[0] == 'A' && b[1] == 'M' && b[2] == 'Q' && b[3] == 'P' {
return 0, &AMQPProto{}, nil
}
buf := buffer.New(b)
header, err := frames.ParseHeader(buf)
if err != nil {
return 0, nil, err
}
bodySize := int64(header.Size - frames.HeaderSize)
if bodySize == 0 {
// keep alive frame
return 0, &KeepAlive{}, nil
}
// parse the frame
b, ok := buf.Next(bodySize)
if !ok {
return 0, nil, err
}
fr, err := frames.ParseBody(buffer.New(b))
if err != nil {
return 0, nil, err
}
return header.Channel, fr, nil
}
func encodeMultiFrameTransfer(channel uint16, linkHandle, deliveryID uint32, payload []byte, edit func(int, *frames.PerformTransfer)) ([][]byte, error) {
frameData := [][]byte{}
format := uint32(0)
payloadBuf := &buffer.Buffer{}
// determine the number of frames to create
chunks := len(payload) / 32
if r := len(payload) % 32; r > 0 {
chunks++
}
if chunks < 2 {
return nil, errors.New("payload is too small for multi-frame transfer")
}
more := true
for chunk := 0; chunk < chunks; chunk++ {
encoding.WriteDescriptor(payloadBuf, encoding.TypeCodeApplicationData)
var err error
if chunk+1 < chunks {
err = encoding.WriteBinary(payloadBuf, payload[chunk*32:chunk*32+32])
} else {
// final frame
err = encoding.WriteBinary(payloadBuf, payload[chunk*32:])
more = false
}
if err != nil {
return nil, err
}
var fr *frames.PerformTransfer
if chunk == 0 {
// first frame requires extra data
fr = &frames.PerformTransfer{
Handle: linkHandle,
DeliveryID: &deliveryID,
DeliveryTag: []byte("tag"),
MessageFormat: &format,
More: true,
Payload: payloadBuf.Detach(),
}
} else {
fr = &frames.PerformTransfer{
Handle: linkHandle,
More: more,
Payload: payloadBuf.Detach(),
}
}
if edit != nil {
edit(chunk, fr)
}
b, err := EncodeFrame(frames.TypeAMQP, channel, fr)
if err != nil {
return nil, err
}
frameData = append(frameData, b)
}
return frameData, nil
}
type readTimer interface {
C() <-chan time.Time
Stop() bool
}
func newNopTimer() nopTimer {
return nopTimer{t: make(chan time.Time)}
}
type nopTimer struct {
t chan time.Time
}
func (n nopTimer) C() <-chan time.Time {
return n.t
}
func (n nopTimer) Stop() bool {
close(n.t)
return true
}
type timer struct {
t *time.Timer
}
func (t timer) C() <-chan time.Time {
return t.t.C
}
func (t timer) Stop() bool {
return t.t.Stop()
}