internal/pkg/logger/http.go (207 lines of code) (raw):
// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.
package logger
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/gofrs/uuid"
"github.com/rs/zerolog"
"go.elastic.co/apm/module/apmzerolog/v2"
"github.com/elastic/fleet-server/v7/internal/pkg/apikey"
)
const (
HeaderRequestID = "X-Request-Id"
httpSlashPrefix = "HTTP/"
)
type ReaderCounter struct {
io.ReadCloser
count uint64
}
func NewReaderCounter(r io.ReadCloser) *ReaderCounter {
return &ReaderCounter{
ReadCloser: r,
}
}
func (rd *ReaderCounter) Read(buf []byte) (int, error) {
n, err := rd.ReadCloser.Read(buf)
atomic.AddUint64(&rd.count, uint64(n)) //nolint:gosec // disable G115
return n, err
}
func (rd *ReaderCounter) Count() uint64 {
return atomic.LoadUint64(&rd.count)
}
type ResponseCounter struct {
http.ResponseWriter
count uint64
statusCode int
}
func NewResponseCounter(w http.ResponseWriter) *ResponseCounter {
return &ResponseCounter{
ResponseWriter: w,
}
}
func (rc *ResponseCounter) Write(buf []byte) (int, error) {
if rc.statusCode == 0 {
rc.WriteHeader(http.StatusOK)
}
n, err := rc.ResponseWriter.Write(buf)
atomic.AddUint64(&rc.count, uint64(n)) //nolint:gosec // disable G115
return n, err
}
func (rc *ResponseCounter) WriteHeader(statusCode int) {
rc.ResponseWriter.WriteHeader(statusCode)
// Defend unsupported multiple calls to WriteHeader
if rc.statusCode == 0 {
rc.statusCode = statusCode
}
}
// Unwrap unwraps the underlying ResponseWriter
func (rc *ResponseCounter) Unwrap() http.ResponseWriter {
return rc.ResponseWriter
}
func (rc *ResponseCounter) Count() uint64 {
return atomic.LoadUint64(&rc.count)
}
type ctxTSKey struct{}
// CtxStartTime returns the start time associated with a context
func CtxStartTime(ctx context.Context) (time.Time, bool) {
ts, ok := ctx.Value(ctxTSKey{}).(time.Time)
return ts, ok
}
func splitAddr(addr string) (host string, port int) {
host, portS, err := net.SplitHostPort(addr)
if err == nil {
if v, err := strconv.Atoi(portS); err == nil {
port = v
}
}
return //nolint:nakedret // short function
}
// Expects HTTP version in form of HTTP/x.y
func stripHTTP(h string) string {
switch h {
case "HTTP/2.0":
return "2.0"
case "HTTP/1.1":
return "1.1" //nolint:goconst // 1.1 is used by http and tls
default:
if strings.HasPrefix(h, httpSlashPrefix) {
return h[len(httpSlashPrefix):]
}
}
return h
}
func httpMeta(r *http.Request, e *zerolog.Event) {
oldForce := r.URL.ForceQuery
r.URL.ForceQuery = false
e.Str(ECSURLFull, r.URL.String())
r.URL.ForceQuery = oldForce
if domain := r.URL.Hostname(); domain != "" {
e.Str(ECSURLDomain, domain)
}
port := r.URL.Port()
if port != "" {
if v, err := strconv.Atoi(port); err == nil {
e.Int(ECSURLPort, v)
}
}
// HTTP info
e.Str(ECSHTTPVersion, stripHTTP(r.Proto))
e.Str(ECSHTTPRequestMethod, r.Method)
// ApiKey
if apiKey, err := apikey.ExtractAPIKey(r); err == nil {
e.Str(APIKeyID, apiKey.ID)
}
// Client info
if r.RemoteAddr != "" {
e.Str(ECSClientAddress, r.RemoteAddr)
}
// TLS info
e.Bool(ECSTLSEstablished, r.TLS != nil)
}
func httpDebug(r *http.Request, e *zerolog.Event) {
// Client info
if r.RemoteAddr != "" {
remoteIP, remotePort := splitAddr(r.RemoteAddr)
e.Str(ECSClientIP, remoteIP)
e.Int(ECSClientPort, remotePort)
}
if r.TLS != nil {
e.Str(ECSTLSVersion, TLSVersionToString(r.TLS.Version))
e.Str(ECSTLSCipher, tls.CipherSuiteName(r.TLS.CipherSuite))
e.Bool(ECSTLSsResumed, r.TLS.DidResume)
if r.TLS.ServerName != "" {
e.Str(ECSTLSClientServerName, r.TLS.ServerName)
}
if len(r.TLS.PeerCertificates) > 0 && r.TLS.PeerCertificates[0] != nil {
leaf := r.TLS.PeerCertificates[0]
if leaf.SerialNumber != nil {
e.Str(ECSTLSClientSerialNumber, leaf.SerialNumber.String())
}
e.Str(ECSTLSClientIssuer, leaf.Issuer.String())
e.Str(ECSTLSClientSubject, leaf.Subject.String())
e.Str(ECSTLSClientNotBefore, leaf.NotBefore.UTC().Format(ECSTLSClientTimeFormat))
e.Str(ECSTLSClientNotAfter, leaf.NotAfter.UTC().Format(ECSTLSClientTimeFormat))
}
}
}
// Middleware wraps an HTTP Middleware in a request logger.
//
// It will also attach a (zerolog) logger, and request start time to each requests' context.
// The default settings will result in an ECS compliant entry if the response code is not 2XX.
// The middleware will generate a new UUID if there's no X-Request-ID header
// Responses will also have the X-Request-ID header set.
// If debug is enabled a request will result in 2 log entries; one at the start of the request and one when the response is sent (regardless of status code)
func Middleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
start := time.Now().UTC()
// Look for request id
reqID := r.Header.Get(HeaderRequestID)
if reqID == "" { // generate a request ID if there's none
// zerolog has strong support for github.com/rs/xid a 12byte ID - perhaps we should use it if UUID is too big?
uid, err := uuid.NewV4()
if err == nil {
reqID = uid.String()
r.Header.Set(HeaderRequestID, reqID) // incase other handlers need it
}
}
// Insert X-Request-ID header into response
w.Header().Set(HeaderRequestID, reqID)
// get the server bound addr from the req ctx
var addr string
netAddr, ok := r.Context().Value(http.LocalAddrContextKey).(net.Addr)
if ok {
addr = netAddr.String()
}
// Add trace correlation fields
ctx := r.Context()
zlog := zerolog.Ctx(ctx).Hook(apmzerolog.TraceContextHook(ctx))
// Update request context
// NOTE this injects the request id and addr into all logs that use the request logger
zlog = zlog.With().Str(ECSHTTPRequestID, reqID).Str(ECSServerAddress, addr).Logger()
ctx = zlog.WithContext(ctx)
ctx = context.WithValue(ctx, ctxTSKey{}, start)
r = r.WithContext(ctx)
e := zlog.Info()
if !e.Enabled() {
next.ServeHTTP(w, r)
return
}
rdCounter := NewReaderCounter(r.Body)
r.Body = rdCounter
wrCounter := NewResponseCounter(w)
if zlog.Debug().Enabled() {
d := zlog.Debug()
httpMeta(r, d)
httpDebug(r, d)
d.Msg("HTTP start")
}
next.ServeHTTP(wrCounter, r)
httpMeta(r, e)
// Write an info level log line for each HTTP request if debug is enabled, or a non-2XX status is returned.
if zlog.Debug().Enabled() || (wrCounter.statusCode < 200 || wrCounter.statusCode >= 300) {
e.Uint64(ECSHTTPRequestBodyBytes, rdCounter.Count())
e.Uint64(ECSHTTPResponseBodyBytes, wrCounter.Count())
e.Int(ECSHTTPResponseCode, wrCounter.statusCode)
e.Int64(ECSEventDuration, time.Since(start).Nanoseconds())
e.Msg("HTTP Request")
}
}
return http.HandlerFunc(fn)
}
func TLSVersionToString(vers uint16) string {
switch vers {
case tls.VersionTLS10:
return "1.0"
case tls.VersionTLS11:
return "1.1"
case tls.VersionTLS12:
return "1.2"
case tls.VersionTLS13:
return "1.3"
default:
}
return fmt.Sprintf("unknown_0x%x", vers)
}