thrift/server.go (170 lines of code) (raw):

// Copyright (c) 2015 Uber Technologies, Inc.

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package thrift

import (
	"log"
	"strings"
	"sync"

	tchannel "github.com/uber/tchannel-go"
	"github.com/uber/tchannel-go/internal/argreader"
	"github.com/uber/tchannel-go/thirdparty/github.com/apache/thrift/lib/go/thrift"
	"golang.org/x/net/context"
)

// handler pairs a registered TChanServer with the optional callback that is
// invoked after the server's Handle method has produced a response.
type handler struct {
	// server decodes and executes calls for one thrift service.
	server TChanServer
	// postResponseCB, if non-nil, is deferred after each handled call.
	postResponseCB PostResponseCB
}

// Server handles incoming TChannel calls and forwards them to the matching TChanServer.
type Server struct {
	// The embedded RWMutex guards handlers: Register writes under Lock and
	// Handle reads under RLock. Because it is embedded, Lock/Unlock/RLock/RUnlock
	// are also part of Server's method set.
	sync.RWMutex
	// ch is the registrar that "service::method" endpoints are registered on.
	ch tchannel.Registrar
	// log is used to report per-call errors (see onError).
	log tchannel.Logger
	// handlers maps a service name (TChanServer.Service()) to its handler.
	handlers map[string]handler
	// metaHandler backs the meta service's Health endpoint.
	metaHandler *metaHandler
	// ctxFn converts the inbound context.Context and request headers into the
	// thrift Context passed to handlers. Defaults to defaultContextFn.
	ctxFn func(ctx context.Context, method string, headers map[string]string) Context
}

// NewServer returns a server that can serve thrift services over TChannel.
// The returned server already has the meta service (which serves the Health
// endpoint) registered. When registrar is a *tchannel.Channel, the meta
// endpoints are additionally registered on its "tchannel" subchannel.
func NewServer(registrar tchannel.Registrar) *Server { metaHandler := newMetaHandler() server := &Server{ ch: registrar, log: registrar.Logger(), handlers: make(map[string]handler), metaHandler: metaHandler, ctxFn: defaultContextFn, } server.Register(newTChanMetaServer(metaHandler)) if ch, ok := registrar.(*tchannel.Channel); ok { // Register the meta endpoints on the "tchannel" service name. NewServer(ch.GetSubChannel("tchannel")) } return server } // Register registers the given TChanServer to be called on any incoming call for its' services. // TODO(prashant): Replace Register call with this call. func (s *Server) Register(svr TChanServer, opts ...RegisterOption) { service := svr.Service() handler := &handler{server: svr} for _, opt := range opts { opt.Apply(handler) } s.Lock() s.handlers[service] = *handler s.Unlock() for _, m := range svr.Methods() { s.ch.Register(s, service+"::"+m) } } // RegisterHealthHandler uses the user-specified function f for the Health endpoint. func (s *Server) RegisterHealthHandler(f HealthFunc) { wrapped := func(ctx Context, r HealthRequest) (bool, string) { return f(ctx) } s.metaHandler.setHandler(wrapped) } // RegisterHealthRequestHandler uses the user-specified function for the // Health endpoint. The function receives the health request which includes // information about the type of the request being performed. func (s *Server) RegisterHealthRequestHandler(f HealthRequestFunc) { s.metaHandler.setHandler(f) } // SetContextFn sets the function used to convert a context.Context to a thrift.Context. // Note: This API may change and is only intended to bridge different contexts. func (s *Server) SetContextFn(f func(ctx context.Context, method string, headers map[string]string) Context) { s.ctxFn = f } func (s *Server) onError(call *tchannel.InboundCall, err error) { // TODO(prashant): Expose incoming call errors through options for NewServer. 
remotePeer := call.RemotePeer() logger := s.log.WithFields( tchannel.ErrField(err), tchannel.LogField{Key: "method", Value: call.MethodString()}, tchannel.LogField{Key: "callerName", Value: call.CallerName()}, // TODO: These are very similar to the connection fields, but we don't // have access to the connection's logger. Consider exposing the // connection through CurrentCall. tchannel.LogField{Key: "localAddr", Value: call.LocalPeer().HostPort}, tchannel.LogField{Key: "remoteHostPort", Value: remotePeer.HostPort}, tchannel.LogField{Key: "remoteIsEphemeral", Value: remotePeer.IsEphemeral}, tchannel.LogField{Key: "remoteProcess", Value: remotePeer.ProcessName}, ) if tchannel.GetSystemErrorCode(err) == tchannel.ErrCodeTimeout { logger.Debug("Thrift server timeout.") } else { logger.Error("Thrift server error.") } } func defaultContextFn(ctx context.Context, method string, headers map[string]string) Context { return WithHeaders(ctx, headers) } func (s *Server) handle(origCtx context.Context, handler handler, method string, call *tchannel.InboundCall) error { reader, err := call.Arg2Reader() if err != nil { return err } headers, err := ReadHeaders(reader) if err != nil { return err } if err := argreader.EnsureEmpty(reader, "reading request headers"); err != nil { return err } if err := reader.Close(); err != nil { return err } reader, err = call.Arg3Reader() if err != nil { return err } tracer := tchannel.TracerFromRegistrar(s.ch) origCtx = tchannel.ExtractInboundSpan(origCtx, call, headers, tracer) ctx := s.ctxFn(origCtx, method, headers) wp := getProtocolReader(reader) success, resp, err := handler.server.Handle(ctx, method, wp.protocol) thriftProtocolPool.Put(wp) if handler.postResponseCB != nil { defer handler.postResponseCB(ctx, method, resp) } if err != nil { if _, ok := err.(thrift.TProtocolException); ok { // We failed to parse the Thrift generated code, so convert the error to bad request. 
err = tchannel.NewSystemError(tchannel.ErrCodeBadRequest, err.Error()) } reader.Close() call.Response().SendSystemError(err) return nil } if err := argreader.EnsureEmpty(reader, "reading request body"); err != nil { return err } if err := reader.Close(); err != nil { return err } if !success { call.Response().SetApplicationError() } writer, err := call.Response().Arg2Writer() if err != nil { return err } if err := WriteHeaders(writer, ctx.ResponseHeaders()); err != nil { return err } if err := writer.Close(); err != nil { return err } writer, err = call.Response().Arg3Writer() wp = getProtocolWriter(writer) defer thriftProtocolPool.Put(wp) if err := resp.Write(wp.protocol); err != nil { call.Response().SendSystemError(err) return err } return writer.Close() } func getServiceMethod(method string) (string, string, bool) { s := string(method) sep := strings.Index(s, "::") if sep == -1 { return "", "", false } return s[:sep], s[sep+2:], true } // Handle handles an incoming TChannel call and forwards it to the correct handler. func (s *Server) Handle(ctx context.Context, call *tchannel.InboundCall) { op := call.MethodString() service, method, ok := getServiceMethod(op) if !ok { log.Fatalf("Handle got call for %s which does not match the expected call format", op) } s.RLock() handler, ok := s.handlers[service] s.RUnlock() if !ok { log.Fatalf("Handle got call for service %v which is not registered", service) } if err := s.handle(ctx, handler, method, call); err != nil { s.onError(call, err) } }