internal/beater/request/context.go (257 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package request import ( "compress/gzip" "compress/zlib" "context" "encoding/json" "errors" "io" "net/http" "net/netip" "strings" "time" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/apm-server/internal/beater/auth" "github.com/elastic/apm-server/internal/beater/headers" "github.com/elastic/apm-server/internal/logs" "github.com/elastic/apm-server/internal/netutil" ) const ( mimeTypeAny = "*/*" mimeTypeApplicationJSON = "application/json" ) var ( mimeTypesJSON = []string{mimeTypeAny, mimeTypeApplicationJSON} errTimeout = errors.New("request timed out") ) type zlibReadCloseResetter interface { io.ReadCloser zlib.Resetter } // Context abstracts request and response information for http requests type Context struct { // compressedRequestReadCloser will be initialised for requests // with a non-zero ContentLength and empty Content-Encoding. compressedRequestReadCloser compressedRequestReadCloser countingReadCloser countingReadCloser gzipReader *gzip.Reader zlibReader zlibReadCloseResetter Request *http.Request Logger *logp.Logger Authentication auth.AuthenticationDetails Result Result // Timestamp holds the time at which the request was received by // the server. Timestamp time.Time // SourceIP holds the IP address of the originating client, if known, // as recorded in Forwarded, X-Forwarded-For, etc. SourceIP netip.Addr // SourcePort holds the port of the originating client, as recorded in // the Forwarded header. This will be zero unless the port is recorded // in the Forwarded header. SourcePort int // ClientIP holds the IP address of the originating client, if known, // as recorded in Forwarded, X-Forwarded-For, etc. // // For TCP-based requests this will have the same value as SourceIP. ClientIP netip.Addr // ClientPort holds the port of the originating client, as recorded in // the Forwarded header. This will be zero unless the port is recorded // in the Forwarded header. ClientPort int // SourceNATIP holds the IP address of the (source) network peer, or // zero if unknown. SourceNATIP netip.Addr // UserAgent holds the User-Agent request header value. UserAgent string // ResponseWriter is exported to enable passing Context to OTLP handlers // An alternate solution would be to implement context.WriteHeaders() ResponseWriter http.ResponseWriter writeAttempts int } // NewContext creates an empty Context struct func NewContext() *Context { return &Context{} } // Reset allows to reuse a context by removing all request specific information. // // It is valid to call Reset(nil, nil), which will just clear all information. // If w and r are non-nil, the context will be associated with them for handling // the request, and information such as the user agent and source IP will be // extracted for handlers. func (c *Context) Reset(w http.ResponseWriter, r *http.Request) { if c.Request != nil { if c.Request.MultipartForm != nil { err := c.Request.MultipartForm.RemoveAll() if err != nil && c.Logger != nil { c.Logger.Errorw("failed to remove temporary form files", "error", err) } } // Close the request body, which may have been replaced // by decodeRequestBody. net/http holds onto the original // body, so it won't close the decompressor. if c.Request.Body != nil { c.Request.Body.Close() } } *c = Context{ Logger: nil, Authentication: auth.AuthenticationDetails{}, ResponseWriter: w, // Reuse gzip and zlib reader buffers. gzipReader: c.gzipReader, zlibReader: c.zlibReader, } c.Result.Reset() if r != nil { c.setRequest(r) } } func (c *Context) setRequest(r *http.Request) { c.Timestamp = time.Now() c.Request = r c.UserAgent = strings.Join(r.Header["User-Agent"], ", ") ip, port := netutil.SplitAddrPort(r.RemoteAddr) c.SourceIP, c.ClientIP = ip, ip c.SourcePort, c.ClientPort = int(port), int(port) if ip, port := netutil.ClientAddrFromHeaders(r.Header); ip.IsValid() { c.SourceNATIP = c.ClientIP c.SourceIP, c.ClientIP = ip, ip c.SourcePort, c.ClientPort = int(port), int(port) } if c.Request.ContentLength != -1 { c.countingReadCloser.n = c.Request.ContentLength } else { c.countingReadCloser.ReadCloser = c.Request.Body c.Request.Body = &c.countingReadCloser } if err := c.decodeRequestBody(); err != nil { if c.Logger != nil { c.Logger.Errorw("failed to decode request body", "error", err) } c.Result.SetWithError(IDResponseErrorsDecode, err) } } // MultipleWriteAttempts returns a boolean set to true if WriteResult() was called multiple times. func (c *Context) MultipleWriteAttempts() bool { return c.writeAttempts > 1 } // WriteResult sets response headers, and writes the body to the response writer. // In case body is nil only the headers will be set. // In case statusCode indicates an error response, the body is also set as error in the context. // Only first call with write to http response. // This function wraps c.ResponseWriter.Write() - only one or the other should be used. func (c *Context) WriteResult() { if c.MultipleWriteAttempts() { return } c.writeAttempts++ // Before writing the result check for client timeout. // In case it happened override the result with timeout error. if err := c.Request.Context().Err(); errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { c.Result.SetDefault(IDResponseErrorsTimeout) c.Result.Err = errTimeout c.Result.Body = errTimeout.Error() } c.ResponseWriter.Header().Set(headers.XContentTypeOptions, "nosniff") body := c.Result.Body if body == nil { c.ResponseWriter.WriteHeader(c.Result.StatusCode) return } // wrap body in map: necessary to keep current logic if c.Result.Failure() { if b, ok := body.(string); ok { body = map[string]string{"error": b} } } var err error if c.acceptJSON() { c.ResponseWriter.Header().Set(headers.ContentType, "application/json") c.ResponseWriter.WriteHeader(c.Result.StatusCode) err = c.writeJSON(body, true) } else { c.ResponseWriter.Header().Set(headers.ContentType, "text/plain; charset=utf-8") c.ResponseWriter.WriteHeader(c.Result.StatusCode) err = c.writePlain(body) } if err != nil { c.errOnWrite(err) } } func (c *Context) acceptJSON() bool { acceptHeader := c.Request.Header.Get(headers.Accept) for _, s := range mimeTypesJSON { if strings.Contains(acceptHeader, s) { return true } } return false } func (c *Context) writeJSON(body interface{}, pretty bool) error { enc := json.NewEncoder(c.ResponseWriter) if pretty { enc.SetIndent("", " ") } return enc.Encode(body) } func (c *Context) writePlain(body interface{}) error { if b, ok := body.(string); ok { _, err := c.ResponseWriter.Write([]byte(b + "\n")) return err } // unexpected behavior to return json but changing this would be breaking return c.writeJSON(body, false) } func (c *Context) errOnWrite(err error) { if c.Logger == nil { c.Logger = logp.NewLogger(logs.Response) } c.Logger.Errorw("write response", "error", err) } func (c *Context) decodeRequestBody() error { if c.Request.ContentLength == 0 { return nil } var reader io.ReadCloser var err error switch c.Request.Header.Get("Content-Encoding") { case "deflate": reader, err = c.resetZlib(c.Request.Body) case "gzip": reader, err = c.resetGzip(c.Request.Body) default: // Sniff encoding from payload by looking at the first two bytes. // This produces much less garbage than opportunistically calling // gzip.NewReader, zlib.NewReader, etc. // // Portions of code based on compress/zlib and compress/gzip. const ( zlibDeflate = 8 gzipID1 = 0x1f gzipID2 = 0x8b ) rc := &c.compressedRequestReadCloser rc.ReadCloser = c.Request.Body if _, err := c.Request.Body.Read(rc.magic[:]); err != nil { if err == io.EOF { // Leave the original request body in place, // which should continue returning io.EOF. return nil } return err } if rc.magic[0] == gzipID1 && rc.magic[1] == gzipID2 { reader, err = c.resetGzip(rc) } else if rc.magic[0]&0x0f == zlibDeflate { reader, err = c.resetZlib(rc) } else { reader = rc } } if err != nil { return err } c.Request.ContentLength = -1 c.Request.Body = reader return nil } func (c *Context) resetZlib(r io.Reader) (io.ReadCloser, error) { if c.zlibReader == nil { zr, err := zlib.NewReader(r) if err != nil { return nil, err } c.zlibReader = zr.(zlibReadCloseResetter) } else if err := c.zlibReader.Reset(r, nil); err != nil { return nil, err } return c.zlibReader, nil } func (c *Context) resetGzip(r io.Reader) (io.ReadCloser, error) { var err error if c.gzipReader == nil { c.gzipReader, err = gzip.NewReader(r) } else { err = c.gzipReader.Reset(r) } return c.gzipReader, err } // RequestBodyBytes returns the original c.Request.ContentLength if it // was not -1, otherwise it returns the number of bytes read from the // request body. // // RequestBodyBytes must not be called concurrently with // c.Request.Body.Read(). func (c *Context) RequestBodyBytes() int64 { return c.countingReadCloser.n } type countingReadCloser struct { io.ReadCloser n int64 } func (r *countingReadCloser) Read(p []byte) (int, error) { n, err := r.ReadCloser.Read(p) if n > 0 { r.n += int64(n) } return n, err } type compressedRequestReadCloser struct { io.ReadCloser magic [2]byte magicRead int } func (r *compressedRequestReadCloser) Read(p []byte) (int, error) { var nmagic int if r.magicRead < 2 { nmagic = copy(p[:], r.magic[r.magicRead:]) r.magicRead += nmagic } n, err := r.ReadCloser.Read(p[nmagic:]) return n + nmagic, err }