banyand/queue/sub/server.go (252 lines of code) (raw):

// Licensed to Apache Software Foundation (ASF) under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Apache Software Foundation (ASF) licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package sub import ( "context" "net" "net/http" "runtime/debug" "strconv" "sync" "time" "github.com/go-chi/chi/v5" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/recovery" grpc_validator "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/pkg/errors" grpclib "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/health" "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" clusterv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/cluster/v1" databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1" measurev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/measure/v1" streamv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/stream/v1" "github.com/apache/skywalking-banyandb/banyand/observability" "github.com/apache/skywalking-banyandb/banyand/queue" "github.com/apache/skywalking-banyandb/pkg/bus" "github.com/apache/skywalking-banyandb/pkg/healthcheck" "github.com/apache/skywalking-banyandb/pkg/logger" "github.com/apache/skywalking-banyandb/pkg/meter" "github.com/apache/skywalking-banyandb/pkg/run" ) const defaultRecvSize = 10 << 20 var ( errServerCert = errors.New("invalid server cert file") errServerKey = errors.New("invalid server key file") errNoAddr = errors.New("no address") _ run.PreRunner = (*server)(nil) _ run.Service = (*server)(nil) queueSubScope = observability.RootScope.SubScope("queue_sub") ) type server struct { databasev1.UnimplementedSnapshotServiceServer streamv1.UnimplementedStreamServiceServer creds credentials.TransportCredentials omr observability.MetricsRegistry httpSrv *http.Server log *logger.Logger ser *grpclib.Server listeners map[bus.Topic][]bus.MessageListener topicMap map[string]bus.Topic clusterv1.UnimplementedServiceServer metrics *metrics clientCloser context.CancelFunc host string addr string httpAddr string certFile string keyFile string maxRecvMsgSize run.Bytes listenersLock sync.RWMutex port uint32 httpPort uint32 tls bool } // NewServer returns a new gRPC server. func NewServer(omr observability.MetricsRegistry) queue.Server { return &server{ listeners: make(map[bus.Topic][]bus.MessageListener), topicMap: make(map[string]bus.Topic), omr: omr, } } func (s *server) PreRun(_ context.Context) error { s.log = logger.GetLogger("server-queue-sub") s.metrics = newMetrics(s.omr.With(queueSubScope)) return nil } func (s *server) Name() string { return "server-queue" } func (s *server) Role() databasev1.Role { return databasev1.Role_ROLE_DATA } func (s *server) GetPort() *uint32 { return &s.port } func (s *server) FlagSet() *run.FlagSet { fs := run.NewFlagSet("grpc") s.maxRecvMsgSize = defaultRecvSize fs.VarP(&s.maxRecvMsgSize, "max-recv-msg-size", "", "the size of max receiving message") fs.BoolVar(&s.tls, "tls", false, "connection uses TLS if true, else plain TCP") fs.StringVar(&s.certFile, "cert-file", "", "the TLS cert file") fs.StringVar(&s.keyFile, "key-file", "", "the TLS key file") fs.StringVar(&s.host, "grpc-host", "", "the host of banyand listens") fs.Uint32Var(&s.port, "grpc-port", 17912, "the port of banyand listens") fs.Uint32Var(&s.httpPort, "http-port", 17913, "the port of banyand http api listens") return fs } func (s *server) Validate() error { s.addr = net.JoinHostPort(s.host, strconv.FormatUint(uint64(s.port), 10)) if s.addr == ":" { return errNoAddr } s.httpAddr = net.JoinHostPort(s.host, strconv.FormatUint(uint64(s.httpPort), 10)) if s.httpAddr == ":" { return errNoAddr } if !s.tls { return nil } if s.certFile == "" { return errServerCert } if s.keyFile == "" { return errServerKey } creds, errTLS := credentials.NewServerTLSFromFile(s.certFile, s.keyFile) if errTLS != nil { return errors.Wrap(errTLS, "failed to load cert and key") } s.creds = creds return nil } func (s *server) Serve() run.StopNotify { var opts []grpclib.ServerOption if s.tls { opts = []grpclib.ServerOption{grpclib.Creds(s.creds)} } grpcPanicRecoveryHandler := func(p any) (err error) { s.log.Error().Interface("panic", p).Str("stack", string(debug.Stack())).Msg("recovered from panic") return status.Errorf(codes.Internal, "%s", p) } streamChain := []grpclib.StreamServerInterceptor{ recovery.StreamServerInterceptor(recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), } unaryChain := []grpclib.UnaryServerInterceptor{ grpc_validator.UnaryServerInterceptor(), recovery.UnaryServerInterceptor(recovery.WithRecoveryHandler(grpcPanicRecoveryHandler)), } opts = append(opts, grpclib.MaxRecvMsgSize(int(s.maxRecvMsgSize)), grpclib.ChainUnaryInterceptor(unaryChain...), grpclib.ChainStreamInterceptor(streamChain...), ) s.ser = grpclib.NewServer(opts...) clusterv1.RegisterServiceServer(s.ser, s) grpc_health_v1.RegisterHealthServer(s.ser, health.NewServer()) databasev1.RegisterSnapshotServiceServer(s.ser, s) streamv1.RegisterStreamServiceServer(s.ser, &streamService{ser: s}) measurev1.RegisterMeasureServiceServer(s.ser, &measureService{ser: s}) var ctx context.Context ctx, s.clientCloser = context.WithCancel(context.Background()) clientOpts := make([]grpclib.DialOption, 0, 1) if s.creds == nil { clientOpts = append(clientOpts, grpclib.WithTransportCredentials(insecure.NewCredentials())) } else { clientOpts = append(clientOpts, grpclib.WithTransportCredentials(s.creds)) } stopCh := make(chan struct{}) client, err := healthcheck.NewClient(ctx, s.log, s.addr, clientOpts) if err != nil { s.log.Error().Err(err).Msg("Failed to health check client") close(stopCh) return stopCh } gwMux := runtime.NewServeMux(runtime.WithHealthzEndpoint(client)) if err := databasev1.RegisterSnapshotServiceHandlerFromEndpoint(ctx, gwMux, s.addr, clientOpts); err != nil { s.log.Error().Err(err).Msg("Failed to register snapshot service") close(stopCh) return stopCh } mux := chi.NewRouter() mux.Mount("/api", http.StripPrefix("/api", gwMux)) s.httpSrv = &http.Server{ Addr: s.httpAddr, Handler: mux, ReadHeaderTimeout: 3 * time.Second, } var wg sync.WaitGroup wg.Add(2) go func() { lis, err := net.Listen("tcp", s.addr) if err != nil { s.log.Error().Err(err).Msg("Failed to listen") close(stopCh) return } s.log.Info().Str("addr", s.addr).Msg("Listening to") err = s.ser.Serve(lis) if err != nil { s.log.Error().Err(err).Msg("server is interrupted") } wg.Done() }() go func() { s.log.Info().Str("listenAddr", s.httpAddr).Msg("Start healthz http server") err := s.httpSrv.ListenAndServe() if err != http.ErrServerClosed { s.log.Error().Err(err) } wg.Done() }() go func() { wg.Wait() s.log.Info().Msg("All servers are stopped") close(stopCh) }() return stopCh } func (s *server) GracefulStop() { s.log.Info().Msg("stopping") stopped := make(chan struct{}) s.clientCloser() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.httpSrv.Shutdown(ctx) go func() { s.ser.GracefulStop() close(stopped) }() t := time.NewTimer(10 * time.Second) select { case <-t.C: s.ser.Stop() s.log.Info().Msg("force stopped") case <-stopped: t.Stop() s.log.Info().Msg("stopped gracefully") } } type metrics struct { totalStarted meter.Counter totalFinished meter.Counter totalErr meter.Counter totalLatency meter.Counter totalMsgReceived meter.Counter totalMsgReceivedErr meter.Counter totalMsgSent meter.Counter totalMsgSentErr meter.Counter } func newMetrics(factory *observability.Factory) *metrics { return &metrics{ totalStarted: factory.NewCounter("total_started", "topic"), totalFinished: factory.NewCounter("total_finished", "topic"), totalErr: factory.NewCounter("total_err", "topic"), totalLatency: factory.NewCounter("total_latency", "topic"), totalMsgReceived: factory.NewCounter("total_msg_received", "topic"), totalMsgReceivedErr: factory.NewCounter("total_msg_received_err", "topic"), totalMsgSent: factory.NewCounter("total_msg_sent", "topic"), totalMsgSentErr: factory.NewCounter("total_msg_sent_err", "topic"), } }