utils/http.go (255 lines of code) (raw):

// Copyright (c) 2017-2018 Uber Technologies, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package utils import ( "compress/gzip" "encoding/json" "fmt" "github.com/uber-go/tally" "github.com/uber/aresdb/common" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" "golang.org/x/net/netutil" "net" "net/http" "strconv" "time" ) const ( HTTPContentTypeHeaderKey = "Content-Type" HTTPAcceptTypeHeaderKey = "Accept" HTTPAcceptEncodingHeaderKey = "Accept-Encoding" HTTPContentEncodingHeaderKey = "Content-Encoding" HTTPContentTypeApplicationJson = "application/json" HTTPContentTypeApplicationGRPC = "application/grpc" // HTTPContentTypeUpsertBatch defines the upsert data content type. HTTPContentTypeUpsertBatch = "application/upsert-data" // HTTPContentTypeHyperLogLog defines the hyperloglog query result content type. HTTPContentTypeHyperLogLog = "application/hll" HTTPContentEncodingGzip = "gzip" // CompressionThreshold is the min number of bytes beyond which we will compress json payload CompressionThreshold = 1 << 10 ) var epoch = time.Unix(0, 0).Format(time.RFC1123) var noCacheHeaders = map[string]string{ "Expires": epoch, "Cache-Control": "no-cache, private, max-age=0", "Pragma": "no-cache", "X-Accel-Expires": "0", } var etagHeaders = []string{ "ETag", "If-Modified-Since", "If-Match", "If-None-Match", "If-Range", "If-Unmodified-Since", } // NoCache sets no cache headers and removes any ETag headers that may have been set. func NoCache(h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { for _, v := range etagHeaders { if r.Header.Get(v) != "" { r.Header.Del(v) } } for k, v := range noCacheHeaders { w.Header().Set(k, v) } h.ServeHTTP(w, r) } return http.HandlerFunc(fn) } // GetOrigin returns the caller of the request. func GetOrigin(r *http.Request) string { origin := r.Header.Get("RPC-Caller") if origin == "" { origin = r.Header.Get("X-Uber-Origin") } if origin == "" { origin = "UNKNOWN" } return origin } // LimitServe will start a http server on the port with the handler and at most maxConnection concurrent connections. func LimitServe(port int, handler http.Handler, httpCfg common.HTTPConfig) { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { GetLogger().Fatal(err) } defer listener.Close() listener = netutil.LimitListener(listener, httpCfg.MaxConnections) server := &http.Server{ ReadTimeout: time.Duration(httpCfg.ReadTimeOutInSeconds) * time.Second, WriteTimeout: time.Duration(httpCfg.WriteTimeOutInSeconds) * time.Second, Handler: h2c.NewHandler(handler, &http2.Server{}), } GetLogger().Fatal(server.Serve(listener)) } // LimitServeAsync will start a http server on the port with the handler and at most maxConnection concurrent connections. func LimitServeAsync(port int, handler http.Handler, httpCfg common.HTTPConfig) (chan error, *http.Server) { listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { GetLogger().Fatal(err) } listener = netutil.LimitListener(listener, httpCfg.MaxConnections) server := &http.Server{ ReadTimeout: time.Duration(httpCfg.ReadTimeOutInSeconds) * time.Second, WriteTimeout: time.Duration(httpCfg.WriteTimeOutInSeconds) * time.Second, Handler: h2c.NewHandler(handler, &http2.Server{}), } errChan := make(chan error) go func() { defer listener.Close() errChan <- server.Serve(listener) }() return errChan, server } // HandlerFunc defines http handler function type HandlerFunc func(rw *ResponseWriter, r *http.Request) // HTTPHandlerWrapper wraps http handler function type HTTPHandlerWrapper func(handler HandlerFunc) HandlerFunc // ApplyHTTPWrappers apply wrappers according to the order func ApplyHTTPWrappers(handler HandlerFunc, wrappers ...HTTPHandlerWrapper) http.HandlerFunc { h := handler for _, wrapper := range wrappers { h = wrapper(h) } return func(writer http.ResponseWriter, request *http.Request) { rw := NewResponseWriter(writer) h(rw, request) } } // MetricsLoggingMiddleWareProvider provides middleware for metrics and logger for http requests type MetricsLoggingMiddleWareProvider struct { scope tally.Scope logger common.Logger } // NewMetricsLoggingMiddleWareProvider creates metrics and logging middleware provider func NewMetricsLoggingMiddleWareProvider(scope tally.Scope, logger common.Logger) MetricsLoggingMiddleWareProvider { return MetricsLoggingMiddleWareProvider{ scope: scope, logger: logger, } } // WithMetrics plug in metrics middleware func (p *MetricsLoggingMiddleWareProvider) WithMetrics(next HandlerFunc) HandlerFunc { funcName := GetFuncName(next) return func(rw *ResponseWriter, r *http.Request) { origin := GetOrigin(r) stopWatch := p.scope.Tagged(map[string]string{ metricsTagHandler: funcName, metricsTagOrigin: origin, }).Timer(scopeNameHTTPHandlerLatency).Start() next(rw, r) stopWatch.Stop() p.scope.Tagged(map[string]string{ metricsTagHandler: funcName, metricsTagOrigin: origin, metricsTagStatusCode: strconv.Itoa(rw.statusCode), }).Counter(scopeNameHTTPHandlerCall).Inc(1) } } // WithLogging plug in logging middleware func (p *MetricsLoggingMiddleWareProvider) WithLogging(next HandlerFunc) HandlerFunc { return func(rw *ResponseWriter, r *http.Request) { next(rw, r) if rw.err != nil { p.logger.With( "request", rw.req, "status", rw.statusCode, "error", rw.err, "method", r.Method, "name", r.URL.Path, ).Errorf("request failed") } else { p.logger.With( "request", rw.req, "name", r.URL.Path, ).Debug("request succeeded") } } } func setCommonHeaders(w http.ResponseWriter) { w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate") w.Header().Set("Pragma", "no-cache") w.Header().Set("Expires", "0") } // ErrorResponse represents error response. // swagger:response errorResponse type ErrorResponse struct { //in: body Body APIError } // ResponseWriter decorates http.ResponseWriter type ResponseWriter struct { http.ResponseWriter statusCode int req interface{} err error } // NewResponseWriter returns response writer with status code 200 func NewResponseWriter(rw http.ResponseWriter) *ResponseWriter { return &ResponseWriter{ statusCode: http.StatusOK, ResponseWriter: rw, } } // SetRequest set unmarshalled request body to response writer for logging purpose func (s *ResponseWriter) SetRequest(req interface{}) { s.req = req } // WriteHeader implements http.ResponseWriter WriteHeader for write status code func (s *ResponseWriter) WriteHeader(code int) { if code > 0 { s.statusCode = code s.ResponseWriter.WriteHeader(code) } } // WriteBytes implements http.ResponseWriter Write for write bytes func (s *ResponseWriter) WriteBytes(bts []byte) { setCommonHeaders(s) s.Write(bts) } // WriteBytesWithCode writes bytes with code func (s *ResponseWriter) WriteBytesWithCode(code int, bts []byte) { setCommonHeaders(s) s.WriteHeader(code) if bts != nil { s.Write(bts) } } // WriteJSONBytes write json bytes with default status ok func (s *ResponseWriter) WriteJSONBytes(jsonBytes []byte, marshalErr error) { s.WriteJSONBytesWithCode(http.StatusOK, jsonBytes, marshalErr) } // WriteJSONBytesWithCode write json bytes and marshal error to response func (s *ResponseWriter) WriteJSONBytesWithCode(code int, jsonBytes []byte, marshalErr error) { s.Header().Set(HTTPContentTypeHeaderKey, HTTPContentTypeApplicationJson) if marshalErr != nil { jsonMarshalErrorResponse := ErrorResponse{} code = http.StatusInternalServerError jsonMarshalErrorResponse.Body.Code = code jsonMarshalErrorResponse.Body.Message = "failed to marshal object" jsonMarshalErrorResponse.Body.Cause = marshalErr // ignore this error since this should not happen jsonBytes, _ = json.Marshal(jsonMarshalErrorResponse.Body) } if jsonBytes == nil { return } // try best effort write with gzip compression willCompress := len(jsonBytes) > CompressionThreshold if willCompress { gw, err := gzip.NewWriterLevel(s, gzip.BestSpeed) if err == nil { defer gw.Close() s.Header().Set(HTTPContentEncodingHeaderKey, HTTPContentEncodingGzip) setCommonHeaders(s) s.WriteHeader(code) _, _ = gw.Write(jsonBytes) return } } // default to normal json response s.WriteBytesWithCode(code, jsonBytes) } // WriteObject write json object to response func (s *ResponseWriter) WriteObject(obj interface{}) { s.WriteObjectWithCode(http.StatusOK, obj) } // WriteObjectWithCode serialize object and write code func (s *ResponseWriter) WriteObjectWithCode(code int, obj interface{}) { if obj != nil { jsonBytes, err := json.Marshal(obj) s.WriteJSONBytesWithCode(code, jsonBytes, err) } else { s.WriteBytesWithCode(code, nil) } } // WriteErrorWithCode writes error with specific code func (s *ResponseWriter) WriteErrorWithCode(code int, err error) { s.err = err var errorResponse ErrorResponse if e, ok := err.(APIError); ok { errorResponse.Body = e } else { errorResponse.Body.Message = err.Error() } errorResponse.Body.Code = code s.WriteObjectWithCode(errorResponse.Body.Code, errorResponse.Body) } // WriteError write error to response func (s *ResponseWriter) WriteError(err error) { if e, ok := err.(APIError); ok { s.WriteErrorWithCode(e.Code, err) } else { s.WriteErrorWithCode(http.StatusInternalServerError, err) } }