banyand/liaison/grpc/server.go (238 lines of code) (raw):

// Licensed to Apache Software Foundation (ASF) under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Apache Software Foundation (ASF) licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. // Package grpc implements the gRPC services defined by APIs. package grpc import ( "context" "net" "runtime/debug" "time" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery" grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" "github.com/pkg/errors" grpclib "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1" measurev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/measure/v1" propertyv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/property/v1" streamv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/stream/v1" "github.com/apache/skywalking-banyandb/banyand/metadata" "github.com/apache/skywalking-banyandb/banyand/metadata/schema" "github.com/apache/skywalking-banyandb/banyand/observability" "github.com/apache/skywalking-banyandb/banyand/queue" "github.com/apache/skywalking-banyandb/pkg/logger" "github.com/apache/skywalking-banyandb/pkg/run" ) const defaultRecvSize = 10 << 20 var ( errServerCert = errors.New("invalid server cert file") errServerKey = errors.New("invalid server key file") errNoAddr = errors.New("no address") errQueryMsg = errors.New("invalid query message") errAccessLogRootPath = errors.New("access log root path is required") ) type server struct { pipeline queue.Queue creds credentials.TransportCredentials *indexRuleRegistryServer measureSVC *measureService log *logger.Logger ser *grpclib.Server *propertyServer *topNAggregationRegistryServer *groupRegistryServer stopCh chan struct{} streamSVC *streamService *measureRegistryServer *streamRegistryServer *indexRuleBindingRegistryServer addr string keyFile string certFile string accessLogRootPath string accessLogRecorders []accessLogRecorder maxRecvMsgSize run.Bytes tls bool enableIngestionAccessLog bool } // NewServer returns a new gRPC server. func NewServer(_ context.Context, pipeline queue.Queue, schemaRegistry metadata.Repo) run.Unit { streamSVC := &streamService{ discoveryService: newDiscoveryService(pipeline, schema.KindStream, schemaRegistry), } measureSVC := &measureService{ discoveryService: newDiscoveryService(pipeline, schema.KindMeasure, schemaRegistry), } s := &server{ pipeline: pipeline, streamSVC: streamSVC, measureSVC: measureSVC, streamRegistryServer: &streamRegistryServer{ schemaRegistry: schemaRegistry, }, indexRuleBindingRegistryServer: &indexRuleBindingRegistryServer{ schemaRegistry: schemaRegistry, }, indexRuleRegistryServer: &indexRuleRegistryServer{ schemaRegistry: schemaRegistry, }, measureRegistryServer: &measureRegistryServer{ schemaRegistry: schemaRegistry, }, groupRegistryServer: &groupRegistryServer{ schemaRegistry: schemaRegistry, }, topNAggregationRegistryServer: &topNAggregationRegistryServer{ schemaRegistry: schemaRegistry, }, propertyServer: &propertyServer{ schemaRegistry: schemaRegistry, }, } s.accessLogRecorders = []accessLogRecorder{streamSVC, measureSVC} return s } func (s *server) PreRun() error { s.log = logger.GetLogger("liaison-grpc") s.streamSVC.setLogger(s.log) s.measureSVC.setLogger(s.log) components := []*discoveryService{ s.streamSVC.discoveryService, s.measureSVC.discoveryService, } for _, c := range components { c.SetLogger(s.log) if err := c.initialize(); err != nil { return err } } if s.enableIngestionAccessLog { for _, alr := range s.accessLogRecorders { if err := alr.activeIngestionAccessLog(s.accessLogRootPath); err != nil { return err } } } return nil } func (s *server) Name() string { return "grpc" } func (s *server) FlagSet() *run.FlagSet { fs := run.NewFlagSet("grpc") s.maxRecvMsgSize = defaultRecvSize fs.VarP(&s.maxRecvMsgSize, "max-recv-msg-size", "", "the size of max receiving message") fs.BoolVar(&s.tls, "tls", false, "connection uses TLS if true, else plain TCP") fs.StringVar(&s.certFile, "cert-file", "", "the TLS cert file") fs.StringVar(&s.keyFile, "key-file", "", "the TLS key file") fs.StringVar(&s.addr, "addr", ":17912", "the address of banyand listens") fs.BoolVar(&s.enableIngestionAccessLog, "enable-ingestion-access-log", false, "enable ingestion access log") fs.StringVar(&s.accessLogRootPath, "access-log-root-path", "", "access log root path") return fs } func (s *server) Validate() error { if s.addr == "" { return errNoAddr } if s.enableIngestionAccessLog && s.accessLogRootPath == "" { return errAccessLogRootPath } observability.UpdateAddress("grpc", s.addr) if !s.tls { return nil } if s.certFile == "" { return errServerCert } if s.keyFile == "" { return errServerKey } creds, errTLS := credentials.NewServerTLSFromFile(s.certFile, s.keyFile) if errTLS != nil { return errors.Wrap(errTLS, "failed to load cert and key") } s.creds = creds return nil } func (s *server) Serve() run.StopNotify { var opts []grpclib.ServerOption if s.tls { opts = []grpclib.ServerOption{grpclib.Creds(s.creds)} } grpcPanicRecoveryHandler := func(p any) (err error) { s.log.Error().Interface("panic", p).Str("stack", string(debug.Stack())).Msg("recovered from panic") return status.Errorf(codes.Internal, "%s", p) } unaryMetrics, streamMetrics := observability.MetricsServerInterceptor() streamChain := []grpclib.StreamServerInterceptor{ grpc_validator.StreamServerInterceptor(), recovery.StreamServerInterceptor(recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), } if streamMetrics != nil { streamChain = append(streamChain, streamMetrics) } unaryChain := []grpclib.UnaryServerInterceptor{ grpc_validator.UnaryServerInterceptor(), recovery.UnaryServerInterceptor(recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), } if unaryMetrics != nil { unaryChain = append(unaryChain, unaryMetrics) } opts = append(opts, grpclib.MaxRecvMsgSize(int(s.maxRecvMsgSize)), grpclib.ChainUnaryInterceptor(unaryChain...), grpclib.ChainStreamInterceptor(streamChain...), ) s.ser = grpclib.NewServer(opts...) streamv1.RegisterStreamServiceServer(s.ser, s.streamSVC) measurev1.RegisterMeasureServiceServer(s.ser, s.measureSVC) // register *Registry databasev1.RegisterGroupRegistryServiceServer(s.ser, s.groupRegistryServer) databasev1.RegisterIndexRuleBindingRegistryServiceServer(s.ser, s.indexRuleBindingRegistryServer) databasev1.RegisterIndexRuleRegistryServiceServer(s.ser, s.indexRuleRegistryServer) databasev1.RegisterStreamRegistryServiceServer(s.ser, s.streamRegistryServer) databasev1.RegisterMeasureRegistryServiceServer(s.ser, s.measureRegistryServer) propertyv1.RegisterPropertyServiceServer(s.ser, s.propertyServer) databasev1.RegisterTopNAggregationRegistryServiceServer(s.ser, s.topNAggregationRegistryServer) grpc_health_v1.RegisterHealthServer(s.ser, health.NewServer()) s.stopCh = make(chan struct{}) go func() { lis, err := net.Listen("tcp", s.addr) if err != nil { s.log.Error().Err(err).Msg("Failed to listen") close(s.stopCh) return } s.log.Info().Str("addr", s.addr).Msg("Listening to") err = s.ser.Serve(lis) if err != nil { s.log.Error().Err(err).Msg("server is interrupted") } close(s.stopCh) }() return s.stopCh } func (s *server) GracefulStop() { s.log.Info().Msg("stopping") stopped := make(chan struct{}) go func() { s.ser.GracefulStop() if s.enableIngestionAccessLog { for _, alr := range s.accessLogRecorders { _ = alr.Close() } } close(stopped) }() t := time.NewTimer(10 * time.Second) select { case <-t.C: s.ser.Stop() s.log.Info().Msg("force stopped") case <-stopped: t.Stop() s.log.Info().Msg("stopped gracefully") } } type accessLogRecorder interface { activeIngestionAccessLog(root string) error Close() error }