msi/msi.go (112 lines of code) (raw):

// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. package msi import ( "encoding/json" "fmt" "net/url" "os" "strconv" "time" "github.com/Azure/azure-extension-foundation/errorhelper" "github.com/Azure/azure-extension-foundation/httputil" ) const ( metadataIdentityURL = "http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01" clientIdQueryParam = "client_id" objectIdQueryParam = "object_id" resourceQueryParam = "resource" armResourceUri = "https://management.core.windows.net/" identityEnvVar = "IDENTITY_ENDPOINT" ) type Msi struct { AccessToken string `json:"access_token"` ClientID string `json:"client_id"` ExpiresIn string `json:"expires_in"` ExpiresOn string `json:"expires_on"` // expressed in seconds from epoch ExtExpiresIn string `json:"ext_expires_in"` NotBefore string `json:"not_before"` Resource string `json:"resource"` TokenType string `json:"token_type"` } type MsiProvider interface { GetMsi() (Msi, error) GetMsiForResource(targetResource string) (Msi, error) GetMsiUsingClientId(clientId string, targetResource string) (Msi, error) GetMsiUsingObjectId(objectId string, targetResource string) (Msi, error) } type provider struct { httpClient httputil.HttpClient } func NewMsiProvider(client httputil.HttpClient) provider { return provider{httpClient: client} } func (p *provider) getMsiHelper(queryParams map[string]string) (*Msi, error) { var msi = Msi{} requestUrl, err := url.Parse(GetMetadataIdentityURL()) if err != nil { return &msi, err } urlQuery := requestUrl.Query() for key, value := range queryParams { urlQuery.Add(key, value) } requestUrl.RawQuery = urlQuery.Encode() code, body, err := p.httpClient.Get(requestUrl.String(), map[string]string{"Metadata": "true"}) if err != nil { return &msi, err } if code != 200 { return &msi, errorhelper.AddStackToError(fmt.Errorf("unable to get msi, metadata service response code %v", code)) } err = json.Unmarshal(body, &msi) if err != nil { return &msi, errorhelper.AddStackToError(fmt.Errorf("unable to deserialize metadata service response")) } return &msi, nil } func (p *provider) GetMsi() (Msi, error) { msi, err := p.getMsiHelper(map[string]string{resourceQueryParam: armResourceUri}) return *msi, err } func (p *provider) GetMsiForResource(targetResource string) (Msi, error) { msi, err := p.getMsiHelper(map[string]string{resourceQueryParam: targetResource}) return *msi, err } func (p *provider) GetMsiUsingClientId(clientId string, targetResource string) (Msi, error) { msi, err := p.getMsiHelper(map[string]string{clientIdQueryParam: clientId, resourceQueryParam: targetResource}) return *msi, err } func (p *provider) GetMsiUsingObjectId(objectId string, targetResource string) (Msi, error) { msi, err := p.getMsiHelper(map[string]string{objectIdQueryParam: objectId, resourceQueryParam: targetResource}) return *msi, err } // check expiry of MSI token based on time func (msi *Msi) IsMsiTokenExpired() (bool, error) { expiryTime, err := msi.GetExpiryTime() if err != nil { return false, err } // Consider token expired 2 minutes before expiry time expiryTime = expiryTime.Add(-2 * time.Minute) if time.Now().After(expiryTime) { return true, nil } else { return false, nil } } func (msi *Msi) GetExpiryTime() (time.Time, error) { expiryTimeInSeconds, err := strconv.ParseInt(msi.ExpiresOn, 10, 64) if err != nil { return time.Unix(0, 0), err } expiryTime := time.Unix(expiryTimeInSeconds, 0) return expiryTime, nil } func (msi *Msi) GetJson() (string, error) { jsonBytes, err := json.Marshal(msi) return string(jsonBytes[:]), err } func GetMetadataIdentityURL() string { envMetadataIdentityURL := os.Getenv(identityEnvVar) if envMetadataIdentityURL != "" { return envMetadataIdentityURL } return metadataIdentityURL }