banyand/liaison/http/server.go (341 lines of code) (raw):
// Licensed to Apache Software Foundation (ASF) under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Apache Software Foundation (ASF) licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Package http implements the liaison HTTP server, which exposes the gRPC APIs over HTTP/JSON through grpc-gateway.
package http
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/go-chi/chi/v5"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/pkg/errors"
"go.uber.org/multierr"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
commonv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/common/v1"
databasev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/database/v1"
measurev1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/measure/v1"
propertyv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/property/v1"
streamv1 "github.com/apache/skywalking-banyandb/api/proto/banyandb/stream/v1"
"github.com/apache/skywalking-banyandb/pkg/healthcheck"
"github.com/apache/skywalking-banyandb/pkg/logger"
"github.com/apache/skywalking-banyandb/pkg/run"
pkgtls "github.com/apache/skywalking-banyandb/pkg/tls"
)
// Compile-time checks that *server implements the run lifecycle interfaces.
var (
	_ run.Config  = (*server)(nil)
	_ run.Service = (*server)(nil)
)

// Sentinel errors returned by (*server).Validate.
var (
	errServerCert = errors.New("http: invalid server cert file")
	errServerKey  = errors.New("http: invalid server key file")
	errNoAddr     = errors.New("http: no address")
)
// NewServer returns the liaison HTTP server unit with its stop channel
// initialized. All other configuration comes from the flags registered in
// FlagSet.
func NewServer() Server {
	s := new(server)
	s.stopCh = make(chan struct{})
	return s
}
// Server is the liaison HTTP service: a run.Unit that additionally exposes
// the port it is configured to listen on.
type Server interface {
	// GetPort returns a pointer to the configured listen port.
	GetPort() *uint32
	run.Unit
}
// server is the liaison HTTP unit. It serves a grpc-gateway router, mounted
// under /api, whose handlers are registered against the gRPC server at
// grpcAddr.
//
// NOTE(review): grpcCtx, grpcCancel, gwMux, grpcClient and creds are
// reassigned by the certificate-update goroutine started in PreRun without any
// locking, while Serve and GracefulStop also read/write them; confirm whether
// a mutex is needed.
type server struct {
	// creds are the transport credentials used when dialing grpcAddr; rebuilt
	// from grpcTLSReloader on each (re)initialization when that reloader is set.
	creds credentials.TransportCredentials
	// tlsReloader supplies the HTTP listener's TLS config when http-tls is set.
	tlsReloader *pkgtls.Reloader
	// grpcTLSReloader watches http-grpc-cert-file for the upstream gRPC connection.
	grpcTLSReloader *pkgtls.Reloader
	l               *logger.Logger
	// handlerWrapper is the http.Server's handler; the router inside it is
	// replaced every time the gRPC client is (re)initialized.
	handlerWrapper *atomicHandler
	srv            *http.Server
	// stopCh is closed when the listener goroutine exits or Serve fails to start.
	stopCh chan struct{}
	// gwMux is the grpc-gateway mux of the most recent initialization.
	gwMux *runtime.ServeMux
	// grpcClient backs the gateway's health endpoint.
	grpcClient *healthcheck.Client
	// grpcCtx is passed to every gateway registration and the health client;
	// grpcCancel cancels it.
	grpcCtx    context.Context
	grpcCancel context.CancelFunc
	host       string
	// listenAddr is host:port, computed in Validate.
	listenAddr string
	grpcAddr   string
	keyFile    string
	certFile   string
	grpcCert   string
	port       uint32
	tls        bool
}
// FlagSet registers the liaison HTTP server's command-line flags: the listen
// host and port, the gRPC endpoint that requests are sent to, and the TLS
// settings for both the HTTP listener and the upstream gRPC connection.
func (p *server) FlagSet() *run.FlagSet {
	fs := run.NewFlagSet("http")

	// HTTP listener.
	fs.StringVar(&p.host, "http-host", "", "listen host for http")
	fs.Uint32Var(&p.port, "http-port", 17913, "listen port for http")

	// Upstream gRPC endpoint.
	fs.StringVar(&p.grpcAddr, "http-grpc-addr", "localhost:17912", "http server redirect grpc requests to this address")

	// TLS material.
	fs.StringVar(&p.certFile, "http-cert-file", "", "the TLS cert file of http server")
	fs.StringVar(&p.keyFile, "http-key-file", "", "the TLS key file of http server")
	fs.StringVar(&p.grpcCert, "http-grpc-cert-file", "", "the grpc TLS cert file if grpc server enables tls")
	fs.BoolVar(&p.tls, "http-tls", false, "connection uses TLS if true, else plain HTTP")

	return fs
}
// Validate computes listenAddr from host and port and checks the flag
// combination: when TLS is enabled both the cert and the key file must be set.
func (p *server) Validate() error {
	p.listenAddr = net.JoinHostPort(p.host, strconv.FormatUint(uint64(p.port), 10))
	switch {
	// NOTE(review): JoinHostPort always appends the formatted port (at least
	// ":0"), so this comparison appears never to be true; confirm what "no
	// address" is meant to reject.
	case p.listenAddr == ":":
		return errNoAddr
	case !p.tls:
		return nil
	case p.certFile == "":
		return errServerCert
	case p.keyFile == "":
		return errServerKey
	default:
		return nil
	}
}
// Name returns the unit name; PreRun also uses it as the logger name.
func (p *server) Name() string {
	const unitName = "liaison-http"
	return unitName
}

// Role reports that this unit runs as part of the liaison role.
func (p *server) Role() databasev1.Role {
	role := databasev1.Role_ROLE_LIAISON
	return role
}

// GetPort returns a pointer to the port field bound to the http-port flag.
func (p *server) GetPort() *uint32 {
	port := &p.port
	return port
}
// PreRun initializes the logger, the optional TLS reloaders and the
// http.Server. It does not start listening; Serve does that.
//
// With http-tls set, a reloader for the HTTP cert/key pair is created (Serve
// starts it). With http-grpc-cert-file set, a client certificate reloader is
// created and started here, and a goroutine rebuilds the gRPC client and
// gateway on every update notification. That goroutine exits when stopCh is
// closed or the update channel is closed.
func (p *server) PreRun(_ context.Context) error {
	p.l = logger.GetLogger(p.Name())
	p.l.Info().Str("level", p.l.GetLevel().String()).Msg("Logger initialized")
	p.l.Debug().Bool("tls", p.tls).Str("certFile", p.certFile).Str("keyFile", p.keyFile).Msg("Flag values after parsing")

	if p.tls {
		p.l.Debug().Str("certFile", p.certFile).Str("keyFile", p.keyFile).Msg("Initializing TLSReloader for HTTP")
		var err error
		p.tlsReloader, err = pkgtls.NewReloader(p.certFile, p.keyFile, p.l)
		if err != nil {
			p.l.Error().Err(err).Msg("Failed to initialize TLSReloader for HTTP")
			return err
		}
	} else {
		p.l.Warn().Msg("HTTP TLS is disabled, skipping TLSReloader initialization")
	}

	if p.grpcCert != "" {
		p.l.Debug().Str("grpcCert", p.grpcCert).Msg("Initializing TLS credentials for gRPC connection")
		// A client cert reloader only watches the certificate file.
		var err error
		p.grpcTLSReloader, err = pkgtls.NewClientCertReloader(p.grpcCert, p.l)
		if err != nil {
			p.l.Error().Err(err).Msg("Failed to initialize gRPC TLS reloader")
			return err
		}
		if err = p.grpcTLSReloader.Start(); err != nil {
			p.l.Error().Err(err).Msg("Failed to start gRPC TLS reloader")
			return err
		}
		certUpdateCh := p.grpcTLSReloader.GetUpdateChannel()
		p.l.Info().Msg("Starting certificate update notification listener")
		go func() {
			p.l.Info().Msg("Certificate update notification goroutine started")
			for {
				select {
				case _, ok := <-certUpdateCh:
					if !ok {
						// A closed channel would otherwise fire on every iteration
						// and trigger an endless stream of reloads.
						p.l.Info().Msg("Certificate update channel closed, stopping listener")
						return
					}
					p.l.Info().Msg("Received certificate update notification")
					p.reloadGRPCClient()
				case <-p.stopCh:
					p.l.Info().Msg("Stopping certificate update notification listener")
					return
				}
			}
		}()
	}

	p.handlerWrapper = &atomicHandler{}
	p.srv = &http.Server{
		Addr:              p.listenAddr,
		Handler:           p.handlerWrapper,
		ReadHeaderTimeout: 3 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	if p.tls {
		// The reloader's TLS config lets certificate changes take effect without a restart.
		p.srv.TLSConfig = p.tlsReloader.GetTLSConfig()
	}
	return nil
}

// reloadGRPCClient rebuilds the gRPC credentials, client and gateway router
// after grpcTLSReloader reported a certificate change.
//
// The new connections get a fresh context. The previous context is canceled
// only after initGRPCClient succeeded and swapped in the new router. On
// failure, the new context is canceled and the previous context, client, mux
// and credentials are restored. The router still being served therefore keeps
// working connections.
//
// NOTE(review): these fields are mutated here without synchronization while
// Serve/GracefulStop may access them concurrently.
func (p *server) reloadGRPCClient() {
	p.l.Info().Msg("Processing certificate update")
	prevCtx, prevCancel := p.grpcCtx, p.grpcCancel
	prevClient, prevMux, prevCreds := p.grpcClient, p.gwMux, p.creds

	// Fixed pause before rebuilding. This is not a real debounce; presumably it
	// lets a burst of file events settle. TODO confirm the intent.
	time.Sleep(200 * time.Millisecond)

	p.grpcCtx, p.grpcCancel = context.WithCancel(context.Background())
	if err := p.initGRPCClient(); err != nil {
		p.l.Error().Err(err).Msg("Failed to reinitialize gRPC client after credential update")
		p.grpcCancel()
		p.grpcCtx, p.grpcCancel = prevCtx, prevCancel
		p.grpcClient, p.gwMux, p.creds = prevClient, prevMux, prevCreds
		return
	}
	p.l.Info().Msg("Successfully reinitialized gRPC client with new credentials")
	if prevCancel != nil {
		p.l.Info().Msg("Canceling existing gRPC connections")
		prevCancel()
	}
}
// Serve starts the HTTP TLS reloader (if any), builds the gRPC client and
// gateway router, and runs the listener in a background goroutine.
//
// The returned channel is closed when the listener goroutine exits. If startup
// fails, it is closed before Serve returns.
func (p *server) Serve() run.StopNotify {
	p.grpcCtx, p.grpcCancel = context.WithCancel(context.Background())

	// abort logs a startup failure and reports it by closing stopCh immediately.
	abort := func(err error, msg string) run.StopNotify {
		p.l.Error().Err(err).Msg(msg)
		close(p.stopCh)
		return p.stopCh
	}

	if p.tls && p.tlsReloader != nil {
		if err := p.tlsReloader.Start(); err != nil {
			return abort(err, "Failed to start TLSReloader for HTTP")
		}
	}
	if err := p.initGRPCClient(); err != nil {
		return abort(err, "Failed to initialize gRPC client")
	}

	go func() {
		defer close(p.stopCh)
		p.l.Info().Str("listenAddr", p.listenAddr).Msg("Start liaison http server")
		var err error
		if p.tls {
			// The cert and key come from srv.TLSConfig (set in PreRun), so the file arguments stay empty.
			err = p.srv.ListenAndServeTLS("", "")
		} else {
			err = p.srv.ListenAndServe()
		}
		// ErrServerClosed is the normal result of Shutdown/Close.
		if !errors.Is(err, http.ErrServerClosed) {
			p.l.Error().Err(err).Msg("HTTP server failed")
		}
	}()
	return p.stopCh
}
// initGRPCClient builds the dial options from the current credentials, creates
// the health-check client and a new grpc-gateway mux, and registers every
// gateway handler against grpcAddr using p.grpcCtx. It then installs a fresh
// router in handlerWrapper. It is called from Serve and again after each gRPC
// certificate reload. The router is only swapped in when every step succeeded.
func (p *server) initGRPCClient() error {
	if p.grpcClient != nil {
		// The previous client is not closed here. Presumably its connections are
		// released when the context it was created with is canceled by the caller.
		p.l.Debug().Msg("Replacing existing gRPC client")
	}

	var opts []grpc.DialOption
	switch {
	case p.grpcTLSReloader != nil:
		host, _, err := net.SplitHostPort(p.grpcAddr)
		if err != nil {
			return errors.Wrap(err, "failed to split gRPC address")
		}
		// Wildcard or empty hosts cannot be verified against a certificate, so
		// verify against localhost. SplitHostPort strips IPv6 brackets, so the
		// IPv6 wildcard arrives as "::", not "[::]".
		switch host {
		case "", "0.0.0.0", "::":
			host = "localhost"
		}
		// Fetch a fresh TLS config so a reloaded certificate is picked up.
		tlsConfig, err := p.grpcTLSReloader.GetClientTLSConfig(host)
		if err != nil {
			return errors.Wrap(err, "failed to get TLS config from reloader")
		}
		p.creds = credentials.NewTLS(tlsConfig)
		p.l.Debug().Msg("Created fresh gRPC credentials from reloader")
		opts = append(opts, grpc.WithTransportCredentials(p.creds))
	case p.creds != nil:
		opts = append(opts, grpc.WithTransportCredentials(p.creds))
	default:
		opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
	}

	var err error
	p.grpcClient, err = healthcheck.NewClient(p.grpcCtx, p.l, p.grpcAddr, opts)
	if err != nil {
		return errors.Wrap(err, "failed to create health check client")
	}

	p.gwMux = runtime.NewServeMux(runtime.WithHealthzEndpoint(p.grpcClient))
	err = multierr.Combine(
		commonv1.RegisterServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterStreamRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterMeasureRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterIndexRuleRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterIndexRuleBindingRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterGroupRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterTopNAggregationRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterSnapshotServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		databasev1.RegisterPropertyRegistryServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		streamv1.RegisterStreamServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		measurev1.RegisterMeasureServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
		propertyv1.RegisterPropertyServiceHandlerFromEndpoint(p.grpcCtx, p.gwMux, p.grpcAddr, opts),
	)
	if err != nil {
		return errors.Wrap(err, "failed to register endpoints")
	}

	// Build a brand-new router each time. Remounting /api on an existing router
	// would conflict.
	router := chi.NewRouter()
	router.Mount("/api", http.StripPrefix("/api", p.gwMux))
	if err = p.setRootPath(router); err != nil {
		return err
	}
	p.handlerWrapper.Store(router)
	return nil
}
func (p *server) GracefulStop() {
if p.tlsReloader != nil {
p.tlsReloader.Stop()
}
if p.grpcTLSReloader != nil {
p.grpcTLSReloader.Stop()
}
if p.grpcCancel != nil {
p.grpcCancel()
}
if err := p.srv.Close(); err != nil {
p.l.Error().Err(err)
}
}
// intercept404 wraps handler so that a 404 it produces is replaced by the
// response of on404. The 404 status and any body handler writes after it are
// swallowed; on404 then writes to the original ResponseWriter.
func intercept404(handler, on404 http.Handler) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		hw := &hookedResponseWriter{ResponseWriter: w}
		handler.ServeHTTP(hw, r)
		if !hw.got404 {
			return
		}
		on404.ServeHTTP(w, r)
	})
}

// hookedResponseWriter forwards everything to the embedded ResponseWriter
// except a 404 status. A 404 status is recorded in got404 instead, and every
// subsequent Write is discarded.
type hookedResponseWriter struct {
	http.ResponseWriter
	got404 bool
}

// WriteHeader records a 404 instead of sending it; other statuses pass through.
func (hrw *hookedResponseWriter) WriteHeader(status int) {
	if status != http.StatusNotFound {
		hrw.ResponseWriter.WriteHeader(status)
		return
	}
	hrw.got404 = true
}

// Write drops the body once a 404 has been recorded, still reporting success.
func (hrw *hookedResponseWriter) Write(buf []byte) (int, error) {
	if !hrw.got404 {
		return hrw.ResponseWriter.Write(buf)
	}
	return len(buf), nil
}
// serveFileContents returns a handler that serves the named file from files
// with an HTML content type. A plain-text 404 is returned in three cases: the
// request's Accept header does not mention text/html, the file cannot be
// opened, or the file cannot be stat'ed. The file handle is closed after every
// request.
func serveFileContents(file string, files http.FileSystem) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !strings.Contains(r.Header.Get("Accept"), "text/html") {
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprint(w, "404 not found")
			return
		}
		index, err := files.Open(file)
		if err != nil {
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, "%s not found", file)
			return
		}
		// Without this every request leaks an open file handle. The close error
		// is irrelevant for a read-only handle once the response is written.
		defer func() { _ = index.Close() }()
		fi, err := index.Stat()
		if err != nil {
			w.WriteHeader(http.StatusNotFound)
			fmt.Fprintf(w, "%s not found", file)
			return
		}
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		http.ServeContent(w, r, fi.Name(), fi.ModTime(), index)
	}
}
// atomicHandler is an http.Handler whose delegate can be replaced at runtime
// without locking. The gateway router can therefore be swapped (e.g. after a
// gRPC certificate reload) while the http.Server keeps serving. The zero value
// is ready to use and answers every request with 404 until a handler is
// stored.
type atomicHandler struct {
	value atomic.Value // stores atomicHandlerBox
}

// atomicHandlerBox gives atomic.Value a single concrete type to hold.
// atomic.Value panics when storing nil or a value whose concrete type differs
// from the first one stored. Boxing lets Store accept any http.Handler,
// including nil.
type atomicHandlerBox struct {
	h http.Handler
}

// ServeHTTP dispatches to the most recently stored handler. It replies 404
// when no handler (or nil) has been stored.
func (h *atomicHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if box, ok := h.value.Load().(atomicHandlerBox); ok && box.h != nil {
		box.h.ServeHTTP(w, r)
		return
	}
	http.NotFound(w, r)
}

// Store atomically replaces the delegate handler. Storing nil makes ServeHTTP
// reply 404.
func (h *atomicHandler) Store(handler http.Handler) {
	h.value.Store(atomicHandlerBox{h: handler})
}