runtime/tchannel_inbound_call.go (218 lines of code) (raw):

// Copyright (c) 2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package zanzibar import ( "bytes" "context" "fmt" "time" "github.com/pkg/errors" "github.com/uber-go/tally" "github.com/uber/tchannel-go" "go.uber.org/thriftrw/protocol/binary" "go.uber.org/thriftrw/wire" "go.uber.org/zap" ) type tchannelInboundCall struct { endpoint *TChannelEndpoint call *tchannel.InboundCall success bool responded bool startTime time.Time finishTime time.Time reqHeaders map[string]string resHeaders map[string]string // Logger logs entries with default fields that contains request meta info contextLogger ContextLogger // Scope emit metrics with default tags that contains request meta info scope tally.Scope } func (c *tchannelInboundCall) start() { c.startTime = time.Now() } func (c *tchannelInboundCall) finish(ctx context.Context, err error) { c.finishTime = time.Now() if err != nil { errCause := tchannel.GetSystemErrorCode(errors.Cause(err)) errTag := map[string]string{scopeTagError: errCause.MetricsKey()} c.scope.Tagged(errTag).Counter(endpointSystemErrors).Inc(1) } else if !c.success { // The endpoint already has emitted an app-error stat in the template } else { c.scope.Counter(endpointSuccess).Inc(1) } delta := c.finishTime.Sub(c.startTime) c.scope.Timer(endpointLatency).Record(delta) c.scope.Histogram(endpointLatencyHist, tally.DefaultBuckets).RecordDuration(delta) c.scope.Counter(endpointRequest).Inc(1) fields := c.logFields(ctx) if err == nil { c.contextLogger.Debug(ctx, "Finished an incoming server TChannel request", fields...) } else { fields = append(fields, zap.Error(err)) c.contextLogger.Warn(ctx, "Failed to serve incoming TChannel request", fields...) } } func (c *tchannelInboundCall) logFields(ctx context.Context) []zap.Field { fields := []zap.Field{ zap.String(logFieldRequestRemoteAddr, c.call.RemotePeer().HostPort), zap.String("calling-service", c.call.CallerName()), } for k, v := range c.resHeaders { fields = append(fields, zap.String( fmt.Sprintf("%s-%s", logFieldEndpointResponseHeaderPrefix, k), v, )) } fields = append(fields, GetLogFieldsFromCtx(ctx)...) return fields } // readReqHeaders reads request headers from arg2 func (c *tchannelInboundCall) readReqHeaders(ctx context.Context) error { // fail fast if timed out if deadline, ok := ctx.Deadline(); ok && time.Now().After(deadline) { return context.DeadlineExceeded } treader, err := c.call.Arg2Reader() if err != nil { return errors.Wrapf(err, "Could not create arg2reader for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } c.reqHeaders, err = ReadHeaders(treader) if err != nil { _ = treader.Close() return errors.Wrapf(err, "Could not read headers for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } if err := EnsureEmpty(treader, "reading request headers"); err != nil { _ = treader.Close() return errors.Wrapf(err, "Could not ensure arg2reader is empty for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } if err := treader.Close(); err != nil { return errors.Wrapf(err, "Could not close arg2reader for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } return nil } // readReqBody reads request body from arg3 func (c *tchannelInboundCall) readReqBody(ctx context.Context) (wireValue wire.Value, err error) { // fail fast if timed out if deadline, ok := ctx.Deadline(); ok && time.Now().After(deadline) { err = context.DeadlineExceeded return } treader, err := c.call.Arg3Reader() if err != nil { err = errors.Wrapf(err, "Could not create arg3reader for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) return } buf := GetBuffer() defer PutBuffer(buf) if _, err = buf.ReadFrom(treader); err != nil { _ = treader.Close() err = errors.Wrapf(err, "Could not read from arg3reader for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) return } wireValue, err = binary.Default.Decode(bytes.NewReader(buf.Bytes()), wire.TStruct) if err != nil { c.contextLogger.WarnZ(ctx, "Could not decode arg3 for inbound request", zap.Error(err)) err = errors.Wrapf(err, "Could not decode arg3 for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) return } if err = EnsureEmpty(treader, "reading request body"); err != nil { _ = treader.Close() err = errors.Wrapf(err, "Could not ensure arg3reader is empty for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) return } if err = treader.Close(); err != nil { err = errors.Wrapf(err, "Could not close arg3reader for inbound %s.%s (%s) request", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) return } return } // handle tchannel server endpoint call func (c *tchannelInboundCall) handle(ctx context.Context, wireValue *wire.Value) (resp RWTStruct, err error) { // fail fast if timed out if deadline, ok := ctx.Deadline(); ok && time.Now().After(deadline) { err = context.DeadlineExceeded return } ctx, c.success, resp, c.resHeaders, err = c.endpoint.Handle(ctx, c.reqHeaders, wireValue) if c.endpoint.callback != nil { defer c.endpoint.callback(ctx, c.endpoint.Method, resp) } if err != nil { c.contextLogger.Warn(ctx, "Unexpected tchannel system error", zap.Error(err)) if er := c.call.Response().SendSystemError(errors.New("Server Error")); er != nil { c.contextLogger.Warn(ctx, "Error sending server error response", zap.Error(er)) } return } if !c.success { if err = c.call.Response().SetApplicationError(); err != nil { c.contextLogger.Warn(ctx, "Could not set application error for inbound response", zap.Error(err)) return } } return } // writeResHeaders writes response headers to arg2 func (c *tchannelInboundCall) writeResHeaders(ctx context.Context) error { // fail fast if timed out if deadline, ok := ctx.Deadline(); ok && time.Now().After(deadline) { return context.DeadlineExceeded } twriter, err := c.call.Response().Arg2Writer() if err != nil { return errors.Wrapf(err, "Could not create arg2writer for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } if err = WriteHeaders(twriter, c.resHeaders); err != nil { _ = twriter.Close() return errors.Wrapf(err, "Could not write headers for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } if err = twriter.Close(); err != nil { return errors.Wrapf(err, "Could not close arg2writer for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } return nil } // writeResBody writes response body to arg3 func (c *tchannelInboundCall) writeResBody(ctx context.Context, resp RWTStruct) error { // fail fast if timed out if deadline, ok := ctx.Deadline(); ok && time.Now().After(deadline) { return context.DeadlineExceeded } structWireValue, err := resp.ToWire() if err != nil { if er := c.call.Response().SendSystemError(errors.New("Server Error")); er != nil { c.contextLogger.WarnZ(ctx, "Error sending server error response", zap.Error(er)) } return errors.Wrapf(err, "Could not serialize arg3 for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } twriter, err := c.call.Response().Arg3Writer() if err != nil { return errors.Wrapf(err, "Could not create arg3writer for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } err = binary.Default.Encode(structWireValue, twriter) if err != nil { _ = twriter.Close() return errors.Wrapf(err, "Could not write arg3 for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } c.responded = true if err = twriter.Close(); err != nil { return errors.Wrapf(err, "Could not close arg3writer for inbound %s.%s (%s) response", c.endpoint.EndpointID, c.endpoint.HandlerID, c.endpoint.Method, ) } return nil }