x-pack/filebeat/input/http_endpoint/handler.go

// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
// or more contributor license agreements. Licensed under the Elastic License;
// you may not use this file except in compliance with the Elastic License.

package http_endpoint

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"reflect"
	"sort"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/google/cel-go/cel"
	"github.com/google/cel-go/checker/decls"
	"github.com/google/cel-go/common/types"
	"github.com/google/cel-go/common/types/ref"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/elastic/beats/v7/libbeat/beat"
	"github.com/elastic/beats/v7/libbeat/common"
	"github.com/elastic/beats/v7/libbeat/common/jsontransform"
	"github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httplog"
	"github.com/elastic/elastic-agent-libs/logp"
	"github.com/elastic/elastic-agent-libs/mapstr"
	"github.com/elastic/mito/lib"
)

const headerContentEncoding = "Content-Encoding"

var (
	errBodyEmpty       = errors.New("body cannot be empty")
	errUnsupportedType = errors.New("only JSON objects are accepted")
	errNotCRC          = errors.New("event not processed as CRC request")
)

type handler struct {
	ctx context.Context

	metrics   *inputMetrics
	publish   func(beat.Event)
	log       *logp.Logger
	validator apiValidator

	txBaseID    string        // Random value to make transaction IDs unique.
	txIDCounter atomic.Uint64 // Transaction ID counter that is incremented for each request.

	// inFlight is the sum of message body lengths
	// that have been received but not yet ACKed
	// or timed out or otherwise handled.
	//
	// Requests that do not request a timeout do
	// not contribute to this value.
	inFlight atomic.Int64
	// maxInFlight is the maximum value of inFlight
	// that will be allowed for any messages received
	// by the handler. If non-zero, inFlight may
	// not exceed this value.
	maxInFlight int64
	retryAfter  int

	reqLogger    *zap.Logger
	host, scheme string

	program               *program
	messageField          string
	responseCode          int
	responseBody          string
	includeHeaders        []string
	preserveOriginalEvent bool
	crc                   *crcValidator
}

func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	txID := h.nextTxID()
	h.log.Debugw("request", "url", r.URL, "tx_id", txID)

	status, err := h.validator.validateRequest(r)
	if err != nil {
		h.sendAPIErrorResponse(txID, w, r, h.log, status, err)
		return
	}

	wait, err := getTimeoutWait(r.URL, h.log)
	if err != nil {
		h.sendAPIErrorResponse(txID, w, r, h.log, http.StatusBadRequest, err)
		return
	}

	var (
		acked   chan struct{}
		timeout *time.Timer
	)
	if h.maxInFlight != 0 {
		// Consider non-ACKing messages as well. These do not add
		// to the sum of in-flight bytes, but we can still assess
		// whether a message would take us over the limit.
		inFlight := h.inFlight.Load() + r.ContentLength
		if inFlight > h.maxInFlight {
			w.Header().Set("Content-Type", "application/json")
			w.Header().Set("Retry-After", strconv.Itoa(h.retryAfter))
			w.WriteHeader(http.StatusServiceUnavailable)
			_, err := fmt.Fprintf(w,
				`{"warn":"max in flight message memory exceeded","max_in_flight":%d,"in_flight":%d}`,
				h.maxInFlight, inFlight,
			)
			if err != nil {
				h.log.Errorw("failed to write 503", "error", err)
			}
			return
		}
	}
	if wait != 0 {
		acked = make(chan struct{})
		timeout = time.NewTimer(wait)
		h.inFlight.Add(r.ContentLength)
		defer func() {
			// Any return will be a message handling completion and the
			// removal of the allocation from the queue assuming that
			// the client has requested a timeout.
			// Either we have an early error condition or timeout and the
			// message is dropped, we have ACKed all the events in the
			// request, or the input has been cancelled.
			h.inFlight.Add(-r.ContentLength)
		}()
	}

	start := time.Now()
	acker := newBatchACKTracker(func() {
		h.metrics.batchACKTime.Update(time.Since(start).Nanoseconds())
		h.metrics.batchesACKedTotal.Inc()
		if acked != nil {
			close(acked)
		}
	})
	h.metrics.batchesReceived.Add(1)
	h.metrics.contentLength.Update(r.ContentLength)

	body, status, err := getBodyReader(r)
	if err != nil {
		h.sendAPIErrorResponse(txID, w, r, h.log, status, err)
		h.metrics.apiErrors.Add(1)
		return
	}
	defer body.Close()

	if h.reqLogger != nil {
		// If we are logging, keep a copy of the body for the logger.
		// This is stashed in the r.Body field. This is only safe
		// because we are closing the original body in a defer and
		// r.Body is not otherwise referenced by the non-logging logic
		// after the call to getBodyReader above.
		var buf bytes.Buffer
		body = io.NopCloser(io.TeeReader(body, &buf))
		r.Body = io.NopCloser(&buf)
	}

	objs, status, err := httpReadJSON(body, h.program)
	if err != nil {
		h.sendAPIErrorResponse(txID, w, r, h.log, status, err)
		h.metrics.apiErrors.Add(1)
		return
	}

	var headers map[string]interface{}
	if len(h.includeHeaders) != 0 {
		headers = getIncludedHeaders(r, h.includeHeaders)
	}

	var (
		respCode int
		respBody string
	)

	h.metrics.batchSize.Update(int64(len(objs)))
	for _, obj := range objs {
		var err error
		if h.crc != nil {
			respCode, respBody, err = h.crc.validate(obj)
			if err == nil {
				// CRC request processed
				break
			} else if !errors.Is(err, errNotCRC) {
				h.metrics.apiErrors.Add(1)
				h.sendAPIErrorResponse(txID, w, r, h.log, http.StatusBadRequest, err)
				return
			}
		}

		acker.Add()
		if err = h.publishEvent(obj, headers, acker); err != nil {
			h.metrics.apiErrors.Add(1)
			h.sendAPIErrorResponse(txID, w, r, h.log, http.StatusInternalServerError, err)
			return
		}
		h.metrics.eventsPublished.Add(1)
		respCode, respBody = h.responseCode, h.responseBody
	}

	acker.Ready()
	if acked == nil {
		h.sendResponse(w, respCode, respBody)
	} else {
		select {
		case <-acked:
			h.log.Debugw("request acked", "tx_id", txID)
			if !timeout.Stop() {
				<-timeout.C
			}
			h.sendResponse(w, respCode, respBody)
		case <-timeout.C:
			h.log.Debugw("request timed out", "tx_id", txID)
			h.sendAPIErrorResponse(txID, w, r, h.log, http.StatusGatewayTimeout, errTookTooLong)
		case <-h.ctx.Done():
			h.log.Debugw("request context cancelled", "tx_id", txID)
			h.sendAPIErrorResponse(txID, w, r, h.log, http.StatusGatewayTimeout, h.ctx.Err())
		}
		if h.reqLogger != nil {
			h.logRequest(txID, r, respCode, nil)
		}
	}
	h.metrics.batchProcessingTime.Update(time.Since(start).Nanoseconds())
	h.metrics.batchesPublished.Add(1)
}

var errTookTooLong = errors.New("could not publish event within timeout")

// getTimeoutWait returns the duration given by the wait_for_completion_timeout
// URL query parameter, if present. Any other query parameter is an error.
func getTimeoutWait(u *url.URL, log *logp.Logger) (time.Duration, error) {
	q := u.Query()
	switch len(q) {
	case 0:
		return 0, nil
	case 1:
		if _, ok := q["wait_for_completion_timeout"]; !ok {
			// Get the only key in q. We don't know what it is, so iterate
			// over the first one of one.
			var k string
			for k = range q {
				break
			}
			return 0, fmt.Errorf("unexpected URL query: %s", k)
		}
	default:
		delete(q, "wait_for_completion_timeout")
		keys := make([]string, 0, len(q))
		for k := range q {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return 0, fmt.Errorf("unexpected URL query: %s", strings.Join(keys, ", "))
	}
	p := q.Get("wait_for_completion_timeout")
	if p == "" {
		// This will never happen; it is already handled in the switch above.
		return 0, nil
	}
	log.Debugw("wait_for_completion_timeout parameter", "value", p)
	t, err := time.ParseDuration(p)
	if err != nil {
		return 0, fmt.Errorf("could not parse wait_for_completion_timeout parameter: %w", err)
	}
	if t < 0 {
		return 0, fmt.Errorf("negative wait_for_completion_timeout parameter: %s", p)
	}
	return t, nil
}

func (h *handler) sendAPIErrorResponse(txID string, w http.ResponseWriter, r *http.Request, log *logp.Logger, status int, apiError error) {
	log.Errorw("request error", "tx_id", txID, "status_code", status, "error", apiError)

	w.Header().Add("Content-Type", "application/json")
	w.WriteHeader(status)

	var (
		mw  io.Writer = w
		buf bytes.Buffer
	)
	if h.reqLogger != nil {
		mw = io.MultiWriter(mw, &buf)
	}
	enc := json.NewEncoder(mw)
	enc.SetEscapeHTML(false)
	err := enc.Encode(map[string]interface{}{"message": apiError.Error()})
	if err != nil {
		log.Debugw("Failed to write HTTP response.", "error", err, "client.address", r.RemoteAddr)
	}

	if h.reqLogger != nil {
		h.logRequest(txID, r, status, buf.Bytes())
	}
}

func (h *handler) logRequest(txID string, r *http.Request, status int, respBody []byte) {
	// Populate and preserve scheme and host if they are missing;
	// they are required for httputil.DumpRequestOut.
	var scheme, host string
	if r.URL.Scheme == "" {
		scheme = r.URL.Scheme
		r.URL.Scheme = h.scheme
	}
	if r.URL.Host == "" {
		host = r.URL.Host
		r.URL.Host = h.host
	}
	extra := make([]zapcore.Field, 1, 4)
	extra[0] = zap.Int("status", status)
	addr, port, err := net.SplitHostPort(r.RemoteAddr)
	if err == nil {
		extra = append(extra,
			zap.String("source.ip", addr),
			zap.String("source.port", port),
		)
	}
	if len(respBody) != 0 {
		extra = append(extra,
			zap.ByteString("http.response.body.content", respBody),
		)
	}
	h.log.Debugw("new request trace transaction", "id", txID)
	// Limit request logging body size to 10kiB.
	const maxBodyLen = 10 * (1 << 10)
	httplog.LogRequest(h.reqLogger.With(zap.String("transaction.id", txID)), r, maxBodyLen, extra...)
	if scheme != "" {
		r.URL.Scheme = scheme
	}
	if host != "" {
		r.URL.Host = host
	}
}

func (h *handler) nextTxID() string {
	count := h.txIDCounter.Add(1)
	return h.formatTxID(count)
}

func (h *handler) formatTxID(count uint64) string {
	return h.txBaseID + "-" + strconv.FormatUint(count, 10)
}

func (h *handler) sendResponse(w http.ResponseWriter, status int, message string) {
	w.Header().Add("Content-Type", "application/json")
	w.WriteHeader(status)
	if _, err := io.WriteString(w, message); err != nil {
		h.log.Debugw("Failed writing response to client.", "error", err)
	}
}

func (h *handler) publishEvent(obj, headers mapstr.M, acker *batchACKTracker) error {
	event := beat.Event{
		Timestamp: time.Now().UTC(),
		Private:   acker,
	}
	if h.messageField == "." {
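		// A messageField of "." means the decoded object becomes the event
		// fields directly rather than being nested under a key.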
		event.Fields = obj
	} else {
		if _, err := event.PutValue(h.messageField, obj); err != nil {
			return fmt.Errorf("failed to put data into event key %q: %w", h.messageField, err)
		}
	}
	if h.preserveOriginalEvent {
		event.Fields["event"] = mapstr.M{
			"original": obj.String(),
		}
	}
	if len(headers) > 0 {
		event.Fields["headers"] = headers
	}

	h.publish(event)
	return nil
}

func httpReadJSON(body io.Reader, prg *program) (objs []mapstr.M, status int, err error) {
	if body == http.NoBody {
		return nil, http.StatusNotAcceptable, errBodyEmpty
	}
	obj, err := decodeJSON(body, prg)
	if err != nil {
		return nil, http.StatusBadRequest, err
	}
	return obj, http.StatusOK, err
}

func decodeJSON(body io.Reader, prg *program) (objs []mapstr.M, err error) {
	decoder := json.NewDecoder(body)
	for decoder.More() {
		var raw json.RawMessage
		if err = decoder.Decode(&raw); err != nil {
			if err == io.EOF { //nolint:errorlint // This will never be a wrapped error.
				break
			}
			return nil, fmt.Errorf("malformed JSON object at stream position %d: %w", decoder.InputOffset(), err)
		}

		var obj interface{}
		if err = newJSONDecoder(bytes.NewReader(raw)).Decode(&obj); err != nil {
			return nil, fmt.Errorf("malformed JSON object at stream position %d: %w", decoder.InputOffset(), err)
		}

		if prg != nil {
			obj, err = prg.eval(obj)
			if err != nil {
				return nil, err
			}
			if _, ok := obj.([]interface{}); ok {
				// Re-marshal to ensure the raw bytes agree with the constructed object.
				// This is only necessary when the program constructs an array return.
				raw, err = json.Marshal(obj)
				if err != nil {
					return nil, fmt.Errorf("failed to remarshal object: %w", err)
				}
			}
		}

		switch v := obj.(type) {
		case map[string]interface{}:
			objs = append(objs, v)
		case []interface{}:
			nobjs, err := decodeJSONArray(bytes.NewReader(raw))
			if err != nil {
				return nil, fmt.Errorf("recursive error %d: %w", decoder.InputOffset(), err)
			}
			objs = append(objs, nobjs...)
		default:
			return nil, fmt.Errorf("%w: %T", errUnsupportedType, v)
		}
	}
	for i := range objs {
		jsontransform.TransformNumbers(objs[i])
	}
	return objs, nil
}

type program struct {
	prg cel.Program
	ast *cel.Ast
}

func newProgram(src string, log *logp.Logger) (*program, error) {
	if src == "" {
		return nil, nil
	}

	registry, err := types.NewRegistry()
	if err != nil {
		return nil, fmt.Errorf("failed to create env: %w", err)
	}
	env, err := cel.NewEnv(
		cel.Declarations(decls.NewVar("obj", decls.Dyn)),
		cel.OptionalTypes(cel.OptionalTypesVersion(lib.OptionalTypesVersion)),
		cel.CustomTypeAdapter(&numberAdapter{registry}),
		cel.CustomTypeProvider(registry),
		lib.Debug(debug(log)),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create env: %w", err)
	}

	ast, iss := env.Compile(src)
	if iss.Err() != nil {
		return nil, fmt.Errorf("failed compilation: %w", iss.Err())
	}

	prg, err := env.Program(ast)
	if err != nil {
		return nil, fmt.Errorf("failed program instantiation: %w", err)
	}
	return &program{prg: prg, ast: ast}, nil
}

func debug(log *logp.Logger) func(string, any) {
	log = log.Named("http_endpoint_cel_debug")
	return func(tag string, value any) {
		level := "DEBUG"
		if _, ok := value.(error); ok {
			level = "ERROR"
		}
		log.Debugw(level, "tag", tag, "value", value)
	}
}

var _ types.Adapter = (*numberAdapter)(nil)

type numberAdapter struct {
	fallback types.Adapter
}

func (a *numberAdapter) NativeToValue(value any) ref.Val {
	switch value := value.(type) {
	case []any:
		for i, v := range value {
			value[i] = a.NativeToValue(v)
		}
	case map[string]any:
		for k, v := range value {
			value[k] = a.NativeToValue(v)
		}
	case json.Number:
		var errs []error
		i, err := value.Int64()
		if err == nil {
			return types.Int(i)
		}
		errs = append(errs, err)
		f, err := value.Float64()
		if err == nil {
			// Literalise floats that could have been an integer greater than
			// can be stored without loss of precision in a double.
			// This is any integer wider than the IEEE-754 double mantissa.
			// As a heuristic, allow anything that includes a decimal point
			// or uses scientific notation. We could be more careful, but
			// it is likely not important, and other languages use the same
			// rule.
			if f >= 0x1p53 && !strings.ContainsFunc(string(value), func(r rune) bool {
				return r == '.' || r == 'e' || r == 'E'
			}) {
				return types.String(value)
			}
			return types.Double(f)
		}
		errs = append(errs, err)
		return types.NewErr("%v", errors.Join(errs...))
	}
	return a.fallback.NativeToValue(value)
}

func (p *program) eval(obj interface{}) (interface{}, error) {
	out, _, err := p.prg.Eval(map[string]interface{}{"obj": obj})
	if err != nil {
		err = lib.DecoratedError{AST: p.ast, Err: err}
		return nil, fmt.Errorf("failed eval: %w", err)
	}

	v, err := out.ConvertToNative(reflect.TypeOf((*structpb.Value)(nil)))
	if err != nil {
		return nil, fmt.Errorf("failed proto conversion: %w", err)
	}
	switch v := v.(type) {
	case *structpb.Value:
		return v.AsInterface(), nil
	default:
		// This should never happen.
		return nil, fmt.Errorf("unexpected native conversion type: %T", v)
	}
}

func decodeJSONArray(raw *bytes.Reader) (objs []mapstr.M, err error) {
	dec := newJSONDecoder(raw)
	token, err := dec.Token()
	if err != nil {
		if err == io.EOF { //nolint:errorlint // This will never be a wrapped error.
			return nil, nil
		}
		return nil, fmt.Errorf("failed reading JSON array: %w", err)
	}
	if token != json.Delim('[') {
		return nil, fmt.Errorf("malformed JSON array, not starting with delimiter [ at position: %d", dec.InputOffset())
	}
	for dec.More() {
		var raw json.RawMessage
		if err := dec.Decode(&raw); err != nil {
			if err == io.EOF { //nolint:errorlint // This will never be a wrapped error.
				break
			}
			return nil, fmt.Errorf("malformed JSON object at stream position %d: %w", dec.InputOffset(), err)
		}

		var obj interface{}
		if err := newJSONDecoder(bytes.NewReader(raw)).Decode(&obj); err != nil {
			return nil, fmt.Errorf("malformed JSON object at stream position %d: %w", dec.InputOffset(), err)
		}

		m, ok := obj.(map[string]interface{})
		if ok {
			objs = append(objs, m)
		}
	}
	return objs, nil
}

func getIncludedHeaders(r *http.Request, headerConf []string) (includedHeaders mapstr.M) {
	includedHeaders = mapstr.M{}
	for _, header := range headerConf {
		if value, found := r.Header[header]; found {
			includedHeaders[common.DeDot(header)] = value
		}
	}
	return includedHeaders
}

func newJSONDecoder(r io.Reader) *json.Decoder {
	dec := json.NewDecoder(r)
	dec.UseNumber()
	return dec
}

// getBodyReader returns a reader that decodes the specified Content-Encoding.
func getBodyReader(r *http.Request) (body io.ReadCloser, status int, err error) {
	switch enc := r.Header.Get(headerContentEncoding); enc {
	case "gzip", "x-gzip":
		gzipReader, err := newPooledGzipReader(r.Body)
		if err != nil {
			return nil, http.StatusInternalServerError, fmt.Errorf("failed to create gzip reader: %w", err)
		}
		return gzipReader, 0, nil
	case "":
		// No encoding.
		return r.Body, 0, nil
	default:
		return nil, http.StatusUnsupportedMediaType, fmt.Errorf("unsupported Content-Encoding type %q", enc)
	}
}
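// Illustrative usage note (an editorial sketch, not part of the original
// implementation): a client can post a stream of JSON objects and ask the
// handler to hold the HTTP response until the resulting events are ACKed by
// supplying a wait_for_completion_timeout query parameter, which is parsed by
// getTimeoutWait and enforced in ServeHTTP. For example:
//
//	POST /?wait_for_completion_timeout=30s HTTP/1.1
//	Content-Type: application/json
//
//	{"message":"first"}
//	{"message":"second"}
//
// If the events are ACKed within the timeout, the handler replies with the
// configured responseCode and responseBody; otherwise it replies with
// 504 Gateway Timeout and the errTookTooLong message. Request bodies sent
// with Content-Encoding: gzip or x-gzip are transparently decompressed by
// getBodyReader.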