pkg/proxy/proxy.go (202 lines of code) (raw):

package proxy import ( "context" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "strconv" "strings" "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/gorilla/mux" "github.com/pkg/errors" "monis.app/mlog" "github.com/Azure/azure-workload-identity/pkg/version" "github.com/Azure/azure-workload-identity/pkg/webhook" ) const ( // "/metadata" portion is case-insensitive in IMDS tokenPathPrefix = "/{type:(?i:metadata)}/identity/oauth2/token" // #nosec // readyzPathPrefix is the path for readiness probe readyzPathPrefix = "/readyz" // metadataIPAddress is the IP address of the metadata service metadataIPAddress = "169.254.169.254" // metadataPort is the port of the metadata service metadataPort = 80 // localhost is the hostname of the localhost localhost = "localhost" ) var ( userAgent = version.GetUserAgent("proxy") ) type Proxy interface { Run(ctx context.Context) error } type proxy struct { port int tenantID string authorityHost string logger mlog.Logger } // using this from https://github.com/Azure/go-autorest/blob/b3899c1057425994796c92293e931f334af63b4e/autorest/adal/token.go#L1055-L1067 // this struct works with the adal sdks used in clients and azure-cli token requests type token struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` // AAD returns expires_in as a string, ADFS returns it as an int ExpiresIn json.Number `json:"expires_in"` // expires_on can be in two formats, a UTC time stamp or the number of seconds. ExpiresOn string `json:"expires_on"` NotBefore json.Number `json:"not_before"` Resource string `json:"resource"` Type string `json:"token_type"` } // NewProxy returns a proxy instance func NewProxy(port int, logger mlog.Logger) (Proxy, error) { // tenantID is required for fetching a token using client assertions // the mutating webhook will inject the tenantID for the cluster tenantID := os.Getenv(webhook.AzureTenantIDEnvVar) // authorityHost is required for fetching a token using client assertions authorityHost := os.Getenv(webhook.AzureAuthorityHostEnvVar) if tenantID == "" { return nil, errors.Errorf("%s not set", webhook.AzureTenantIDEnvVar) } if authorityHost == "" { return nil, errors.Errorf("%s not set", webhook.AzureAuthorityHostEnvVar) } return &proxy{ port: port, tenantID: tenantID, authorityHost: authorityHost, logger: logger, }, nil } // Run runs the proxy server func (p *proxy) Run(ctx context.Context) error { rtr := mux.NewRouter() rtr.PathPrefix(tokenPathPrefix).HandlerFunc(p.msiHandler) rtr.PathPrefix(readyzPathPrefix).HandlerFunc(p.readyzHandler) rtr.PathPrefix("/").HandlerFunc(p.defaultPathHandler) p.logger.Info("starting the proxy server", "port", p.port, "userAgent", userAgent) server := &http.Server{ Addr: fmt.Sprintf("%s:%d", localhost, p.port), ReadHeaderTimeout: 5 * time.Second, Handler: rtr, } go func() { if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed { panic(err) } }() <-ctx.Done() p.logger.Info("shutting down the proxy server") // shutdown the server gracefully with a 5 second timeout shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() return server.Shutdown(shutdownCtx) } func (p *proxy) msiHandler(w http.ResponseWriter, r *http.Request) { p.logger.Info("received token request", "method", r.Method, "uri", r.RequestURI) w.Header().Set("Server", userAgent) clientID, resource := parseTokenRequest(r) // if clientID not found in request, then we default to the AZURE_CLIENT_ID if present. // This is to keep consistent with the current behavior in pod identity v1 where we // default the client id to the one in AzureIdentity. if clientID == "" { p.logger.Info("client_id not found in request, defaulting to AZURE_CLIENT_ID", "method", r.Method, "uri", r.RequestURI) clientID = os.Getenv(webhook.AzureClientIDEnvVar) } if clientID == "" { http.Error(w, "The client_id parameter or AZURE_CLIENT_ID environment variable must be set", http.StatusBadRequest) return } if resource == "" { http.Error(w, "The resource parameter is required.", http.StatusBadRequest) return } // get the token using the msal token, err := doTokenRequest(r.Context(), clientID, resource, p.tenantID, p.authorityHost) if err != nil { p.logger.Error("failed to get token", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } p.logger.Info("successfully acquired token", "method", r.Method, "uri", r.RequestURI) // write the token to the response w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(token); err != nil { p.logger.Error("failed to encode token", err) } } func (p *proxy) defaultPathHandler(w http.ResponseWriter, r *http.Request) { client := &http.Client{} req, err := http.NewRequest(r.Method, r.URL.String(), r.Body) if err != nil || req == nil { p.logger.Error("failed to create new request", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } host := fmt.Sprintf("%s:%d", metadataIPAddress, metadataPort) req.Host = host req.URL.Host = host req.URL.Scheme = "http" if r.Header != nil { copyHeader(req.Header, r.Header) } resp, err := client.Do(req) if err != nil { p.logger.Error("failed executing request", err, "url", req.URL.String()) http.Error(w, err.Error(), http.StatusInternalServerError) return } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { p.logger.Error("failed to read response body", err, "url", req.URL.String()) http.Error(w, err.Error(), http.StatusInternalServerError) } p.logger.Info("received response from IMDS", "method", r.Method, "uri", r.RequestURI, "status", resp.StatusCode) copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) _, _ = w.Write(body) } func (p *proxy) readyzHandler(w http.ResponseWriter, r *http.Request) { p.logger.Info("received readyz request", "method", r.Method, "uri", r.RequestURI) fmt.Fprintf(w, "ok") } func doTokenRequest(ctx context.Context, clientID, resource, tenantID, authorityHost string) (*token, error) { tokenFilePath := os.Getenv(webhook.AzureFederatedTokenFileEnvVar) cred := confidential.NewCredFromAssertionCallback(func(context.Context, confidential.AssertionRequestOptions) (string, error) { return readJWTFromFS(tokenFilePath) }) authority, err := url.JoinPath(authorityHost, tenantID) if err != nil { return nil, errors.Wrap(err, "failed to construct authority URL") } confidentialClientApp, err := confidential.New(authority, clientID, cred) if err != nil { return nil, errors.Wrap(err, "failed to create confidential client app") } result, err := confidentialClientApp.AcquireTokenByCredential(ctx, []string{getScope(resource)}) if err != nil { return nil, errors.Wrap(err, "failed to acquire token") } return &token{ AccessToken: result.AccessToken, Resource: resource, Type: "Bearer", // -10s is to account for current time changes between the calls ExpiresIn: json.Number(strconv.FormatInt(int64(time.Until(result.ExpiresOn)/time.Second)-10, 10)), // There is a difference in parsing between the azure sdks and how azure-cli works // Using the unix time to be consistent with response from IMDS which works with // all the clients. ExpiresOn: strconv.FormatInt(result.ExpiresOn.UTC().Unix(), 10), }, nil } func parseTokenRequest(r *http.Request) (string, string) { var clientID, resource string if r.URL != nil { // Query always return a non-nil map clientID = r.URL.Query().Get("client_id") resource = r.URL.Query().Get("resource") } return clientID, resource } func copyHeader(dst, src http.Header) { for k, vv := range src { for _, v := range vv { dst.Add(k, v) } } } func readJWTFromFS(tokenFilePath string) (string, error) { token, err := os.ReadFile(tokenFilePath) if err != nil { return "", err } return string(token), nil } // ref: https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/747 // For MSAL (v2.0 endpoint) asking an access token for a resource that accepts a v1.0 access token, // Azure AD parses the desired audience from the requested scope by taking everything before the // last slash and using it as the resource identifier. // For example, if the scope is "https://vault.azure.net/.default", the resource identifier is "https://vault.azure.net". // If the scope is "http://database.windows.net//.default", the resource identifier is "http://database.windows.net/". func getScope(resource string) string { if !strings.HasSuffix(resource, "/.default") { resource = resource + "/.default" } return resource }