runtime/tchannel_server.go (194 lines of code) (raw):
// Copyright (c) 2023 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package zanzibar
import (
"context"
"fmt"
"strings"
"sync"
"github.com/opentracing/opentracing-go"
"github.com/pborman/uuid"
"github.com/uber-go/tally"
"github.com/uber/jaeger-client-go"
"github.com/uber/tchannel-go"
"go.uber.org/zap"
netContext "golang.org/x/net/context"
)
// PostResponseCB is the signature of a callback that is run after a response
// has been completely processed (e.g. written to the channel).
// It receives the request context, the method name and the response object,
// which gives the server a chance to clean up resources held by the response.
type PostResponseCB func(ctx context.Context, method string, response RWTStruct)
// TChannelEndpoint wraps a TChannelHandler so that it can be registered with a
// TChannelRouter to handle inbound tchannel calls. It only has one Handle
// method, which is delegated to the embedded TChannelHandler.
type TChannelEndpoint struct {
// TChannelHandler is the embedded handler that the endpoint delegates to.
TChannelHandler
// EndpointID and HandlerID identify the endpoint; the router attaches them
// to the request context as log fields and metric scope tags.
EndpointID string
HandlerID string
// Method is the tchannel method name; the router keys its endpoint map by it.
Method string
// callback is the optional post response callback; it is nil when the
// endpoint was created through NewTChannelEndpoint.
callback PostResponseCB
}
// TChannelRouter handles incoming TChannel calls and routes them to the
// TChannelEndpoint registered for the call's method.
type TChannelRouter struct {
// The embedded RWMutex guards access to endpoints.
sync.RWMutex
// registrar is the tchannel registrar the router registers itself with for
// every method; it is also used to look up the tracer for inbound spans.
registrar tchannel.Registrar
// endpoints maps a method name to the endpoint registered for it.
endpoints map[string]*TChannelEndpoint
// contextLogger is used for logging and is handed to every inbound call.
contextLogger ContextLogger
// scope is the base metrics scope that per-endpoint tags are applied to.
scope tally.Scope
// extractor is the user-provided extractor that derives extra scope tags
// and log fields from the request context.
extractor ContextExtractor
// requestUUIDHeaderKey is the request header key the request UUID is read from.
requestUUIDHeaderKey string
}
// netContextRouter adapts a TChannelRouter to the handler interface expected by
// the tchannel registrar, whose Handle method consumes a netContext
// (golang.org/x/net/context) Context instead of the stdlib context.
type netContextRouter struct {
// router is the TChannelRouter that calls are forwarded to.
router *TChannelRouter
}
// Handle receives an inbound call with a netContext Context and forwards it,
// unchanged, to the wrapped TChannelRouter's Handle method.
func (ncr netContextRouter) Handle(ctx netContext.Context, call *tchannel.InboundCall) {
ncr.router.Handle(ctx, call)
}
// NewTChannelEndpoint builds a TChannelEndpoint that handles incoming calls
// for the given method and has no post response callback.
func NewTChannelEndpoint(
	endpointID, handlerID, method string,
	handler TChannelHandler,
) *TChannelEndpoint {
	// A nil callback means "no post response callback".
	var noCallback PostResponseCB
	return NewTChannelEndpointWithPostResponseCB(endpointID, handlerID, method, handler, noCallback)
}
// NewTChannelEndpointWithPostResponseCB builds a TChannelEndpoint for the
// given method. The callback is optional: pass nil to create an endpoint
// without a post response callback.
func NewTChannelEndpointWithPostResponseCB(
	endpointID, handlerID, method string,
	handler TChannelHandler,
	callback PostResponseCB,
) *TChannelEndpoint {
	endpoint := new(TChannelEndpoint)
	endpoint.TChannelHandler = handler
	endpoint.EndpointID = endpointID
	endpoint.HandlerID = handlerID
	endpoint.Method = method
	endpoint.callback = callback
	return endpoint
}
// NewTChannelRouter creates a TChannel router that can serve thrift services
// over TChannel. The router starts with no endpoints and takes its logger,
// root metrics scope, context extractor and request UUID header key from the
// given gateway.
func NewTChannelRouter(registrar tchannel.Registrar, g *Gateway) *TChannelRouter {
	router := &TChannelRouter{
		registrar: registrar,
		endpoints: make(map[string]*TChannelEndpoint),
	}
	router.contextLogger = g.ContextLogger
	router.scope = g.RootScope
	router.extractor = g.ContextExtractor
	router.requestUUIDHeaderKey = g.requestUUIDHeaderKey
	return router
}
// Register adds the given TChannelEndpoint to the router under e.Method and
// registers the router with the underlying tchannel registrar for that method.
//
// It returns an error, leaving the router unchanged, if an endpoint is already
// registered for e.Method.
func (s *TChannelRouter) Register(e *TChannelEndpoint) error {
	// The duplicate check and the insert must happen under one write lock.
	// Checking under a read lock and then re-acquiring a write lock would let
	// two concurrent registrations of the same method both pass the check,
	// with the later one silently overwriting the earlier one.
	s.Lock()
	if _, ok := s.endpoints[e.Method]; ok {
		s.Unlock()
		return fmt.Errorf("handler for '%s' is already registered", e.Method)
	}
	s.endpoints[e.Method] = e
	s.Unlock()

	// The registrar is called after the lock has been released so that the
	// router's lock is never held while calling into the registrar.
	ncr := netContextRouter{router: s}
	s.registrar.Register(ncr, e.Method)
	return nil
}
// Handle serves a single inbound TChannel call by forwarding it to the
// endpoint registered for the call's method.
//
// Calls whose method name does not contain "::" and calls for a method with no
// registered endpoint are logged as errors and dropped. Otherwise the
// endpoint's identifiers are attached to the context as log fields and scope
// tags, the request headers are processed, and the body is handled in a
// separate goroutine while Handle waits for either the body result or the
// context to be done.
func (s *TChannelRouter) Handle(ctx context.Context, call *tchannel.InboundCall) {
	method := call.MethodString()
	if !strings.Contains(method, "::") {
		s.contextLogger.Error(
			ctx,
			"Handle got call for which does not match the expected call format",
			zap.String(logFieldRequestMethod, method),
		)
		return
	}

	s.RLock()
	endpoint, registered := s.endpoints[method]
	s.RUnlock()
	if !registered {
		s.contextLogger.Error(
			ctx,
			"Handle got call for method which is not registered",
			zap.String(logFieldRequestMethod, method),
		)
		return
	}

	// put log fields on the context
	ctx = WithLogFields(
		ctx,
		zap.String(logFieldEndpointID, endpoint.EndpointID),
		zap.String(logFieldEndpointHandler, endpoint.HandlerID),
		zap.String(logFieldRequestMethod, endpoint.Method),
	)

	// put scope tags on the context
	tags := map[string]string{
		scopeTagEndpoint:       endpoint.EndpointID,
		scopeTagHandler:        endpoint.HandlerID,
		scopeTagEndpointMethod: endpoint.Method,
		scopeTagProtocol:       scopeTagTChannel,
	}
	ctx = WithScopeTagsDefault(ctx, tags, s.scope)

	inbound := &tchannelInboundCall{
		call:          call,
		endpoint:      endpoint,
		contextLogger: s.contextLogger,
		scope:         s.scope.Tagged(tags),
	}
	inbound.start()

	// The deferred finish reads ctx and err when Handle returns, so it sees
	// the context produced by handleHeader and whichever error is set last.
	var err error
	ctx, err = s.handleHeader(ctx, inbound)
	defer func() {
		inbound.finish(ctx, err)
	}()
	if err != nil {
		return
	}

	// Buffered so the body goroutine can always deliver its result and exit,
	// even when Handle has already returned because the context was done.
	bodyResult := make(chan error, 1)
	go func() {
		bodyResult <- s.handleBody(ctx, inbound)
	}()

	select {
	case err = <-bodyResult:
	case <-ctx.Done():
		err = ctx.Err()
		// check if context was Canceled due to handle response; if so, wait
		// for the body handler's own result instead of reporting cancellation
		if err == context.Canceled && inbound.responded {
			err = <-bodyResult
		}
	}
}
// handleHeader reads the request headers of the inbound call and enriches the
// context with the request UUID, the request headers, and the scope tags and
// log fields produced by the user-provided extractor.
//
// It returns the enriched context, or the context it was given together with
// the error if the request headers could not be read.
func (s *TChannelRouter) handleHeader(
	ctx context.Context,
	c *tchannelInboundCall,
) (context.Context, error) {
	err := c.readReqHeaders(ctx)
	if err != nil {
		return ctx, err
	}

	// fall back to a freshly generated UUID when the header is absent
	requestID, found := c.reqHeaders[s.requestUUIDHeaderKey]
	if !found {
		requestID = uuid.New()
	}
	ctx = withRequestUUID(ctx, requestID)

	// request headers go on the context so that the user-provided extractor
	// functions can turn selected headers into metric tags or log fields
	ctx = WithEndpointRequestHeadersField(ctx, c.reqHeaders)

	// metric tags chosen by the user-provided extractor, copied into a map
	// owned by this call
	tags := make(map[string]string)
	for name, value := range s.extractor.ExtractScopeTags(ctx) {
		tags[name] = value
	}
	ctx = WithScopeTagsDefault(ctx, tags, c.scope)
	if len(tags) > 0 {
		c.scope = c.scope.Tagged(tags)
	}

	// log fields chosen by the user-provided extractor, plus the request UUID
	fields := append(
		s.extractor.ExtractLogFields(ctx),
		zap.String(logFieldRequestUUID, requestID),
	)
	return WithLogFields(ctx, fields...), nil
}
// handleBody reads the request body of the inbound call, extracts the inbound
// tracing span, invokes the endpoint handler and writes the response headers
// and body. It returns the first error encountered, or nil when the response
// was written.
func (s *TChannelRouter) handleBody(
	ctx context.Context,
	c *tchannelInboundCall,
) (err error) {
	wireValue, err := c.readReqBody(ctx)
	if err != nil {
		return err
	}

	// trace request
	tracer := tchannel.TracerFromRegistrar(s.registrar)
	ctx = tchannel.ExtractInboundSpan(ctx, c.call, c.reqHeaders, tracer)
	if span := opentracing.SpanFromContext(ctx); span != nil {
		// trace identifiers are added as log fields only for jaeger spans
		if spanCtx, isJaeger := span.Context().(jaeger.SpanContext); isJaeger {
			ctx = WithLogFields(
				ctx,
				zap.String(TraceIDKey, spanCtx.TraceID().String()),
				zap.String(TraceSpanKey, spanCtx.SpanID().String()),
				zap.Bool(TraceSampledKey, spanCtx.IsSampled()),
			)
		}
	}

	// handle request
	resp, err := c.handle(ctx, &wireValue)
	if err != nil {
		return err
	}

	// TODO: put response headers on ctx for final metrics and logs
	//ctx = WithEndpointResponseHeadersField(ctx, c.resHeaders)

	// write response
	if err = c.writeResHeaders(ctx); err != nil {
		return err
	}
	return c.writeResBody(ctx, resp)
}