runtime/gateway.go (790 lines of code) (raw):

// Copyright (c) 2023 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package zanzibar import ( "context" "fmt" "io" "io/ioutil" "net" "net/http" "os" "path/filepath" "strconv" "strings" "sync" "time" metricCollector "github.com/afex/hystrix-go/hystrix/metric_collector" "github.com/opentracing/opentracing-go" "github.com/pkg/errors" "github.com/uber-go/tally" "github.com/uber-go/tally/m3" jaegerConfig "github.com/uber/jaeger-client-go/config" jaegerLibTally "github.com/uber/jaeger-lib/metrics/tally" "github.com/uber/tchannel-go" "github.com/uber/zanzibar/runtime/jsonwrapper" "github.com/uber/zanzibar/runtime/plugins" "go.uber.org/yarpc" "go.uber.org/yarpc/api/transport" "go.uber.org/yarpc/transport/grpc" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) var levelMap = map[string]zapcore.Level{ "debug": zapcore.DebugLevel, "info": zapcore.InfoLevel, "warn": zapcore.WarnLevel, "error": zapcore.ErrorLevel, "dpanic": zapcore.DPanicLevel, "panic": zapcore.PanicLevel, "fatal": zapcore.FatalLevel, } var defaultShutdownPollInterval = 500 * time.Millisecond var defaultCloseTimeout = 10000 * time.Millisecond const ( localhost = "127.0.0.1" testenv = "test" metricsServiceFromEnvKey = "metrics.serviceNameEnv" serviceFromEnvKey = "serviceNameEnv" skipZanzibarLogsKey = "contextlogger.zanzibar.skiplogs" ) // Options configures the gateway type Options struct { MetricsBackend tally.CachedStatsReporter MetricsDefaultBuckets tally.Buckets // To support non-default histogram metric bucketing specific to a gateway LogWriter zapcore.WriteSyncer GetContextScopeExtractors func() []ContextScopeTagsExtractor GetContextFieldExtractors func() []ContextLogFieldsExtractor JSONWrapper jsonwrapper.JSONWrapper NotFoundHandler func(*Gateway) http.HandlerFunc Tracer opentracing.Tracer TracerCloser io.Closer // If present, request uuid is retrieved from the incoming request // headers using the key, and put on the context. Otherwise, a new // uuid is created for the incoming request. RequestUUIDHeaderKey string } // Gateway type type Gateway struct { HTTPPort int32 TChannelPort int32 RealHTTPPort int32 RealHTTPAddr string RealTChannelPort int32 RealTChannelAddr string WaitGroup *sync.WaitGroup ServerTChannel *tchannel.Channel ClientTChannels map[string]*tchannel.Channel ContextLogger ContextLogger ContextMetrics ContextMetrics ContextExtractor ContextExtractor RootScope tally.Scope Logger *zap.Logger ServiceName string Config *StaticConfig HTTPRouter HTTPRouter ServerTChannelRouter *TChannelRouter TChannelSubLoggerLevel zapcore.Level Tracer opentracing.Tracer JSONWrapper jsonwrapper.JSONWrapper // gRPC client dispatcher for gRPC client lifecycle management GRPCClientDispatcher *yarpc.Dispatcher atomLevel *zap.AtomicLevel loggerFile *os.File scopeCloser io.Closer metricsBackend tally.CachedStatsReporter runtimeMetrics RuntimeMetricsCollector metricsDefaultBuckets tally.Buckets logEncoder zapcore.Encoder logWriter zapcore.WriteSyncer logWriteSyncer zapcore.WriteSyncer httpServer *HTTPServer localHTTPServer *HTTPServer tchannelServer *tchannel.Channel tracerCloser io.Closer notFoundHandler http.HandlerFunc requestUUIDHeaderKey string isUnhealthy bool } // DefaultDependencies are the common dependencies for all modules type DefaultDependencies struct { // ContextExtractor extracts context for scope and logs field ContextExtractor ContextExtractor // ContextLogger is a logger with request-scoped log fields ContextLogger ContextLogger // ContextMetrics emit metrics from context ContextMetrics ContextMetrics Logger *zap.Logger Scope tally.Scope Tracer opentracing.Tracer Config *StaticConfig ServerTChannel *tchannel.Channel Gateway *Gateway // dispatcher for managing gRPC clients GRPCClientDispatcher *yarpc.Dispatcher // JSONWrapper provides json marshaling and unmarshaling functionality JSONWrapper jsonwrapper.JSONWrapper } // CreateGateway func func CreateGateway( config *StaticConfig, opts *Options, ) (*Gateway, error) { var metricsBackend tally.CachedStatsReporter var metricsDefaultBuckets tally.Buckets var logWriter zapcore.WriteSyncer var scopeTagsExtractors []ContextScopeTagsExtractor var logFieldsExtractors []ContextLogFieldsExtractor var jsonWrapper jsonwrapper.JSONWrapper if opts == nil { opts = &Options{} } if opts.MetricsBackend != nil { metricsBackend = opts.MetricsBackend } if opts.MetricsDefaultBuckets != nil { metricsDefaultBuckets = opts.MetricsDefaultBuckets } else { metricsDefaultBuckets = tally.DefaultBuckets } if opts.LogWriter != nil { logWriter = opts.LogWriter } if opts.JSONWrapper != nil { jsonWrapper = opts.JSONWrapper } else { jsonWrapper = jsonwrapper.NewDefaultJSONWrapper() } if opts.GetContextScopeExtractors != nil { scopeTagsExtractors = opts.GetContextScopeExtractors() } else { scopeTagsExtractors = []ContextScopeTagsExtractor{GetEndpointRequestHeadersFromCtx} } if opts.GetContextFieldExtractors != nil { logFieldsExtractors = opts.GetContextFieldExtractors() } else { logFieldsExtractors = []ContextLogFieldsExtractor{ func(ctx context.Context) []zap.Field { reqHeaders := GetEndpointRequestHeadersFromCtx(ctx) fields := make([]zap.Field, 0, len(reqHeaders)) for k, v := range reqHeaders { fields = append(fields, zap.String(k, v)) } return fields }, } } extractors := &ContextExtractors{ ScopeTagsExtractors: scopeTagsExtractors, LogFieldsExtractors: logFieldsExtractors, } var service string if config.ContainsKey(serviceFromEnvKey) { service = os.Getenv(config.MustGetString(serviceFromEnvKey)) } if service == "" { service = config.MustGetString("serviceName") } gateway := &Gateway{ HTTPPort: int32(config.MustGetInt("http.port")), TChannelPort: int32(config.MustGetInt("tchannel.port")), ServiceName: service, WaitGroup: &sync.WaitGroup{}, Config: config, ContextExtractor: extractors, JSONWrapper: jsonWrapper, logWriter: logWriter, metricsBackend: metricsBackend, metricsDefaultBuckets: metricsDefaultBuckets, requestUUIDHeaderKey: opts.RequestUUIDHeaderKey, } gateway.setupConfig(config) config.Freeze() // order matters for following setup method calls if err := gateway.setupMetrics(config); err != nil { return nil, err } if err := gateway.setupLogger(config); err != nil { return nil, err } if opts.Tracer != nil && opts.TracerCloser != nil && config.ContainsKey("jaeger.tracer.custom") && config.MustGetBoolean("jaeger.tracer.custom") { gateway.Tracer = opts.Tracer gateway.tracerCloser = opts.TracerCloser } else { if err := gateway.setupTracer(config); err != nil { return nil, err } } if opts.NotFoundHandler != nil && config.ContainsKey("http.notFoundHandler.custom") && config.MustGetBoolean("http.notFoundHandler.custom") { gateway.notFoundHandler = opts.NotFoundHandler(gateway) } // setup router after metrics and logs gateway.HTTPRouter = NewHTTPRouter(gateway) if err := gateway.setupHTTPServer(); err != nil { return nil, err } if err := gateway.setupServerTChannel(config); err != nil { return nil, err } // setup gRPC client dispatcher after metrics, logger and tracer if err := gateway.setupGRPCClientDispatcher(config); err != nil { return nil, err } gateway.registerPredefined() return gateway, nil } // Bootstrap func func (gateway *Gateway) Bootstrap() error { env := gateway.Config.MustGetString("env") // start HTTP server gateway.RootScope.Counter("server.bootstrap").Inc(1) _, err := gateway.localHTTPServer.JustListen() if err != nil { gateway.Logger.Error("Error listening on port", zap.Error(err)) return errors.Wrap(err, "error listening on port") } if gateway.localHTTPServer.RealIP != gateway.httpServer.RealIP && env != testenv { _, err := gateway.httpServer.JustListen() if err != nil { gateway.Logger.Error("Error listening on port", zap.Error(err)) return errors.Wrap(err, "error listening on port") } } else { // Do not start at the same IP gateway.httpServer = gateway.localHTTPServer } gateway.RealHTTPPort = gateway.httpServer.RealPort gateway.RealHTTPAddr = gateway.httpServer.RealAddr gateway.WaitGroup.Add(1) go gateway.httpServer.JustServe(gateway.WaitGroup) if gateway.httpServer != gateway.localHTTPServer { gateway.WaitGroup.Add(1) go gateway.localHTTPServer.JustServe(gateway.WaitGroup) } // start TChannel server ip := localhost if gateway.Config.ContainsKey("tchannel.server.ip") { ip = gateway.Config.MustGetString("tchannel.server.ip") } else if env != testenv { tchannelIP, err := tchannel.ListenIP() if err != nil { return errors.Wrap(err, "error finding the best IP for tchannel") } ip = tchannelIP.String() } tchannelAddr := ip + ":" + strconv.Itoa(int(gateway.TChannelPort)) ln, err := net.Listen("tcp", tchannelAddr) if err != nil { gateway.Logger.Error("Error listening tchannel port", zap.Error(err)) return err } gateway.RealTChannelAddr = ln.Addr().String() gateway.RealTChannelPort = int32(ln.Addr().(*net.TCPAddr).Port) // tchannel serve does not block, connection handling is done in different goroutine err = gateway.tchannelServer.Serve(ln) if err != nil { gateway.Logger.Error("Error starting tchannel server", zap.Error(err)) return err } gateway.RootScope.Counter("startup.success").Inc(1) if gateway.GRPCClientDispatcher != nil { err = gateway.GRPCClientDispatcher.Start() if err != nil { gateway.Logger.Error("error starting gRPC client dispatcher", zap.Error(err)) return err } } return nil } func (gateway *Gateway) registerPredefined() { deps := &DefaultDependencies{ Scope: gateway.RootScope, ContextLogger: gateway.ContextLogger, Logger: gateway.Logger, Tracer: gateway.Tracer, } tracer := NewRouterEndpoint( gateway.ContextExtractor, deps, "health", "health", gateway.handleHealthRequest, ) _ = gateway.HTTPRouter.Handle("GET", "/health", http.HandlerFunc(tracer.HandleRequest)) } func (gateway *Gateway) handleHealthRequest( ctx context.Context, req *ServerHTTPRequest, res *ServerHTTPResponse, ) context.Context { if gateway.isUnhealthy { message := "Unhealthy, from " + gateway.ServiceName bytes := []byte( "{\"ok\":false,\"message\":\"" + message + "\"}\n", ) res.WriteJSONBytes(503, nil, bytes) return ctx } message := "Healthy, from " + gateway.ServiceName bytes := []byte( "{\"ok\":true,\"message\":\"" + message + "\"}\n", ) res.WriteJSONBytes(200, nil, bytes) return ctx } // Shutdown starts the graceful shutdown, blocks until it is complete func (gateway *Gateway) Shutdown() { // stop accepting incoming requests as soon as shutdown signal is received. gateway.isUnhealthy = true var swg sync.WaitGroup ctx, cancel := context.WithTimeout(context.Background(), gateway.ShutdownTimeout()) defer cancel() ec := make(chan error, 4) if gateway.localHTTPServer != gateway.httpServer { swg.Add(1) go func() { defer swg.Done() if err := gateway.localHTTPServer.Shutdown(ctx); err != nil { ec <- errors.Wrap(err, "error shutting down local http server") } }() } // shutdown http server swg.Add(1) go func() { defer swg.Done() if err := gateway.httpServer.Shutdown(ctx); err != nil { ec <- errors.Wrap(err, "error shutting down http server") } }() // shutdown tchannel server swg.Add(1) go func() { defer swg.Done() if err := gateway.shutdownTChannelServerAndClients(ctx); err != nil { ec <- errors.Wrap(err, "error shutting down tchannel server or clients") } }() // wait for servers to shutdown before stopping GRPCClientDispatcher swg.Wait() // stop all grpc clients if gateway.GRPCClientDispatcher != nil { swg.Add(1) go func() { defer swg.Done() if err := gateway.GRPCClientDispatcher.Stop(); err != nil { ec <- errors.Wrap(err, "error stopping gRPC client dispatcher") } }() } swg.Wait() select { case err := <-ec: // close ec so that the range ec will not block forever close(ec) errs := make([]string, 0, cap(ec)) errs = append(errs, err.Error()) for e := range ec { errs = append(errs, e.Error()) } gateway.Logger.Error(fmt.Sprintf( "%d errors when shutting down the servers: %s", len(errs), strings.Join(errs, ";")), ) gateway.RootScope.Counter("shutdown.failure").Inc(1) default: gateway.Logger.Info("servers are shut down gracefully") gateway.RootScope.Counter("shutdown.success").Inc(1) } _ = gateway.tracerCloser.Close() gateway.metricsBackend.Flush() _ = gateway.scopeCloser.Close() // close log files as the last step if gateway.loggerFile != nil { _ = gateway.loggerFile.Sync() _ = gateway.loggerFile.Close() } // stop collecting runtime metrics if gateway.runtimeMetrics != nil { gateway.runtimeMetrics.Stop() } } // Close shuts down the servers and returns immediately func (gateway *Gateway) Close() { if gateway.localHTTPServer != gateway.httpServer { gateway.localHTTPServer.Close() } gateway.httpServer.Close() gateway.tchannelServer.Close() _ = gateway.tracerCloser.Close() gateway.metricsBackend.Flush() _ = gateway.scopeCloser.Close() // close log files as the last step if gateway.loggerFile != nil { _ = gateway.loggerFile.Sync() _ = gateway.loggerFile.Close() } // stop collecting runtime metrics if gateway.runtimeMetrics != nil { gateway.runtimeMetrics.Stop() } } // ShutdownTimeout returns the shutdown configured timeout, which default to 10s. func (gateway *Gateway) ShutdownTimeout() time.Duration { if gateway.Config.ContainsKey("shutdown.timeout") { return time.Duration(gateway.Config.MustGetInt("shutdown.timeout")) * time.Millisecond } return defaultCloseTimeout } // InspectOrDie inspects the config for this gateway func (gateway *Gateway) InspectOrDie() map[string]interface{} { return gateway.Config.InspectOrDie() } // Wait for gateway to close the server func (gateway *Gateway) Wait() { gateway.WaitGroup.Wait() } func (gateway *Gateway) setupConfig(config *StaticConfig) { useDC := config.MustGetBoolean("useDatacenter") if useDC { dcFile := config.MustGetString("datacenterFile") bytes, err := ioutil.ReadFile(dcFile) if err != nil { panic("expected datacenterFile: " + dcFile + " to exist") } config.SetSeedOrDie("datacenter", string(bytes)) } else { config.SetConfigValueOrDie("datacenter", []byte("unknown"), "string") } } func (gateway *Gateway) setupMetrics(config *StaticConfig) (err error) { metricsType := config.MustGetString("metrics.type") var service string if config.ContainsKey(metricsServiceFromEnvKey) { service = os.Getenv(config.MustGetString(metricsServiceFromEnvKey)) } if service == "" { service = config.MustGetString("metrics.serviceName") } env := config.MustGetString("env") if metricsType == "m3" { if gateway.metricsBackend != nil { panic("expected no metrics backend in gateway.") } opts := m3.Options{ HostPorts: []string{config.MustGetString("metrics.m3.hostPort")}, Service: service, Env: env, MaxQueueSize: int(config.MustGetInt("metrics.m3.maxQueueSize")), MaxPacketSizeBytes: int32(config.MustGetInt("metrics.m3.maxPacketSizeBytes")), } if gateway.metricsBackend, err = m3.NewReporter(opts); err != nil { return err } } else if gateway.metricsBackend == nil { panic("expected gateway to have MetricsBackend in opts") } defaultTags := map[string]string{ "env": env, "service": service, "dc": gateway.Config.MustGetString("datacenter"), } if config.MustGetBoolean("metrics.m3.includeHost") { defaultTags["host"] = GetHostname() } // Adds in any env variable variables specified in config envVarsToTagInRootScope := []string{} config.MustGetStruct("envVarsToTagInRootScope", &envVarsToTagInRootScope) for _, envVarName := range envVarsToTagInRootScope { envVarValue := os.Getenv(envVarName) defaultTags[envVarName] = envVarValue } gateway.RootScope, gateway.scopeCloser = tally.NewRootScope( tally.ScopeOptions{ Tags: defaultTags, CachedReporter: gateway.metricsBackend, DefaultBuckets: gateway.metricsDefaultBuckets, Separator: tally.DefaultSeparator, SanitizeOptions: &m3.DefaultSanitizerOpts, }, time.Duration(config.MustGetInt("metrics.flushInterval"))*time.Millisecond, ) gateway.ContextMetrics = NewContextMetrics(gateway.RootScope) // start collecting runtime metrics collectInterval := time.Duration(config.MustGetInt("metrics.runtime.collectInterval")) * time.Millisecond runtimeMetricsOpts := RuntimeMetricsOptions{ EnableCPUMetrics: config.MustGetBoolean("metrics.runtime.enableCPUMetrics"), EnableMemMetrics: config.MustGetBoolean("metrics.runtime.enableMemMetrics"), EnableGCMetrics: config.MustGetBoolean("metrics.runtime.enableGCMetrics"), CollectInterval: collectInterval, } // runtime metrics should always have host tag runtimeScope := gateway.RootScope if _, ok := defaultTags["host"]; !ok { runtimeScope = gateway.RootScope.Tagged(map[string]string{ "host": GetHostname(), }) } gateway.runtimeMetrics = StartRuntimeMetricsCollector( runtimeMetricsOpts, runtimeScope, ) //Initialize M3Collector for hystrix metrics c := plugins.InitializeM3Collector(gateway.RootScope) metricCollector.Registry.Register(c.NewM3Collector) return nil } func (gateway *Gateway) setupLogger(config *StaticConfig) error { var output zapcore.WriteSyncer logEncoder := zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()) logLevel := zap.InfoLevel if config.ContainsKey("logger.level") { levelString := config.MustGetString("logger.level") var ok bool logLevel, ok = levelMap[levelString] if !ok { return errors.Errorf("unknown log level for gateway logger: %s", levelString) } } tempLogger := zap.New( zapcore.NewCore( logEncoder, os.Stderr, logLevel, ), ) loggerFileName := config.MustGetString("logger.fileName") loggerOutput := config.MustGetString("logger.output") if loggerFileName == "" || loggerOutput == "stdout" { var writer zapcore.WriteSyncer if gateway.logWriter != nil { writer = zap.CombineWriteSyncers(os.Stdout, gateway.logWriter) } else { writer = os.Stdout } output = writer } else { err := os.MkdirAll(filepath.Dir(loggerFileName), 0777) if err != nil { tempLogger.Error("Error creating log directory", zap.Error(err)) return errors.Wrap(err, "Error creating log directory") } loggerFile, err := os.OpenFile( loggerFileName, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644, ) if err != nil { tempLogger.Error("Error opening log file", zap.Error(err)) return errors.Wrap(err, "Error opening log file") } gateway.loggerFile = loggerFile if gateway.logWriter != nil { writer := zap.CombineWriteSyncers(loggerFile, gateway.logWriter) output = writer } else { output = loggerFile } } atomLevel := zap.NewAtomicLevelAt(logLevel) zapLogger := zap.New(zapcore.NewCore( logEncoder, output, atomLevel, )) gateway.atomLevel = &atomLevel gateway.logEncoder = logEncoder gateway.logWriteSyncer = output // Default to a STDOUT logger gateway.Logger = zapLogger.With( zap.String("zone", gateway.Config.MustGetString("datacenter")), zap.String("env", gateway.Config.MustGetString("env")), zap.String("hostname", GetHostname()), zap.String("service", gateway.Config.MustGetString("serviceName")), zap.Int("pid", os.Getpid()), ) gateway.ContextLogger = NewContextLogger(gateway.Logger) if config.ContainsKey(skipZanzibarLogsKey) { skipZanzibarLogs := config.MustGetBoolean(skipZanzibarLogsKey) gateway.ContextLogger.SetSkipZanzibarLogs(skipZanzibarLogs) } return nil } // SubLogger returns a sub logger clone with given name and log level. func (gateway *Gateway) SubLogger(name string, level zapcore.Level) *zap.Logger { newCore := zapcore.NewCore( gateway.logEncoder.Clone(), gateway.logWriteSyncer, level, ) return gateway.Logger.With( zap.String("subLogger", name), ).WithOptions( zap.WrapCore(func(core zapcore.Core) zapcore.Core { return newCore }), ) } func (gateway *Gateway) initJaegerConfig(config *StaticConfig) *jaegerConfig.Configuration { return &jaegerConfig.Configuration{ ServiceName: config.MustGetString("serviceName"), Disabled: config.MustGetBoolean("jaeger.disabled"), Reporter: &jaegerConfig.ReporterConfig{ LocalAgentHostPort: config.MustGetString("jaeger.reporter.hostport"), BufferFlushInterval: time.Duration(config.MustGetInt("jaeger.reporter.flush.milliseconds")) * time.Millisecond, }, Sampler: &jaegerConfig.SamplerConfig{ Type: config.MustGetString("jaeger.sampler.type"), Param: config.MustGetFloat("jaeger.sampler.param"), }, } } func (gateway *Gateway) setupTracer(config *StaticConfig) error { levelString := gateway.Config.MustGetString("subLoggerLevel.jaeger") level, ok := levelMap[levelString] if !ok { return errors.Errorf("unknown sub logger level for jaeger tracer: %s", levelString) } opts := []jaegerConfig.Option{ // TChannel logger implements jaeger logger interface jaegerConfig.Logger(NewTChannelLogger(gateway.SubLogger("jaeger", level))), jaegerConfig.Metrics(jaegerLibTally.Wrap(gateway.RootScope)), } jc := gateway.initJaegerConfig(config) tracer, closer, err := jc.NewTracer(opts...) if err != nil { return errors.Wrapf(err, "error initializing Jaeger tracer client") } opentracing.SetGlobalTracer(tracer) gateway.Tracer = tracer gateway.tracerCloser = closer return nil } func (gateway *Gateway) setupHTTPServer() error { levelString := gateway.Config.MustGetString("subLoggerLevel.http") level, ok := levelMap[levelString] if !ok { return errors.Errorf("unknown sub logger level for http server: %s", levelString) } httpLogger := gateway.SubLogger("http", level) listenIP, err := tchannel.ListenIP() if err != nil { return errors.Wrap(err, "error finding the best IP") } gateway.httpServer = &HTTPServer{ Server: &http.Server{ Addr: listenIP.String() + ":" + strconv.FormatInt(int64(gateway.HTTPPort), 10), Handler: gateway.HTTPRouter, }, Logger: httpLogger, } gateway.localHTTPServer = &HTTPServer{ Server: &http.Server{ Addr: "127.0.0.1:" + strconv.FormatInt(int64(gateway.HTTPPort), 10), Handler: gateway.HTTPRouter, }, Logger: httpLogger, } return nil } func (gateway *Gateway) setupServerTChannel(config *StaticConfig) error { serviceName := config.MustGetString("tchannel.serviceName") processName := config.MustGetString("tchannel.processName") levelString := gateway.Config.MustGetString("subLoggerLevel.tchannel") level, ok := levelMap[levelString] if !ok { return errors.Errorf("unknown sub logger level for tchannel server: %s", levelString) } gateway.TChannelSubLoggerLevel = level channel, err := tchannel.NewChannel( serviceName, &tchannel.ChannelOptions{ ProcessName: processName, Tracer: gateway.Tracer, Logger: NewTChannelLogger(gateway.SubLogger("tchannel", level)), StatsReporter: NewTChannelStatsReporter( gateway.RootScope, ), }) if err != nil { return errors.Errorf("Error creating top channel:\n%s", err) } gateway.ServerTChannel = channel gateway.tchannelServer = gateway.ServerTChannel gateway.ServerTChannelRouter = NewTChannelRouter(channel, gateway) // client tchannels are created explicitly for each client if "dedicated.tchannel.client: true" gateway.ClientTChannels = make(map[string]*tchannel.Channel) return nil } // SetupClientTChannel sets up a dedicated tchannel for each client with a given service name // If multiple backends with the same service name exist (for e.g. presentation service), then // all of them would receive the same channel. The method is exported because it is called from // the generated clients if "dedicated.tchannel.client: true" else server tchannel is reused func (gateway *Gateway) SetupClientTChannel(config *StaticConfig, serviceName string) *tchannel.Channel { if ch, ok := gateway.ClientTChannels[serviceName]; ok { gateway.Logger.Info(fmt.Sprintf("returning already initialised TChannel client for [%v]", serviceName)) return ch } processName := config.MustGetString("tchannel.processName") level := gateway.TChannelSubLoggerLevel channel, err := tchannel.NewChannel( // when specifying the service name for the channel, we reuse the server service // name else calls from other unauthorised sources may be blocked config.MustGetString("tchannel.serviceName"), &tchannel.ChannelOptions{ ProcessName: processName, Tracer: gateway.Tracer, Logger: NewTChannelLogger(gateway.SubLogger("tchannel", level)), StatsReporter: NewTChannelStatsReporter(gateway.RootScope), }) scope := gateway.RootScope.Tagged(map[string]string{ "client": serviceName, }) if err != nil { scope.Gauge("tchannel.client.running").Update(0) gateway.Logger.Info(fmt.Sprintf("Failed to initiate dedicated TChannel client for [%v]", serviceName)) } else { gateway.Logger.Info(fmt.Sprintf("Dedicated TChannel client initiated for client [%v]", serviceName)) scope.Gauge("tchannel.client.running").Update(1) } gateway.ClientTChannels[serviceName] = channel return channel } func (gateway *Gateway) setupGRPCClientDispatcher(config *StaticConfig) error { ip := config.MustGetString("sidecarRouter.default.grpc.ip") port := config.MustGetInt("sidecarRouter.default.grpc.port") address := fmt.Sprintf("%s:%d", ip, port) clientServiceNameMapping := make(map[string]string) config.MustGetStruct("grpc.clientServiceNameMapping", &clientServiceNameMapping) if len(clientServiceNameMapping) == 0 { return nil } unaryOutbound := grpc.NewTransport( grpc.Logger(gateway.Logger), grpc.Tracer(gateway.Tracer), ).NewSingleOutbound(address) outbounds := make(yarpc.Outbounds, len(clientServiceNameMapping)) for key, value := range clientServiceNameMapping { outbounds[key] = transport.Outbounds{ ServiceName: value, Unary: unaryOutbound, } } dispatcher := yarpc.NewDispatcher(yarpc.Config{ Name: config.MustGetString("serviceName"), Outbounds: outbounds, Logging: yarpc.LoggingConfig{ Zap: gateway.Logger, // TODO: set proper extractors }, Metrics: yarpc.MetricsConfig{ // TODO: contextual scope Tally: gateway.RootScope, }, }) gateway.GRPCClientDispatcher = dispatcher return nil } // GetDirnameFromRuntimeCaller will compute the current dirname // if passed a filename from runtime.Caller(0). This is useful // for doing __dirname/__FILE__ for golang. func GetDirnameFromRuntimeCaller(file string) string { dirname := filepath.Dir(file) // Strip _obj dirs generated by test -cover ... if filepath.Base(dirname) == "_obj" { dirname = filepath.Dir(dirname) } // Strip _obj_test in go test -cover if filepath.Base(dirname) == "_obj_test" { dirname = filepath.Dir(dirname) } // go test -cover does weird folder stuff if filepath.Base(dirname) == "_test" { dirname = filepath.Dir(dirname) } // if filepath then we are done, otherwise its go package name if filepath.IsAbs(dirname) { return dirname } // If dirname is not absolute then its a package name... return filepath.Join(os.Getenv("GOPATH"), "src", dirname) } // GetHostname returns hostname func GetHostname() string { host, err := os.Hostname() if err != nil { host = "unknown" } return host } // shutdownTChannelServerAndClients gracefully shuts down the tchannel server, blocks until the shutdown is // complete or the timeout has reached if there is one associated with the given context // It also shuts down all the dedicated client tchannel connections on a best effort basis func (gateway *Gateway) shutdownTChannelServerAndClients(ctx context.Context) error { shutdownPollInterval := defaultShutdownPollInterval if gateway.Config.ContainsKey("shutdown.pollInterval") { shutdownPollInterval = time.Duration(gateway.Config.MustGetInt("shutdown.pollInterval")) * time.Millisecond } ticker := time.NewTicker(shutdownPollInterval) defer ticker.Stop() gateway.Logger.Info("Closing the TChannel server") gateway.tchannelServer.Close() for { select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: if gateway.tchannelServer.Closed() { gateway.Logger.Info("TChannel server closed successfully") for serviceName, clientTchannel := range gateway.ClientTChannels { go func(service string, ch *tchannel.Channel) { gateway.Logger.Info(fmt.Sprintf("Closing TChannel client for [%v]", service)) ch.Close() gateway.RootScope.Tagged(map[string]string{ "client": service, }).Gauge("tchannel.client.running").Update(0) }(serviceName, clientTchannel) } } else { gateway.Logger.Info(fmt.Sprintf("Failed to close TChannel server within %v ms", shutdownPollInterval)) } return nil } } }