pkg/plugin/healthz.go (136 lines of code) (raw):
// Copyright (c) Microsoft and contributors. All rights reserved.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
package plugin
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/Azure/kubernetes-kms/pkg/version"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"k8s.io/apimachinery/pkg/util/uuid"
kmsv1 "k8s.io/kms/apis/v1beta1"
kmsv2 "k8s.io/kms/apis/v2"
"monis.app/mlog"
)
// healthCheckPlainText is the fixed payload that the health check encrypts
// and then decrypts to confirm the round trip returns the same bytes.
const healthCheckPlainText = "healthcheck"
// HealthZ serves the HTTP health check endpoint of the KMS plugin.
type HealthZ struct {
	// KMSv1Server and KMSv2Server are invoked directly by the health check
	// to run an encrypt/decrypt round trip with test data.
	KMSv1Server *KeyManagementServiceServer
	KMSv2Server *KeyManagementServiceV2Server

	// HealthCheckURL supplies both the listen address (its Host) and the
	// path at which the health handler is registered.
	HealthCheckURL *url.URL

	// UnixSocketPath is the unix socket dialed to reach the plugin's gRPC
	// endpoint.
	UnixSocketPath string

	// RPCTimeout bounds the context shared by every call made while
	// handling a single health check request.
	RPCTimeout time.Duration
}
// Serve registers ServeHTTP at the path of HealthCheckURL and starts an HTTP
// server listening on HealthCheckURL.Host. The call stays inside
// ListenAndServe until the server stops; any error other than
// http.ErrServerClosed is handed to mlog.Fatal.
func (h *HealthZ) Serve() {
	mux := http.NewServeMux()
	mux.HandleFunc(h.HealthCheckURL.EscapedPath(), h.ServeHTTP)

	srv := &http.Server{
		Addr:              h.HealthCheckURL.Host,
		ReadHeaderTimeout: 5 * time.Second,
		Handler:           mux,
	}

	err := srv.ListenAndServe()
	if err == nil || err == http.ErrServerClosed {
		// A regular shutdown is not a failure.
		return
	}
	mlog.Fatal(err, "failed to start health check server", "url", h.HealthCheckURL.String())
}
// ServeHTTP runs a single health check and reports the outcome over HTTP.
//
// All steps share one context bounded by h.RPCTimeout:
//   - dial the plugin's unix socket and validate the v1 Version and v2 Status
//     responses (see checkRPC); a failure here yields 503,
//   - run an encrypt/decrypt round trip through KMSv1Server; a failure
//     yields 500,
//   - run an encrypt/decrypt round trip through KMSv2Server; a failure
//     yields 500.
//
// When every step succeeds the handler responds 200 with the body "ok".
// The incoming request is not inspected.
func (h *HealthZ) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	mlog.Trace("Started health check")
	ctx, cancel := context.WithTimeout(context.Background(), h.RPCTimeout)
	defer cancel()

	conn, err := h.dialUnixSocket()
	if err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}
	defer conn.Close()

	// Clients for both KMS API versions share the same connection.
	kmsV1Client := kmsv1.NewKeyManagementServiceClient(conn)
	kmsV2Client := kmsv2.NewKeyManagementServiceClient(conn)

	// Check the version/status responses from the plugin's gRPC endpoint.
	if err := h.checkRPC(ctx, kmsV1Client, kmsV2Client); err != nil {
		http.Error(w, err.Error(), http.StatusServiceUnavailable)
		return
	}

	// Each API version performs one encrypt and one decrypt, so a single
	// health check issues four crypto calls in total.
	// NOTE(review): the previous comment said these reach the keyvault and
	// that the check runs every 10 seconds; neither is visible in this file.
	if err := h.checkKMSv1RoundTrip(ctx); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if err := h.checkKMSv2RoundTrip(ctx); err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
	if _, err := w.Write([]byte("ok")); err != nil {
		// The 200 header has already been sent, so http.Error could not
		// change the response any more; record the failure instead.
		mlog.Error("failed to write health check response", err)
		return
	}
	mlog.Trace("Completed health check")
}

// checkKMSv1RoundTrip encrypts the health check payload with the v1 server,
// decrypts the result and verifies that the original payload comes back.
func (h *HealthZ) checkKMSv1RoundTrip(ctx context.Context) error {
	enc, err := h.KMSv1Server.Encrypt(ctx, &kmsv1.EncryptRequest{Plain: []byte(healthCheckPlainText)})
	if err != nil {
		return err
	}
	dec, err := h.KMSv1Server.Decrypt(ctx, &kmsv1.DecryptRequest{Cipher: enc.Cipher})
	if err != nil {
		return err
	}
	if string(dec.Plain) != healthCheckPlainText {
		return fmt.Errorf("plain text mismatch after decryption")
	}
	return nil
}

// checkKMSv2RoundTrip encrypts the health check payload with the v2 server,
// decrypts the result and verifies that the original payload comes back.
func (h *HealthZ) checkKMSv2RoundTrip(ctx context.Context) error {
	// The prefix distinguishes UIDs generated by this health check from those
	// generated by the API server.
	uid := "local-healthz-check-" + string(uuid.NewUUID())

	encResp, err := h.KMSv2Server.Encrypt(ctx, &kmsv2.EncryptRequest{
		Plaintext: []byte(healthCheckPlainText),
		Uid:       uid,
	})
	if err != nil {
		return err
	}

	decResp, err := h.KMSv2Server.Decrypt(ctx, &kmsv2.DecryptRequest{
		Ciphertext:  encResp.Ciphertext,
		KeyId:       encResp.KeyId,
		Uid:         uid, // same uid so the encrypt/decrypt pair can be tracked together
		Annotations: encResp.Annotations,
	})
	if err != nil {
		return err
	}
	if string(decResp.Plaintext) != healthCheckPlainText {
		return fmt.Errorf("plain text mismatch after decryption with KMSv2")
	}
	return nil
}
// checkRPC verifies that the gRPC endpoint answers for both KMS API versions.
//
// It sends a v1 VersionRequest and requires the reported API version, runtime
// name and runtime version to match the values from the version package, then
// sends a v2 StatusRequest and requires the reported version to equal
// version.KMSv2APIVersion. The first RPC error or mismatch is returned.
func (h *HealthZ) checkRPC(
	ctx context.Context,
	kmsV1Client kmsv1.KeyManagementServiceClient,
	kmsV2Client kmsv2.KeyManagementServiceClient,
) error {
	v1Resp, err := kmsV1Client.Version(ctx, &kmsv1.VersionRequest{})
	if err != nil {
		return err
	}
	v1Matches := v1Resp.Version == version.KMSv1APIVersion &&
		v1Resp.RuntimeName == version.Runtime &&
		v1Resp.RuntimeVersion == version.BuildVersion
	if !v1Matches {
		return fmt.Errorf("failed to get correct version response")
	}

	v2Resp, err := kmsV2Client.Status(ctx, &kmsv2.StatusRequest{})
	if err != nil {
		return err
	}
	if want, got := version.KMSv2APIVersion, v2Resp.Version; got != want {
		return fmt.Errorf(
			"failed to get correct version response for v2 expected: %s, got: %s",
			want,
			got,
		)
	}

	return nil
}
// dialUnixSocket creates a gRPC client connection targeting h.UnixSocketPath.
// The connection uses insecure transport credentials and a custom dialer that
// connects over the "unix" network.
func (h *HealthZ) dialUnixSocket() (*grpc.ClientConn, error) {
	creds := grpc.WithTransportCredentials(insecure.NewCredentials())

	// gRPC passes the target address to the dialer; treat it as a socket path.
	unixDialer := func(ctx context.Context, addr string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, "unix", addr)
	}

	return grpc.Dial(h.UnixSocketPath, creds, grpc.WithContextDialer(unixDialer))
}