cmd/webhook/main.go (173 lines of code) (raw):
package main
import (
"crypto/tls"
"flag"
"fmt"
"net/http"
"github.com/open-policy-agent/cert-controller/pkg/rotator"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
"monis.app/mlog"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/controller-runtime/pkg/manager"
"sigs.k8s.io/controller-runtime/pkg/manager/signals"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
"sigs.k8s.io/controller-runtime/pkg/webhook"
"github.com/Azure/azure-workload-identity/pkg/metrics"
"github.com/Azure/azure-workload-identity/pkg/util"
"github.com/Azure/azure-workload-identity/pkg/version"
wh "github.com/Azure/azure-workload-identity/pkg/webhook"
)
// webhooks lists the webhook configurations handed to the cert rotator
// (its Webhooks field in mainErr). Only the mutating webhook configuration is
// listed. NOTE(review): presumably the rotator injects the generated CA
// bundle into these configurations — confirm against cert-controller docs.
var webhooks = []rotator.WebhookInfo{
	{
		Name: "azure-wi-webhook-mutating-webhook-configuration",
		Type: rotator.Mutating,
	},
}
// Fixed names used to configure certificate rotation and the webhook's DNS name.
const (
	// secretName is the Secret (in util.GetNamespace()) given to the cert
	// rotator as its SecretKey; presumably it stores the generated serving
	// cert/key — confirm against cert-controller docs.
	secretName = "azure-wi-webhook-server-cert" // #nosec
	// serviceName is the webhook Service; it forms the first label of dnsName.
	serviceName = "azure-wi-webhook-webhook-service"
	// caName is passed to the cert rotator as CAName.
	caName = "azure-workload-identity-ca"
	// caOrganization is passed to the cert rotator as CAOrganization.
	caOrganization = "azure-workload-identity"
)
var (
	// Flag targets; populated by flag.Parse in mainErr.
	audience            string // --audience: audience passed to the pod mutator
	webhookCertDir      string // --webhook-cert-dir: where the webhook server and rotator read/write certs
	tlsMinVersion       string // --tls-min-version: parsed by parseTLSVersion
	healthAddr          string // --health-addr: health probe bind address
	metricsAddr         string // --metrics-addr: metrics bind address
	disableCertRotation bool   // --disable-cert-rotation: skip the cert rotator entirely
	metricsBackend      string // --metrics-backend: backend for metrics.InitMetricsExporter
	logLevel            string // --log-level: mlog verbosity

	// dnsName is the in-cluster DNS name of the webhook Service,
	// <service name>.<namespace>.svc, used as the cert rotator's DNSName.
	// It is computed once at package initialisation from util.GetNamespace().
	dnsName = fmt.Sprintf("%s.%s.svc", serviceName, util.GetNamespace())
	// scheme is the runtime scheme handed to the manager; init populates it.
	scheme = runtime.NewScheme()
	// entryLog is the logger used by the entry-point code in this file.
	entryLog = mlog.New().WithName("entrypoint")
)
// init registers the client-go Kubernetes API types into scheme, which the
// controller manager is later built with.
//
// A registration failure indicates a broken build rather than a runtime
// condition. Panicking here surfaces it immediately instead of letting it show
// up later as confusing type-registration errors from the manager's client.
func init() {
	if err := clientgoscheme.AddToScheme(scheme); err != nil {
		panic(fmt.Errorf("unable to add client-go types to scheme: %w", err))
	}
}
// main delegates all work to mainErr and, if that reports a failure, logs it
// fatally via mlog.
func main() {
	err := mainErr()
	if err == nil {
		return
	}
	mlog.Fatal(err)
}
// mainErr is the real entry point. It does the following, then blocks in
// mgr.Start until the signal-handler context is cancelled:
//   - parses flags
//   - configures logging and metrics
//   - builds the controller manager that hosts the mutating webhook
//   - optionally wires up certificate rotation
//   - registers probes and (asynchronously) the webhook handler
//
// Every failure is returned rather than exiting in place, so that the deferred
// mlog cleanup runs before main reports the error.
func mainErr() error {
	defer mlog.Setup()()

	flag.StringVar(&audience, "audience", "", "Audience for service account token")
	flag.StringVar(&webhookCertDir, "webhook-cert-dir", "/certs", "Webhook certificates dir to use. Defaults to /certs")
	flag.BoolVar(&disableCertRotation, "disable-cert-rotation", false, "disable automatic generation and rotation of webhook TLS certificates/keys")
	flag.StringVar(&tlsMinVersion, "tls-min-version", "1.3", "Minimum TLS version")
	flag.StringVar(&healthAddr, "health-addr", ":9440", "The address the health endpoint binds to")
	flag.StringVar(&metricsAddr, "metrics-addr", ":8095", "The address the metrics endpoint binds to")
	flag.StringVar(&metricsBackend, "metrics-backend", "prometheus", "Backend used for metrics")
	flag.StringVar(&logLevel, "log-level", "",
		"In order of increasing verbosity: unset (empty string), info, debug, trace and all.")
	flag.Parse()

	ctx := signals.SetupSignalHandler()

	if err := mlog.ValidateAndSetLogLevelAndFormatGlobally(ctx, mlog.LogSpec{
		Level:  mlog.LogLevel(logLevel),
		Format: mlog.FormatJSON,
	}); err != nil {
		return fmt.Errorf("invalid --log-level set: %w", err)
	}

	// nolint:staticcheck
	// controller-runtime forces us to use the deprecated logr.Logger returned by mlog.Logr here
	log.SetLogger(mlog.Logr())

	// Validate the remaining flag input before creating any clients or
	// exporters, so a bad --tls-min-version fails fast with no side effects.
	tlsVersion, err := parseTLSVersion(tlsMinVersion)
	if err != nil {
		return fmt.Errorf("entrypoint: unable to parse TLS version: %w", err)
	}

	// GetConfig (instead of GetConfigOrDie) keeps a kubeconfig failure on the
	// normal error path instead of exiting the process from inside the library
	// and skipping the deferred cleanup above.
	config, err := ctrl.GetConfig()
	if err != nil {
		return fmt.Errorf("entrypoint: unable to get kubeconfig: %w", err)
	}
	config.UserAgent = version.GetUserAgent("webhook")

	// initialize metrics exporter before creating measurements
	entryLog.Info("initializing metrics backend", "backend", metricsBackend)
	if err := metrics.InitMetricsExporter(metricsBackend); err != nil {
		return fmt.Errorf("entrypoint: failed to initialize metrics exporter: %w", err)
	}

	// log the user agent as it makes it easier to debug issues
	entryLog.Info("setting up manager", "userAgent", config.UserAgent)

	serverOpts := webhook.Options{
		CertDir: webhookCertDir,
		TLSOpts: []func(c *tls.Config){func(c *tls.Config) { c.MinVersion = tlsVersion }},
	}

	mgr, err := ctrl.NewManager(config, ctrl.Options{
		Scheme:                 scheme,
		LeaderElection:         false,
		HealthProbeBindAddress: healthAddr,
		Metrics: metricsserver.Options{
			BindAddress: metricsAddr,
		},
		WebhookServer:  webhook.NewServer(serverOpts),
		MapperProvider: apiutil.NewDynamicRESTMapper,
	})
	if err != nil {
		return fmt.Errorf("entrypoint: unable to set up controller manager: %w", err)
	}

	// setupFinished gates webhook registration and the probes. With rotation
	// enabled it is handed to the rotator as IsReady (presumably closed once the
	// certs are generated and valid). With rotation disabled it is closed
	// immediately, because the certs are expected to already be in
	// webhookCertDir.
	setupFinished := make(chan struct{})
	if !disableCertRotation {
		entryLog.Info("setting up cert rotation")
		if err := rotator.AddRotator(mgr, &rotator.CertRotator{
			SecretKey: types.NamespacedName{
				Namespace: util.GetNamespace(),
				Name:      secretName,
			},
			CertDir:        webhookCertDir,
			CAName:         caName,
			CAOrganization: caOrganization,
			DNSName:        dnsName,
			IsReady:        setupFinished,
			Webhooks:       webhooks,
		}); err != nil {
			return fmt.Errorf("entrypoint: unable to set up cert rotation: %w", err)
		}
	} else {
		close(setupFinished)
	}

	setupProbeEndpoints(mgr, setupFinished)
	// Registration waits on setupFinished, which may only be closed after the
	// manager (and thus the rotator) starts, so it must not block this goroutine.
	go setupWebhook(mgr, setupFinished)

	entryLog.Info("starting manager")
	if err := mgr.Start(ctx); err != nil {
		return fmt.Errorf("entrypoint: unable to run manager: %w", err)
	}
	return nil
}
// setupWebhook blocks until setupFinished is closed, then builds the pod
// mutator and registers it at /mutate-v1-pod on the manager's webhook server.
//
// mainErr runs this in its own goroutine. A failure to construct the mutator
// panics, which terminates the process because nothing recovers in that
// goroutine.
func setupWebhook(mgr manager.Manager, setupFinished chan struct{}) {
	// Certificate generation (or its skip) must complete before the handler is
	// registered.
	<-setupFinished

	server := mgr.GetWebhookServer()

	entryLog.Info("registering webhook to the webhook server")
	mutator, err := wh.NewPodMutator(
		mgr.GetClient(),
		mgr.GetAPIReader(),
		audience,
		mgr.GetScheme(),
		mgr.GetConfig(),
	)
	if err != nil {
		panic(fmt.Errorf("unable to set up pod mutator: %w", err))
	}

	admission := &webhook.Admission{Handler: mutator}
	server.Register("/mutate-v1-pod", admission)
}
// setupProbeEndpoints registers the "healthz" and "readyz" checks on mgr.
// Both checks share a single checker that behaves as follows:
//   - While setupFinished is still open, it reports "certs are not ready yet".
//   - Once setupFinished is closed, it delegates to the webhook server's
//     StartedChecker.
//
// The webhook server's StartedChecker cannot be obtained up front, because
// doing so starts the webhook. Checks also cannot be added after
// Manager.Start. This wrapper therefore resolves StartedChecker lazily on each
// probe, after the certs are in place and the mutating webhook is set up.
// Failure to register either check panics.
func setupProbeEndpoints(mgr ctrl.Manager, setupFinished chan struct{}) {
	certsReady := func(req *http.Request) error {
		// Non-blocking test: fail the probe while setup is still running.
		select {
		case <-setupFinished:
		default:
			return fmt.Errorf("certs are not ready yet")
		}
		startedCheck := mgr.GetWebhookServer().StartedChecker()
		return startedCheck(req)
	}

	err := mgr.AddHealthzCheck("healthz", certsReady)
	if err != nil {
		panic(fmt.Errorf("unable to add healthz check: %w", err))
	}
	err = mgr.AddReadyzCheck("readyz", certsReady)
	if err != nil {
		panic(fmt.Errorf("unable to add readyz check: %w", err))
	}
	entryLog.Info("added healthz and readyz check")
}
// parseTLSVersion maps the --tls-min-version flag value ("1.0", "1.1", "1.2"
// or "1.3") to the matching crypto/tls version constant.
//
// Any other input returns 0 and an error that quotes the rejected value, so a
// misconfigured flag is easy to spot in the wrapped startup error.
func parseTLSVersion(tlsVersion string) (uint16, error) {
	switch tlsVersion {
	case "1.0":
		return tls.VersionTLS10, nil
	case "1.1":
		return tls.VersionTLS11, nil
	case "1.2":
		return tls.VersionTLS12, nil
	case "1.3":
		return tls.VersionTLS13, nil
	default:
		return 0, fmt.Errorf("invalid TLS version %q: must be one of 1.0, 1.1, 1.2, 1.3", tlsVersion)
	}
}