json/handler.go (123 lines of code) (raw):

// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package json

import (
	"fmt"
	"reflect"

	"github.com/uber/tchannel-go"

	"github.com/opentracing/opentracing-go"
	"golang.org/x/net/context"
)

// Reflect types used when validating handler signatures. Each is obtained via
// a nil pointer to the interface followed by Elem, which yields the interface
// type itself (reflect.TypeOf on a nil interface value would return nil).
var (
	// typeOfError is the error interface type; a handler's second return
	// value must be assignable to it.
	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
	// typeOfContext is the json.Context interface type; a handler's first
	// parameter must be exactly this type.
	typeOfContext = reflect.TypeOf((*Context)(nil)).Elem()
)

// Handlers is the map from method names to handlers. Each value should be a
// function of the form func(json.Context, *ArgType) (*ResType, error).
type Handlers map[string]interface{} // verifyHandler ensures that the given t is a function with the following signature: // func(json.Context, *ArgType)(*ResType, error) func verifyHandler(t reflect.Type) error { if t.NumIn() != 2 || t.NumOut() != 2 { return fmt.Errorf("handler should be of format func(json.Context, *ArgType) (*ResType, error)") } isStructPtr := func(t reflect.Type) bool { return t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct } isMap := func(t reflect.Type) bool { return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String } validateArgRes := func(t reflect.Type, name string) error { if !isStructPtr(t) && !isMap(t) { return fmt.Errorf("%v should be a pointer to a struct, or a map[string]interface{}", name) } return nil } if t.In(0) != typeOfContext { return fmt.Errorf("arg0 should be of type json.Context") } if err := validateArgRes(t.In(1), "second argument"); err != nil { return err } if err := validateArgRes(t.Out(0), "first return value"); err != nil { return err } if !t.Out(1).AssignableTo(typeOfError) { return fmt.Errorf("second return value should be an error") } return nil } type handler struct { handler reflect.Value argType reflect.Type isArgMap bool tracer func() opentracing.Tracer } func toHandler(f interface{}) (*handler, error) { hV := reflect.ValueOf(f) if err := verifyHandler(hV.Type()); err != nil { return nil, err } argType := hV.Type().In(1) return &handler{handler: hV, argType: argType, isArgMap: argType.Kind() == reflect.Map}, nil } // Register registers the specified methods specified as a map from method name to the // JSON handler function. 
The handler functions should have the following signature: // func(context.Context, *ArgType)(*ResType, error) func Register(registrar tchannel.Registrar, funcs Handlers, onError func(context.Context, error)) error { handlers := make(map[string]*handler) handler := tchannel.HandlerFunc(func(ctx context.Context, call *tchannel.InboundCall) { h, ok := handlers[string(call.Method())] if !ok { onError(ctx, fmt.Errorf("call for unregistered method: %s", call.Method())) return } if err := h.Handle(ctx, call); err != nil { onError(ctx, err) } }) for m, f := range funcs { h, err := toHandler(f) if err != nil { return fmt.Errorf("%v cannot be used as a handler: %v", m, err) } h.tracer = func() opentracing.Tracer { return tchannel.TracerFromRegistrar(registrar) } handlers[m] = h registrar.Register(handler, m) } return nil } // Handle deserializes the JSON arguments and calls the underlying handler. func (h *handler) Handle(tctx context.Context, call *tchannel.InboundCall) error { var headers map[string]string if err := tchannel.NewArgReader(call.Arg2Reader()).ReadJSON(&headers); err != nil { return fmt.Errorf("arg2 read failed: %v", err) } tctx = tchannel.ExtractInboundSpan(tctx, call, headers, h.tracer()) ctx := WithHeaders(tctx, headers) var arg3 reflect.Value var callArg reflect.Value if h.isArgMap { arg3 = reflect.New(h.argType) // New returns a pointer, but the method accepts the map directly. callArg = arg3.Elem() } else { arg3 = reflect.New(h.argType.Elem()) callArg = arg3 } if err := tchannel.NewArgReader(call.Arg3Reader()).ReadJSON(arg3.Interface()); err != nil { return fmt.Errorf("arg3 read failed: %v", err) } args := []reflect.Value{reflect.ValueOf(ctx), callArg} results := h.handler.Call(args) res := results[0].Interface() err := results[1].Interface() // If an error was returned, we create an error arg3 to respond with. if err != nil { // TODO(prashantv): More consistent error handling between json/raw/thrift.. 
if serr, ok := err.(tchannel.SystemError); ok { return call.Response().SendSystemError(serr) } call.Response().SetApplicationError() // TODO(prashant): Allow client to customize the error in more ways. res = struct { Type string `json:"type"` Message string `json:"message"` }{ Type: "error", Message: err.(error).Error(), } } if err := tchannel.NewArgWriter(call.Response().Arg2Writer()).WriteJSON(ctx.ResponseHeaders()); err != nil { return err } return tchannel.NewArgWriter(call.Response().Arg3Writer()).WriteJSON(res) }