extension/azureauthextension/extension.go (189 lines of code) (raw):
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0
package azureauthextension // import "github.com/open-telemetry/opentelemetry-collector-contrib/extension/azureauthextension"
import (
	"context"
	"crypto"
	"crypto/subtle"
	"crypto/x509"
	"errors"
	"fmt"
	"net/http"
	"os"
	"strings"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"go.opentelemetry.io/collector/component"
	"go.opentelemetry.io/collector/extension"
	"go.opentelemetry.io/collector/extension/extensionauth"
	"go.uber.org/zap"
)
// authenticator is the azureauth extension. It holds the Azure credential
// built by newAzureAuthenticator and uses it both to authorize outgoing
// HTTP requests and to validate incoming bearer tokens.
type authenticator struct {
	// credential issues Azure access tokens. newAzureAuthenticator leaves it
	// nil when none of the supported credential sources is configured.
	credential azcore.TokenCredential
	// logger is the logger supplied at construction time.
	logger *zap.Logger
}
// Compile-time proof that *authenticator is a collector extension that can
// act as both a client-side and a server-side authenticator.
var _ extension.Extension = (*authenticator)(nil)

var _ extensionauth.HTTPClient = (*authenticator)(nil)

var _ extensionauth.Server = (*authenticator)(nil)

// *authenticator can also be handed directly to Azure SDK clients.
var _ azcore.TokenCredential = (*authenticator)(nil)
// newAzureAuthenticator builds an authenticator whose credential is derived
// from cfg. The configured sources are constructed in this order:
//   - default identity (cfg.UseDefault)
//   - workload identity (cfg.Workload)
//   - managed identity (cfg.Managed)
//   - service principal (cfg.ServicePrincipal), first by certificate, then by secret
//
// When more than one source is configured, the one constructed last provides
// the credential. Any construction failure aborts immediately with an error
// naming the source that failed.
func newAzureAuthenticator(cfg *Config, logger *zap.Logger) (*authenticator, error) {
	const prefix = "failed to create authenticator using"
	var cred azcore.TokenCredential

	if cfg.UseDefault {
		c, err := azidentity.NewDefaultAzureCredential(nil)
		if err != nil {
			return nil, fmt.Errorf("%s default identity: %w", prefix, err)
		}
		cred = c
	}

	if wl := cfg.Workload; wl != nil {
		c, err := azidentity.NewWorkloadIdentityCredential(&azidentity.WorkloadIdentityCredentialOptions{
			ClientID:      wl.ClientID,
			TenantID:      wl.TenantID,
			TokenFilePath: wl.FederatedTokenFile,
		})
		if err != nil {
			return nil, fmt.Errorf("%s workload identity: %w", prefix, err)
		}
		cred = c
	}

	if mi := cfg.Managed; mi != nil {
		// A nil options pointer selects the system-assigned identity; a
		// client ID selects a user-assigned one.
		var opts *azidentity.ManagedIdentityCredentialOptions
		if mi.ClientID != "" {
			opts = &azidentity.ManagedIdentityCredentialOptions{ID: azidentity.ClientID(mi.ClientID)}
		}
		c, err := azidentity.NewManagedIdentityCredential(opts)
		if err != nil {
			return nil, fmt.Errorf("%s managed identity: %w", prefix, err)
		}
		cred = c
	}

	if sp := cfg.ServicePrincipal; sp != nil {
		if sp.ClientCertificatePath != "" {
			cert, key, err := getCertificateAndKey(sp.ClientCertificatePath)
			if err != nil {
				return nil, fmt.Errorf("%s service principal with certificate: %w", prefix, err)
			}
			c, err := azidentity.NewClientCertificateCredential(sp.TenantID, sp.ClientID, []*x509.Certificate{cert}, key, nil)
			if err != nil {
				return nil, fmt.Errorf("%s service principal with certificate: %w", prefix, err)
			}
			cred = c
		}
		if sp.ClientSecret != "" {
			c, err := azidentity.NewClientSecretCredential(sp.TenantID, sp.ClientID, sp.ClientSecret, nil)
			if err != nil {
				return nil, fmt.Errorf("%s service principal with secret: %w", prefix, err)
			}
			cred = c
		}
	}

	return &authenticator{credential: cred, logger: logger}, nil
}
// getCertificateAndKey reads filename and parses its contents with
// azidentity.ParseCertificates (no password). It returns the first
// certificate in the file together with the private key.
//
// It returns an error if the file cannot be read, if parsing fails, or if
// parsing yields no certificate. The last check guards the certs[0] index
// instead of relying on the parser never returning an empty slice.
func getCertificateAndKey(filename string) (*x509.Certificate, crypto.PrivateKey, error) {
	data, err := os.ReadFile(filename)
	if err != nil {
		return nil, nil, fmt.Errorf("could not read the certificate file: %w", err)
	}
	certs, privateKey, err := azidentity.ParseCertificates(data, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to parse certificates: %w", err)
	}
	if len(certs) == 0 {
		return nil, nil, errors.New("failed to parse certificates: no certificate found")
	}
	return certs[0], privateKey, nil
}
// Start implements component.Component. The authenticator needs no startup
// work, so it always succeeds.
func (*authenticator) Start(context.Context, component.Host) error {
	return nil
}
// Shutdown implements component.Component. There is nothing to release, so
// it always succeeds.
func (*authenticator) Shutdown(context.Context) error {
	return nil
}
// GetToken implements azcore.TokenCredential by delegating to the credential
// chosen when the authenticator was created. It returns an error rather than
// panicking if that credential was never set.
func (a *authenticator) GetToken(ctx context.Context, options policy.TokenRequestOptions) (azcore.AccessToken, error) {
	cred := a.credential
	if cred != nil {
		return cred.GetToken(ctx, options)
	}
	// Not expected: newAzureAuthenticator fails when a supported credential
	// cannot be created, and unsupported configurations should already be
	// rejected by config validation.
	return azcore.AccessToken{}, errors.New("unexpected: credentials were not initialized")
}
// getHeaderValue returns the first value of the header named header in
// headers.
//
// Header names are case-insensitive (RFC 9110 §5.1), so the lookup order is:
//  1. the exact name;
//  2. its lowercase form;
//  3. a case-insensitive scan over all keys.
//
// When several differently-cased keys match in the scan, the
// lexicographically smallest key wins, so the result does not depend on map
// iteration order.
//
// It returns an error if the header is absent, or if it is present but has
// no values or an empty first value.
func getHeaderValue(header string, headers map[string][]string) (string, error) {
	value, ok := headers[header]
	if !ok {
		value, ok = headers[strings.ToLower(header)]
	}
	if !ok {
		// Covers sources that use other casings, e.g. "AUTHORIZATION".
		matched := ""
		for key, v := range headers {
			if strings.EqualFold(key, header) && (!ok || key < matched) {
				matched, value, ok = key, v, true
			}
		}
	}
	if !ok {
		return "", fmt.Errorf("missing %q header", header)
	}
	if len(value) == 0 || value[0] == "" {
		return "", fmt.Errorf("empty %q header", header)
	}
	return value[0], nil
}
// getTokenForHost requests an access token for the default scope of host,
// "https://<host>/.default", and returns the raw token string, or the error
// from the token request.
//
// It goes through GetToken rather than calling a.credential directly, so an
// authenticator without an initialized credential reports an error instead
// of panicking on a nil interface.
func (a *authenticator) getTokenForHost(ctx context.Context, host string) (string, error) {
	// Example: if host is "management.azure.com", then the scope to get the
	// token will be "https://management.azure.com/.default".
	// See default scope: https://learn.microsoft.com/en-us/entra/identity-platform/scopes-oidc#the-default-scope.
	scope := fmt.Sprintf("https://%s/.default", host)
	// TODO Cache the tokens
	token, err := a.GetToken(ctx, policy.TokenRequestOptions{Scopes: []string{scope}})
	if err != nil {
		return "", err
	}
	return token.Token, nil
}
// Authenticate implements extensionauth.Server. It accepts a request only if:
//   - its "Authorization" header has the form "Bearer <token>"; and
//   - <token> equals the token this authenticator obtains for the default
//     scope of the request's "Host" header.
//
// The returned context is always ctx unchanged.
//
// The scheme is matched case-insensitively, because authentication schemes
// are case-insensitive tokens (RFC 9110 §11.1). The token comparison runs in
// constant time, so response timing does not reveal how much of a
// client-supplied token matched.
func (a *authenticator) Authenticate(ctx context.Context, headers map[string][]string) (context.Context, error) {
	// See request header: https://learn.microsoft.com/en-us/rest/api/azure/#request-header
	auth, err := getHeaderValue("Authorization", headers)
	if err != nil {
		return ctx, err
	}
	host, err := getHeaderValue("Host", headers)
	if err != nil {
		return ctx, err
	}
	authFormat := strings.Split(auth, " ")
	if len(authFormat) != 2 {
		return ctx, errors.New(`authorization header does not follow expected format "Bearer <Token>"`)
	}
	if !strings.EqualFold(authFormat[0], "Bearer") {
		return ctx, fmt.Errorf(`expected "Bearer" as schema, got %q`, authFormat[0])
	}
	token, err := a.getTokenForHost(ctx, host)
	if err != nil {
		return ctx, err
	}
	if subtle.ConstantTimeCompare([]byte(authFormat[1]), []byte(token)) != 1 {
		return ctx, errors.New("unauthorized: invalid token")
	}
	return ctx, nil
}
// RoundTripper implements extensionauth.HTTPClient. It wraps base in a
// transport that attaches an Azure bearer token obtained from this
// authenticator to outgoing requests. It never returns an error.
func (a *authenticator) RoundTripper(base http.RoundTripper) (http.RoundTripper, error) {
	rt := &roundTripper{auth: a, base: base}
	return rt, nil
}
// roundTripper is the http.RoundTripper returned by
// authenticator.RoundTripper. It sets an Authorization header on a copy of
// each request and then delegates to base.
type roundTripper struct {
	// base is the wrapped transport that performs the actual request.
	base http.RoundTripper
	// auth supplies the access tokens for outgoing requests.
	auth *authenticator
}

// Compile-time check that *roundTripper implements http.RoundTripper.
var _ http.RoundTripper = (*roundTripper)(nil)
// RoundTrip sends a copy of request through the base transport with an
// "Authorization: Bearer <token>" header. The token is requested for the
// default scope of the request's host.
//
// The host is looked up in this order:
//  1. an explicit "Host" header, if one is set;
//  2. request.Host;
//  3. request.URL.Host.
//
// Go's HTTP client keeps the target host in the Host and URL fields rather
// than in the Header map, so checking the header alone rejected ordinary
// outgoing requests.
//
// The http.RoundTripper contract requires the request body to be closed even
// on error. RoundTrip therefore closes it whenever it fails before handing
// the request to the base transport.
func (r *roundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
	fail := func(err error) (*http.Response, error) {
		if request.Body != nil {
			// Best effort: the caller is already being handed err.
			_ = request.Body.Close()
		}
		return nil, err
	}

	// RoundTrippers must not modify the caller's request, so work on a clone.
	req := request.Clone(request.Context())

	host := req.Header.Get("Host")
	if host == "" {
		host = req.Host
	}
	if host == "" && req.URL != nil {
		host = req.URL.Host
	}
	if host == "" {
		return fail(errors.New(`missing "host" header`))
	}

	token, err := r.auth.getTokenForHost(req.Context(), host)
	if err != nil {
		return fail(fmt.Errorf("azureauth: failed to get token: %w", err))
	}

	// Clone leaves a nil Header nil, and Set on a nil map would panic.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	req.Header.Set("Authorization", "Bearer "+token)
	return r.base.RoundTrip(req)
}