internal/controller/acrpullbinding_v1beta2_controller.go (215 lines of code) (raw):
package controller
import (
"context"
"fmt"
"slices"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
azworkloadidentity "github.com/Azure/azure-workload-identity/pkg/webhook"
msiacrpullv1beta2 "github.com/Azure/msi-acrpull/api/v1beta2"
"github.com/Azure/msi-acrpull/pkg/authorizer"
"github.com/go-logr/logr"
authenticationv1 "k8s.io/api/authentication/v1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
corev1client "k8s.io/client-go/kubernetes/typed/core/v1"
ctrl "sigs.k8s.io/controller-runtime"
crclient "sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/handler"
"sigs.k8s.io/controller-runtime/pkg/reconcile"
)
// CoreOpts holds the client, logger, scheme and clock that V1beta2ReconcilerOpts embeds.
type CoreOpts struct {
// Client, Logger and Scheme are copied verbatim onto the generic reconciler by NewV1beta2Reconciler.
Client crclient.Client
Logger logr.Logger
Scheme *runtime.Scheme
// now supplies the current time; NewV1beta2Reconciler replaces a nil value with time.Now.
now func() time.Time
}
// ServiceAccountTokenMinter requests a token for the named service account in the given namespace.
// The Status.Token of the returned TokenRequest is what the reconciler forwards to the ARM token fetcher.
type ServiceAccountTokenMinter func(ctx context.Context, serviceAccountNamespace, serviceAccountName string) (*authenticationv1.TokenRequest, error)
// armTokenFetcher retrieves an ARM access token for a binding spec. tenantId, clientId and
// serviceAccountToken are only populated by the reconciler when the binding uses workload identity;
// otherwise they are passed as empty strings. NewV1beta2Reconciler defaults this to authorizer.ARMTokenForBinding.
type armTokenFetcher func(ctx context.Context, spec msiacrpullv1beta2.AcrPullBindingSpec, tenantId, clientId, serviceAccountToken string) (azcore.AccessToken, error)
// armAcrTokenExchanger trades an ARM access token for an ACR access token for the given ACR configuration.
// NewV1beta2Reconciler defaults this to authorizer.ExchangeACRAccessTokenForSpec.
type armAcrTokenExchanger func(ctx context.Context, armToken azcore.AccessToken, spec msiacrpullv1beta2.AcrConfiguration) (azcore.AccessToken, error)
// V1beta2ReconcilerOpts configures the inputs for reconciling v1beta2 pull bindings
type V1beta2ReconcilerOpts struct {
CoreOpts
// TTLRotationFraction is the fraction of a credential's lifetime (expiry minus last refresh)
// after which the reconciler considers it due for rotation; see refreshBoundary.
TTLRotationFraction float64
// ServiceAccountClient is used by the default token minter to create service account tokens.
ServiceAccountClient corev1client.ServiceAccountsGetter
// ServiceAccountTokenAudience is the single audience requested by the default token minter.
ServiceAccountTokenAudience string
// exposed here to allow unit tests to over-write them
// (nil values are replaced with the production implementations by NewV1beta2Reconciler)
mintToken ServiceAccountTokenMinter
fetchArmToken armTokenFetcher
exchangeArmTokenForAcrToken armAcrTokenExchanger
}
// NewV1beta2Reconciler builds the reconciler for v1beta2 AcrPullBindings by wiring the
// version-specific accessors and credential logic into the shared genericReconciler.
//
// Nil hooks on opts are defaulted in place, so the caller's struct is mutated: now falls
// back to time.Now, fetchArmToken and exchangeArmTokenForAcrToken fall back to the authorizer
// package implementations, and mintToken falls back to a TokenRequest issued through
// opts.ServiceAccountClient for the opts.ServiceAccountTokenAudience audience.
func NewV1beta2Reconciler(opts *V1beta2ReconcilerOpts) *PullBindingReconciler {
	if opts.now == nil {
		opts.now = time.Now
	}
	if opts.fetchArmToken == nil {
		opts.fetchArmToken = authorizer.ARMTokenForBinding
	}
	if opts.exchangeArmTokenForAcrToken == nil {
		opts.exchangeArmTokenForAcrToken = authorizer.ExchangeACRAccessTokenForSpec
	}
	if opts.mintToken == nil {
		opts.mintToken = func(ctx context.Context, serviceAccountNamespace, serviceAccountName string) (*authenticationv1.TokenRequest, error) {
			return opts.ServiceAccountClient.ServiceAccounts(serviceAccountNamespace).CreateToken(ctx, serviceAccountName, &authenticationv1.TokenRequest{
				Spec: authenticationv1.TokenRequestSpec{
					Audiences: []string{opts.ServiceAccountTokenAudience},
				},
			}, metav1.CreateOptions{})
		}
	}

	return &PullBindingReconciler{
		genericReconciler: &genericReconciler[*msiacrpullv1beta2.AcrPullBinding]{
			Client: opts.Client,
			Logger: opts.Logger,
			Scheme: opts.Scheme,
			NewBinding: func() *msiacrpullv1beta2.AcrPullBinding {
				return &msiacrpullv1beta2.AcrPullBinding{}
			},
			// The mutators below always work on a deep copy so the input binding is left untouched.
			AddFinalizer: func(binding *msiacrpullv1beta2.AcrPullBinding, finalizer string) *msiacrpullv1beta2.AcrPullBinding {
				updated := binding.DeepCopy()
				updated.ObjectMeta.Finalizers = append(updated.ObjectMeta.Finalizers, finalizer)
				return updated
			},
			RemoveFinalizer: func(binding *msiacrpullv1beta2.AcrPullBinding, finalizer string) *msiacrpullv1beta2.AcrPullBinding {
				updated := binding.DeepCopy()
				updated.ObjectMeta.Finalizers = slices.DeleteFunc(updated.ObjectMeta.Finalizers, func(s string) bool {
					return s == finalizer
				})
				return updated
			},
			GetServiceAccountName: func(binding *msiacrpullv1beta2.AcrPullBinding) string {
				return binding.Spec.ServiceAccountName
			},
			GetPullSecretName: func(binding *msiacrpullv1beta2.AcrPullBinding) string {
				return pullSecretName(binding.ObjectMeta.Name)
			},
			GetInputsHash: func(binding *msiacrpullv1beta2.AcrPullBinding) string {
				return inputsHash(binding.Spec)
			},
			// CreatePullCredential returns the docker config for the binding's ACR server together
			// with the expiry of the ACR token embedded in it.
			CreatePullCredential: func(ctx context.Context, binding *msiacrpullv1beta2.AcrPullBinding, serviceAccount *corev1.ServiceAccount) (string, time.Time, error) {
				// These stay empty unless the binding uses workload identity.
				var tenantId, clientId, token string
				if binding.Spec.Auth.WorkloadIdentity != nil {
					if binding.Spec.Auth.WorkloadIdentity.TenantID != "" {
						// An explicit tenant ID on the spec means both IDs are taken from the spec.
						tenantId = binding.Spec.Auth.WorkloadIdentity.TenantID
						clientId = binding.Spec.Auth.WorkloadIdentity.ClientID
					} else {
						// Otherwise both IDs must be present as annotations on the service account.
						for _, annotation := range []struct { // n.b. we need an array here to be able to test for the error output
							value string
							into  *string
						}{
							{value: azworkloadidentity.ClientIDAnnotation, into: &clientId},
							{value: azworkloadidentity.TenantIDAnnotation, into: &tenantId},
						} {
							value, set := serviceAccount.Annotations[annotation.value]
							if !set {
								return "", time.Time{}, fmt.Errorf("service account %s missing %s annotation", serviceAccount.Name, annotation.value)
							}
							*annotation.into = value
						}
					}
					response, err := opts.mintToken(ctx, serviceAccount.Namespace, serviceAccount.Name)
					if err != nil {
						return "", time.Time{}, fmt.Errorf("failed to mint service account token: %w", err)
					}
					token = response.Status.Token
				}

				// Errors are wrapped with %w so callers can still inspect the underlying cause.
				armToken, err := opts.fetchArmToken(ctx, binding.Spec, tenantId, clientId, token)
				if err != nil {
					return "", time.Time{}, fmt.Errorf("failed to retrieve ARM token: %w", err)
				}
				acrToken, err := opts.exchangeArmTokenForAcrToken(ctx, armToken, binding.Spec.ACR)
				if err != nil {
					return "", time.Time{}, fmt.Errorf("failed to retrieve ACR token: %w", err)
				}
				dockerConfig, err := authorizer.CreateACRDockerCfg(binding.Spec.ACR.Server, acrToken)
				if err != nil {
					return "", time.Time{}, fmt.Errorf("failed to write ACR dockercfg: %w", err)
				}
				return dockerConfig, acrToken.ExpiresOn, nil
			},
			UpdateStatusError: func(binding *msiacrpullv1beta2.AcrPullBinding, s string) *msiacrpullv1beta2.AcrPullBinding {
				updated := binding.DeepCopy()
				updated.Status.Error = s
				return updated
			},
			NeedsRefresh: func(logger logr.Logger, pullSecret *corev1.Secret, now func() time.Time) bool {
				return needsRefresh(now, pullSecretRefresh(logger, pullSecret), pullSecretExpiry(logger, pullSecret), opts.TTLRotationFraction)
			},
			// RequeueAfter yields the time remaining until the refresh boundary recorded in the
			// binding's status. It is zero when either status timestamp is unset and negative once
			// the boundary has already passed.
			RequeueAfter: func(now func() time.Time) func(binding *msiacrpullv1beta2.AcrPullBinding) time.Duration {
				return func(binding *msiacrpullv1beta2.AcrPullBinding) time.Duration {
					var requeueAfter time.Duration
					if binding.Status.TokenExpirationTime != nil && binding.Status.LastTokenRefreshTime != nil {
						refresh, expiry := binding.Status.LastTokenRefreshTime.Time, binding.Status.TokenExpirationTime.Time
						requeueAfter = refreshBoundary(refresh, expiry, opts.TTLRotationFraction).Sub(now())
					}
					return requeueAfter
				}
			},
			// NeedsStatusUpdate reports whether the status carries a stale error or timestamps
			// that are missing or differ from the supplied refresh/expiry values.
			NeedsStatusUpdate: func(refresh time.Time, expiry time.Time, binding *msiacrpullv1beta2.AcrPullBinding) bool {
				return binding.Status.Error != "" || binding.Status.TokenExpirationTime == nil || !binding.Status.TokenExpirationTime.Equal(&metav1.Time{Time: expiry}) ||
					binding.Status.LastTokenRefreshTime == nil || !binding.Status.LastTokenRefreshTime.Equal(&metav1.Time{Time: refresh})
			},
			// UpdateStatus records the new timestamps and clears any previous error.
			UpdateStatus: func(refresh time.Time, expiry time.Time, binding *msiacrpullv1beta2.AcrPullBinding) *msiacrpullv1beta2.AcrPullBinding {
				updated := binding.DeepCopy()
				updated.Status.TokenExpirationTime = &metav1.Time{Time: expiry}
				updated.Status.LastTokenRefreshTime = &metav1.Time{Time: refresh}
				updated.Status.Error = ""
				return updated
			},
			now: opts.now,
		},
	}
}
// PullBindingReconciler reconciles AcrPullBindings
// It embeds a genericReconciler instantiated for the v1beta2 AcrPullBinding type.
type PullBindingReconciler struct {
*genericReconciler[*msiacrpullv1beta2.AcrPullBinding]
}
// SetupWithManager registers the service-account field index for v1beta2 AcrPullBindings and
// then builds the "acr-pull-binding-v1beta2" controller, which reconciles AcrPullBindings and
// additionally reacts to Secret and ServiceAccount events through the mapping functions.
// An error from registering the index or from completing the controller is returned as-is.
func (r *PullBindingReconciler) SetupWithManager(ctx context.Context, mgr ctrl.Manager) error {
	indexer := mgr.GetFieldIndexer()
	indexErr := indexer.IndexField(ctx, &msiacrpullv1beta2.AcrPullBinding{}, serviceAccountField, indexV1beta2PullBindingByServiceAccount)
	if indexErr != nil {
		return indexErr
	}

	// n.b. we do not need to add the imagePullSecretsField indexer on service accounts since v1beta1 controller does it
	// n.b. we do not need to add the pullBindingField indexer on service accounts since v1beta1 controller does it

	controller := ctrl.NewControllerManagedBy(mgr).
		For(&msiacrpullv1beta2.AcrPullBinding{}).
		Named("acr-pull-binding-v1beta2")

	onSecret := handler.EnqueueRequestsFromMapFunc(enqueuePullBindingsForPullSecret(mgr))
	controller = controller.Watches(&corev1.Secret{}, onSecret)

	onServiceAccount := handler.EnqueueRequestsFromMapFunc(enqueueV1beta2PullBindingsForServiceAccount(mgr))
	controller = controller.Watches(&corev1.ServiceAccount{}, onServiceAccount)

	return controller.Complete(r)
}
// indexV1beta2PullBindingByServiceAccount is the field-index extractor for v1beta2
// AcrPullBindings: it yields the binding's Spec.ServiceAccountName as the sole index value,
// and nil for any object that is not a v1beta2 AcrPullBinding.
func indexV1beta2PullBindingByServiceAccount(object crclient.Object) []string {
	if binding, isBinding := object.(*msiacrpullv1beta2.AcrPullBinding); isBinding {
		return []string{binding.Spec.ServiceAccountName}
	}
	return nil
}
// enqueueV1beta2PullBindingsForServiceAccount returns a map function that, given a service
// account, produces one reconcile request for every v1beta2 AcrPullBinding in the same
// namespace whose serviceAccountField index value equals the service account's name.
//
// If listing the bindings fails, the error is logged through the manager's logger and no
// requests are returned.
func enqueueV1beta2PullBindingsForServiceAccount(mgr ctrl.Manager) func(ctx context.Context, object crclient.Object) []reconcile.Request {
	return func(ctx context.Context, object crclient.Object) []reconcile.Request {
		var pullBindings msiacrpullv1beta2.AcrPullBindingList
		if err := mgr.GetClient().List(ctx, &pullBindings, crclient.InNamespace(object.GetNamespace()), crclient.MatchingFields{serviceAccountField: object.GetName()}); err != nil {
			// A map function cannot return an error, so surface it in the log instead of dropping it.
			mgr.GetLogger().Error(err, "failed to list v1beta2 pull bindings for service account",
				"namespace", object.GetNamespace(), "serviceAccount", object.GetName())
			return nil
		}
		var requests []reconcile.Request
		// Iterate by index to avoid copying each binding and taking the address of the loop variable.
		for i := range pullBindings.Items {
			requests = append(requests, reconcile.Request{
				NamespacedName: crclient.ObjectKeyFromObject(&pullBindings.Items[i]),
			})
		}
		return requests
	}
}
//+kubebuilder:rbac:groups=acrpull.microsoft.com,resources=acrpullbindings,verbs=get;list;watch;create;update;patch;delete
//+kubebuilder:rbac:groups=acrpull.microsoft.com,resources=acrpullbindings/status,verbs=get;update;patch
//+kubebuilder:rbac:groups=acrpull.microsoft.com,resources=acrpullbindings/finalizers,verbs=update
//+kubebuilder:rbac:groups="",resources=secrets,verbs=*
//+kubebuilder:rbac:groups="",resources=serviceaccounts,verbs=get;list;watch;update;patch
//+kubebuilder:rbac:groups="",resources=serviceaccounts/token,verbs=create
// inputsHash captures all the inputs for the pull binding which, if changed, would require a token regeneration
//
// The hashed inputs are the binding's service account name, the selected auth mode with its
// identifying fields, and the ACR environment, server and scope. For workload identity the
// explicit TenantID and ClientID are included because they are used to fetch the ARM token
// when set on the spec. They are appended after the pre-existing inputs, so bindings that
// leave both empty keep the hash they had before.
//
// NOTE(review): fields are concatenated without separators, so distinct specs could in theory
// produce the same input bytes; adding separators would change every existing hash.
func inputsHash(spec msiacrpullv1beta2.AcrPullBindingSpec) string {
	inputs := []byte(spec.ServiceAccountName)
	switch {
	case spec.Auth.ManagedIdentity != nil:
		mi := spec.Auth.ManagedIdentity
		inputs = append(inputs, "managedIdentity"+mi.ResourceID+mi.ClientID...)
	case spec.Auth.WorkloadIdentity != nil:
		wi := spec.Auth.WorkloadIdentity
		inputs = append(inputs, "workloadIdentity"+wi.ServiceAccountName+wi.TenantID+wi.ClientID...)
	}
	inputs = append(inputs, string(spec.ACR.Environment)+spec.ACR.Server+spec.ACR.Scope...)
	return base36sha224(inputs)
}
// refreshBoundary returns the instant at which ttlRotationFraction of the credential's
// lifetime (expiry minus refresh) has elapsed, measured from refresh.
func refreshBoundary(refresh, expiry time.Time, ttlRotationFraction float64) time.Time {
	lifetime := float64(expiry.Sub(refresh))
	elapsedAtRotation := time.Duration(lifetime * ttlRotationFraction)
	return refresh.Add(elapsedAtRotation)
}
// needsRefresh reports whether the current time, as given by now, is strictly past the
// refresh boundary computed from refresh, expiry and ttlRotationFraction.
func needsRefresh(now func() time.Time, refresh, expiry time.Time, ttlRotationFraction float64) bool {
	current := now()
	boundary := refreshBoundary(refresh, expiry, ttlRotationFraction)
	return boundary.Before(current)
}