testing/proxytest/proxytest.go (273 lines of code) (raw):

// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one // or more contributor license agreements. Licensed under the Elastic License 2.0; // you may not use this file except in compliance with the Elastic License 2.0. package proxytest import ( "bufio" "context" "crypto" "crypto/tls" "crypto/x509" "fmt" "io" "log" "log/slog" "net" "net/http" "net/http/httptest" "net/url" "strings" "sync" "testing" "github.com/gofrs/uuid/v5" ) type Proxy struct { *httptest.Server // Port is the port Server is listening on. Port string // LocalhostURL is the server URL as "http(s)://localhost:PORT". // Deprecated. Use Proxy.URL instead. LocalhostURL string // proxiedRequests is a "request log" for every request the proxy receives. proxiedRequests []string proxiedRequestsMu sync.Mutex requestsWG *sync.WaitGroup opts options log *slog.Logger ca ca client *http.Client } type Option func(o *options) type options struct { addr string rewriteHost func(string) string rewriteURL func(u *url.URL) // logFn if set will be used to log every request. logFn func(format string, a ...any) verbose bool serverTLSConfig *tls.Config capriv crypto.PrivateKey cacert *x509.Certificate client *http.Client } type ca struct { capriv crypto.PrivateKey cacert *x509.Certificate } // WithAddress will set the address the server will listen on. The format is as // defined by net.Listen for a tcp connection. func WithAddress(addr string) Option { return func(o *options) { o.addr = addr } } // WithHTTPClient sets http.Client used to proxy requests to the target host. func WithHTTPClient(c *http.Client) Option { return func(o *options) { o.client = c } } // WithMITMCA sets the CA used for MITM (men in the middle) when proxying HTTPS // requests. It's used to generate TLS certificates matching the target host. // Ideally the CA is the same as the one issuing the TLS certificate for the // proxy set by WithServerTLSConfig. func WithMITMCA(priv crypto.PrivateKey, cert *x509.Certificate) func(o *options) { return func(o *options) { o.capriv = priv o.cacert = cert } } // WithRequestLog sets the proxy to log every request using logFn. It uses name // as a prefix to the log. func WithRequestLog(name string, logFn func(format string, a ...any)) Option { return func(o *options) { o.logFn = func(format string, a ...any) { logFn("[proxy-"+name+"] "+format, a...) } } } // WithRewrite will replace old by new on the request URL host when forwarding it. func WithRewrite(old, new string) Option { return func(o *options) { o.rewriteHost = func(s string) string { return strings.Replace(s, old, new, 1) } } } // WithRewriteFn calls f on the request *url.URL before forwarding it. // It takes precedence over WithRewrite. Use if more control over the rewrite // is needed. func WithRewriteFn(f func(u *url.URL)) Option { return func(o *options) { o.rewriteURL = f } } // WithServerTLSConfig sets the TLS config for the server. func WithServerTLSConfig(tc *tls.Config) Option { return func(o *options) { o.serverTLSConfig = tc } } // WithVerboseLog sets the proxy to log every request verbosely and enables // debug level logging. WithRequestLog must be used as well, otherwise // WithVerboseLog will not take effect. func WithVerboseLog() Option { return func(o *options) { o.verbose = true } } // New returns a new Proxy ready for use. Use: // - WithAddress to set the proxy's address, // - WithRewrite or WithRewriteFn to rewrite the URL before forwarding the request. // // Check the other With* functions for more options. func New(t *testing.T, optns ...Option) *Proxy { t.Helper() opts := options{addr: "127.0.0.1:0", client: &http.Client{}} for _, o := range optns { o(&opts) } if opts.logFn == nil { opts.logFn = func(format string, a ...any) {} } l, err := net.Listen("tcp", opts.addr) //nolint:gosec,nolintlint // it's a test if err != nil { t.Fatalf("NewServer failed to create a net.Listener: %v", err) } // Create a text handler that writes to standard output lv := slog.LevelInfo if opts.verbose { lv = slog.LevelDebug } p := Proxy{ requestsWG: &sync.WaitGroup{}, opts: opts, client: opts.client, log: slog.New(slog.NewTextHandler(logfWriter(opts.logFn), &slog.HandlerOptions{ Level: lv, })), } if opts.capriv != nil && opts.cacert != nil { p.ca = ca{capriv: opts.capriv, cacert: opts.cacert} } p.Server = httptest.NewUnstartedServer( http.HandlerFunc(func(ww http.ResponseWriter, r *http.Request) { // Sometimes, on CI obviously, the last log happens after the test // finishes. See https://github.com/elastic/elastic-agent/issues/5869. // Therefore, let's add an extra layer to try to avoid that. p.requestsWG.Add(1) defer p.requestsWG.Done() w := &proxyResponseWriter{w: ww} requestID := uuid.Must(uuid.NewV4()).String() p.log.Info(fmt.Sprintf("STARTING - %s '%s' %s %s", r.Method, r.URL, r.Proto, r.RemoteAddr)) rr := addIDToReqCtx(r, requestID) rrr := addLoggerReqCtx(rr, p.log.With("req_id", requestID)) p.ServeHTTP(w, rrr) p.log.Info(fmt.Sprintf("[%s] DONE %d - %s %s %s %s\n", requestID, w.statusCode, r.Method, r.URL, r.Proto, r.RemoteAddr)) }), ) p.Server.Listener = l if opts.serverTLSConfig != nil { p.Server.TLS = opts.serverTLSConfig } u, err := url.Parse(p.URL) if err != nil { panic(fmt.Sprintf("could parse fleet-server URL: %v", err)) } p.Port = u.Port() p.LocalhostURL = "http://localhost:" + p.Port return &p } func (p *Proxy) Start() error { p.Server.Start() u, err := url.Parse(p.URL) if err != nil { return fmt.Errorf("could not parse fleet-server URL: %w", err) } p.Port = u.Port() p.LocalhostURL = "http://localhost:" + p.Port p.log.Info(fmt.Sprintf("running on %s -> %s", p.URL, p.LocalhostURL)) return nil } func (p *Proxy) StartTLS() error { p.Server.StartTLS() u, err := url.Parse(p.URL) if err != nil { return fmt.Errorf("could not parse fleet-server URL: %w", err) } p.Port = u.Port() p.LocalhostURL = "https://localhost:" + p.Port p.log.Info(fmt.Sprintf("running on %s -> %s", p.URL, p.LocalhostURL)) return nil } func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodConnect { p.serveHTTPS(w, r) return } p.serveHTTP(w, r) } func (p *Proxy) Close() { // Sometimes, on CI obviously, the last log happens after the test // finishes. See https://github.com/elastic/elastic-agent/issues/5869. // So, manually wait all ongoing requests to finish. p.requestsWG.Wait() p.Server.Close() } func (p *Proxy) serveHTTP(w http.ResponseWriter, r *http.Request) { resp, err := p.processRequest(r) if err != nil { w.WriteHeader(http.StatusInternalServerError) msg := fmt.Sprintf("could not make request: %#v", err.Error()) log.Print(msg) _, _ = fmt.Fprint(w, msg) return } defer resp.Body.Close() w.WriteHeader(resp.StatusCode) for k, v := range resp.Header { w.Header()[k] = v } if _, err = io.Copy(w, resp.Body); err != nil { p.opts.logFn("[ERROR] could not write response body: %v", err) } } // processRequest executes the configured request manipulation and perform the // request. func (p *Proxy) processRequest(r *http.Request) (*http.Response, error) { origURL := r.URL.String() switch { case p.opts.rewriteURL != nil: p.opts.rewriteURL(r.URL) case p.opts.rewriteHost != nil: r.URL.Host = p.opts.rewriteHost(r.URL.Host) } // It should not be required, however if not set, enroll will fail with // "Unknown resource" r.Host = r.URL.Host p.log.Debug(fmt.Sprintf("original URL: %s, new URL: %s", origURL, r.URL.String())) p.proxiedRequestsMu.Lock() p.proxiedRequests = append(p.proxiedRequests, fmt.Sprintf("%s - %s %s %s", r.Method, r.URL.Scheme, r.URL.Host, r.URL.String())) p.proxiedRequestsMu.Unlock() // when modifying the request, RequestURI isn't updated, and it isn't // needed anyway, so remove it. r.RequestURI = "" return p.client.Do(r) } // ProxiedRequests returns a slice with the "request log" with every request the // proxy received. func (p *Proxy) ProxiedRequests() []string { p.proxiedRequestsMu.Lock() defer p.proxiedRequestsMu.Unlock() var rs []string rs = append(rs, p.proxiedRequests...) return rs } var _ http.Hijacker = &proxyResponseWriter{} // proxyResponseWriter wraps a http.ResponseWriter to expose the status code // through proxyResponseWriter.statusCode type proxyResponseWriter struct { w http.ResponseWriter statusCode int } func (s *proxyResponseWriter) Header() http.Header { return s.w.Header() } func (s *proxyResponseWriter) Write(bs []byte) (int, error) { return s.w.Write(bs) } func (s *proxyResponseWriter) WriteHeader(statusCode int) { s.statusCode = statusCode s.w.WriteHeader(statusCode) } func (s *proxyResponseWriter) StatusCode() int { return s.statusCode } func (s *proxyResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { hijacker, ok := s.w.(http.Hijacker) if !ok { return nil, nil, fmt.Errorf("%T does not support hijacking", s.w) } return hijacker.Hijack() } type ctxKeyRecID struct{} type ctxKeyLogger struct{} func addIDToReqCtx(r *http.Request, id string) *http.Request { return r.WithContext(context.WithValue(r.Context(), ctxKeyRecID{}, id)) } func idFromReqCtx(r *http.Request) string { //nolint:unused // kept for completeness return r.Context().Value(ctxKeyRecID{}).(string) } func addLoggerReqCtx(r *http.Request, log *slog.Logger) *http.Request { return r.WithContext(context.WithValue(r.Context(), ctxKeyLogger{}, log)) } func loggerFromReqCtx(r *http.Request) *slog.Logger { l, ok := r.Context().Value(ctxKeyLogger{}).(*slog.Logger) if !ok { return slog.Default() } return l } type logfWriter func(format string, a ...any) func (w logfWriter) Write(p []byte) (n int, err error) { w(string(p)) return len(p), nil }