pkg/cmd/podidentity/detect.go (366 lines of code) (raw):
package podidentity
import (
"context"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
aadpodv1 "github.com/Azure/aad-pod-identity/pkg/apis/aadpodidentity/v1"
"github.com/pkg/errors"
"github.com/spf13/cobra"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/runtime/serializer/json"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
"monis.app/mlog"
"sigs.k8s.io/controller-runtime/pkg/client"
"github.com/Azure/azure-workload-identity/pkg/cmd/podidentity/k8s"
"github.com/Azure/azure-workload-identity/pkg/cmd/serviceaccount/options"
"github.com/Azure/azure-workload-identity/pkg/kuberneteshelper"
"github.com/Azure/azure-workload-identity/pkg/webhook"
)
// scheme is the runtime scheme handed to the YAML serializer built in prerun;
// init registers the client-go built-in types with it.
var scheme = runtime.NewScheme()
const (
// imageRepository and imageTag are combined with an image name to build the
// proxy and proxy-init image references injected into generated manifests.
imageRepository = "mcr.microsoft.com/oss/azure/workload-identity"
imageTag = "v1.5.0"
// Image names appended to imageRepository.
proxyInitImageName = "proxy-init"
proxyImageName = "proxy"
// Container names used for the injected init container and sidecar.
proxyInitContainerName = "azwi-proxy-init"
proxyContainerName = "azwi-proxy"
// nextStepsLogMessage is logged by run after the config files have been generated.
nextStepsLogMessage = `Next steps:
1. Install the Azure Workload Identity Webhook. Refer to https://azure.github.io/azure-workload-identity/docs/installation.html.
2. Create federated identity credential for all identities used in this namespace. Refer to https://azure.github.io/azure-workload-identity/docs/topics/federated-identity-credential.html.
3. Review the generated config files and apply them with 'kubectl apply -f <generated file>'.`
)
var (
proxyInitImage = fmt.Sprintf("%s/%s:%s", imageRepository, proxyInitImageName, imageTag)
proxyImage = fmt.Sprintf("%s/%s:%s", imageRepository, proxyImageName, imageTag)
)
// init registers the client-go built-in API types with the package scheme so
// the serializer created in prerun can encode them.
func init() {
// The returned error is deliberately discarded; a registration failure is
// not surfaced here.
_ = clientgoscheme.AddToScheme(scheme)
}
// detectCmd holds the flag values and runtime dependencies of the "detect"
// command.
type detectCmd struct {
// namespace is the namespace to detect the configuration in (--namespace).
namespace string
// outputDir is the directory the generated files are written to (--output-dir).
outputDir string
// proxyPort is the port configured on the injected proxy containers (--proxy-port).
proxyPort int32
// serviceAccountTokenExpiration is written as an annotation on the generated
// service account; prerun validates it against the webhook's min/max bounds.
serviceAccountTokenExpiration time.Duration
// tenantID, when non-empty, is written as an annotation on the generated
// service account (--tenant-id).
tenantID string
// kubeClient is the Kubernetes client, set in prerun.
kubeClient client.Client
// serializer encodes objects as YAML, set in prerun.
serializer *json.Serializer
}
// newDetectCmd builds the "detect" cobra command, wiring its flags to a fresh
// detectCmd whose prerun and run methods back PreRunE and RunE.
func newDetectCmd() *cobra.Command {
	dc := &detectCmd{}

	cmd := &cobra.Command{
		Use:   "detect",
		Short: "Detect the existing aad-pod-identity configuration",
		Long:  "This command will detect the existing aad-pod-identity configuration and generate a sample configuration file for migration to workload identity",
		PreRunE: func(_ *cobra.Command, _ []string) error {
			return dc.prerun()
		},
		RunE: func(_ *cobra.Command, _ []string) error {
			return dc.run()
		},
	}

	// The webhook default is expressed in seconds.
	defaultTokenExpiration := time.Duration(webhook.DefaultServiceAccountTokenExpiration) * time.Second

	flags := cmd.Flags()
	flags.StringVar(&dc.namespace, "namespace", "default", "Namespace to detect the configuration")
	flags.StringVarP(&dc.outputDir, "output-dir", "o", "", "Output directory to write the configuration files")
	flags.Int32VarP(&dc.proxyPort, "proxy-port", "p", 8000, "Proxy port to use for the proxy container")
	flags.DurationVar(
		&dc.serviceAccountTokenExpiration,
		options.ServiceAccountTokenExpiration.Flag,
		defaultTokenExpiration,
		options.ServiceAccountTokenExpiration.Description,
	)
	flags.StringVar(&dc.tenantID, "tenant-id", "", "Managed identity tenant id. If specified, the tenant id will be set as an annotation on the service account.")

	// The flag is registered just above, so the returned error is discarded.
	_ = cmd.MarkFlagRequired("output-dir")

	return cmd
}
// prerun prepares the command before run executes. It builds the YAML
// serializer, validates the service account token expiration against the
// webhook's minimum and maximum, creates the Kubernetes client and makes sure
// the output directory exists.
//
// It returns an error when the expiration is out of range, when the client
// cannot be created, or when the output directory cannot be created.
func (dc *detectCmd) prerun() error {
	dc.serializer = json.NewSerializerWithOptions(
		json.DefaultMetaFactory, scheme, scheme,
		json.SerializerOptions{
			Yaml:   true,
			Pretty: true,
			Strict: true,
		},
	)

	// TODO(aramase): this validation can be refactored to a common function as it's used in multiple places
	minTokenExpirationDuration := time.Duration(webhook.MinServiceAccountTokenExpiration) * time.Second
	maxTokenExpirationDuration := time.Duration(webhook.MaxServiceAccountTokenExpiration) * time.Second
	if dc.serviceAccountTokenExpiration < minTokenExpirationDuration {
		return errors.Errorf("--service-account-token-expiration must be greater than or equal to %s", minTokenExpirationDuration.String())
	}
	if dc.serviceAccountTokenExpiration > maxTokenExpirationDuration {
		return errors.Errorf("--service-account-token-expiration must be less than or equal to %s", maxTokenExpirationDuration.String())
	}

	var err error
	dc.kubeClient, err = kuberneteshelper.GetKubeClient()
	if err != nil {
		return errors.Wrap(err, "failed to get Kubernetes client")
	}

	// MkdirAll does nothing when the directory already exists and fails when
	// the path exists but is not a directory, so no preceding Stat is needed.
	// Calling it unconditionally also surfaces errors (e.g. permission denied)
	// that an os.IsNotExist check would silently skip.
	if err := os.MkdirAll(dc.outputDir, 0755); err != nil {
		return errors.Wrapf(err, "failed to create output directory %q", dc.outputDir)
	}
	return nil
}
// run detects the aad-pod-identity configuration in dc.namespace and, for
// every workload that uses it, writes a service account manifest and a
// modified workload manifest into dc.outputDir.
//
// Detection is namespace scoped and proceeds as follows:
//  1. List the AzureIdentityBinding and AzureIdentity objects in the namespace
//     and map each binding selector to the user-assigned AzureIdentity it
//     references.
//  2. For each selector, list the pods labelled aadpodv1.CRDLabelKey=<selector>.
//  3. For each pod, record its controller owner reference; a pod without one
//     is treated as standalone and added to the results directly.
//  4. Resolve every recorded owner reference to its top-level owner.
//  5. For each result, generate a service account file and a resource file.
//
// It returns the first error encountered while listing, resolving owners or
// writing files.
func (dc *detectCmd) run() error {
	mlog.Debug("detecting aad-pod-identity configuration", "namespace", dc.namespace)

	azureIdentityBindings, err := kuberneteshelper.ListAzureIdentityBinding(context.TODO(), dc.kubeClient, dc.namespace)
	if err != nil {
		return err
	}
	azureIdentities, err := kuberneteshelper.ListAzureIdentity(context.TODO(), dc.kubeClient, dc.namespace)
	if err != nil {
		return err
	}

	// only user-assigned managed identities are considered
	azureIdentityMap := make(map[string]aadpodv1.AzureIdentity)
	for _, azureIdentity := range azureIdentities {
		if azureIdentity.Spec.Type == aadpodv1.UserAssignedMSI {
			azureIdentityMap[azureIdentity.Name] = azureIdentity
		}
	}

	labelsToAzureIdentityMap := filterAzureIdentities(azureIdentityBindings, azureIdentityMap)
	if count := len(labelsToAzureIdentityMap); count > 0 {
		mlog.Debug("found valid aad-pod-identity bindings", "count", count)
	} else {
		mlog.Debug("did not find any valid aad-pod-identity bindings")
	}

	ownerReferences := make(map[metav1.OwnerReference]string)
	results := make(map[client.Object]string)

	for selector, azureIdentity := range labelsToAzureIdentityMap {
		mlog.Debug("getting pods", "selector", selector)
		pods, err := kuberneteshelper.ListPods(context.TODO(), dc.kubeClient, dc.namespace, map[string]string{aadpodv1.CRDLabelKey: selector})
		if err != nil {
			return err
		}

		for i := range pods {
			// for pods created by higher level constructors like deployment, statefulset, cronjob, job, daemonset, replicaset, replicationcontroller
			// we can get the owner reference with pod.OwnerReferences
			ownerFound := false
			for _, ownerReference := range pods[i].OwnerReferences {
				// only get the owner reference that was set by the parent controller
				if ownerReference.Controller != nil && *ownerReference.Controller {
					// OwnerReference carries pointer fields, and map keys compare
					// pointers by address. Clear them on this local copy so that
					// pods sharing a controller collapse into a single entry
					// instead of one entry per pod. Owner lookup only needs the
					// kind and name.
					ownerReference.Controller = nil
					ownerReference.BlockOwnerDeletion = nil
					ownerReferences[ownerReference] = azureIdentity.Spec.ClientID
					ownerFound = true
					break
				}
			}
			// this is a standalone pod, so add it to the results
			if !ownerFound {
				p := pods[i]
				results[&p] = azureIdentity.Spec.ClientID
			}
		}
	}

	// Different intermediate owners (for example two ReplicaSets of the same
	// Deployment) resolve to the same top-level owner, but each lookup returns
	// a distinct object. Track owners by concrete type and name so each one is
	// generated only once.
	seenOwners := make(map[string]struct{}, len(ownerReferences))
	for ownerReference, clientID := range ownerReferences {
		owner, err := dc.getOwner(ownerReference)
		if err != nil {
			return err
		}
		ownerKey := fmt.Sprintf("%T/%s", owner, owner.GetName())
		if _, ok := seenOwners[ownerKey]; ok {
			continue
		}
		seenOwners[ownerKey] = struct{}{}
		results[owner] = clientID
	}

	// results contains all the resources that we need to generate a config file.
	// for each entry in the results map, we will generate a service account yaml file
	// and a resource file
	for o, clientID := range results {
		localObject := k8s.NewLocalObject(o)
		sa, err := dc.createServiceAccountFile(localObject.GetServiceAccountName(), localObject.GetName(), clientID)
		if err != nil {
			return err
		}
		if err = dc.createResourceFile(localObject, sa); err != nil {
			return err
		}
		mlog.Debug("generated config",
			"kind", strings.ToLower(localObject.GetObjectKind().GroupVersionKind().Kind),
			"name", localObject.GetName(),
			"clientID", clientID,
		)
	}

	if len(results) == 0 {
		mlog.Debug("no aad-pod-identity configuration found", "namespace", dc.namespace)
		return nil
	}

	mlog.Info("generated resource and service account files", "directory", dc.outputDir)
	mlog.Info(nextStepsLogMessage)
	return nil
}
// createServiceAccountFile writes a service account yaml file to the output
// directory and returns the service account it encoded.
//
//  1. If name is empty or "default", a new service account is generated with
//     ownerName as its name, in dc.namespace.
//  2. Otherwise the service account called name is fetched from the cluster
//     and modified to produce the desired yaml file.
//
// The service account is annotated with the client id
// (webhook.ClientIDAnnotation), the token expiration in whole seconds
// (webhook.ServiceAccountTokenExpiryAnnotation) and, when dc.tenantID is set,
// the tenant id (webhook.TenantIDAnnotation). Existing annotations are kept.
//
// The file is named after ownerName (see getServiceAccountFileName). An error
// is returned when the service account cannot be fetched or the file cannot be
// created, written or closed.
func (dc *detectCmd) createServiceAccountFile(name, ownerName, clientID string) (*corev1.ServiceAccount, error) {
	sa := &corev1.ServiceAccount{}
	if name == "" || name == "default" {
		mlog.Debug("generating a new service account instead of using default service account", "owner", ownerName)
		// generate a new service account yaml file with owner name as service account name
		sa.SetName(ownerName)
		sa.SetNamespace(dc.namespace)
	} else {
		// get service account referenced by the owner
		var err error
		if sa, err = kuberneteshelper.GetServiceAccount(context.TODO(), dc.kubeClient, dc.namespace, name); err != nil {
			return nil, err
		}
	}

	// set the annotations for the service account, keeping any existing ones
	saAnnotations := sa.GetAnnotations()
	if saAnnotations == nil {
		saAnnotations = make(map[string]string)
	}
	saAnnotations[webhook.ClientIDAnnotation] = clientID
	// Round to the nearest second before converting to a string
	saAnnotations[webhook.ServiceAccountTokenExpiryAnnotation] = fmt.Sprintf("%.0f", dc.serviceAccountTokenExpiration.Round(time.Second).Seconds())
	if dc.tenantID != "" {
		saAnnotations[webhook.TenantIDAnnotation] = dc.tenantID
	}
	sa.SetAnnotations(saAnnotations)
	sa.SetGroupVersionKind(schema.GroupVersionKind{Group: "", Version: "v1", Kind: "ServiceAccount"})

	// Reset server-populated metadata, mirroring createResourceFile, so a
	// service account fetched from the cluster yields a clean, applyable yaml.
	sa.SetManagedFields(nil)
	sa.SetResourceVersion("")
	sa.SetUID("")
	sa.CreationTimestamp.Time = time.Time{}
	sa.SetSelfLink("")
	sa.SetGeneration(0)

	// write the service account yaml file
	file, err := os.Create(dc.getServiceAccountFileName(ownerName))
	if err != nil {
		return nil, err
	}
	if err := dc.serializer.Encode(sa, file); err != nil {
		// the encode error is the one worth reporting
		_ = file.Close()
		return nil, err
	}
	// Close can report a failed write, so its error must not be dropped.
	if err := file.Close(); err != nil {
		return nil, err
	}
	return sa, nil
}
// createResourceFile writes the modified resource yaml file to the output
// directory.
//
// The object's service account name is set to the name of sa, so a resource
// that used the default service account now points at the service account
// generated by createServiceAccountFile().
//
// The resource yaml will contain:
//  1. proxy container that is required for migration
//  2. proxy-init init container that sets up iptables rules to redirect IMDS traffic to proxy
//
// Server-populated metadata and status are reset before encoding. An error is
// returned when the file cannot be created, written or closed.
func (dc *detectCmd) createResourceFile(localObject k8s.LocalObject, sa *corev1.ServiceAccount) error {
	// add the init container to the container list
	localObject.SetInitContainers(dc.addProxyInitContainer(localObject.GetInitContainers()))
	// add the proxy container to the container list
	localObject.SetContainers(dc.addProxyContainer(localObject.GetContainers()))
	// set the service account name for the object
	localObject.SetServiceAccountName(sa.GetName())
	// reset the managed fields to reduce clutter in the output yaml
	localObject.SetManagedFields(nil)
	// reset the resource version, uid and other metadata to make the yaml file applyable
	localObject.SetResourceVersion("")
	localObject.SetUID("")
	localObject.SetCreationTimestamp(metav1.Time{})
	localObject.SetSelfLink("")
	localObject.SetGeneration(0)
	localObject.ResetStatus()
	// set the group version kind explicitly before serializing the object
	localObject.SetGVK()

	// write the modified object to the output dir
	file, err := os.Create(dc.getResourceFileName(localObject))
	if err != nil {
		return err
	}
	if err := dc.serializer.Encode(localObject.GetObject(), file); err != nil {
		// the encode error is the one worth reporting
		_ = file.Close()
		return err
	}
	// Close can report a failed write, so its error must not be dropped.
	return file.Close()
}
// addProxyInitContainer returns initContainers with the proxy-init container
// appended. When a container whose image starts with the proxy-init image
// path is already present, the list is returned unchanged.
func (dc *detectCmd) addProxyInitContainer(initContainers []corev1.Container) []corev1.Container {
	proxyInitImagePrefix := imageRepository + "/" + proxyInitImageName
	for i := range initContainers {
		if strings.HasPrefix(initContainers[i].Image, proxyInitImagePrefix) {
			return initContainers
		}
	}

	privileged := true
	// proxy-init needs to be run as root
	rootUser := int64(0)

	securityContext := &corev1.SecurityContext{
		Privileged: &privileged,
		RunAsUser:  &rootUser,
		Capabilities: &corev1.Capabilities{
			Add:  []corev1.Capability{"NET_ADMIN"},
			Drop: []corev1.Capability{"ALL"},
		},
	}
	portEnv := corev1.EnvVar{
		Name:  "PROXY_PORT",
		Value: strconv.Itoa(int(dc.proxyPort)),
	}

	// appending to a nil slice allocates, so no explicit initialization is needed
	return append(initContainers, corev1.Container{
		Name:            proxyInitContainerName,
		Image:           proxyInitImage,
		ImagePullPolicy: corev1.PullIfNotPresent,
		SecurityContext: securityContext,
		Env:             []corev1.EnvVar{portEnv},
	})
}
// addProxyContainer returns containers with the proxy container appended.
// When a container whose image starts with the proxy image path is already
// present, the list is returned unchanged.
func (dc *detectCmd) addProxyContainer(containers []corev1.Container) []corev1.Container {
	proxyImagePrefix := imageRepository + "/" + proxyImageName
	for i := range containers {
		if strings.HasPrefix(containers[i].Image, proxyImagePrefix) {
			return containers
		}
	}

	logLevel := mlog.LevelInfo // somewhat arbitrary decision
	portArg := fmt.Sprintf("--proxy-port=%d", dc.proxyPort)
	logLevelArg := fmt.Sprintf("--log-level=%s", logLevel)

	// the postStart hook runs the proxy binary with --probe on the same port
	postStart := &corev1.LifecycleHandler{
		Exec: &corev1.ExecAction{
			Command: []string{"/proxy", portArg, "--probe", logLevelArg},
		},
	}

	// appending to a nil slice allocates, so no explicit initialization is needed
	return append(containers, corev1.Container{
		Name:            proxyContainerName,
		Image:           proxyImage,
		ImagePullPolicy: corev1.PullIfNotPresent,
		Args:            []string{portArg, logLevelArg},
		Ports: []corev1.ContainerPort{
			{ContainerPort: dc.proxyPort},
		},
		Lifecycle: &corev1.Lifecycle{PostStart: postStart},
	})
}
// getOwner returns the top level owner of the resource referenced by ownerRef.
// It follows controller owner references upwards until it reaches an object
// that has no controller owner, and returns that object. Any lookup error
// along the way is returned as is.
func (dc *detectCmd) getOwner(ownerRef metav1.OwnerReference) (owner client.Object, err error) {
	current := ownerRef
	for {
		mlog.Debug("getting owner reference", "name", current.Name)
		obj, lookupErr := dc.getOwnerObject(current)
		if lookupErr != nil {
			return nil, lookupErr
		}

		// look for the first owner reference that was set by a controller
		hasController := false
		for _, candidate := range obj.GetOwnerReferences() {
			if candidate.Controller == nil || !*candidate.Controller {
				continue
			}
			current = candidate
			hasController = true
			break
		}
		if !hasController {
			return obj, nil
		}
	}
}
// getOwnerObject fetches the object named by ownerRef from dc.namespace,
// choosing the concrete type from the owner reference kind. Kinds other than
// the ones listed below yield an "unsupported owner kind" error.
func (dc *detectCmd) getOwnerObject(ownerRef metav1.OwnerReference) (client.Object, error) {
	var target client.Object
	switch ownerRef.Kind {
	case "Deployment":
		target = &appsv1.Deployment{}
	case "StatefulSet":
		target = &appsv1.StatefulSet{}
	case "CronJob":
		target = &batchv1.CronJob{}
	case "Job":
		target = &batchv1.Job{}
	case "DaemonSet":
		target = &appsv1.DaemonSet{}
	case "ReplicaSet":
		target = &appsv1.ReplicaSet{}
	case "ReplicationController":
		target = &corev1.ReplicationController{}
	default:
		return nil, fmt.Errorf("unsupported owner kind: %s", ownerRef.Kind)
	}
	return kuberneteshelper.GetObject(context.TODO(), dc.kubeClient, dc.namespace, ownerRef.Name, target)
}
// getResourceFileName returns the path of the resource yaml file for obj:
// <outputDir>/<object name>.yaml.
func (dc *detectCmd) getResourceFileName(obj k8s.LocalObject) string {
	fileName := obj.GetName() + ".yaml"
	return filepath.Join(dc.outputDir, fileName)
}
// getServiceAccountFileName returns the path of the service account yaml
// file for the given prefix: <outputDir>/<prefix>-serviceaccount.yaml.
func (dc *detectCmd) getServiceAccountFileName(prefix string) string {
	fileName := prefix + "-serviceaccount.yaml"
	return filepath.Join(dc.outputDir, fileName)
}
// filterAzureIdentities will filter out the Azure identities referenced in AzureIdentityBinding.
// The return value is a map of selector used in AzureIdentityBinding to the AzureIdentity.
//
// Bindings with an empty selector or an empty AzureIdentity reference are skipped, as are
// bindings whose AzureIdentity is not present in identities. For each selector, the first
// binding (in slice order) that resolves to a known identity wins; later bindings with the
// same selector are ignored.
func filterAzureIdentities(bindings []aadpodv1.AzureIdentityBinding, identities map[string]aadpodv1.AzureIdentity) map[string]aadpodv1.AzureIdentity {
	labelsToAzureIdentityMap := make(map[string]aadpodv1.AzureIdentity)
	// selectorToBinding remembers which binding claimed each selector, so the
	// debug message can name the binding in use rather than the identity.
	selectorToBinding := make(map[string]string)

	for _, binding := range bindings {
		if binding.Spec.Selector == "" || binding.Spec.AzureIdentity == "" {
			continue
		}
		// this can happen when multiple AzureIdentityBinding exist in the namespace with same selector
		// Multiple AzureIdentityBinding with same selector are configured in AAD Pod Identity to enable
		// a single pod to have access to multiple identities.
		// In case of workload identity, we can only annotate with a single client id and there can only
		// be one AZURE_CLIENT_ID environment variable. The client id annotation will be configured to the first
		// AzureIdentityBinding with the selector. The workload will use the client id of the specific identity
		// to get a token and will not really use the AZURE_CLIENT_ID environment variable.
		if usedBinding, ok := selectorToBinding[binding.Spec.Selector]; ok {
			mlog.Debug("multiple AzureIdentityBinding found, using the first one",
				"selector", binding.Spec.Selector,
				"binding", usedBinding,
				"ignoredBinding", binding.Name,
			)
			continue
		}
		if azureIdentity, ok := identities[binding.Spec.AzureIdentity]; ok {
			labelsToAzureIdentityMap[binding.Spec.Selector] = azureIdentity
			selectorToBinding[binding.Spec.Selector] = binding.Name
		}
	}

	return labelsToAzureIdentityMap
}