cmd/server/main.go (159 lines of code) (raw):
// Copyright (c) Microsoft and contributors. All rights reserved.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
package main
import (
"context"
"flag"
"fmt"
"math"
"net"
"net/url"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"github.com/Azure/kubernetes-kms/pkg/config"
"github.com/Azure/kubernetes-kms/pkg/metrics"
"github.com/Azure/kubernetes-kms/pkg/plugin"
"github.com/Azure/kubernetes-kms/pkg/utils"
"github.com/Azure/kubernetes-kms/pkg/version"
"google.golang.org/grpc"
"k8s.io/klog/v2"
kmsv1 "k8s.io/kms/apis/v1beta1"
kmsv2 "k8s.io/kms/apis/v2"
"monis.app/mlog"
)
// Command-line flags. Each one is registered on flag.CommandLine during package
// initialization; setupKMSPlugin reads them after calling flag.Parse.

// Azure Key Vault target: vault, key, key version, and whether the vault is a
// Managed HSM.
var (
	keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
	keyName      = flag.String("key-name", "", "Azure Key Vault KMS key name")
	keyVersion   = flag.String("key-version", "", "Azure Key Vault KMS key version")
	managedHSM   = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
)

// Proxy settings forwarded to plugin.NewKeyVaultClient.
var (
	proxyMode    = flag.Bool("proxy-mode", false, "Proxy mode")
	proxyAddress = flag.String("proxy-address", "", "proxy address")
	proxyPort    = flag.Int("proxy-port", 7788, "port for proxy")
)

// Endpoints exposed by the process: the gRPC KMS listener, the health check,
// and the metrics exporter.
var (
	listenAddr     = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
	healthzPort    = flag.Uint("healthz-port", 8787, "port for health check")
	healthzPath    = flag.String("healthz-path", "/healthz", "path for health check")
	healthzTimeout = flag.Duration("healthz-timeout", 20*time.Second, "RPC timeout for health check")
	metricsBackend = flag.String("metrics-backend", "prometheus", "Backend used for metrics")
	metricsAddress = flag.String("metrics-addr", "8095", "The address the metric endpoint binds to")
)

// Cloud provider config location, logging, and version reporting.
var (
	configFilePath = flag.String("config-file-path", "/etc/kubernetes/azure.json", "Path for Azure Cloud Provider config file")
	logFormatJSON  = flag.Bool("log-format-json", false, "set log formatter to json")
	logLevel       = flag.Uint("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
	versionInfo    = flag.Bool("version", false, "Prints the version information")
)

// Deprecated spelling of -config-file-path (same default and description).
// The returned pointer is discarded; the flag stays registered so command
// lines that still pass it continue to parse.
// TODO remove this flag in future release.
var _ = flag.String("configFilePath", "/etc/kubernetes/azure.json", "[DEPRECATED] Path for Azure Cloud Provider config file")
// main is the process entry point: it runs the KMS plugin and treats any error
// returned by setupKMSPlugin as fatal.
func main() {
	err := setupKMSPlugin()
	if err == nil {
		return
	}
	mlog.Fatal(err)
}
// setupKMSPlugin runs the KMS plugin end to end.
//
// It performs these steps in order:
//   - parses the command-line flags;
//   - configures logging and the metrics exporter;
//   - builds the Azure Key Vault client;
//   - serves the KMS v1 and v2 gRPC APIs together with a health-check endpoint.
//
// It blocks until one of two things happens:
//   - A shutdown signal arrives (see withShutdownSignal). The gRPC server is
//     then stopped gracefully and nil is returned.
//   - Setup or the gRPC server fails. That error is returned.
//
// With -version it only prints the version information and returns.
func setupKMSPlugin() error {
	defer mlog.Setup()() // set up log flushing and attempt to flush on exit

	flag.Parse()
	ctx := withShutdownSignal(context.Background())

	logFormat := mlog.FormatText
	if *logFormatJSON {
		logFormat = mlog.FormatJSON
	}
	// The level is narrowed through uint8 below; reject values that would be truncated.
	if *logLevel > math.MaxUint8 {
		return fmt.Errorf("invalid log level: %d", *logLevel)
	}
	if err := mlog.ValidateAndSetKlogLevelAndFormatGlobally(ctx, klog.Level(uint8(*logLevel)), logFormat); err != nil {
		// The verbosity flag is -v (there is no --log-level flag), and this call
		// validates the log format as well as the level.
		return fmt.Errorf("invalid log level (-v=%d) or log format: %w", *logLevel, err)
	}

	if *versionInfo {
		if err := version.PrintVersion(); err != nil {
			return fmt.Errorf("failed to print version: %w", err)
		}
		return nil
	}

	// Reject an out-of-range health check port up front instead of failing only
	// after the gRPC server is already serving.
	if *healthzPort > math.MaxUint16 {
		return fmt.Errorf("invalid healthz port: %d", *healthzPort)
	}

	// The deprecated -configFilePath flag is still registered, but its value is
	// never read. flag.Visit only visits flags that were set explicitly, so warn
	// the user rather than ignoring the value silently.
	flag.Visit(func(f *flag.Flag) {
		if f.Name == "configFilePath" {
			mlog.Warning("flag --configFilePath is deprecated and ignored, use --config-file-path instead", "value", f.Value.String())
		}
	})

	// initialize metrics exporter
	if err := metrics.InitMetricsExporter(*metricsBackend, *metricsAddress); err != nil {
		return fmt.Errorf("failed to initialize metrics exporter: %w", err)
	}

	mlog.Always("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate)

	pluginConfig := &plugin.Config{
		KeyVaultName:   *keyvaultName,
		KeyName:        *keyName,
		KeyVersion:     *keyVersion,
		ManagedHSM:     *managedHSM,
		ProxyMode:      *proxyMode,
		ProxyAddress:   *proxyAddress,
		ProxyPort:      *proxyPort,
		ConfigFilePath: *configFilePath,
	}

	azureConfig, err := config.GetAzureConfig(pluginConfig.ConfigFilePath)
	if err != nil {
		return fmt.Errorf("failed to get azure config: %w", err)
	}

	kvClient, err := plugin.NewKeyVaultClient(
		azureConfig,
		pluginConfig.KeyVaultName,
		pluginConfig.KeyName,
		pluginConfig.KeyVersion,
		pluginConfig.ProxyMode,
		pluginConfig.ProxyAddress,
		pluginConfig.ProxyPort,
		pluginConfig.ManagedHSM,
	)
	if err != nil {
		return fmt.Errorf("failed to create key vault client: %w", err)
	}

	// Build both KMS servers before binding the endpoint, so that a failure here
	// cannot leave a bound socket (and its socket file) behind.
	kmsV1Server, err := plugin.NewKMSv1Server(kvClient)
	if err != nil {
		return fmt.Errorf("failed to create server: %w", err)
	}
	kmsV2Server, err := plugin.NewKMSv2Server(kvClient)
	if err != nil {
		return fmt.Errorf("failed to create kms V2 server: %w", err)
	}

	// Initialize and run the GRPC server
	listener, err := newListener(*listenAddr)
	if err != nil {
		return err
	}

	s := grpc.NewServer(grpc.UnaryInterceptor(utils.UnaryServerInterceptor))
	kmsv1.RegisterKeyManagementServiceServer(s, kmsV1Server)
	kmsv2.RegisterKeyManagementServiceServer(s, kmsV2Server)

	mlog.Always("Listening for connections", "addr", listener.Addr().String())

	// Hand Serve's result back to this function instead of exiting from inside
	// the goroutine. That way the deferred log flush runs and main reports the
	// error like any other. The buffer lets the goroutine exit even when
	// nobody is receiving any more (e.g. after GracefulStop).
	serveErr := make(chan error, 1)
	go func() {
		serveErr <- s.Serve(listener)
	}()

	// Health check for kms v1 and v2
	healthz := &plugin.HealthZ{
		KMSv1Server: kmsV1Server,
		KMSv2Server: kmsV2Server,
		HealthCheckURL: &url.URL{
			Host: net.JoinHostPort("", strconv.FormatUint(uint64(*healthzPort), 10)),
			Path: *healthzPath,
		},
		UnixSocketPath: listener.Addr().String(),
		RPCTimeout:     *healthzTimeout,
	}
	go healthz.Serve()

	select {
	case <-ctx.Done():
	case err := <-serveErr:
		// grpc documents that Serve returns a non-nil error unless Stop or
		// GracefulStop was called, and neither has been called yet here.
		if err != nil {
			return fmt.Errorf("failed to serve kms server: %w", err)
		}
		return fmt.Errorf("kms server stopped unexpectedly")
	}

	// gracefully stop the grpc server
	mlog.Always("terminating the server")
	s.GracefulStop()
	return nil
}

// newListener parses endpoint (for example "unix:///opt/azurekms.socket") with
// utils.ParseEndpoint and opens a listener on the resulting protocol/address.
//
// For unix endpoints, any existing file at the socket path is removed first,
// because net.Listen cannot bind a path that already exists. No filesystem
// entry is touched for other protocols: there the address is host:port, not a
// path.
func newListener(endpoint string) (net.Listener, error) {
	proto, addr, err := utils.ParseEndpoint(endpoint)
	if err != nil {
		return nil, fmt.Errorf("failed to parse endpoint: %w", err)
	}
	if proto == "unix" {
		if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
			return nil, fmt.Errorf("failed to remove socket file %s: %w", addr, err)
		}
	}
	listener, err := net.Listen(proto, addr)
	if err != nil {
		return nil, fmt.Errorf("failed to listen addr: %s, proto: %s: %w", addr, proto, err)
	}
	return listener, nil
}
// withShutdownSignal returns a copy of the parent context that will close if
// the process receives termination signals.
// withShutdownSignal returns a child of ctx that is canceled when the process
// receives SIGTERM or SIGINT (os.Interrupt), or when ctx itself is done.
//
// Only the first signal is intercepted. Once it arrives, or once ctx is done,
// the helper goroutine stops relaying signals with signal.Stop and exits.
// Absent other signal.Notify registrations in the process, a second
// SIGTERM/SIGINT then falls back to Go's default handling and terminates the
// process, even if graceful shutdown is stuck.
func withShutdownSignal(ctx context.Context) context.Context {
	signalChan := make(chan os.Signal, 1)
	// os.Interrupt is syscall.SIGINT on Unix; listing both is harmless.
	signal.Notify(signalChan, syscall.SIGTERM, syscall.SIGINT, os.Interrupt)
	nctx, cancel := context.WithCancel(ctx)
	go func() {
		defer signal.Stop(signalChan)
		select {
		case <-signalChan:
			mlog.Always("received shutdown signal")
			cancel()
		case <-nctx.Done():
			// The parent context ended first; nothing to cancel, just stop
			// intercepting signals and let this goroutine exit.
		}
	}()
	return nctx
}