transport/httpcommon/diag.go (160 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package httpcommon
import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"net/http/httptrace"
"net/textproto"
"net/url"
"strings"
"github.com/elastic/elastic-agent-libs/transport/tlscommon"
)
// DiagRequests returns a diagnostics hook callback that sends the passed
// requests through a round tripper generated from the settings and returns a
// log of the httptrace events observed while doing so.
//
// The callback never fails; problems are reported in the returned bytes:
//   - a nil receiver yields "error: nil httpcommon.HTTPTransportSettings",
//   - an empty reqs slice yields "error: 0 requests",
//   - a failure to build the round tripper yields "unable to get roundtripper: <err>".
//
// Requests are sent sequentially, in slice order, via RoundTrip. Every
// response body is closed; round-trip errors are passed through diagError
// before being logged.
func (settings *HTTPTransportSettings) DiagRequests(reqs []*http.Request, opts ...TransportOption) func() []byte {
	if settings == nil {
		return func() []byte {
			return []byte(`error: nil httpcommon.HTTPTransportSettings`)
		}
	}
	if len(reqs) == 0 {
		return func() []byte {
			return []byte(`error: 0 requests`)
		}
	}
	return func() []byte {
		var b bytes.Buffer
		rt, err := settings.RoundTripper(opts...)
		if err != nil {
			b.WriteString("unable to get roundtripper: " + err.Error())
			return b.Bytes()
		}
		// A round tripper is requested on every run of the hook, so release
		// whatever keep-alive connections it pooled once diagnostics are done.
		// Only idle connections are affected, and only when the round tripper
		// exposes the hook (the same check http.Client.CloseIdleConnections uses).
		if ci, ok := http.RoundTripper(rt).(interface{ CloseIdleConnections() }); ok {
			defer ci.CloseIdleConnections()
		}

		// Everything below is written to b with UTC, microsecond timestamps.
		logger := log.New(&b, "", log.LstdFlags|log.Lmicroseconds|log.LUTC)
		if settings.TLS == nil {
			logger.Print("No TLS settings")
		} else {
			logger.Print("TLS settings detected")
		}
		logger.Printf("Proxy disable=%v url=%s", settings.Proxy.Disable, settings.Proxy.URL)

		// ct logs each connection lifecycle event reported by net/http.
		ct := &httptrace.ClientTrace{
			GetConn: func(hostPort string) {
				logger.Printf("GetConn called for %q", hostPort)
			},
			GotConn: func(connInfo httptrace.GotConnInfo) {
				logger.Printf("GotConn for %q", connInfo.Conn.RemoteAddr())
			},
			GotFirstResponseByte: func() {
				logger.Print("Response started")
			},
			Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
				logger.Printf("Got info response status=%d, headers=%v", code, header)
				return nil
			},
			DNSStart: func(info httptrace.DNSStartInfo) {
				logger.Printf("Starting DNS lookup for %q", info.Host)
			},
			DNSDone: func(info httptrace.DNSDoneInfo) {
				logger.Printf("Done DNS lookup: %+v", info)
			},
			ConnectStart: func(network, addr string) {
				logger.Printf("Connection started to %s:%s", network, addr)
			},
			ConnectDone: func(network, addr string, err error) {
				logger.Printf("Connection to %s:%s done, err: %v", network, addr, err)
			},
			TLSHandshakeStart: func() {
				logger.Print("TLS handshake starting")
			},
			TLSHandshakeDone: func(state tls.ConnectionState, err error) {
				logger.Printf("TLS handshake done. state=%+v err=%v", state, err)
				// Dump the certificates presented by the peer, then every
				// chain that verification accepted.
				logger.Printf("Peer certificate count %d", len(state.PeerCertificates))
				for i, crt := range state.PeerCertificates {
					logger.Printf("- Peer Certificate %d\n\t%s", i, tlscommon.CertDiagString(crt))
				}
				logger.Printf("Verified chains count: %d", len(state.VerifiedChains))
				for i, chain := range state.VerifiedChains {
					for j, crt := range chain {
						logger.Printf("- Chain %d certificate %d\n\t%s", i, j, tlscommon.CertDiagString(crt))
					}
				}
			},
			WroteHeaders: func() {
				logger.Printf("Wrote request headers")
			},
			Wait100Continue: func() {
				logger.Printf("Waiting for continue")
			},
			WroteRequest: func(info httptrace.WroteRequestInfo) {
				logger.Printf("Wrote request err=%v", info.Err)
			},
		}

		for i, req := range reqs {
			// A malformed entry must not panic the diagnostics hook.
			if req == nil || req.URL == nil {
				logger.Printf("request %d error: nil request or request URL", i)
				continue
			}
			logger.Printf("Request %d to %s starting", i, req.URL.String())
			req = req.WithContext(httptrace.WithClientTrace(req.Context(), ct))

			resp, err := rt.RoundTrip(req)
			if err != nil {
				logger.Printf("request %d error: %s", i, diagError(err))
				continue
			}
			// isGoHTTPResp inspects the body, so it has to run before Close.
			schemeMismatch := isGoHTTPResp(resp)
			resp.Body.Close()
			if schemeMismatch {
				logger.Printf("request %d error: HTTP request sent to HTTPS server.", i)
			} else {
				logger.Printf("request %d successful. status=%d", i, resp.StatusCode)
			}
		}
		return b.Bytes()
	}
}
// isGoHTTPResp reports whether r is the response a Go http.Server sends when
// a plain HTTP request is made to an HTTPS server: a 400 Bad Request whose
// body is exactly "Client sent an HTTP request to an HTTPS server.\n".
// Non-Go servers may return a net.OpError instead.
//
// The body is read (at most one byte past the expected message) but not
// closed; closing stays with the caller. A nil response, a nil body or a read
// error yields false.
func isGoHTTPResp(r *http.Response) bool {
	const goTLSMismatchBody = "Client sent an HTTP request to an HTTPS server.\n"

	if r == nil || r.Body == nil || r.StatusCode != http.StatusBadRequest {
		return false
	}
	// Only an exact match counts, so cap the read at one byte beyond the
	// expected message rather than buffering an arbitrarily large error body.
	// The extra byte is what lets a longer body be told apart from a match.
	p, err := io.ReadAll(io.LimitReader(r.Body, int64(len(goTLSMismatchBody))+1))
	if err != nil {
		return false
	}
	return string(p) == goTLSMismatchBody
}
// diagError tries to diagnose err and returns its message annotated with a
// cause or possible cause in a human readable format.
// If no known failure is recognised, err.Error() is returned unchanged; a nil
// err yields the empty string.
//
// Checks run from most to least specific: typed x509/net/tls errors first,
// then text matching on the wrapped error message.
func diagError(err error) string {
	if err == nil {
		return ""
	}

	// client does not support server algorithm
	if errors.Is(err, x509.ErrUnsupportedAlgorithm) {
		return fmt.Sprintf("%v: caused by client does not support server's signature algorithm.", err)
	}

	// a *net.OpError could indicate an HTTP request made to an HTTPS server
	var netErr *net.OpError
	if errors.As(err, &netErr) && netErr.Err != nil {
		if netErr.Err.Error() == "read: connection reset by peer" {
			return fmt.Sprintf("%v: possible cause: HTTP schema used for HTTPS server.", netErr)
		}
	}

	// Client does not have CA that matches server cert
	var unknownAuthErr x509.UnknownAuthorityError
	if errors.As(err, &unknownAuthErr) {
		return fmt.Sprintf("%v: caused by no trusted client CA.", err)
	}

	// CA is ok but the server's cert is not.
	var certValidErr x509.CertificateInvalidError
	if errors.As(err, &certValidErr) {
		return fmt.Sprintf("%v: caused by invalid server certificate.", certValidErr)
	}

	// cert validation error can indicate that a custom CA needs to be used
	var tlsErr *tls.CertificateVerificationError
	if errors.As(err, &tlsErr) {
		return fmt.Sprintf("%v: possible cause: client TLS settings incorrect.", tlsErr)
	}

	// Text-based diagnoses come last as other failures can be embedded in a
	// url.Error. When a *url.Error is present its wrapped error is inspected
	// and the url.Error itself is reported. Otherwise err is used directly:
	// only http.Client is documented to wrap failures in *url.Error, while a
	// bare RoundTripper may return the underlying error as is.
	subject, detail := err, ""
	var uErr *url.Error
	if errors.As(err, &uErr) && uErr.Err != nil {
		subject, detail = uErr, uErr.Err.Error()
	} else {
		detail = err.Error()
	}

	diagnoses := [...]struct{ marker, cause string }{
		// HTTPS request made to an HTTP server
		{"http: server gave HTTP response to HTTPS client", "caused by using HTTPS schema on HTTP server."},
		{"remote error: tls: certificate required", "caused by missing mTLS client cert."},
		{"remote error: tls: expired certificate", "caused by expired mTLS client cert."},
		{"remote error: tls: bad certificate", "caused by invalid mTLS client cert, does the server trust the CA used for the client cert?."},
	}
	for _, d := range diagnoses {
		if strings.Contains(detail, d.marker) {
			return fmt.Sprintf("%v: %s", subject, d.cause)
		}
	}
	return err.Error()
}