pkg/authorizer/token_exchanger.go (71 lines of code) (raw):
package authorizer
import (
"context"
"encoding/json"
"fmt"
"net/url"
"time"
msiacrpullv1beta2 "github.com/Azure/msi-acrpull/api/v1beta2"
"github.com/golang-jwt/jwt/v5"
"github.com/pkg/errors"
"k8s.io/utils/ptr"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
)
// ExchangeACRAccessToken exchanges an ARM access token for a token usable
// against the Azure Container Registry whose host name is acrFQDN.
//
// The exchange happens in two steps:
//  1. armToken is exchanged for an ACR refresh token.
//  2. If scope is non-empty, that refresh token is exchanged for an ACR access
//     token restricted to scope. The scope string is passed through to the
//     registry unchanged; presumably a registry scope such as
//     "repository:<name>:pull" — confirm against callers.
//
// When scope is empty, the unscoped refresh token itself is returned, for
// legacy compatibility.
//
// In both cases, ExpiresOn on the result is taken from the token's "exp"
// claim. The token is parsed WITHOUT signature or claims verification: it was
// just returned by the registry, and it is only inspected here to learn its
// lifetime.
//
// acrFQDN is prefixed with "https://" to build the registry endpoint. The
// function returns an error in these cases:
//   - acrFQDN does not yield a host name.
//   - Either exchange fails or returns an empty token.
//   - The resulting token's "exp" claim cannot be read.
func ExchangeACRAccessToken(ctx context.Context, armToken azcore.AccessToken, acrFQDN, scope string) (azcore.AccessToken, error) {
	endpoint, err := url.Parse(fmt.Sprintf("https://%s", acrFQDN))
	if err != nil {
		return azcore.AccessToken{}, fmt.Errorf("failed to parse ACR endpoint: %w", err)
	}
	// Both exchange steps must name the same service: the refresh token from
	// step 1 is issued for this service, so step 2 has to target it too.
	service := endpoint.Hostname()
	if service == "" {
		return azcore.AccessToken{}, fmt.Errorf("failed to parse ACR endpoint: no host name in %q", acrFQDN)
	}

	client, err := azcontainerregistry.NewAuthenticationClient(endpoint.String(), nil)
	if err != nil {
		return azcore.AccessToken{}, fmt.Errorf("failed to create ACR authentication client: %w", err)
	}

	refreshResponse, err := client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantTypeAccessToken, service, &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
		AccessToken: ptr.To(armToken.Token),
	})
	if err != nil {
		return azcore.AccessToken{}, fmt.Errorf("failed to exchange AAD access token for ACR refresh token: %w", err)
	}
	if refreshResponse.RefreshToken == nil {
		return azcore.AccessToken{}, errors.New("got an empty response when exchanging AAD access token for ACR refresh token")
	}

	// For legacy compatibility, we allow exposing the unscoped refresh token.
	accessToken := *refreshResponse.RefreshToken
	if scope != "" {
		accessResponse, err := client.ExchangeACRRefreshTokenForACRAccessToken(ctx, service, scope, *refreshResponse.RefreshToken, &azcontainerregistry.AuthenticationClientExchangeACRRefreshTokenForACRAccessTokenOptions{
			GrantType: ptr.To(azcontainerregistry.TokenGrantTypeRefreshToken),
		})
		if err != nil {
			return azcore.AccessToken{}, fmt.Errorf("failed to exchange ACR refresh token for ACR access token: %w", err)
		}
		if accessResponse.AccessToken == nil {
			return azcore.AccessToken{}, errors.New("got an empty response when exchanging ACR refresh token for ACR access token")
		}
		accessToken = *accessResponse.AccessToken
	}

	// Only the expiry is needed, so skip signature and claims validation.
	token, _, err := jwt.NewParser(jwt.WithoutClaimsValidation()).ParseUnverified(accessToken, jwt.MapClaims{})
	if err != nil {
		return azcore.AccessToken{}, fmt.Errorf("failed to parse ACR access token: %w", err)
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return azcore.AccessToken{}, fmt.Errorf("unexpected claim type %T from ACR access token", token.Claims)
	}

	var expiry time.Time
	switch exp := claims["exp"].(type) {
	case float64:
		expiry = time.Unix(int64(exp), 0)
	case json.Number:
		// Float64 accepts both integer ("1700000000") and exponent ("1.7e9")
		// forms of a JSON number. Int64 rejects the latter, and ignoring that
		// error would silently yield an expiry of 1970-01-01.
		seconds, parseErr := exp.Float64()
		if parseErr != nil {
			return azcore.AccessToken{}, fmt.Errorf("failed to parse ACR access token expiration %q: %w", string(exp), parseErr)
		}
		expiry = time.Unix(int64(seconds), 0)
	default:
		// Also reached when the claim is absent (exp is then nil).
		return azcore.AccessToken{}, fmt.Errorf("failed to parse ACR access token expiration: unexpected exp claim type %T", exp)
	}

	return azcore.AccessToken{
		Token:     accessToken,
		ExpiresOn: expiry,
	}, nil
}
// ExchangeACRAccessTokenForSpec is a convenience wrapper around
// ExchangeACRAccessToken. It reads the registry host from spec.Server and the
// requested token scope from spec.Scope. An empty spec.Scope yields the
// unscoped refresh token, exactly as ExchangeACRAccessToken does.
func ExchangeACRAccessTokenForSpec(ctx context.Context, armToken azcore.AccessToken, spec msiacrpullv1beta2.AcrConfiguration) (azcore.AccessToken, error) {
	registry, tokenScope := spec.Server, spec.Scope
	return ExchangeACRAccessToken(ctx, armToken, registry, tokenScope)
}