pkg/cmd/serviceaccount/auth/provider.go (219 lines of code) (raw):
package auth
import (
"crypto/tls"
"net/http"
"os"
"path/filepath"
"runtime"
"time"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/google/uuid"
nethttplibrary "github.com/microsoft/kiota-http-go"
msgrapsdkgo "github.com/microsoftgraph/msgraph-sdk-go"
msgraphgocore "github.com/microsoftgraph/msgraph-sdk-go-core"
"github.com/pkg/errors"
"github.com/spf13/pflag"
ini "gopkg.in/ini.v1"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/client-go/rest"
"k8s.io/client-go/transport"
"monis.app/mlog"
"github.com/Azure/azure-workload-identity/pkg/cloud"
)
// Supported values for the --auth-method flag.
const (
// clientSecretAuthMethod authenticates with --client-id and --client-secret.
clientSecretAuthMethod = "client_secret"
// clientCertificateAuthMethod authenticates with --client-id,
// --certificate-path and --private-key-path.
clientCertificateAuthMethod = "client_certificate"
// cliAuthMethod is the default value of --auth-method; Validate builds the
// client with cloud.NewAzureClientWithCLI for it.
cliAuthMethod = "cli"
)
// Provider is an interface for getting an Azure client.
//
// In this file the only implementation is authArgs, which populates the
// client and tenant ID inside Validate; the two getters return zero values
// until Validate has succeeded.
type Provider interface {
// AddFlags registers the provider's command-line flags on f.
AddFlags(f *pflag.FlagSet)
// GetAzureClient returns the Azure client held by the provider.
GetAzureClient() cloud.Interface
// GetAzureTenantID returns the Azure tenant ID held by the provider.
GetAzureTenantID() string
// Validate checks the configured values and reports the first problem found.
Validate() error
}
// authArgs is an implementation of the Provider interface.
// Fields prefixed with "raw" hold flag values exactly as typed; Validate
// parses them into the typed fields next to them.
type authArgs struct {
// rawAzureEnvironment is the --azure-env flag value, resolved in Validate
// with azure.EnvironmentFromName.
rawAzureEnvironment string
// rawSubscriptionID is the --subscription-id flag value.
rawSubscriptionID string
// subscriptionID is the subscription resolved by Validate, taken from
// rawSubscriptionID or, when none is given, from the Azure CLI config
// directory.
subscriptionID uuid.UUID
// authMethod is the --auth-method flag value. Validate may switch it from
// "cli" to "client_secret" when both a client ID and a secret are supplied.
authMethod string
// rawClientID is the --client-id flag value.
rawClientID string
// tenantID is set by Validate from cloud.GetTenantID.
tenantID string
// clientID is rawClientID parsed as a UUID; only parsed for the
// client_secret and client_certificate auth methods.
clientID uuid.UUID
// clientSecret is the --client-secret flag value.
clientSecret string
// certificatePath is the --certificate-path flag value.
certificatePath string
// privateKeyPath is the --private-key-path flag value.
privateKeyPath string
// azureClient is built by Validate for the selected auth method.
azureClient cloud.Interface
// client is the HTTP client handed to the cloud package calls in Validate.
client *http.Client
}
// NewProvider returns a Provider backed by authArgs whose HTTP client is the
// one built by defaultClient. All other fields start at their zero values and
// are filled in by the flags registered in AddFlags and by Validate.
func NewProvider() Provider {
return &authArgs{client: defaultClient()}
}
// defaultClient builds the HTTP client used for all Azure calls made through
// this package: the transport from defaultTransport wrapped by defaultWrap, a
// three hour overall timeout, and redirects left unfollowed.
func defaultClient() *http.Client {
	// Hand the redirect response back to the caller instead of following it.
	keepLastResponse := func(_ *http.Request, _ []*http.Request) error {
		return http.ErrUseLastResponse // copied from MS SDK
	}

	client := &http.Client{
		// make it impossible for requests to hang indefinitely
		Timeout:       3 * time.Hour,
		CheckRedirect: keepLastResponse,
	}
	client.Transport = defaultWrap(defaultTransport())
	return client
}
// defaultTransport returns a clone of http.DefaultTransport with a larger
// per-host idle connection pool and an explicit TLS configuration, finished
// off with the apimachinery transport defaults.
func defaultTransport() *http.Transport {
	tlsConfig := &tls.Config{
		// same as client-go and MS SDK
		MinVersion: tls.VersionTLS12,
		// enable HTTP2
		// setting this explicitly is only required in very specific circumstances
		// it is simpler to just set it here than to try and determine if we need to
		NextProtos: []string{"h2", "http/1.1"},
	}

	tr := http.DefaultTransport.(*http.Transport).Clone()
	tr.MaxIdleConnsPerHost = 25 // copied from client-go
	tr.TLSClientConfig = tlsConfig

	// Apply the shared defaults last so they see the fields set above.
	utilnet.SetTransportDefaults(tr)
	return tr
}
// defaultWrap layers the round trippers used by this package on top of rt.
// From innermost to outermost: the MS Graph default middlewares, the
// Kubernetes user agent, and the lazily applied debug wrappers.
func defaultWrap(rt http.RoundTripper) http.RoundTripper {
	clientOptions := msgrapsdkgo.GetDefaultClientOptions()
	middlewares := msgraphgocore.GetDefaultMiddlewaresWithOptions(&clientOptions)

	pipeline := newMiddlewarePipeline(middlewares, rt)
	withUserAgent := transport.NewUserAgentRoundTripper(rest.DefaultKubernetesUserAgent(), pipeline)
	return newDelayDebugWrappers(withUserAgent)
}
// delayDebugWrappers is an http.RoundTripper that decides on every request,
// rather than once at construction time, whether to add client-go's debug
// wrappers around the wrapped transport.
type delayDebugWrappers struct {
// transport is the round tripper that requests are forwarded to.
transport http.RoundTripper
}
// newDelayDebugWrappers wraps rt in a delayDebugWrappers and returns it as an
// http.RoundTripper.
func newDelayDebugWrappers(rt http.RoundTripper) http.RoundTripper {
return &delayDebugWrappers{transport: rt}
}
// RoundTrip forwards req to the wrapped transport. When trace logging is
// enabled at the time of the call, the transport is first wrapped with
// transport.DebugWrappers for this request.
func (d *delayDebugWrappers) RoundTrip(req *http.Request) (*http.Response, error) {
	if !mlog.Enabled(mlog.LevelTrace) {
		return d.transport.RoundTrip(req)
	}
	// delay wrapping because DebugWrappers makes static checks about log level
	return transport.DebugWrappers(d.transport).RoundTrip(req)
}
// middlewarePipeline is an http.RoundTripper that passes a request through
// each middleware in order and then hands it to transport.
// copied from MS SDK so we can inject custom base round tripper
type middlewarePipeline struct {
// transport is the base round tripper invoked after the last middleware.
transport http.RoundTripper
// middlewares are invoked in slice order, starting at index 0.
middlewares []nethttplibrary.Middleware
}
// newMiddlewarePipeline returns a middlewarePipeline that runs middlewares in
// front of rt, typed as an http.RoundTripper.
func newMiddlewarePipeline(middlewares []nethttplibrary.Middleware, rt http.RoundTripper) http.RoundTripper {
return &middlewarePipeline{
transport: rt,
middlewares: middlewares,
}
}
// Next runs the middleware at middlewareIndex, passing it the pipeline and the
// index of the following step so it can continue the chain. Once the index is
// past the last middleware, the request goes to the base transport.
func (p *middlewarePipeline) Next(req *http.Request, middlewareIndex int) (*http.Response, error) {
	if middlewareIndex >= len(p.middlewares) {
		// Chain exhausted: send the request for real.
		return p.transport.RoundTrip(req)
	}
	current := p.middlewares[middlewareIndex]
	following := middlewareIndex + 1
	return current.Intercept(p, following, req)
}
// RoundTrip implements http.RoundTripper by starting the pipeline at the
// first middleware (index 0).
func (p *middlewarePipeline) RoundTrip(req *http.Request) (*http.Response, error) {
return p.Next(req, 0)
}
// AddFlags registers every flag of this package on f, binding each one to the
// matching raw field of a. All flags are strings; only --subscription-id has a
// shorthand (-s).
func (a *authArgs) AddFlags(f *pflag.FlagSet) {
	stringFlags := []struct {
		target    *string
		name      string
		shorthand string
		value     string
		usage     string
	}{
		{&a.rawAzureEnvironment, "azure-env", "", "AzurePublicCloud", "the target Azure cloud"},
		{&a.rawSubscriptionID, "subscription-id", "s", "", "azure subscription id (required)"},
		{&a.authMethod, "auth-method", "", cliAuthMethod, "auth method to use. Supported values: cli, client_secret, client_certificate"},
		{&a.rawClientID, "client-id", "", "", "client id (used with --auth-method=[client_secret|client_certificate])"},
		{&a.clientSecret, "client-secret", "", "", "client secret (used with --auth-method=client_secret)"},
		{&a.certificatePath, "certificate-path", "", "", "path to client certificate (used with --auth-method=client_certificate)"},
		{&a.privateKeyPath, "private-key-path", "", "", "path to private key (used with --auth-method=client_certificate)"},
	}

	// An empty shorthand registers a long-only flag, so one call covers both kinds.
	for _, sf := range stringFlags {
		f.StringVarP(sf.target, sf.name, sf.shorthand, sf.value, sf.usage)
	}
}
// GetAzureClient returns the Azure client stored on a. The field is assigned
// by Validate, so this is nil on a provider that has not been validated.
func (a *authArgs) GetAzureClient() cloud.Interface {
return a.azureClient
}
// GetAzureTenantID returns the Azure tenant ID stored on a. The field is
// assigned by Validate, so this is empty on a provider that has not been
// validated.
func (a *authArgs) GetAzureTenantID() string {
return a.tenantID
}
// Validate checks the flag values, resolves the subscription and tenant IDs,
// and builds the Azure client for the selected auth method. The first failure
// is returned.
//
// Steps, in order:
//  1. --auth-method must be set. If it is "cli" but both --client-id and
//     --client-secret were given, it is switched to "client_secret".
//  2. The auth method must be one of cli, client_secret or client_certificate,
//     and the flags that method needs must be present. Unsupported methods are
//     rejected here, before the Azure CLI config is read or the tenant is
//     looked up.
//  3. --subscription-id, when given, must parse as a UUID. A malformed value is
//     an error; it is never silently replaced by another subscription. When the
//     flag is empty (or the all-zero UUID), the subscription selected in the
//     Azure CLI config directory (~/.azure) is used instead.
//  4. --azure-env must name a known Azure environment.
//  5. The tenant ID is obtained from cloud.GetTenantID and the client is built
//     with the constructor matching the auth method.
func (a *authArgs) Validate() error {
	if a.authMethod == "" {
		return errors.New("--auth-method is a required parameter")
	}

	// Supplying client credentials while leaving the default "cli" method
	// selects client-secret auth implicitly.
	if a.authMethod == cliAuthMethod && a.rawClientID != "" && a.clientSecret != "" {
		a.authMethod = clientSecretAuthMethod
	}

	var err error
	switch a.authMethod {
	case cliAuthMethod:
		// No additional flags are required.
	case clientSecretAuthMethod:
		if a.clientID, err = uuid.Parse(a.rawClientID); err != nil {
			return errors.Wrap(err, "parsing --client-id")
		}
		if a.clientSecret == "" {
			return errors.New(`--client-secret must be specified when --auth-method="client_secret"`)
		}
	case clientCertificateAuthMethod:
		if a.clientID, err = uuid.Parse(a.rawClientID); err != nil {
			return errors.Wrap(err, "parsing --client-id")
		}
		if a.certificatePath == "" || a.privateKeyPath == "" {
			return errors.New(`--certificate-path and --private-key-path must be specified when --auth-method="client_certificate"`)
		}
	default:
		// Fail fast on a typo instead of doing the lookups below first.
		return errors.Errorf("--auth-method: ERROR: method unsupported. method=%q", a.authMethod)
	}

	// Reset first so a repeated call never reuses a previously resolved value.
	a.subscriptionID = uuid.Nil
	if a.rawSubscriptionID != "" {
		// A value the user typed must be usable as-is; falling back to the
		// Azure CLI selection here could target a different subscription.
		if a.subscriptionID, err = uuid.Parse(a.rawSubscriptionID); err != nil {
			return errors.Wrap(err, "parsing --subscription-id")
		}
	}
	if a.subscriptionID == uuid.Nil {
		subID, azErr := getSubFromAzDir(filepath.Join(getHomeDir(), ".azure"))
		if azErr != nil || subID == uuid.Nil {
			return errors.New("--subscription-id is required (and must be a valid UUID)")
		}
		mlog.Info("No subscription provided, using selected subscription from Azure CLI", "subscriptionID", subID.String())
		a.subscriptionID = subID
	}

	env, err := azure.EnvironmentFromName(a.rawAzureEnvironment)
	if err != nil {
		return errors.Wrap(err, "failed to parse --azure-env as a valid target Azure cloud environment")
	}

	subscription := a.subscriptionID.String()
	if a.tenantID, err = cloud.GetTenantID(subscription, a.client); err != nil {
		return err
	}

	// The method was validated above, so exactly one case matches.
	switch a.authMethod {
	case cliAuthMethod:
		a.azureClient, err = cloud.NewAzureClientWithCLI(env, subscription, a.client)
	case clientSecretAuthMethod:
		a.azureClient, err = cloud.NewAzureClientWithClientSecret(env, subscription, a.clientID.String(), a.clientSecret, a.tenantID, a.client)
	case clientCertificateAuthMethod:
		a.azureClient, err = cloud.NewAzureClientWithClientCertificateFile(env, subscription, a.clientID.String(), a.tenantID, a.certificatePath, a.privateKeyPath, a.client)
	}
	return err
}
// getSubFromAzDir reads the Azure CLI directory at root and returns the
// subscription ID recorded for the currently selected cloud. It loads
// "clouds.config" (per-cloud subscriptions) and then "config" (selected
// cloud); a failure to load either file is returned wrapped.
func getSubFromAzDir(root string) (uuid.UUID, error) {
	var none uuid.UUID

	subscriptions, err := ini.Load(filepath.Join(root, "clouds.config"))
	if err != nil {
		return none, errors.Wrap(err, "error decoding cloud subscription config")
	}

	settings, err := ini.Load(filepath.Join(root, "config"))
	if err != nil {
		return none, errors.Wrap(err, "error decoding cloud config")
	}

	selected := getSelectedCloudFromAzConfig(settings)
	return getCloudSubFromAzConfig(selected, subscriptions)
}
// getSelectedCloudFromAzConfig returns the value of the "name" key in the
// [cloud] section of the Azure CLI config. It falls back to "AzureCloud" when
// the section or key is missing, or when the value is empty.
func getSelectedCloudFromAzConfig(f *ini.File) string {
	const fallback = "AzureCloud"

	section, err := f.GetSection("cloud")
	if err != nil {
		return fallback
	}
	key, err := section.GetKey("name")
	if err != nil {
		return fallback
	}
	if value := key.String(); value != "" {
		return value
	}
	return fallback
}
// getCloudSubFromAzConfig looks up the "subscription" key in the section of f
// named after cloud and parses it as a UUID. A missing section or key yields a
// wrapped error; a value that is not a UUID yields the parse error as-is.
func getCloudSubFromAzConfig(cloud string, f *ini.File) (uuid.UUID, error) {
	var none uuid.UUID

	section, err := f.GetSection(cloud)
	if err != nil {
		return none, errors.Wrap(err, "could not find user defined subscription id")
	}

	key, err := section.GetKey("subscription")
	if err != nil {
		return none, errors.Wrap(err, "error reading subscription id from cloud config")
	}

	rawID := key.String()
	return uuid.Parse(rawID)
}
// getHomeDir returns the user's home directory as described by environment
// variables, or "" when none of them is set. On Windows it prefers
// HOMEDRIVE+HOMEPATH and falls back to USERPROFILE; everywhere else it reads
// HOME.
func getHomeDir() string {
	if runtime.GOOS != "windows" {
		return os.Getenv("HOME")
	}
	if dir := os.Getenv("HOMEDRIVE") + os.Getenv("HOMEPATH"); dir != "" {
		return dir
	}
	return os.Getenv("USERPROFILE")
}