internal/service/server.go (184 lines of code) (raw):

package service import ( "context" "errors" "fmt" "github.com/aliyun/alibabacloud-kms-agent/internal/conf" "github.com/aliyun/alibabacloud-kms-agent/internal/logger" "github.com/aliyun/alibabacloud-kms-agent/internal/model" "io/ioutil" "net" "net/http" "os" "os/signal" "runtime/debug" "strings" "sync/atomic" "syscall" "time" "github.com/aliyun/alibabacloud-kms-agent/internal/cache" "github.com/aliyun/alibabacloud-kms-agent/internal/kms" ) const ( PingPath = "/ping" QueryStylePath = "/secretsmanager/get" defaultRequestTimeout = 30 * time.Second ) type Server struct { listener *net.TCPListener cacheStore cache.Cache kmsClient *kms.KeyManagementService loggerWrapper logger.Wrapper ssrfToken string ssrfHeaders []string pathPrefix string maxConn int responseType model.ResponseType IgnoreTransientErrors bool DisableSSRFToken bool connCount int32 } func NewServer(serverConfig conf.ServerConfig, cacheStore cache.Cache, kmsClient *kms.KeyManagementService, lw logger.Wrapper) (*Server, error) { addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", *serverConfig.HttpPort)) if err != nil { return nil, err } listener, err := net.ListenTCP("tcp", addr) if err != nil { return nil, err } var token string if !*serverConfig.DisableSSRFToken { token, err = getToken(*serverConfig.SSRFEnvVariables) if err != nil { return nil, err } } return &Server{ listener: listener, cacheStore: cacheStore, kmsClient: kmsClient, loggerWrapper: lw, ssrfToken: token, ssrfHeaders: *serverConfig.SSRFHeaders, pathPrefix: *serverConfig.PathPrefix, maxConn: *serverConfig.MaxConn, responseType: model.ResponseType(*serverConfig.ResponseType), IgnoreTransientErrors: *serverConfig.IgnoreTransientErrors, DisableSSRFToken: *serverConfig.DisableSSRFToken, }, nil } func (s *Server) Serve() { mux := http.NewServeMux() mux.HandleFunc("/", s.handleRequest) server := &http.Server{ Handler: mux, ReadTimeout: defaultRequestTimeout, WriteTimeout: defaultRequestTimeout, } go func() { s.loggerWrapper.Info("server listening on: %s", s.listener.Addr().String()) if err := server.Serve(s.listener); err != nil { s.loggerWrapper.Error("server listening err: %v", err) os.Exit(1) } }() // 监听中断信号以优雅地关闭服务器 signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) <-signals s.loggerWrapper.Info("shutting down the server") ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := server.Shutdown(ctx); err != nil { s.loggerWrapper.Error("server shutdown Failed :%v", err) } s.loggerWrapper.Info("server exiting") } func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) { //defer recover defer func(w http.ResponseWriter, r *http.Request) { if err := recover(); err != nil { w.WriteHeader(http.StatusInternalServerError) s.loggerWrapper.Error("HandlePanic:%v\tTraceStack:%s", err, string(debug.Stack())) } }(w, r) if err := s.validateMaxConn(r); err != nil { http.Error(w, err.Error(), http.StatusTooManyRequests) return } defer atomic.AddInt32(&s.connCount, -1) if err := s.validateToken(r); err != nil { http.Error(w, err.Error(), http.StatusForbidden) return } if err := s.validateMethod(r); err != nil { http.Error(w, err.Error(), http.StatusMethodNotAllowed) return } // Ping if r.URL.Path == PingPath { s.handlePing(w, r) return } // Get Secret if r.URL.Path == QueryStylePath || strings.HasPrefix(r.URL.Path, s.pathPrefix) { s.handleGetSecret(w, r) return } http.NotFound(w, r) } func (s *Server) validateMaxConn(r *http.Request) error { isPing := r.URL.Path == "/ping" limit := s.maxConn + 1 if isPing { limit += 3 } count := atomic.AddInt32(&s.connCount, 1) if count >= int32(limit) { return errors.New("connection limit exceeded") } return nil } func (s *Server) validateToken(r *http.Request) error { if r.URL.Path == "/ping" { return nil } if _, ok := r.Header["X-Forwarded-For"]; ok { errors.New("forwarded") } if s.DisableSSRFToken { return nil } for _, header := range s.ssrfHeaders { if token := r.Header.Get(header); token == s.ssrfToken { return nil } } return errors.New("bad token") } func (s *Server) validateMethod(r *http.Request) error { if r.Method != http.MethodGet { return errors.New("http method just allowed get") } return nil } func getToken(envs []string) (string, error) { var found string for _, envName := range envs { val, exists := os.LookupEnv(envName) if exists { found = val break } } if found == "" { return "", errors.New("environment variable not present, you must set one valid SSRFEnvVariable") } if !strings.HasPrefix(found, "file://") { return found, nil } file := strings.TrimPrefix(found, "file://") content, err := ioutil.ReadFile(file) if err != nil { return "", err } return strings.TrimSpace(string(content)), nil }