lib/middleware/middleware.go (60 lines of code) (raw):

// Copyright (c) 2016-2019 Uber Technologies, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package middleware import ( "net/http" "strconv" "strings" "time" "github.com/go-chi/chi" "github.com/uber-go/tally" ) // tagEndpoint tags stats by endpoint path and method, ignoring any path variables. // For example, "/foo/{foo}/bar/{bar}" is tagged with endpoint "foo.bar" // // Note: tagEndpoint should always be called AFTER the "next" handler serves, // such that chi can populate proper route context with the path. // // Wrong: // // tagEndpoint(stats, r).Counter("n").Inc(1) // next.ServeHTTP(w, r) // // Right: // // next.ServeHTTP(w, r) // tagEndpoint(stats, r).Counter("n").Inc(1) // func tagEndpoint(stats tally.Scope, r *http.Request) tally.Scope { ctx := chi.RouteContext(r.Context()) var staticParts []string for _, part := range strings.Split(ctx.RoutePattern(), "/") { if len(part) == 0 || isPathVariable(part) { continue } staticParts = append(staticParts, part) } return stats.Tagged(map[string]string{ "endpoint": strings.Join(staticParts, "."), "method": strings.ToUpper(r.Method), }) } // isPathVariable returns true if s is a path variable, e.g. "{foo}". func isPathVariable(s string) bool { return len(s) >= 2 && s[0] == '{' && s[len(s)-1] == '}' } // LatencyTimer measures endpoint latencies. func LatencyTimer(stats tally.Scope) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() next.ServeHTTP(w, r) tagEndpoint(stats, r).Timer("latency").Record(time.Since(start)) }) } } type recordStatusWriter struct { http.ResponseWriter wroteHeader bool code int } func (w *recordStatusWriter) WriteHeader(code int) { if !w.wroteHeader { w.code = code w.wroteHeader = true w.ResponseWriter.WriteHeader(code) } } func (w *recordStatusWriter) Write(b []byte) (int, error) { w.WriteHeader(http.StatusOK) return w.ResponseWriter.Write(b) } // StatusCounter measures endpoint status count. func StatusCounter(stats tally.Scope) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { recordw := &recordStatusWriter{w, false, http.StatusOK} next.ServeHTTP(recordw, r) tagEndpoint(stats, r).Counter(strconv.Itoa(recordw.code)).Inc(1) }) } }