testutils/relay.go (100 lines of code) (raw):
// Copyright (c) 2015 Uber Technologies, Inc.
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package testutils
import (
"io"
"net"
"sync"
"testing"
"github.com/uber/tchannel-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
)
// frameRelay is a test TCP relay. Every accepted connection is paired with a
// new outbound connection to destination, and each frame read from either side
// is passed through relayFunc before being forwarded (a nil result drops it).
type frameRelay struct {
	// Configuration supplied by FrameRelay.
	t           testing.TB
	destination string
	relayFunc   func(outgoing bool, f *tchannel.Frame) *tchannel.Frame

	// closed becomes non-zero once the relay's cancel function has run.
	closed atomic.Uint32

	sync.Mutex            // protects conns
	conns      []net.Conn // connections (inbound and outbound) closed on cancel

	// wg tracks goroutines that cancel waits for before returning.
	wg sync.WaitGroup
}
// listen starts accepting TCP connections on an ephemeral loopback port and
// relays each accepted connection to r.destination via relayConn.
//
// It returns the address the relay listens on and a cancel function. cancel
// marks the relay closed, stops accepting, closes every tracked connection and
// waits for the accept loop and all relay goroutines to exit.
func (r *frameRelay) listen() (listenHostPort string, cancel func()) {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(r.t, err, "net.Listen failed")

	// Track the accept loop so cancel cannot return while it may still be
	// dialing or reporting through r.t; reporting after the test has
	// completed panics.
	r.wg.Add(1)
	go func() {
		defer r.wg.Done()
		for {
			c, err := ln.Accept()
			if err != nil {
				if r.closed.Load() == 0 {
					r.t.Errorf("Accept failed: %v", err)
				}
				return
			}

			r.Lock()
			if r.closed.Load() > 0 {
				// cancel has already closed the tracked connections, so a
				// connection accepted in that window must be closed here or
				// it would leak.
				r.Unlock()
				c.Close()
				return
			}
			r.conns = append(r.conns, c)
			r.Unlock()

			r.relayConn(c)
		}
	}()

	return ln.Addr().String(), func() {
		// Mark closed before closing anything so goroutines observing the
		// resulting errors treat them as expected.
		r.closed.Inc()
		ln.Close()

		r.Lock()
		for _, c := range r.conns {
			c.Close()
		}
		r.Unlock()

		// Wait for the accept loop and all relay goroutines to exit.
		r.wg.Wait()
	}
}
// relayConn dials r.destination for the accepted connection c and starts two
// relayBetween goroutines, one per direction.
//
// The outbound connection is recorded in r.conns so cancel closes it. If the
// relay has already been cancelled, both connections are closed and no
// goroutines are started.
func (r *frameRelay) relayConn(c net.Conn) {
	outC, err := net.Dial("tcp", r.destination)
	if err != nil {
		// Nothing will ever be relayed for c; close it so its peer is not
		// left waiting until cancel.
		c.Close()
		if r.closed.Load() > 0 {
			// After cancel the test may already be finishing; reporting
			// through r.t at that point is both expected noise and unsafe.
			return
		}
		assert.NoError(r.t, err, "relay connection failed")
		return
	}

	r.Lock()
	defer r.Unlock()

	if r.closed.Load() > 0 {
		// cancel may have run before c was tracked; closing it again is
		// harmless if it was.
		outC.Close()
		c.Close()
		return
	}

	r.conns = append(r.conns, outC)
	r.wg.Add(2)
	go r.relayBetween(true /* outgoing */, c, outC)
	go r.relayBetween(false /* outgoing */, outC, c)
}
// relayBetween reads frames from c, passes each one to r.relayFunc together
// with outgoing, and writes the returned frame to outC. A nil result from
// relayFunc drops the frame.
//
// The loop ends silently on io.EOF from c, or on any read/write error once the
// relay has been cancelled; any other error is reported through r.t and also
// ends the loop.
func (r *frameRelay) relayBetween(outgoing bool, c net.Conn, outC net.Conn) {
	defer r.wg.Done()

	// One frame buffer is reused for every read on this connection.
	inFrame := tchannel.NewFrame(tchannel.MaxFramePayloadSize)

	// shuttingDown reports whether err should be swallowed because the
	// relay has been cancelled.
	shuttingDown := func(err error) bool {
		return err != nil && r.closed.Load() > 0
	}

	for {
		readErr := inFrame.ReadIn(c)
		switch {
		case readErr == io.EOF:
			// The peer closed the connection gracefully.
			return
		case shuttingDown(readErr):
			return
		case !assert.NoError(r.t, readErr, "read frame failed"):
			return
		}

		forwarded := r.relayFunc(outgoing, inFrame)
		if forwarded == nil {
			continue
		}

		writeErr := forwarded.WriteOut(outC)
		if shuttingDown(writeErr) || !assert.NoError(r.t, writeErr, "write frame failed") {
			return
		}
	}
}
// FrameRelay sets up a relay that can modify frames using relayFunc.
//
// The relay listens on an ephemeral loopback port. For every connection it
// accepts it opens a connection to destination and forwards frames in both
// directions. Each frame is handed to relayFunc, where outgoing is true for
// frames read from the accepted connection and false for frames read from
// destination. The returned frame is forwarded; returning nil drops the frame.
//
// It returns the address clients should connect to and a cancel function that
// shuts the relay down and closes its connections.
func FrameRelay(t testing.TB, destination string, relayFunc func(outgoing bool, f *tchannel.Frame) *tchannel.Frame) (listenHostPort string, cancel func()) {
	relay := new(frameRelay)
	relay.t = t
	relay.destination = destination
	relay.relayFunc = relayFunc
	return relay.listen()
}