runtime/router.go (206 lines of code) (raw):
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package zanzibar
import (
"context"
"fmt"
"net/http"
"net/url"
"github.com/opentracing/opentracing-go"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/uber-go/tally"
"github.com/uber/zanzibar/runtime/jsonwrapper"
zrouter "github.com/uber/zanzibar/runtime/router"
"go.uber.org/zap"
)
const (
notFound = "NotFound"
methodNotAllowed = "MethodNotAllowed"
)
// HTTPRouter provides a HTTP router. It will match patterns in URLs and route them to provided HTTP handlers.
//
// This router has support for decoding path "parameters" in the URL into named values. An example:
//
// var r zanzibar.HTTPRouter
//
// r.Handle("GET", "/foo/:bar", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// params := zanzibar.ParamsFromContext(r.Context())
// w.Write("%s", params.Get("bar"))
// }))
type HTTPRouter interface {
// HTTPRouter implements a http.Handle as a convenience to allow HTTPRouter to be invoked by the standard library HTTP server.
http.Handler
// Handle associates a HTTP method and a pattern string to a HTTP handler function. If the method and pattern string
// already exists, an error is returned.
Handle(method, pattern string, handler http.Handler) error
}
// ParamsFromContext extracts the URL parameters that are embedded in the context by the Zanzibar HTTP router implementation.
func ParamsFromContext(ctx context.Context) url.Values {
params := zrouter.ParamsFromContext(ctx)
urlValues := make(url.Values)
for _, paramValue := range params {
urlValues.Add(paramValue.Key, paramValue.Value)
}
return urlValues
}
// HandlerFn is a func that handles ServerHTTPRequest
type HandlerFn func(
context.Context,
*ServerHTTPRequest,
*ServerHTTPResponse,
) context.Context
// RouterEndpoint struct represents an endpoint that can be registered
// into the router itself.
type RouterEndpoint struct {
EndpointName string
HandlerName string
HandlerFn HandlerFn
JSONWrapper jsonwrapper.JSONWrapper
contextExtractor ContextExtractor
contextLogger ContextLogger
scope tally.Scope
tracer opentracing.Tracer
config *StaticConfig
}
// NewRouterEndpoint creates an endpoint that can be registered to HTTPRouter
func NewRouterEndpoint(
extractor ContextExtractor,
deps *DefaultDependencies,
endpointID string,
handlerID string,
handler HandlerFn,
) *RouterEndpoint {
return &RouterEndpoint{
EndpointName: endpointID,
HandlerName: handlerID,
HandlerFn: handler,
contextExtractor: extractor,
contextLogger: deps.ContextLogger,
scope: deps.Scope,
tracer: deps.Tracer,
JSONWrapper: deps.JSONWrapper,
config: deps.Config,
}
}
// HandleRequest is called by the router and starts the request
func (endpoint *RouterEndpoint) HandleRequest(
w http.ResponseWriter,
r *http.Request,
) {
// TODO: (lu) get timeout from endpoint config
//_, ok := ctx.Deadline()
//if !ok {
// var cancel context.CancelFunc
// ctx, cancel = context.WithTimeout(ctx, time.Duration(100)*time.Millisecond)
// defer cancel()
//}
urlValues := ParamsFromContext(r.Context())
req := NewServerHTTPRequest(w, r, urlValues, endpoint)
ctx := req.Context()
endpoint.HandlerFn(ctx, req, req.res)
req.res.flush(ctx)
}
// httpRouter data structure to handle and register endpoints
type httpRouter struct {
gateway *Gateway
httpRouter *zrouter.Router
notFoundEndpoint *RouterEndpoint
methodNotAllowedEndpoint *RouterEndpoint
panicCount tally.Counter
routeMap map[string]*RouterEndpoint
requestUUIDHeaderKey string
}
var _ HTTPRouter = (*httpRouter)(nil)
// NewHTTPRouter allocates a HTTP router
func NewHTTPRouter(gateway *Gateway) HTTPRouter {
deps := &DefaultDependencies{
Logger: gateway.Logger,
ContextLogger: gateway.ContextLogger,
Scope: gateway.RootScope,
Tracer: gateway.Tracer,
Config: gateway.Config,
}
router := &httpRouter{
notFoundEndpoint: NewRouterEndpoint(
gateway.ContextExtractor, deps,
notFound, notFound, nil,
),
methodNotAllowedEndpoint: NewRouterEndpoint(
gateway.ContextExtractor, deps,
methodNotAllowed, methodNotAllowed, nil,
),
gateway: gateway,
panicCount: gateway.RootScope.Counter("runtime.router.panic"),
routeMap: make(map[string]*RouterEndpoint),
requestUUIDHeaderKey: gateway.requestUUIDHeaderKey,
}
notFoundHandler := http.HandlerFunc(router.handleNotFound)
if gateway.notFoundHandler != nil {
notFoundHandler = gateway.notFoundHandler
}
handleMethodNotAllowed := true
if gateway.Config.ContainsKey("http.handleMethodNotAllowed") {
handleMethodNotAllowed = gateway.Config.MustGetBoolean("http.handleMethodNotAllowed")
}
router.httpRouter = &zrouter.Router{
HandleMethodNotAllowed: handleMethodNotAllowed,
NotFound: notFoundHandler,
MethodNotAllowed: http.HandlerFunc(router.handleMethodNotAllowed),
PanicHandler: router.handlePanic,
WhitelistedPaths: router.getWhitelistedPaths(),
}
return router
}
// Register register a handler function.
func (router *httpRouter) Handle(method, prefix string, handler http.Handler) (err error) {
h := func(w http.ResponseWriter, r *http.Request) {
reqUUID := r.Header.Get(router.requestUUIDHeaderKey)
if reqUUID == "" {
reqUUID = uuid.New()
}
ctx := withRequestUUID(r.Context(), reqUUID)
ctx = WithLogFields(ctx, zap.String(logFieldRequestUUID, reqUUID))
r = r.WithContext(ctx)
handler.ServeHTTP(w, r)
}
return router.httpRouter.Handle(method, prefix, http.HandlerFunc(h))
}
// ServeHTTP implements the http.Handle as a convenience to allow HTTPRouter to be invoked by the standard library HTTP server.
func (router *httpRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := WithSafeLogFields(r.Context())
r = r.WithContext(ctx)
router.httpRouter.ServeHTTP(w, r)
}
func (router *httpRouter) handlePanic(
w http.ResponseWriter, r *http.Request, v interface{},
) {
err, ok := v.(error)
if !ok {
err = errors.Errorf("http router panic: %v", v)
}
_, ok = err.(fmt.Formatter)
if !ok {
err = errors.Wrap(err, "wrapped")
}
logger := router.gateway.ContextLogger
logger.Error(r.Context(), "A http request handler paniced", zap.Error(err), zap.Int(logFieldResponseStatusCode, http.StatusInternalServerError))
router.panicCount.Inc(1)
http.Error(w,
http.StatusText(http.StatusInternalServerError),
http.StatusInternalServerError,
)
}
func (router *httpRouter) handleNotFound(
w http.ResponseWriter,
r *http.Request,
) {
scopeTags := map[string]string{
scopeTagEndpoint: router.notFoundEndpoint.EndpointName,
scopeTagHandler: router.notFoundEndpoint.HandlerName,
scopeTagProtocol: scopeTagHTTP,
}
ctx := r.Context()
ctx = WithScopeTagsDefault(ctx, scopeTags, router.gateway.RootScope)
r = r.WithContext(ctx)
req := NewServerHTTPRequest(w, r, nil, router.notFoundEndpoint)
http.NotFound(w, r)
req.res.StatusCode = http.StatusNotFound
req.res.finish(ctx)
}
func (router *httpRouter) handleMethodNotAllowed(
w http.ResponseWriter,
r *http.Request,
) {
scopeTags := map[string]string{
scopeTagEndpoint: router.methodNotAllowedEndpoint.EndpointName,
scopeTagHandler: router.methodNotAllowedEndpoint.HandlerName,
scopeTagProtocol: scopeTagHTTP,
}
ctx := r.Context()
ctx = WithScopeTagsDefault(ctx, scopeTags, router.gateway.RootScope)
r = r.WithContext(ctx)
req := NewServerHTTPRequest(w, r, nil, router.methodNotAllowedEndpoint)
http.Error(w,
http.StatusText(http.StatusMethodNotAllowed),
http.StatusMethodNotAllowed,
)
req.res.StatusCode = http.StatusMethodNotAllowed
req.res.finish(ctx)
}
func (router *httpRouter) getWhitelistedPaths() []string {
var whitelistedPaths []string
if router.gateway.Config != nil &&
router.gateway.Config.ContainsKey("router.whitelistedPaths") {
router.gateway.Config.MustGetStruct("router.whitelistedPaths", &whitelistedPaths)
}
return whitelistedPaths
}