pkg/webhook/webhook.go (394 lines of code) (raw):
package webhook
import (
"context"
"encoding/json"
"fmt"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/pkg/errors"
"monis.app/mlog"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
utilversion "k8s.io/apimachinery/pkg/util/version"
"k8s.io/client-go/discovery"
"k8s.io/client-go/rest"
"k8s.io/utils/ptr"
"github.com/Azure/azure-workload-identity/pkg/config"
)
var (
// ProxyImageRegistry is the image registry for the proxy init and sidecar.
// This is injected via LDFLAGS in the Makefile during the build.
ProxyImageRegistry string
// ProxyImageVersion is the image version of the proxy init and sidecar.
// This is injected via LDFLAGS in the Makefile during the build.
ProxyImageVersion string
)
// +kubebuilder:webhook:path=/mutate-v1-pod,mutating=true,failurePolicy=fail,groups="",resources=pods,verbs=create,versions=v1,name=mutation.azure-workload-identity.io,sideEffects=None,admissionReviewVersions=v1;v1beta1,matchPolicy=Equivalent,reinvocationPolicy=IfNeeded
// +kubebuilder:rbac:groups="",resources=serviceaccounts,verbs=get;list;watch
// this is required for the webhook server certs generated and rotated as part of cert-controller rotator
// +kubebuilder:rbac:groups="",namespace=azure-workload-identity-system,resources=secrets,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=admissionregistration.k8s.io,resources=mutatingwebhookconfigurations,verbs=get;list;watch;update
// podMutator mutates pod objects to add project service account token volume
type podMutator struct {
client client.Client
// reader is an instance of mgr.GetAPIReader that is configured to use the API server.
// This should be used sparingly and only when the client does not fit the use case.
reader client.Reader
config *config.Config
decoder admission.Decoder
audience string
azureAuthorityHost string
proxyImage string
proxyInitImage string
useNativeSidecar bool
}
// NewPodMutator returns a pod mutation handler
func NewPodMutator(client client.Client, reader client.Reader, audience string, scheme *runtime.Scheme, restConfig *rest.Config) (admission.Handler, error) {
c, err := config.ParseConfig()
if err != nil {
return nil, err
}
if audience == "" {
audience = DefaultAudience
}
discoveryClient, err := discovery.NewDiscoveryClientForConfig(restConfig)
if err != nil {
return nil, errors.Wrap(err, "failed to create discovery client")
}
// "SidecarContainers" went beta in 1.29. With the 3 version skew policy,
// between API server and kubelet, 1.32 is the earliest version this can be
// safely used.
useNativeSidecar, err := serverVersionGTE(discoveryClient, utilversion.MajorMinor(1, 32))
if err != nil {
return nil, errors.Wrap(err, "failed to check kubernetes version")
}
// this is used to configure the AZURE_AUTHORITY_HOST env var that's
// used by the azure sdk
azureAuthorityHost, err := getAzureAuthorityHost(c)
if err != nil {
return nil, errors.Wrap(err, "failed to get AAD endpoint")
}
proxyImage := c.ProxyImage
if len(proxyImage) == 0 {
proxyImage = fmt.Sprintf("%s/%s:%s", ProxyImageRegistry, ProxySidecarImageName, ProxyImageVersion)
}
proxyInitImage := c.ProxyInitImage
if len(proxyInitImage) == 0 {
proxyInitImage = fmt.Sprintf("%s/%s:%s", ProxyImageRegistry, ProxyInitImageName, ProxyImageVersion)
}
if err := registerMetrics(); err != nil {
return nil, errors.Wrap(err, "failed to register metrics")
}
return &podMutator{
client: client,
reader: reader,
config: c,
decoder: admission.NewDecoder(scheme),
audience: audience,
azureAuthorityHost: azureAuthorityHost,
proxyImage: proxyImage,
proxyInitImage: proxyInitImage,
useNativeSidecar: useNativeSidecar,
}, nil
}
// PodMutator adds projected service account volume for incoming pods if service account is annotated
func (m *podMutator) Handle(ctx context.Context, req admission.Request) (response admission.Response) {
timeStart := time.Now()
defer func() {
ReportRequest(ctx, req.Namespace, time.Since(timeStart))
}()
pod := &corev1.Pod{}
err := m.decoder.Decode(req, pod)
if err != nil {
return admission.Errored(http.StatusBadRequest, err)
}
podName := pod.GetName()
if podName == "" {
podName = pod.GetGenerateName() + " (prefix)"
}
// for daemonset/deployment pods the namespace field is not set in objectMeta
// explicitly set the namespace to request namespace
pod.Namespace = req.Namespace
serviceAccountName := pod.Spec.ServiceAccountName
// When you create a pod, if you do not specify a service account, it is automatically
// assigned the default service account in the same namespace.
// xref: https://kubernetes.io/docs/tasks/configure-pod-container/configure-service-account/#use-the-default-service-account-to-access-the-api-server
if serviceAccountName == "" {
serviceAccountName = "default"
}
logger := mlog.New().WithName("handler").WithValues("pod", podName, "namespace", pod.Namespace, "service-account", serviceAccountName)
// get service account associated with the pod
serviceAccount := &corev1.ServiceAccount{}
if err = m.client.Get(ctx, types.NamespacedName{Name: serviceAccountName, Namespace: pod.Namespace}, serviceAccount); err != nil {
if !apierrors.IsNotFound(err) {
logger.Error("failed to get service account", err)
return admission.Errored(http.StatusBadRequest, err)
}
// bypass cache and get from the API server as it's not found in cache
err = m.reader.Get(ctx, types.NamespacedName{Name: serviceAccountName, Namespace: pod.Namespace}, serviceAccount)
if err != nil {
logger.Error("failed to get service account", err)
return admission.Errored(http.StatusBadRequest, err)
}
}
if shouldInjectProxySidecar(pod) {
// if the pod has hostNetwork set to true, we cannot inject the proxy sidecar
// as it'll end up modifying the network stack of the host and affecting other pods
if pod.Spec.HostNetwork {
err := errors.New("hostNetwork is set to true, cannot inject proxy sidecar")
logger.Error("failed to inject proxy sidecar", err)
return admission.Errored(http.StatusBadRequest, err)
}
proxyPort, err := getProxyPort(pod)
if err != nil {
logger.Error("failed to get proxy port", err)
return admission.Errored(http.StatusBadRequest, err)
}
pod.Spec.InitContainers = m.injectProxyInitContainer(pod.Spec.InitContainers, proxyPort)
if m.useNativeSidecar {
pod.Spec.InitContainers = m.injectProxySidecarContainer(pod.Spec.InitContainers, proxyPort, ptr.To(corev1.ContainerRestartPolicyAlways))
} else {
pod.Spec.Containers = m.injectProxySidecarContainer(pod.Spec.Containers, proxyPort, nil)
}
}
// get service account token expiration
serviceAccountTokenExpiration, err := getServiceAccountTokenExpiration(pod, serviceAccount)
if err != nil {
logger.Error("failed to get service account token expiration", err)
return admission.Errored(http.StatusBadRequest, err)
}
// get the clientID
clientID := getClientID(serviceAccount)
// get the tenantID
tenantID := getTenantID(serviceAccount, m.config)
// get containers to skip
skipContainers := getSkipContainers(pod)
pod.Spec.InitContainers = m.mutateContainers(pod.Spec.InitContainers, clientID, tenantID, skipContainers)
pod.Spec.Containers = m.mutateContainers(pod.Spec.Containers, clientID, tenantID, skipContainers)
// add the projected service account token volume to the pod if not exists
addProjectedServiceAccountTokenVolume(pod, serviceAccountTokenExpiration, m.audience)
marshaledPod, err := json.Marshal(pod)
if err != nil {
logger.Error("failed to marshal pod object", err)
return admission.Errored(http.StatusInternalServerError, err)
}
return admission.PatchResponseFromRaw(req.Object.Raw, marshaledPod)
}
// mutateContainers mutates the containers by injecting the projected
// service account token volume and environment variables
func (m *podMutator) mutateContainers(containers []corev1.Container, clientID string, tenantID string, skipContainers map[string]struct{}) []corev1.Container {
for i := range containers {
// container is in the skip list
if _, ok := skipContainers[containers[i].Name]; ok {
continue
}
// add environment variables to container if not exists
containers[i] = addEnvironmentVariables(containers[i], clientID, tenantID, m.azureAuthorityHost)
// add the volume mount if not exists
containers[i] = addProjectedTokenVolumeMount(containers[i])
}
return containers
}
func (m *podMutator) injectProxyInitContainer(containers []corev1.Container, proxyPort int32) []corev1.Container {
for _, container := range containers {
if container.Name == ProxyInitContainerName {
return containers
}
}
containers = append(containers, corev1.Container{
Name: ProxyInitContainerName,
Image: m.proxyInitImage,
ImagePullPolicy: corev1.PullIfNotPresent,
SecurityContext: &corev1.SecurityContext{
Capabilities: &corev1.Capabilities{
Add: []corev1.Capability{"NET_ADMIN"},
Drop: []corev1.Capability{"ALL"},
},
Privileged: ptr.To(true),
RunAsNonRoot: ptr.To(false),
RunAsUser: ptr.To[int64](0),
},
Env: []corev1.EnvVar{{
Name: ProxyPortEnvVar,
Value: strconv.FormatInt(int64(proxyPort), 10),
}},
})
return containers
}
func (m *podMutator) injectProxySidecarContainer(containers []corev1.Container, proxyPort int32, restartPolicy *corev1.ContainerRestartPolicy) []corev1.Container {
for _, container := range containers {
if container.Name == ProxySidecarContainerName {
return containers
}
}
logLevel := currentLogLevel() // run the proxy at the same log level as the webhook
containers = append([]corev1.Container{{
Name: ProxySidecarContainerName,
Image: m.proxyImage,
ImagePullPolicy: corev1.PullIfNotPresent,
Args: []string{
fmt.Sprintf("--proxy-port=%d", proxyPort),
fmt.Sprintf("--log-level=%s", logLevel),
},
Ports: []corev1.ContainerPort{{
ContainerPort: proxyPort,
}},
Lifecycle: &corev1.Lifecycle{
PostStart: &corev1.LifecycleHandler{
Exec: &corev1.ExecAction{
Command: []string{
"/proxy",
fmt.Sprintf("--proxy-port=%d", proxyPort),
"--probe",
fmt.Sprintf("--log-level=%s", logLevel),
},
},
},
},
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: ptr.To(false),
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{"ALL"},
},
Privileged: ptr.To(false),
ReadOnlyRootFilesystem: ptr.To(true),
RunAsNonRoot: ptr.To(true),
},
RestartPolicy: restartPolicy,
}}, containers...)
return containers
}
func shouldInjectProxySidecar(pod *corev1.Pod) bool {
if len(pod.Annotations) == 0 {
return false
}
_, ok := pod.Annotations[InjectProxySidecarAnnotation]
return ok
}
// getSkipContainers gets the list of containers to skip based on the annotation
func getSkipContainers(pod *corev1.Pod) map[string]struct{} {
skipContainers := pod.Annotations[SkipContainersAnnotation]
if len(skipContainers) == 0 {
return nil
}
skipContainersList := strings.Split(skipContainers, ";")
m := make(map[string]struct{})
for _, skipContainer := range skipContainersList {
m[strings.TrimSpace(skipContainer)] = struct{}{}
}
return m
}
// getServiceAccountTokenExpiration returns the expiration seconds for the project service account token volume
// Order of preference:
// 1. annotation in the pod
// 2. annotation in the service account
// default expiration if no annotation specified
func getServiceAccountTokenExpiration(pod *corev1.Pod, sa *corev1.ServiceAccount) (int64, error) {
serviceAccountTokenExpiration := DefaultServiceAccountTokenExpiration
var err error
// check if expiry defined in the pod with annotation
if pod.Annotations != nil && pod.Annotations[ServiceAccountTokenExpiryAnnotation] != "" {
if serviceAccountTokenExpiration, err = strconv.ParseInt(pod.Annotations[ServiceAccountTokenExpiryAnnotation], 10, 64); err != nil {
return 0, err
}
} else if sa.Annotations != nil && sa.Annotations[ServiceAccountTokenExpiryAnnotation] != "" {
if serviceAccountTokenExpiration, err = strconv.ParseInt(sa.Annotations[ServiceAccountTokenExpiryAnnotation], 10, 64); err != nil {
return 0, err
}
}
// validate expiration time
if !validServiceAccountTokenExpiry(serviceAccountTokenExpiration) {
return 0, errors.Errorf("token expiration %d not valid. Expected value to be between 3600 and 86400", serviceAccountTokenExpiration)
}
return serviceAccountTokenExpiration, nil
}
// getProxyPort returns the port for the proxy init container and the proxy sidecar container
func getProxyPort(pod *corev1.Pod) (int32, error) {
if len(pod.Annotations) == 0 {
return DefaultProxySidecarPort, nil
}
proxyPort, ok := pod.Annotations[ProxySidecarPortAnnotation]
if !ok {
return DefaultProxySidecarPort, nil
}
parsed, err := strconv.ParseInt(proxyPort, 10, 32)
if err != nil {
return 0, errors.Wrap(err, "failed to parse proxy sidecar port")
}
return int32(parsed), nil //nolint:gosec // disable G115
}
func validServiceAccountTokenExpiry(tokenExpiry int64) bool {
return tokenExpiry <= MaxServiceAccountTokenExpiration && tokenExpiry >= MinServiceAccountTokenExpiration
}
// getClientID returns the clientID to be configured
func getClientID(sa *corev1.ServiceAccount) string {
return sa.Annotations[ClientIDAnnotation]
}
// getTenantID returns the tenantID to be configured
func getTenantID(sa *corev1.ServiceAccount, c *config.Config) string {
// use tenantID if provided in the annotation
if tenantID, ok := sa.Annotations[TenantIDAnnotation]; ok {
return tenantID
}
// use the cluster tenantID as default value
return c.TenantID
}
// addEnvironmentVariables adds the clientID, tenantID and token file path environment variables needed for SDK
func addEnvironmentVariables(container corev1.Container, clientID, tenantID, azureAuthorityHost string) corev1.Container {
m := make(map[string]string)
for _, env := range container.Env {
m[env.Name] = env.Value
}
// add the clientID env var
if _, ok := m[AzureClientIDEnvVar]; !ok {
container.Env = append(container.Env, corev1.EnvVar{Name: AzureClientIDEnvVar, Value: clientID})
}
// add the tenantID env var
if _, ok := m[AzureTenantIDEnvVar]; !ok {
container.Env = append(container.Env, corev1.EnvVar{Name: AzureTenantIDEnvVar, Value: tenantID})
}
// add the token file env var
if _, ok := m[AzureFederatedTokenFileEnvVar]; !ok {
container.Env = append(container.Env, corev1.EnvVar{Name: AzureFederatedTokenFileEnvVar, Value: filepath.Join(TokenFileMountPath, TokenFilePathName)})
}
// add the azure authority host env var
if _, ok := m[AzureAuthorityHostEnvVar]; !ok {
container.Env = append(container.Env, corev1.EnvVar{Name: AzureAuthorityHostEnvVar, Value: azureAuthorityHost})
}
return container
}
// addProjectedTokenVolumeMount adds the projected token volume mount for the container
func addProjectedTokenVolumeMount(container corev1.Container) corev1.Container {
for _, volume := range container.VolumeMounts {
if volume.Name == TokenFilePathName {
return container
}
}
container.VolumeMounts = append(container.VolumeMounts,
corev1.VolumeMount{
Name: TokenFilePathName,
MountPath: TokenFileMountPath,
ReadOnly: true,
})
return container
}
func addProjectedServiceAccountTokenVolume(pod *corev1.Pod, serviceAccountTokenExpiration int64, audience string) {
// add the projected service account token volume to the pod if not exists
for _, volume := range pod.Spec.Volumes {
if volume.Projected == nil {
continue
}
for _, pvs := range volume.Projected.Sources {
if pvs.ServiceAccountToken == nil {
continue
}
if pvs.ServiceAccountToken.Path == TokenFilePathName {
return
}
}
}
// add the projected service account token volume
// the path for this volume will always be set to "azure-identity-token"
pod.Spec.Volumes = append(
pod.Spec.Volumes,
corev1.Volume{
Name: TokenFilePathName,
VolumeSource: corev1.VolumeSource{
Projected: &corev1.ProjectedVolumeSource{
Sources: []corev1.VolumeProjection{
{
ServiceAccountToken: &corev1.ServiceAccountTokenProjection{
Path: TokenFilePathName,
ExpirationSeconds: &serviceAccountTokenExpiration,
Audience: audience,
},
},
},
},
},
},
)
}
// getAzureAuthorityHost returns the active directory endpoint to use for requesting
// tokens based on the azure environment the webhook is configured with.
func getAzureAuthorityHost(c *config.Config) (string, error) {
var env azure.Environment
var err error
if c.Cloud == "" {
env = azure.PublicCloud
} else {
env, err = azure.EnvironmentFromName(c.Cloud)
}
return env.ActiveDirectoryEndpoint, err
}
func currentLogLevel() string {
for _, level := range []mlog.LogLevel{
// iterate in reverse order
mlog.LevelAll,
mlog.LevelTrace,
mlog.LevelDebug,
mlog.LevelInfo,
mlog.LevelWarning,
} {
if mlog.Enabled(level) {
return string(level)
}
}
return "" // this is unreachable
}
// serverVersionGTE returns true if v is greater than or equal to the server version.
func serverVersionGTE(discoveryClient discovery.ServerVersionInterface, v *utilversion.Version) (bool, error) {
// check if the kubernetes version is supported
serverVersion, err := discoveryClient.ServerVersion()
if err != nil {
return false, err
}
sv, err := utilversion.ParseSemantic(serverVersion.GitVersion)
if err != nil {
return false, err
}
return sv.AtLeast(v), nil
}