sdk/internal/mock/mock.go (261 lines of code) (raw):
//go:build go1.18
// +build go1.18
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package mock
import (
"crypto/tls"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"sync/atomic"
"time"
)
// mockReadError is the name of an internal response header. The mock handler
// sets it when a response was configured with WithBodyReadError; Do then swaps
// the response body for a failing reader and removes the header again.
const mockReadError = "mock-read-error"
// Server is a wrapper around an httptest.Server that serves canned responses.
// The serving of requests is not safe for concurrent use. That is acceptable
// for now because each test creates its own server and doesn't create
// additional goroutines.
type Server struct {
	// srv is the underlying test server that receives forwarded requests.
	srv *httptest.Server

	// static is the static response. When it is not nil it is always returned
	// and the queue in resp is ignored.
	static *mockResponse

	// respLock synchronizes access to resp.
	respLock *sync.RWMutex

	// resp is the queue of responses. Each response is taken from the front.
	resp []mockResponse

	// count tracks the number of requests that have been made.
	// Do increments it with sync/atomic.
	count int32

	// routeAllRequestsToMockServer determines whether all requests are routed
	// to the httptest Server. When it is true, Do rewrites the scheme and host
	// of each request's URL to those of the test server.
	routeAllRequestsToMockServer bool
}
// newServer allocates a Server whose response lock is ready to use. The
// underlying httptest.Server is not created here; the exported constructors
// assign srv before starting it.
func newServer() *Server {
	s := new(Server)
	s.respLock = new(sync.RWMutex)
	return s
}
// NewServer creates a new Server object backed by a plain-HTTP httptest server.
// Every ServerOption in cfg is applied, in order, before the server is started.
// The returned close func must be called when the server is no longer needed.
func NewServer(cfg ...ServerOption) (*Server, func()) {
	m := newServer()
	m.srv = httptest.NewUnstartedServer(http.HandlerFunc(m.serveHTTP))
	for i := range cfg {
		cfg[i].apply(m)
	}
	m.srv.Start()
	closeFn := func() { m.srv.Close() }
	return m, closeFn
}
// NewTLSServer creates a new Server object and applies any additional
// ServerOption configurations provided on the Server. It then starts the
// underlying httptest server with TLS via StartTLS.
// Options are applied before StartTLS is called, so settings such as
// WithTLSConfig and WithHTTP2Enabled take effect.
// The returned close func must be called when the server is no longer needed.
// Both returned values are always non-nil.
func NewTLSServer(cfg ...ServerOption) (*Server, func()) {
	s := newServer()
	s.srv = httptest.NewUnstartedServer(http.HandlerFunc(s.serveHTTP))
	for _, c := range cfg {
		c.apply(s)
	}
	s.srv.StartTLS()
	return s, func() { s.srv.Close() }
}
// ServerConfig returns the *http.Server that backs the test server.
// By the time a caller has a Server, Start or StartTLS has already been
// called, so the returned config should be treated as read-only.
func (s *Server) ServerConfig() *http.Server {
	cfg := s.srv.Config
	return cfg
}
// isErrorResp reports whether the response that will be served next is an
// error. A static response, when set, takes precedence over the queue; the
// queue is only peeked, never popped.
// It panics when no static response is set and the queue is empty.
func (s *Server) isErrorResp() bool {
	if st := s.static; st != nil {
		return st.err != nil
	}
	s.respLock.RLock()
	defer s.respLock.RUnlock()
	if len(s.resp) == 0 {
		panic("no more responses")
	}
	return s.resp[0].err != nil
}
// getResponse returns the static response if one is set. Otherwise it pops
// and returns the response at the front of the queue.
// It panics when no static response is set and the queue is empty.
func (s *Server) getResponse() mockResponse {
	// always favor static response
	if s.static != nil {
		return *s.static
	}
	// Hold the lock across both the emptiness check and the pop. Checking
	// len(s.resp) before locking races with writers, and two concurrent
	// callers could both see one remaining response; the second would then
	// index an empty slice.
	s.respLock.Lock()
	defer s.respLock.Unlock()
	if len(s.resp) == 0 {
		panic("no more responses")
	}
	resp := s.resp[0]
	s.resp = s.resp[1:]
	return resp
}
// URL returns the base URL of the underlying httptest server as a string.
func (s *Server) URL() string {
	base := s.srv.URL
	return base
}
// Do implements the azcore.Transport interface on Server.
//
// Every call increments the request counter. If the next response (static or
// queued) is an error, that error is returned without contacting the test
// server. Otherwise req is sent to the httptest server and its response is
// returned.
//
// When the server was configured with WithTransformAllRequestsToTestServerUrl,
// the scheme and host of req.URL are replaced with the test server's for the
// round trip. The caller's URL is restored afterwards.
//
// Responses configured with WithBodyReadError are returned with a body whose
// Read always fails.
//
// Calling this when the response queue is empty and no static
// response has been set will cause a panic.
func (s *Server) Do(req *http.Request) (*http.Response, error) {
	atomic.AddInt32(&s.count, 1)

	// error responses are returned here, without a round trip to the server
	if s.isErrorResp() {
		resp := s.getResponse()
		return nil, resp.err
	}

	var resp *http.Response
	var err error
	if s.routeAllRequestsToMockServer {
		srvURL, parseErr := url.Parse(s.srv.URL)
		if parseErr != nil {
			return nil, fmt.Errorf("unable to parse the test server URL: %w", parseErr)
		}
		// Send a retargeted copy of the URL, then restore the caller's URL so
		// req is unchanged once Do returns.
		originalURL := req.URL
		mockURL := *req.URL
		mockURL.Host = srvURL.Host
		mockURL.Scheme = srvURL.Scheme
		req.URL = &mockURL
		resp, err = s.srv.Client().Do(req)
		req.URL = originalURL
	} else {
		resp, err = s.srv.Client().Do(req)
	}
	if err != nil {
		// resp is passed through as-is: http.Client can return a non-nil
		// response together with an error (CheckRedirect failures).
		return resp, err
	}

	// Wrap the response body in a readFailer if the mock-read-error header is
	// set. Strip the internal header so callers never observe it.
	if resp.Header.Get(mockReadError) != "" {
		resp.Body = &readFailer{wrapped: resp.Body}
		resp.Header.Del(mockReadError)
	}
	return resp, nil
}
// serveHTTP is the handler installed on the httptest server. It works in
// three steps:
//  1. Select the response to serve. Responses whose predicate rejects req are
//     discarded in turn. When a predicate accepts req, the response queued
//     right after it is also discarded.
//  2. Sleep for the response's delay, if any.
//  3. Write the response. It panics if writing the response fails.
func (s *Server) serveHTTP(w http.ResponseWriter, req *http.Request) {
	resp := s.getResponse()
	for resp.pred != nil && !resp.pred(req) {
		// predicate rejected this request: skip to the next response
		resp = s.getResponse()
	}
	if resp.pred != nil {
		// the predicate matched, so drop the response that follows it
		s.getResponse()
	}

	if resp.delay > 0 {
		time.Sleep(resp.delay)
	}
	if err := resp.write(w); err != nil {
		panic(err)
	}
}
// Requests returns the number of times Do has been called. This includes calls
// answered directly with a queued or static error, which never reach the test
// server.
// The counter is read atomically, pairing with the atomic increment in Do, so
// reading it while requests are in flight is not a data race. The value may be
// stale as soon as it is returned.
func (s *Server) Requests() int {
	return int(atomic.LoadInt32(&s.count))
}
// AppendError appends the error to the end of the response queue.
// When this entry reaches the front of the queue, Do returns err instead of
// sending the request to the test server.
func (s *Server) AppendError(err error) {
	// resp is guarded by respLock (getResponse and isErrorResp take it), so
	// the append must hold it too.
	s.respLock.Lock()
	defer s.respLock.Unlock()
	s.resp = append(s.resp, mockResponse{err: err})
}
// RepeatError queues err n times at the end of the response queue.
// A non-positive n leaves the queue unchanged.
func (s *Server) RepeatError(n int, err error) {
	for remaining := n; remaining > 0; remaining-- {
		s.AppendError(err)
	}
}
// SetError installs err as the static response, so the same error is returned
// for every request (assuming err is non-nil).
// Any responses set via other methods will be ignored while it is in place.
func (s *Server) SetError(err error) {
	static := mockResponse{err: err}
	s.static = &static
}
// AppendResponse appends the response to the end of the response queue.
// If no options are provided the default response is an http.StatusOK with no
// body and no extra headers.
func (s *Server) AppendResponse(opts ...ResponseOption) {
	mr := mockResponse{code: http.StatusOK, headers: http.Header{}}
	for _, o := range opts {
		o.apply(&mr)
	}
	// resp is guarded by respLock; hold it for the append so queueing cannot
	// race with getResponse popping from the front.
	s.respLock.Lock()
	defer s.respLock.Unlock()
	s.resp = append(s.resp, mr)
}
// RepeatResponse queues n identical responses, each built from opts, at the
// end of the response queue. A non-positive n leaves the queue unchanged.
// If no options are provided the default response is an http.StatusOK.
func (s *Server) RepeatResponse(n int, opts ...ResponseOption) {
	for remaining := n; remaining > 0; remaining-- {
		s.AppendResponse(opts...)
	}
}
// SetResponse installs a static response built from opts, so the same response
// is returned for every request.
// Any responses set via other methods will be ignored while it is in place.
// If no options are provided the default response is an http.StatusOK.
// NOTE: WithPredicate is not supported here; passing it causes a panic.
func (s *Server) SetResponse(opts ...ResponseOption) {
	mr := &mockResponse{code: http.StatusOK, headers: http.Header{}}
	for _, opt := range opts {
		opt.apply(mr)
	}
	if mr.pred != nil {
		panic("WithPredicate not supported for static responses")
	}
	s.static = mr
}
// ServerOption configures a mock Server. NewServer and NewTLSServer apply each
// option after the underlying httptest.Server is created and before it is
// started.
type ServerOption interface {
	apply(s *Server)
}

// fnSrvOpt adapts a plain function to the ServerOption interface.
type fnSrvOpt func(*Server)

// compile-time check that fnSrvOpt satisfies ServerOption
var _ ServerOption = fnSrvOpt(nil)

// apply calls the wrapped function with s.
func (fn fnSrvOpt) apply(s *Server) {
	fn(s)
}
// WithTransformAllRequestsToTestServerUrl makes Server.Do send every request
// to the test server, whatever scheme and host the request originally
// targeted. For each round trip, Do replaces the scheme and host of req.URL
// with the test server's and keeps the path, query and other URL components.
// It restores the caller's URL once the round trip completes.
func WithTransformAllRequestsToTestServerUrl() ServerOption {
	return fnSrvOpt(func(s *Server) {
		s.routeAllRequestsToMockServer = true
	})
}
// WithTLSConfig sets cfg as the TLS configuration of the underlying
// httptest.Server. httptest uses it when the server is started with TLS,
// as NewTLSServer does.
func WithTLSConfig(cfg *tls.Config) ServerOption {
	setTLS := func(s *Server) {
		s.srv.TLS = cfg
	}
	return fnSrvOpt(setTLS)
}
// WithHTTP2Enabled sets the EnableHTTP2 field of the underlying httptest.Server
// to enabled. httptest only honors it for servers started with TLS.
func WithHTTP2Enabled(enabled bool) ServerOption {
	setHTTP2 := func(s *Server) {
		s.srv.EnableHTTP2 = enabled
	}
	return fnSrvOpt(setHTTP2)
}
// ResponseOption configures a single mock HTTP response before it is queued
// or installed as the static response.
type ResponseOption interface {
	apply(mr *mockResponse)
}

// fnRespOpt adapts a plain function to the ResponseOption interface.
type fnRespOpt func(*mockResponse)

// compile-time check that fnRespOpt satisfies ResponseOption
var _ ResponseOption = fnRespOpt(nil)

// apply calls the wrapped function with mr.
func (fn fnRespOpt) apply(mr *mockResponse) {
	fn(mr)
}
// ResponsePredicate is called with the incoming HTTP request to decide whether
// the response it is attached to (via WithPredicate) applies to that request.
type ResponsePredicate func(*http.Request) bool

// mockResponse describes one canned reply held by a Server, either in its
// queue or as its static response.
type mockResponse struct {
	// code is the HTTP status code passed to WriteHeader.
	code int
	// body is written after the status code when non-nil.
	body []byte
	// headers are added to the response before the status code is written.
	headers http.Header
	// err, when non-nil, is returned by Server.Do instead of making a request.
	err error
	// rerr requests a response body whose Read fails (see WithBodyReadError).
	rerr bool
	// delay is how long the handler sleeps before writing the response.
	delay time.Duration
	// pred optionally restricts which request this response applies to.
	pred ResponsePredicate
}
// write copies mr onto w in this order:
//   - every configured header,
//   - the internal mockReadError marker when rerr is set,
//   - the status code,
//   - the body, if one was configured.
//
// It returns the error from writing the body, if any.
func (mr mockResponse) write(w http.ResponseWriter) error {
	for name, values := range mr.headers {
		for _, value := range values {
			w.Header().Add(name, value)
		}
	}
	if mr.rerr {
		w.Header().Add(mockReadError, "true")
	}
	w.WriteHeader(mr.code)
	if mr.body == nil {
		return nil
	}
	_, err := w.Write(mr.body)
	return err
}
// WithStatusCode makes the HTTP response use status code c.
func WithStatusCode(c int) ResponseOption {
	setCode := func(mr *mockResponse) {
		mr.code = c
	}
	return fnRespOpt(setCode)
}
// WithBody makes the HTTP response carry b as its body. The slice is stored
// as-is, not copied.
func WithBody(b []byte) ResponseOption {
	setBody := func(mr *mockResponse) {
		mr.body = b
	}
	return fnRespOpt(setBody)
}
// WithHeader adds a header named k with value v to the HTTP response. Using it
// more than once with the same k accumulates values rather than replacing them.
// It relies on the response's header map having been initialized, as
// AppendResponse and SetResponse do.
func WithHeader(k, v string) ResponseOption {
	addHeader := func(mr *mockResponse) {
		mr.headers.Add(k, v)
	}
	return fnRespOpt(addHeader)
}
// WithSlowResponse makes the mock handler sleep for d before it writes the
// HTTP response.
func WithSlowResponse(d time.Duration) ResponseOption {
	setDelay := func(mr *mockResponse) {
		mr.delay = d
	}
	return fnRespOpt(setDelay)
}
// WithBodyReadError produces a response whose body returns an error on every
// Read once it has been returned from Server.Do.
func WithBodyReadError() ResponseOption {
	markReadErr := func(mr *mockResponse) {
		mr.rerr = true
	}
	return fnRespOpt(markReadErr)
}
// WithPredicate attaches predicate p, which is called with the incoming HTTP
// request.
//   - If p returns true, the response it is attached to is served and the
//     response queued right after it is discarded.
//   - If p returns false, the attached response is discarded and the next one
//     in the queue is considered instead.
//
// NOTE: not supported for static responses; SetResponse panics if it is used.
func WithPredicate(p ResponsePredicate) ResponseOption {
	setPred := func(mr *mockResponse) {
		mr.pred = p
	}
	return fnRespOpt(setPred)
}
// readFailer is an io.ReadCloser whose Read always fails. Close is forwarded
// to the wrapped reader.
type readFailer struct {
	wrapped io.ReadCloser
}

// compile-time check that *readFailer satisfies io.ReadCloser
var _ io.ReadCloser = (*readFailer)(nil)

// Read ignores p and always reports a failure without reading any bytes.
func (r *readFailer) Read(p []byte) (int, error) {
	failure := errors.New("mock read failure")
	return 0, failure
}

// Close closes the wrapped reader and returns its error.
func (r *readFailer) Close() error {
	inner := r.wrapped
	return inner.Close()
}