pkg/healthprobe/gw_health.go (90 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package healthprobe
import (
"context"
"errors"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"sigs.k8s.io/controller-runtime/pkg/log"
"github.com/Azure/kube-egress-gateway/pkg/consts"
)
// LBProbeServer answers load-balancer health probes for gateways over HTTP.
// A probe for a gateway UID succeeds only while that UID is registered as
// active via AddGateway.
type LBProbeServer struct {
	// lock guards activeGateways, which is read by the HTTP handler and
	// mutated by AddGateway/RemoveGateway.
	lock sync.RWMutex
	// activeGateways is the set of gateway UIDs that should report healthy.
	// Only key presence is checked; AddGateway always stores true.
	activeGateways map[string]bool
	// listenPort is the TCP port the probe HTTP server listens on (all
	// interfaces, since Start joins it with an empty host).
	listenPort int
}
// NewLBProbeServer builds a probe server that will listen on listenPort once
// started. No gateways are active initially, so every probe reports
// unavailable until AddGateway is called.
func NewLBProbeServer(listenPort int) *LBProbeServer {
	svr := new(LBProbeServer)
	svr.listenPort = listenPort
	svr.activeGateways = map[string]bool{}
	return svr
}
// Start runs the gateway load-balancer health probe HTTP server on
// svr.listenPort and blocks until ctx is cancelled, at which point the server
// is closed.
//
// It returns nil after a clean shutdown triggered by ctx, or the error from
// closing the server. If the server stops on its own first (for example the
// port is already in use), that error is returned right away. Previously
// Start kept blocking and finally returned nil, which hid the failure.
func (svr *LBProbeServer) Start(ctx context.Context) error {
	// Named logger so the imported log package is not shadowed.
	logger := log.FromContext(ctx)

	mux := http.NewServeMux()
	mux.HandleFunc(consts.GatewayHealthProbeEndpoint, svr.serveHTTP)
	httpServer := &http.Server{
		Addr:              net.JoinHostPort("", strconv.Itoa(svr.listenPort)),
		Handler:           mux,
		MaxHeaderBytes:    1 << 20,
		IdleTimeout:       90 * time.Second, // matches http.DefaultTransport keep-alive timeout
		ReadHeaderTimeout: 32 * time.Second,
	}

	// Buffered so the serving goroutine can always deliver its error and exit,
	// even if Start has already returned because ctx was cancelled.
	serveErr := make(chan error, 1)
	go func() {
		logger.Info("Starting gateway lb health probe server")
		if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error(err, "failed to start gateway lb health probe server")
			serveErr <- err
		}
	}()

	select {
	case err := <-serveErr:
		// The server failed on its own. Close it to release any connections
		// it still tracks. Report the serve error, which is the one that
		// matters, rather than a secondary close error.
		_ = httpServer.Close()
		return err
	case <-ctx.Done():
	}

	// Shut the server down now that ctx has been cancelled.
	logger.Info("Stopping gateway lb health probe server")
	if err := httpServer.Close(); err != nil {
		logger.Error(err, "failed to close gateway lb health probe server")
		return err
	}
	return nil
}
// AddGateway registers gatewayUID as active. From then on, probes for that
// UID are answered with 200 OK. Adding an already-active UID changes nothing.
// The returned error is always nil.
func (svr *LBProbeServer) AddGateway(gatewayUID string) error {
	svr.lock.Lock()
	// Unlock via defer so the mutex is released even if the write panics.
	defer svr.lock.Unlock()

	svr.activeGateways[gatewayUID] = true
	return nil
}
// RemoveGateway unregisters gatewayUID. From then on, probes for that UID are
// answered with 503 Service Unavailable. Removing a UID that is not active is
// a no-op. The returned error is always nil.
func (svr *LBProbeServer) RemoveGateway(gatewayUID string) error {
	svr.lock.Lock()
	defer svr.lock.Unlock()

	// delete is a no-op for absent keys, so no existence check is needed.
	delete(svr.activeGateways, gatewayUID)
	return nil
}
// GetGateways returns a snapshot of the UIDs of all currently active gateways.
//
// The order is unspecified because it follows Go map iteration order. The
// result is a freshly allocated slice that callers may modify. It is nil
// when no gateway is active.
func (svr *LBProbeServer) GetGateways() []string {
	svr.lock.RLock()
	defer svr.lock.RUnlock()

	if len(svr.activeGateways) == 0 {
		// Keep returning nil (not an empty slice) when nothing is active.
		return nil
	}
	// The final size is known under the lock, so allocate once instead of
	// growing through repeated appends.
	res := make([]string, 0, len(svr.activeGateways))
	for gatewayUID := range svr.activeGateways {
		res = append(res, gatewayUID)
	}
	return res
}
// serveHTTP handles one health probe. The request path must have exactly two
// "/" separators (e.g. "/<prefix>/<gatewayUID>"). Any other shape gets
// 400 Bad Request. The final path segment is taken as the gateway UID.
// The response is 200 OK if that UID is active and 503 Service Unavailable
// otherwise. No body is written.
func (svr *LBProbeServer) serveHTTP(resp http.ResponseWriter, req *http.Request) {
	path := req.URL.Path

	// Exactly two slashes means the path splits into three segments.
	if strings.Count(path, "/") != 2 {
		resp.WriteHeader(http.StatusBadRequest)
		return
	}
	gatewayUID := path[strings.LastIndex(path, "/")+1:]

	svr.lock.RLock()
	_, active := svr.activeGateways[gatewayUID]
	svr.lock.RUnlock()

	status := http.StatusServiceUnavailable
	if active {
		status = http.StatusOK
	}
	resp.WriteHeader(status)
}