pkg/plugin/healthz.go (136 lines of code) (raw):

// Copyright (c) Microsoft and contributors. All rights reserved. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. package plugin import ( "context" "fmt" "net" "net/http" "net/url" "time" "github.com/Azure/kubernetes-kms/pkg/version" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "k8s.io/apimachinery/pkg/util/uuid" kmsv1 "k8s.io/kms/apis/v1beta1" kmsv2 "k8s.io/kms/apis/v2" "monis.app/mlog" ) const ( healthCheckPlainText = "healthcheck" ) // HealthZ is the health check server for the KMS plugin. type HealthZ struct { KMSv1Server *KeyManagementServiceServer KMSv2Server *KeyManagementServiceV2Server HealthCheckURL *url.URL UnixSocketPath string RPCTimeout time.Duration } // Serve creates the http handler for serving health requests. func (h *HealthZ) Serve() { serveMux := http.NewServeMux() serveMux.HandleFunc(h.HealthCheckURL.EscapedPath(), h.ServeHTTP) server := &http.Server{ Addr: h.HealthCheckURL.Host, ReadHeaderTimeout: 5 * time.Second, Handler: serveMux, } if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { mlog.Fatal(err, "failed to start health check server", "url", h.HealthCheckURL.String()) } } func (h *HealthZ) ServeHTTP(w http.ResponseWriter, _ *http.Request) { mlog.Trace("Started health check") ctx, cancel := context.WithTimeout(context.Background(), h.RPCTimeout) defer cancel() conn, err := h.dialUnixSocket() if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } defer conn.Close() // create the kms client for v1 kmsClient := kmsv1.NewKeyManagementServiceClient(conn) // create the kms client for v2 kmsV2Client := kmsv2.NewKeyManagementServiceClient(conn) // check version response against KMS-Plugin's gRPC endpoint. err = h.checkRPC(ctx, kmsClient, kmsV2Client) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) return } // Both encryption and decryption calls are made for each version, // resulting in a total of 4 calls to the keyvault. // Additionally, a health check is performed every 10 seconds. // v1 checks // check the configured keyvault, key, key version and permissions are still // valid to encrypt and decrypt with test data. enc, err := h.KMSv1Server.Encrypt(ctx, &kmsv1.EncryptRequest{Plain: []byte(healthCheckPlainText)}) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } dec, err := h.KMSv1Server.Decrypt(ctx, &kmsv1.DecryptRequest{Cipher: enc.Cipher}) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } if string(dec.Plain) != healthCheckPlainText { http.Error(w, "plain text mismatch after decryption", http.StatusInternalServerError) return } // v2 checks. // appending a string to UUID allows us to differentiate the UUIDs generated by us from those generated by the API server. uid := "local-healthz-check-" + string(uuid.NewUUID()) v2EncryptResponse, err := h.KMSv2Server.Encrypt( ctx, &kmsv2.EncryptRequest{ Plaintext: []byte(healthCheckPlainText), Uid: uid, }, ) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } v2DecryptResponse, err := h.KMSv2Server.Decrypt(ctx, &kmsv2.DecryptRequest{ Ciphertext: v2EncryptResponse.Ciphertext, KeyId: v2EncryptResponse.KeyId, Uid: uid, // passing the same uid to track roundtrip encrypt/decrypt calls Annotations: v2EncryptResponse.Annotations, }) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } if string(v2DecryptResponse.Plaintext) != healthCheckPlainText { http.Error(w, "plain text mismatch after decryption with KMSv2", http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) if _, err = w.Write([]byte("ok")); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } mlog.Trace("Completed health check") } // checkRPC initiates a grpc request to validate the socket is responding // sends a KMS VersionRequest and checks if the VersionResponse is valid. func (h *HealthZ) checkRPC( ctx context.Context, kmsV1Client kmsv1.KeyManagementServiceClient, kmsV2Client kmsv2.KeyManagementServiceClient, ) error { v, err := kmsV1Client.Version(ctx, &kmsv1.VersionRequest{}) if err != nil { return err } if v.Version != version.KMSv1APIVersion || v.RuntimeName != version.Runtime || v.RuntimeVersion != version.BuildVersion { return fmt.Errorf("failed to get correct version response") } v2Status, err := kmsV2Client.Status(ctx, &kmsv2.StatusRequest{}) if err != nil { return err } if v2Status.Version != version.KMSv2APIVersion { return fmt.Errorf( "failed to get correct version response for v2 expected: %s, got: %s", version.KMSv2APIVersion, v2Status.Version, ) } return nil } func (h *HealthZ) dialUnixSocket() (*grpc.ClientConn, error) { return grpc.Dial( h.UnixSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(func(ctx context.Context, target string) (net.Conn, error) { return (&net.Dialer{}).DialContext(ctx, "unix", target) }), ) }