// internal/service/server.go (184 lines of code) (raw):
package service
import (
"context"
"errors"
"fmt"
"github.com/aliyun/alibabacloud-kms-agent/internal/conf"
"github.com/aliyun/alibabacloud-kms-agent/internal/logger"
"github.com/aliyun/alibabacloud-kms-agent/internal/model"
"io/ioutil"
"net"
"net/http"
"os"
"os/signal"
"runtime/debug"
"strings"
"sync/atomic"
"syscall"
"time"
"github.com/aliyun/alibabacloud-kms-agent/internal/cache"
"github.com/aliyun/alibabacloud-kms-agent/internal/kms"
)
// Request paths routed by the server.
const (
	// PingPath is the health-check endpoint.
	PingPath = "/ping"
	// QueryStylePath is the fixed path that routes to the secret handler.
	QueryStylePath = "/secretsmanager/get"
)

// defaultRequestTimeout is applied as both the read and the write timeout of
// the HTTP server.
const defaultRequestTimeout = 30 * time.Second
// Server is the agent's local HTTP front end. It owns the TCP listener,
// the request-admission state (connection counter, SSRF token settings) and
// the dependencies handed to it by NewServer.
type Server struct {
// listener is the TCP listener that Serve passes to the HTTP server.
listener *net.TCPListener
// cacheStore is the cache supplied to NewServer; presumably consulted by the
// secret handler, which is not visible in this part of the file.
cacheStore cache.Cache
// kmsClient is the KMS client supplied to NewServer; presumably used by the
// secret handler, which is not visible in this part of the file.
kmsClient *kms.KeyManagementService
// loggerWrapper receives the server's info and error log messages.
loggerWrapper logger.Wrapper
// ssrfToken is the expected token value that validateToken compares request
// headers against. It is left empty when DisableSSRFToken is set.
ssrfToken string
// ssrfHeaders lists the request header names that may carry the token.
ssrfHeaders []string
// pathPrefix is the URL path prefix that routes a request to the secret
// handler.
pathPrefix string
// maxConn is the configured limit used by validateMaxConn.
maxConn int
// responseType is taken from the server configuration; its use is not
// visible in this part of the file.
responseType model.ResponseType
// IgnoreTransientErrors is taken from the server configuration; its use is
// not visible in this part of the file.
IgnoreTransientErrors bool
// DisableSSRFToken turns off the token check in validateToken.
DisableSSRFToken bool
// connCount counts requests currently being handled; it is only accessed
// through sync/atomic.
connCount int32
}
// NewServer builds a Server that listens on 127.0.0.1 at the configured HTTP
// port.
//
// Unless the SSRF token is disabled in serverConfig, the expected token is
// resolved through getToken from the configured environment variable names.
// Every pointer field of serverConfig read here is dereferenced without a nil
// check, so the configuration must be fully populated by the caller.
//
// It returns an error when the address cannot be resolved, the port cannot be
// bound, or the token cannot be obtained. In the last case the listener that
// was already opened is closed again so the port is not left bound.
func NewServer(serverConfig conf.ServerConfig, cacheStore cache.Cache,
	kmsClient *kms.KeyManagementService, lw logger.Wrapper) (*Server, error) {
	addr, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("127.0.0.1:%d", *serverConfig.HttpPort))
	if err != nil {
		return nil, err
	}
	listener, err := net.ListenTCP("tcp", addr)
	if err != nil {
		return nil, err
	}

	var token string
	if !*serverConfig.DisableSSRFToken {
		token, err = getToken(*serverConfig.SSRFEnvVariables)
		if err != nil {
			// Release the socket: no Server is returned, so nobody else
			// could ever close it. The token error is the one worth
			// reporting, so a close failure is deliberately ignored.
			_ = listener.Close()
			return nil, err
		}
	}

	return &Server{
		listener:              listener,
		cacheStore:            cacheStore,
		kmsClient:             kmsClient,
		loggerWrapper:         lw,
		ssrfToken:             token,
		ssrfHeaders:           *serverConfig.SSRFHeaders,
		pathPrefix:            *serverConfig.PathPrefix,
		maxConn:               *serverConfig.MaxConn,
		responseType:          model.ResponseType(*serverConfig.ResponseType),
		IgnoreTransientErrors: *serverConfig.IgnoreTransientErrors,
		DisableSSRFToken:      *serverConfig.DisableSSRFToken,
	}, nil
}
// Serve runs the HTTP server on the Server's listener and blocks until the
// process receives SIGINT or SIGTERM, after which it shuts the server down
// gracefully with a 5-second deadline.
//
// All paths are routed through handleRequest. If the listener fails for any
// reason other than a deliberate shutdown, the error is logged and the process
// exits with status 1.
func (s *Server) Serve() {
	mux := http.NewServeMux()
	mux.HandleFunc("/", s.handleRequest)
	server := &http.Server{
		Handler:      mux,
		ReadTimeout:  defaultRequestTimeout,
		WriteTimeout: defaultRequestTimeout,
	}

	go func() {
		s.loggerWrapper.Info("server listening on: %s", s.listener.Addr().String())
		// http.Server.Serve returns http.ErrServerClosed as soon as Shutdown
		// is called. That is the normal exit path and must not be treated as
		// a fatal error, otherwise os.Exit would cut the graceful shutdown
		// below short.
		if err := server.Serve(s.listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			s.loggerWrapper.Error("server listening err: %v", err)
			os.Exit(1)
		}
	}()

	// Wait for an interrupt/termination signal so the server can be shut
	// down gracefully.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
	defer signal.Stop(signals)
	<-signals

	s.loggerWrapper.Info("shutting down the server")
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := server.Shutdown(ctx); err != nil {
		s.loggerWrapper.Error("server shutdown Failed :%v", err)
	}
	s.loggerWrapper.Info("server exiting")
}
// handleRequest is the single entry point for every HTTP request.
//
// It runs the admission checks in order (connection limit, token, HTTP
// method), answering 429, 403 or 405 respectively when a check fails, and
// then dispatches on the URL path: PingPath goes to handlePing,
// QueryStylePath or any path starting with the configured prefix goes to
// handleGetSecret, and everything else gets a 404. A panic raised while
// handling is recovered, answered with a 500 and logged with a stack trace.
func (s *Server) handleRequest(w http.ResponseWriter, r *http.Request) {
	// Registered first so it runs last and covers everything below.
	defer func() {
		cause := recover()
		if cause == nil {
			return
		}
		w.WriteHeader(http.StatusInternalServerError)
		s.loggerWrapper.Error("HandlePanic:%v\tTraceStack:%s", cause, string(debug.Stack()))
	}()

	reject := func(err error, status int) {
		http.Error(w, err.Error(), status)
	}

	if err := s.validateMaxConn(r); err != nil {
		reject(err, http.StatusTooManyRequests)
		return
	}
	// The request was admitted; give its slot back when handling ends.
	defer atomic.AddInt32(&s.connCount, -1)

	if err := s.validateToken(r); err != nil {
		reject(err, http.StatusForbidden)
		return
	}
	if err := s.validateMethod(r); err != nil {
		reject(err, http.StatusMethodNotAllowed)
		return
	}

	switch path := r.URL.Path; {
	case path == PingPath:
		// Health check.
		s.handlePing(w, r)
	case path == QueryStylePath || strings.HasPrefix(path, s.pathPrefix):
		// Secret retrieval, by query-style path or by path prefix.
		s.handleGetSecret(w, r)
	default:
		http.NotFound(w, r)
	}
}
// validateMaxConn admits a request against the connection limit.
//
// It atomically increments connCount and accepts the request while the new
// count stays at or below maxConn; requests to "/ping" are allowed three
// additional slots. On success the slot stays reserved and the caller must
// decrement connCount when the request finishes. On rejection the increment is
// undone here, so a refused request never holds a slot. Without that, every
// rejection would permanently raise the counter until all requests were
// refused.
func (s *Server) validateMaxConn(r *http.Request) error {
	limit := s.maxConn + 1
	if r.URL.Path == "/ping" {
		limit += 3
	}

	if count := atomic.AddInt32(&s.connCount, 1); count >= int32(limit) {
		// The caller returns without releasing anything on error, so give
		// the slot back before reporting the rejection.
		atomic.AddInt32(&s.connCount, -1)
		return errors.New("connection limit exceeded")
	}
	return nil
}
// validateToken decides whether a request is authorized.
//
// Requests to "/ping" are always allowed. Any other request that carries an
// X-Forwarded-For header is rejected, whether or not the token check is
// enabled. When DisableSSRFToken is set, no token is required. Otherwise at
// least one of the configured headers must carry exactly the expected token.
//
// It returns nil when the request is allowed and a non-nil error otherwise.
func (s *Server) validateToken(r *http.Request) error {
	if r.URL.Path == "/ping" {
		return nil
	}
	// Refuse forwarded requests. The error must be returned; merely
	// constructing it lets the request through.
	if _, forwarded := r.Header["X-Forwarded-For"]; forwarded {
		return errors.New("forwarded")
	}
	if s.DisableSSRFToken {
		return nil
	}
	// Header.Get yields "" for an absent header, so an empty expected token
	// would match every request that sends no token at all.
	if s.ssrfToken == "" {
		return errors.New("bad token")
	}
	for _, header := range s.ssrfHeaders {
		if r.Header.Get(header) == s.ssrfToken {
			return nil
		}
	}
	return errors.New("bad token")
}
// validateMethod accepts only HTTP GET requests and returns an error for any
// other method.
func (s *Server) validateMethod(r *http.Request) error {
	switch r.Method {
	case http.MethodGet:
		return nil
	default:
		return errors.New("http method just allowed get")
	}
}
// getToken resolves the SSRF token from the first environment variable in
// envs that holds a non-empty value.
//
// A value starting with "file://" is treated as a path: the file is read and
// its content, with surrounding whitespace trimmed, is the token. Any other
// value is the token itself.
//
// It returns an error when no listed variable has a non-empty value, when the
// token file cannot be read, or when the token file contains only whitespace.
// An empty token is never returned with a nil error, because an empty
// expected token would match requests that send no token.
func getToken(envs []string) (string, error) {
	var found string
	for _, envName := range envs {
		// Variables that are unset or empty are skipped so that a later
		// candidate can still supply the token.
		if val, exists := os.LookupEnv(envName); exists && val != "" {
			found = val
			break
		}
	}
	if found == "" {
		return "", errors.New("environment variable not present, you must set one valid SSRFEnvVariable")
	}

	if !strings.HasPrefix(found, "file://") {
		return found, nil
	}

	file := strings.TrimPrefix(found, "file://")
	content, err := ioutil.ReadFile(file)
	if err != nil {
		return "", err
	}
	token := strings.TrimSpace(string(content))
	if token == "" {
		return "", fmt.Errorf("SSRF token file %q is empty", file)
	}
	return token, nil
}