pkg/download/blobwithmsitoken.go (144 lines of code) (raw):

package download import ( "fmt" "net/http" url2 "net/url" "strings" "github.com/Azure/azure-extension-foundation/httputil" "github.com/Azure/azure-extension-foundation/msi" "github.com/google/uuid" "github.com/pkg/errors" ) const ( xMsVersionHeaderName = "x-ms-version" xMsVersionValue = "2018-03-28" storageResourceName = "https://storage.azure.com/" ) var azureBlobDomains = map[string]interface{}{ // golang doesn't have builtin hash sets, so this is a workaround for that ".blob.core.": nil, ".blob.azurestack.": nil, } type blobWithMsiToken struct { url string msiProvider MsiProvider } type MsiProvider func() (msi.Msi, error) type MsiDownloader interface { GetMsiProvider(blobUri string) MsiProvider GetMsiProviderByClientId(blobUri, clientId string) MsiProvider GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider } type ProdMsiDownloader struct{} type MockMsiDownloader struct{} // Used only for test var MockReturnErrorForMockMsiDownloader = false // Used only for test func (self *blobWithMsiToken) GetRequest() (*http.Request, error) { msi, err := self.msiProvider() if err != nil { return nil, err } if msi.AccessToken == "" { return nil, errors.New("MSI token is empty") } request, err := http.NewRequest(http.MethodGet, self.url, nil) if err != nil { return nil, err } request.Header.Set(xMsClientRequestIdHeaderName, uuid.New().String()) if IsAzureStorageBlobUri(self.url) { request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", msi.AccessToken)) request.Header.Set(xMsVersionHeaderName, xMsVersionValue) } return request, nil } func NewBlobWithMsiDownload(url string, msiProvider MsiProvider) Downloader { return &blobWithMsiToken{url, msiProvider} } // Uses system identity to get Msi token func (prodMsiDownloader ProdMsiDownloader) GetMsiProvider(blobUri string) MsiProvider { msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior)) return func() (msi.Msi, error) { msi, err := msiProvider.GetMsiForResource(GetResourceNameFromBlobUri(blobUri)) if err != nil { return msi, errors.Wrapf(err, "Unable to get managed identity. "+ "Please make sure that system assigned managed identity is enabled on the VM "+ "or user assigned identity is added to the system.") } return msi, nil } } // Mock implementation of GetMsiProvider func (mockMsiDownloader MockMsiDownloader) GetMsiProvider(blobUri string) MsiProvider { return func() (msi.Msi, error) { mockMsi := msi.Msi{ AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas", Resource: "Msi by System Identity for blob " + blobUri, } if MockReturnErrorForMockMsiDownloader { return mockMsi, errors.New("Error getting msi") } else { return mockMsi, nil } } } // Get Msi token by clientId func (prodMsiDownloader ProdMsiDownloader) GetMsiProviderByClientId(blobUri, clientId string) MsiProvider { msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior)) return func() (msi.Msi, error) { msi, err := msiProvider.GetMsiUsingClientId(clientId, GetResourceNameFromBlobUri(blobUri)) if err != nil { return msi, errors.Wrapf(err, "Unable to get managed identity with client id %s. "+ "Please make sure that the user assigned managed identity is added to the VM ", clientId) } return msi, nil } } // Mock implementation of GetMsiProviderByClientId func (mockMsiDownloader MockMsiDownloader) GetMsiProviderByClientId(blobUri string, clientId string) MsiProvider { return func() (msi.Msi, error) { mockMsi := msi.Msi{ AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas", Resource: "Msi by clientId for blob " + blobUri, } if MockReturnErrorForMockMsiDownloader { return mockMsi, errors.New("Error getting msi") } else { return mockMsi, nil } } } // Get Msi token by objectId func (prodMsiDownloader ProdMsiDownloader) GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider { msiProvider := msi.NewMsiProvider(httputil.NewSecureHttpClient(httputil.DefaultRetryBehavior)) return func() (msi.Msi, error) { msi, err := msiProvider.GetMsiUsingObjectId(objectId, GetResourceNameFromBlobUri(blobUri)) if err != nil { return msi, errors.Wrapf(err, "Unable to get managed identity with object id %s. "+ "Please make sure that the user assigned managed identity is added to the VM ", objectId) } return msi, nil } } // Mock implementation of GetMsiProviderByObjectId func (mockMsiDownloader MockMsiDownloader) GetMsiProviderByObjectId(blobUri, objectId string) MsiProvider { return func() (msi.Msi, error) { mockMsi := msi.Msi{ AccessToken: "uwsihdiuhiuasdfui*(*(&90790asofhdioas", Resource: "Msi by objectId for blob " + blobUri, } if MockReturnErrorForMockMsiDownloader { return mockMsi, errors.New("Error getting msi") } else { return mockMsi, nil } } } func GetResourceNameFromBlobUri(uri string) string { // TODO: update this function as sovereign cloud blob resource strings become available // resource string for getting MSI for azure storage is still https://storage.azure.com/ for sovereign regions but it is expected to change return storageResourceName } func IsAzureStorageBlobUri(url string) bool { parsedUrl, err := url2.Parse(url) if err != nil { return false } host := parsedUrl.Host for validBlobDomain := range azureBlobDomains { if strings.Contains(host, validBlobDomain) { return true } } return false }