runtime/server_http_request.go (669 lines of code) (raw):

// Copyright (c) 2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package zanzibar import ( "context" "fmt" "io" "net/http" "net/url" "strconv" "strings" "time" "github.com/buger/jsonparser" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/uber-go/tally" "github.com/uber/jaeger-client-go" "github.com/uber/zanzibar/runtime/jsonwrapper" "go.uber.org/zap" ) // ServerHTTPRequest struct manages server http request type ServerHTTPRequest struct { httpRequest *http.Request res *ServerHTTPResponse startTime time.Time started bool tracer opentracing.Tracer span opentracing.Span queryValues url.Values parseFailed bool rawBody []byte EndpointName string HandlerName string URL *url.URL Method string Params url.Values Header Header // logger logs entries with default fields that contains request meta info contextLogger ContextLogger // scope emit metrics with default tags that contains request meta info scope tally.Scope jsonWrapper jsonwrapper.JSONWrapper } // NewServerHTTPRequest is helper function to alloc ServerHTTPRequest func NewServerHTTPRequest( w http.ResponseWriter, r *http.Request, params url.Values, endpoint *RouterEndpoint, ) *ServerHTTPRequest { ctx := r.Context() // put request log fields on context logFields := []zap.Field{ zap.String(logFieldEndpointID, endpoint.EndpointName), zap.String(logFieldEndpointHandler, endpoint.HandlerName), zap.String(logFieldRequestHTTPMethod, r.Method), zap.String(logFieldRequestRemoteAddr, r.RemoteAddr), zap.String(logFieldRequestPathname, r.URL.RequestURI()), zap.String(logFieldRequestHost, r.Host), } // put request scope tags on context scopeTags := map[string]string{ scopeTagEndpoint: endpoint.EndpointName, scopeTagHandler: endpoint.HandlerName, scopeTagProtocol: scopeTagHTTP, } if endpoint.contextExtractor != nil { headers := map[string]string{} for k, v := range r.Header { // TODO: this 0th element logic is probably not correct headers[k] = v[0] } ctx = WithEndpointRequestHeadersField(ctx, headers) for k, v := range endpoint.contextExtractor.ExtractScopeTags(ctx) { scopeTags[k] = v } logFields = append(logFields, endpoint.contextExtractor.ExtractLogFields(ctx)...) } // Overriding the api-environment and default to production apiEnvironment := GetAPIEnvironment(endpoint, r) scopeTags[scopeTagsAPIEnvironment] = apiEnvironment logFields = append(logFields, zap.String(apienvironmentKey, apiEnvironment)) // Overriding the environment for shadow requests if endpoint.config != nil { if endpoint.config.ContainsKey("service.shadow.env.override.enable") && endpoint.config.MustGetBoolean("service.shadow.env.override.enable") && endpoint.config.ContainsKey("shadowRequestHeader") && r.Header.Get(endpoint.config.MustGetString("shadowRequestHeader")) != "" { scopeTags[environmentKey] = shadowEnvironment logFields = append(logFields, zap.String(environmentKey, shadowEnvironment)) } } ctx = WithScopeTagsDefault(ctx, scopeTags, endpoint.scope) ctx = WithLogFields(ctx, logFields...) httpRequest := r.WithContext(ctx) scope := getScope(ctx, endpoint.scope) // use the calculated scope instead of making a new one logger := endpoint.contextLogger req := &ServerHTTPRequest{ httpRequest: httpRequest, queryValues: nil, tracer: endpoint.tracer, EndpointName: endpoint.EndpointName, HandlerName: endpoint.HandlerName, URL: httpRequest.URL, Method: httpRequest.Method, Params: params, Header: NewServerHTTPHeader(r.Header), contextLogger: logger, scope: scope, jsonWrapper: endpoint.JSONWrapper, } req.res = NewServerHTTPResponse(w, req) req.start() return req } // GetAPIEnvironment returns the api environment for a given request. // By default, the api environment is set to production. However, there may be // use cases where a different environment may be required for monitoring purposes. // This may be overridden by a non-empty environment value in the request header. func GetAPIEnvironment(endpoint *RouterEndpoint, r *http.Request) string { apiEnvironment := apiEnvironmentDefault if endpoint.config != nil && endpoint.config.ContainsKey("apiEnvironmentHeader") && r.Header.Get(endpoint.config.MustGetString("apiEnvironmentHeader")) != "" { apiEnvironment = r.Header.Get(endpoint.config.MustGetString("apiEnvironmentHeader")) } return apiEnvironment } // Context returns the request's context. func (req *ServerHTTPRequest) Context() context.Context { return req.httpRequest.Context() } // StartTime returns the request's start time. func (req *ServerHTTPRequest) StartTime() time.Time { return req.startTime } // start the request, emit metrics etc func (req *ServerHTTPRequest) start() { if req.started { /* coverage ignore next line */ req.contextLogger.Error(req.Context(), "Cannot start ServerHTTPRequest twice", zap.String("path", req.URL.Path), ) /* coverage ignore next line */ return } req.started = true req.startTime = time.Now() // emit request count req.scope.Counter(endpointRequest).Inc(1) if req.tracer != nil { opName := fmt.Sprintf("%s.%s", req.EndpointName, req.HandlerName) urlTag := opentracing.Tag{Key: "URL", Value: req.URL} MethodTag := opentracing.Tag{Key: "Method", Value: req.Method} carrier := opentracing.HTTPHeadersCarrier(req.httpRequest.Header) spanContext, err := req.tracer.Extract(opentracing.HTTPHeaders, carrier) var span opentracing.Span if err != nil { if err != opentracing.ErrSpanContextNotFound { /* coverage ignore next line */ req.contextLogger.WarnZ(req.Context(), "Error Extracting Trace Headers", zap.Error(err)) } span = req.tracer.StartSpan(opName, urlTag, MethodTag) } else { span = req.tracer.StartSpan(opName, urlTag, MethodTag, ext.RPCServerOption(spanContext)) } req.span = span } req.setupLogFields() } func (req *ServerHTTPRequest) setupLogFields() { fields := GetLogFieldsFromCtx(req.Context()) if span := req.GetSpan(); span != nil { jc, ok := span.Context().(jaeger.SpanContext) if ok { fields = append(fields, zap.String(TraceSpanKey, jc.SpanID().String()), zap.String(TraceIDKey, jc.TraceID().String()), zap.Bool(TraceSampledKey, jc.IsSampled()), ) } } ctx := WithLogFields(req.Context(), fields...) req.httpRequest = req.httpRequest.WithContext(ctx) } // CheckHeaders verifies that request contains required headers. func (req *ServerHTTPRequest) CheckHeaders(headers []string) bool { for _, headerName := range headers { _, ok := req.Header.Get(headerName) if !ok { req.contextLogger.WarnZ(req.Context(), "Got request without mandatory header", zap.String("headerName", headerName), ) if !req.parseFailed { req.res.SendErrorString( 400, "Missing mandatory header: "+headerName, ) req.parseFailed = true } return false } } return true } // PeekBody allows for inspecting a key path inside the body // that is not flushed yet. This is useful for response middlewares // that want to inspect the response body. func (req *ServerHTTPRequest) PeekBody( keys ...string, ) ([]byte, jsonparser.ValueType, error) { value, valueType, _, err := jsonparser.Get( req.rawBody, keys..., ) if err != nil { return nil, -1, err } return value, valueType, nil } func (req *ServerHTTPRequest) parseQueryValues() bool { if req.parseFailed { return false } if req.queryValues != nil { return true } values, err := url.ParseQuery(req.httpRequest.URL.RawQuery) if err != nil { req.contextLogger.WarnZ(req.Context(), "Got request with invalid query string", zap.Error(err)) if !req.parseFailed { req.res.SendErrorString( 400, "Could not parse query string", ) req.parseFailed = true } return false } req.queryValues = values return true } // GetQueryValue will return the first query parameter for key or empty string func (req *ServerHTTPRequest) GetQueryValue(key string) (string, bool) { success := req.parseQueryValues() if !success { return "", false } return req.queryValues.Get(key), true } // SetQueryValue will set the value of the query parameter, replacing any existing value // We only work with one value, and not a list (in keeping with other url.Values methods) func (req *ServerHTTPRequest) SetQueryValue(key string, value string) { if req.queryValues == nil { req.queryValues = make(url.Values) } req.queryValues.Set(key, value) } // GetQueryBool will return a query param as a boolean func (req *ServerHTTPRequest) GetQueryBool(key string) (bool, bool) { success := req.parseQueryValues() if !success { return false, false } value := req.queryValues.Get(key) if value == "true" { return true, true } else if value == "false" { return false, true } err := &strconv.NumError{ Func: "ParseBool", Num: value, Err: strconv.ErrSyntax, } req.LogAndSendQueryError(err, "bool", key, value) return false, false } // GetQueryInt8 will return a query params as int8 func (req *ServerHTTPRequest) GetQueryInt8(key string) (int8, bool) { success := req.parseQueryValues() if !success { return 0, false } value := req.queryValues.Get(key) number, err := strconv.ParseInt(value, 10, 8) if err != nil { req.LogAndSendQueryError(err, "int8", key, value) return 0, false } return int8(number), true } // GetQueryInt16 will return a query params as int16 func (req *ServerHTTPRequest) GetQueryInt16(key string) (int16, bool) { success := req.parseQueryValues() if !success { return 0, false } value := req.queryValues.Get(key) number, err := strconv.ParseInt(value, 10, 16) if err != nil { req.LogAndSendQueryError(err, "int16", key, value) return 0, false } return int16(number), true } // GetQueryInt32 will return a query params as int32 func (req *ServerHTTPRequest) GetQueryInt32(key string) (int32, bool) { success := req.parseQueryValues() if !success { return 0, false } value := req.queryValues.Get(key) number, err := strconv.ParseInt(value, 10, 32) if err != nil { req.LogAndSendQueryError(err, "int32", key, value) return 0, false } return int32(number), true } // GetQueryInt64 will return a query param as int64 func (req *ServerHTTPRequest) GetQueryInt64(key string) (int64, bool) { success := req.parseQueryValues() if !success { return 0, false } value := req.queryValues.Get(key) number, err := strconv.ParseInt(value, 10, 64) if err != nil { req.LogAndSendQueryError(err, "int64", key, value) return 0, false } return number, true } // GetQueryFloat64 will return query param key as float64 func (req *ServerHTTPRequest) GetQueryFloat64(key string) (float64, bool) { success := req.parseQueryValues() if !success { return 0, false } value := req.queryValues.Get(key) number, err := strconv.ParseFloat(value, 64) if err != nil { req.LogAndSendQueryError(err, "float64", key, value) return 0, false } return number, true } // -- Query params as lists -- // GetQueryBoolList will return a query param as a list of boolean func (req *ServerHTTPRequest) GetQueryBoolList(key string) ([]bool, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] ret := make([]bool, len(values)) for i, value := range values { if value == "true" { ret[i] = true } else if value == "false" { ret[i] = false } else { err := &strconv.NumError{ Func: "ParseBool", Num: value, Err: strconv.ErrSyntax, } req.LogAndSendQueryError(err, "bool", key, value) return nil, false } } return ret, true } // GetQueryInt8List will return a query params as list of int8 func (req *ServerHTTPRequest) GetQueryInt8List(key string) ([]int8, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] ret := make([]int8, len(values)) for i, value := range values { number, err := strconv.ParseInt(value, 10, 8) if err != nil { req.LogAndSendQueryError(err, "int8", key, value) return nil, false } ret[i] = int8(number) } return ret, true } // GetQueryInt16List will return a query params as list of int16 func (req *ServerHTTPRequest) GetQueryInt16List(key string) ([]int16, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] ret := make([]int16, len(values)) for i, value := range values { number, err := strconv.ParseInt(value, 10, 16) if err != nil { req.LogAndSendQueryError(err, "int16", key, value) return nil, false } ret[i] = int16(number) } return ret, true } // GetQueryInt32List will return a query params as list of int32 func (req *ServerHTTPRequest) GetQueryInt32List(key string) ([]int32, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] ret := make([]int32, len(values)) for i, value := range values { number, err := strconv.ParseInt(value, 10, 32) if err != nil { req.LogAndSendQueryError(err, "int32", key, value) return nil, false } ret[i] = int32(number) } return ret, true } // GetQueryInt64List will return a query params as list of int64 func (req *ServerHTTPRequest) GetQueryInt64List(key string) ([]int64, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] ret := make([]int64, len(values)) for i, value := range values { number, err := strconv.ParseInt(value, 10, 64) if err != nil { req.LogAndSendQueryError(err, "int64", key, value) return nil, false } ret[i] = number } return ret, true } // GetQueryFloat64List will return a query params as list of float64 func (req *ServerHTTPRequest) GetQueryFloat64List(key string) ([]float64, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] ret := make([]float64, len(values)) for i, value := range values { number, err := strconv.ParseFloat(value, 64) if err != nil { req.LogAndSendQueryError(err, "float64", key, value) return nil, false } ret[i] = number } return ret, true } // GetQueryValues will return all query parameters for key. func (req *ServerHTTPRequest) GetQueryValues(key string) ([]string, bool) { success := req.parseQueryValues() if !success { return nil, false } return req.queryValues[key], true } // GetQueryValueList will return all query parameters for key. func (req *ServerHTTPRequest) GetQueryValueList(key string) ([]string, bool) { return req.GetQueryValues(key) } // -- Query param as set -- /** * A set of bools does not make sense and is unimplemented * Also, in every use-case for a gateway, a set implemented as a map is not very useful, instead one where * it is a list with no duplicates is. Therefore the implementation picks that approach. */ // The "value" in the map representation of a set datastructure var _nullVal = struct{}{} // GetQueryInt8Set will return a query params as set of int8 (implemented as a deduped slice) func (req *ServerHTTPRequest) GetQueryInt8Set(key string) ([]int8, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] set := make(map[int8]struct{}, len(values)) for _, value := range values { number, err := strconv.ParseInt(value, 0, 8) if err != nil { req.LogAndSendQueryError(err, "int8", key, value) return nil, false } set[int8(number)] = _nullVal } ret := make([]int8, len(set)) i := 0 for item := range set { ret[i] = item i++ } return ret, true } // GetQueryInt16Set will return a query params as set of int16 func (req *ServerHTTPRequest) GetQueryInt16Set(key string) ([]int16, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] set := make(map[int16]struct{}, len(values)) for _, value := range values { number, err := strconv.ParseInt(value, 0, 16) if err != nil { req.LogAndSendQueryError(err, "int16", key, value) return nil, false } set[int16(number)] = _nullVal } ret := make([]int16, len(set)) i := 0 for item := range set { ret[i] = item i++ } return ret, true } // GetQueryInt32Set will return a query params as set of int32 func (req *ServerHTTPRequest) GetQueryInt32Set(key string) ([]int32, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] set := make(map[int32]struct{}, len(values)) for _, value := range values { number, err := strconv.ParseInt(value, 10, 32) if err != nil { req.LogAndSendQueryError(err, "int32", key, value) return nil, false } set[int32(number)] = _nullVal } ret := make([]int32, len(set)) i := 0 for item := range set { ret[i] = item i++ } return ret, true } // GetQueryInt64Set will return a query params as set of int64 func (req *ServerHTTPRequest) GetQueryInt64Set(key string) ([]int64, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] set := make(map[int64]struct{}, len(values)) for _, value := range values { number, err := strconv.ParseInt(value, 10, 64) if err != nil { req.LogAndSendQueryError(err, "int64", key, value) return nil, false } set[number] = _nullVal } ret := make([]int64, len(set)) i := 0 for item := range set { ret[i] = item i++ } return ret, true } // GetQueryFloat64Set will return a query params as set of float64 func (req *ServerHTTPRequest) GetQueryFloat64Set(key string) ([]float64, bool) { success := req.parseQueryValues() if !success { return nil, false } values := req.queryValues[key] set := make(map[float64]struct{}, len(values)) for _, value := range values { number, err := strconv.ParseFloat(value, 64) if err != nil { req.LogAndSendQueryError(err, "float64", key, value) return nil, false } set[number] = _nullVal } ret := make([]float64, len(set)) i := 0 for item := range set { ret[i] = item i++ } return ret, true } // GetQueryValueSet will return all query parameters for key as a set func (req *ServerHTTPRequest) GetQueryValueSet(key string) ([]string, bool) { success := req.parseQueryValues() if !success { return nil, false } set := make(map[string]struct{}, len(req.queryValues[key])) for _, v := range req.queryValues[key] { set[v] = _nullVal } ret := make([]string, len(set)) i := 0 for item := range set { ret[i] = item i++ } return ret, true } // HasQueryPrefix will check if any query param starts with key. func (req *ServerHTTPRequest) HasQueryPrefix(prefix string) bool { success := req.parseQueryValues() if !success { return false } for key := range req.queryValues { if strings.HasPrefix(key, prefix) { return true } } return false } // CheckQueryValue will check for a required query param. func (req *ServerHTTPRequest) CheckQueryValue(key string) bool { success := req.parseQueryValues() if !success { return false } values := req.queryValues[key] if len(values) == 0 { req.contextLogger.WarnZ(req.Context(), "Got request with missing query string value", zap.String("expectedKey", key), ) if !req.parseFailed { req.res.SendErrorString( 400, "Could not parse query string", ) req.parseFailed = true } return false } return true } // HasQueryValue will return bool if the query param exists. func (req *ServerHTTPRequest) HasQueryValue(key string) bool { success := req.parseQueryValues() if !success { return false } values := req.queryValues[key] if len(values) == 0 { return false } return true } // ReadAndUnmarshalBody will try to unmarshal into struct or fail func (req *ServerHTTPRequest) ReadAndUnmarshalBody( body interface{}, ) bool { rawBody, success := req.ReadAll() if !success { return false } return req.UnmarshalBody(body, rawBody) } // GetRawBody returns raw body of request func (req *ServerHTTPRequest) GetRawBody() []byte { return req.rawBody } // ReadAll helper to read entire body func (req *ServerHTTPRequest) ReadAll() ([]byte, bool) { if req.rawBody != nil { return req.rawBody, true } rawBody, err := io.ReadAll(req.httpRequest.Body) if err != nil { req.contextLogger.ErrorZ(req.Context(), "Could not read request body", zap.Error(err)) if !req.parseFailed { req.res.SendError(500, "Could not read request body", err) req.parseFailed = true } return nil, false } req.rawBody = rawBody return rawBody, true } // UnmarshalBody helper to unmarshal body into struct func (req *ServerHTTPRequest) UnmarshalBody( body interface{}, rawBody []byte, ) bool { err := req.jsonWrapper.Unmarshal(rawBody, body) if err != nil { req.contextLogger.WarnZ(req.Context(), "Could not parse json", zap.Error(err)) if !req.parseFailed { req.res.SendError(400, "Could not parse json: "+err.Error(), err) req.parseFailed = true } return false } return true } // ReplaceBody replaces the raw request body with given body and updates the request content-length header accordingly. // This method is only supposed to be used in middlewares where request body needs to be modified. // The encoding of the body should stay the same. func (req *ServerHTTPRequest) ReplaceBody(body []byte) { // Replace the cached body bytes and fix dependent header req.rawBody = body if _, ok := req.Header.Get("Content-Length"); ok { req.Header.Set("Content-Length", strconv.Itoa(len(body))) } } // GetSpan returns the http request span func (req *ServerHTTPRequest) GetSpan() opentracing.Span { return req.span } // LogAndSendQueryError handles parse failure of query params by logging the issue and returning a 400 to the requestor func (req *ServerHTTPRequest) LogAndSendQueryError(err error, expected, key, value string) { req.contextLogger.WarnZ(req.Context(), "Got request with invalid query string types", zap.String("expected", expected), zap.String("actual", value), zap.String("key", key), zap.Error(err), ) if !req.parseFailed { req.res.SendError(400, "Could not parse query string", err) req.parseFailed = true } }