msihttpclient/msihttpclient.go (130 lines of code) (raw):
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
package msihttpclient
import (
"bytes"
"crypto/tls"
"fmt"
"github.com/Azure/azure-extension-foundation/errorhelper"
"github.com/Azure/azure-extension-foundation/httputil"
"github.com/Azure/azure-extension-foundation/metadata"
"github.com/Azure/azure-extension-foundation/msi"
"io/ioutil"
"net/http"
"net/url"
)
// msiHttpClient implements httputil.HttpClient by sending requests through
// httpClient with an MSI bearer token attached, retrying non-success responses
// as directed by retryBehavior.
//
// NOTE: NewMsiHttpClient may build this struct with a positional composite
// literal; do not reorder the fields without updating it.
type msiHttpClient struct {
	// httpClient sends the actual HTTP requests.
	httpClient httpClientInterface
	// retryBehavior is called with (statusCode, attempt) and reports whether to retry.
	retryBehavior httputil.RetryBehavior
	// msi caches the current MSI credentials; nil until the first successful fetch.
	msi *msi.Msi
	// msiProvider is used to fetch (and re-fetch on expiry) the MSI credentials.
	msiProvider msi.MsiProvider
	// metadata supplies the Azure resource id appended to URLs as vmResourceId.
	metadata *metadata.Metadata
}
// httpClientInterface is the subset of *http.Client used by msiHttpClient.
type httpClientInterface interface {
	Do(request *http.Request) (*http.Response, error)
}

// getHttpClientFunc constructs the client used to send requests. It is a
// package-level variable rather than a plain function, presumably so it can be
// swapped out (e.g. for a fake in tests).
var getHttpClientFunc = func() httpClientInterface {
	// The TLS configuration permits the server to renegotiate the session at
	// any time; no other transport options are customized.
	return &http.Client{
		Transport: &http.Transport{
			TLSClientConfig: &tls.Config{Renegotiation: tls.RenegotiateFreelyAsClient},
		},
	}
}
// NewMsiHttpClient returns an httputil.HttpClient that attaches an MSI bearer
// token obtained from msiProvider to every request, appends the resource id
// from mdata to each request URL, and retries non-success responses according
// to retryBehavior.
//
// It panics if retryBehavior or msiProvider is nil. MSI credentials are fetched
// eagerly once; a failure there is intentionally ignored because every request
// attempts the fetch again before being sent.
func NewMsiHttpClient(msiProvider msi.MsiProvider, mdata *metadata.Metadata, retryBehavior httputil.RetryBehavior) httputil.HttpClient {
	switch {
	case retryBehavior == nil:
		panic("Retry policy must be specified")
	case msiProvider == nil:
		panic("msiProvider must be specified")
	}

	client := &msiHttpClient{
		httpClient:    getHttpClientFunc(),
		retryBehavior: retryBehavior,
		msiProvider:   msiProvider,
		metadata:      mdata,
	}
	// Best effort: the request path refreshes again and reports any error.
	_ = client.refreshMsiAuthentication()
	return client
}
// Get issues a get request (httputil.OperationGet) with no body and returns the
// response status code and body.
func (client *msiHttpClient) Get(url string, headers map[string]string) (responseCode int, body []byte, err error) {
	responseCode, body, err = client.issueRequest(httputil.OperationGet, url, headers, nil)
	return responseCode, body, err
}
// Post issues a post request (httputil.OperationPost) carrying payload as the
// request body and returns the response status code and body. An empty payload
// is sent without a body.
func (client *msiHttpClient) Post(url string, headers map[string]string, payload []byte) (responseCode int, body []byte, err error) {
	requestBody := bytes.NewBuffer(payload)
	return client.issueRequest(httputil.OperationPost, url, headers, requestBody)
}
// Put issues a put request (httputil.OperationPut) carrying payload as the
// request body and returns the response status code and body. An empty payload
// is sent without a body.
func (client *msiHttpClient) Put(url string, headers map[string]string, payload []byte) (responseCode int, body []byte, err error) {
	responseCode, body, err = client.issueRequest(httputil.OperationPut, url, headers, bytes.NewBuffer(payload))
	return responseCode, body, err
}
// Delete issues a delete request (httputil.OperationDelete) carrying payload as
// the request body and returns the response status code and body. An empty
// payload is sent without a body.
func (client *msiHttpClient) Delete(url string, headers map[string]string, payload []byte) (responseCode int, body []byte, err error) {
	operation := httputil.OperationDelete
	requestBody := bytes.NewBuffer(payload)
	return client.issueRequest(operation, url, headers, requestBody)
}
// addVmIdQueryParameterToUrl returns u with a vmResourceId query parameter
// appended whose value is the Azure resource id from client.metadata.
//
// Existing query parameters are kept verbatim. The '&' separator is only added
// when a query is already present, and the resource id is query-escaped so that
// characters such as '&', '#', '+' or spaces cannot corrupt the query string.
// It returns an error if u cannot be parsed as a URL.
func (client *msiHttpClient) addVmIdQueryParameterToUrl(u string) (string, error) {
	parsedUrl, err := url.Parse(u)
	if err != nil {
		return "", err
	}
	// NOTE(review): assumes GetAzureResourceId returns a plain string.
	vmIdParameter := "vmResourceId=" + url.QueryEscape(client.metadata.GetAzureResourceId())
	if parsedUrl.RawQuery == "" {
		parsedUrl.RawQuery = vmIdParameter
	} else {
		parsedUrl.RawQuery = parsedUrl.RawQuery + "&" + vmIdParameter
	}
	return parsedUrl.String(), nil
}
// refreshMsiAuthentication ensures client.msi holds credentials. On the first
// call it fetches them from client.msiProvider. Afterwards it re-fetches only
// when the cached token reports itself as expired. If any step fails, the
// cached credentials are left unchanged and the error is returned.
func (client *msiHttpClient) refreshMsiAuthentication() error {
	needsRefresh := client.msi == nil
	if !needsRefresh {
		expired, err := client.msi.IsMsiTokenExpired()
		if err != nil {
			return err
		}
		needsRefresh = expired
	}
	if !needsRefresh {
		return nil
	}

	fresh, err := client.msiProvider.GetMsi()
	if err != nil {
		return err
	}
	client.msi = &fresh
	return nil
}
// setMsiAuthenticationHeader sets (replacing any existing value) the request's
// Authorization header to "Bearer <token>" using the cached MSI access token.
// client.msi must be non-nil; callers refresh it beforehand.
func (client *msiHttpClient) setMsiAuthenticationHeader(request *http.Request) {
	bearer := fmt.Sprintf("Bearer %s", client.msi.AccessToken)
	request.Header.Set("Authorization", bearer)
}
// issueRequest sends an MSI-authenticated request and returns the final
// response status code and body.
//
// Behavior:
//   - The VM resource id is appended to requestUrl as a query parameter.
//   - MSI credentials are refreshed (if needed) before every attempt.
//   - Non-success responses are retried while
//     client.retryBehavior(statusCode, attempt) returns true, where attempt
//     starts at 1 for the first retry. Retrying stops as soon as a response
//     succeeds.
//   - Transport errors from Do are returned immediately without retrying.
//   - Entries in headers are applied after the Authorization header on every
//     attempt, so a caller-supplied header takes precedence consistently.
//
// On failure it returns -1, a nil body, and the error annotated with a stack.
func (client *msiHttpClient) issueRequest(operation string, requestUrl string, headers map[string]string, payload *bytes.Buffer) (int, []byte, error) {
	// add query parameter for vmId
	modifiedUrl, err := client.addVmIdQueryParameterToUrl(requestUrl)
	if err != nil {
		return -1, nil, errorhelper.AddStackToError(err)
	}

	request, err := newMsiHttpRequest(operation, modifiedUrl, payload)
	if err != nil {
		return -1, nil, errorhelper.AddStackToError(err)
	}

	if err = client.authorizeMsiRequest(request, headers); err != nil {
		return -1, nil, errorhelper.AddStackToError(err)
	}
	res, err := client.httpClient.Do(request)
	if err != nil {
		return -1, nil, errorhelper.AddStackToError(err)
	}

	for attempt := 1; !httputil.IsSuccessStatusCode(res.StatusCode) && client.retryBehavior(res.StatusCode, attempt); attempt++ {
		// The previous response is abandoned; release it so its connection is not leaked.
		discardMsiResponse(res)

		// The previous Do consumed the request body; rewind it so the retry
		// sends the full payload rather than an empty body.
		if err = rewindMsiRequestBody(request); err != nil {
			return -1, nil, errorhelper.AddStackToError(err)
		}
		if err = client.authorizeMsiRequest(request, headers); err != nil {
			return -1, nil, errorhelper.AddStackToError(err)
		}
		if res, err = client.httpClient.Do(request); err != nil {
			return -1, nil, errorhelper.AddStackToError(err)
		}
	}
	defer res.Body.Close()

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return -1, nil, errorhelper.AddStackToError(err)
	}
	return res.StatusCode, body, nil
}

// newMsiHttpRequest builds the request, attaching payload as the body only when
// it is non-empty. A literal nil is passed otherwise, because a nil
// *bytes.Buffer wrapped in io.Reader would make http.NewRequest panic.
func newMsiHttpRequest(operation string, requestUrl string, payload *bytes.Buffer) (*http.Request, error) {
	if payload == nil || payload.Len() == 0 {
		return http.NewRequest(operation, requestUrl, nil)
	}
	return http.NewRequest(operation, requestUrl, payload)
}

// authorizeMsiRequest refreshes MSI credentials if needed, sets the bearer
// Authorization header, then applies the caller's headers on top.
func (client *msiHttpClient) authorizeMsiRequest(request *http.Request, headers map[string]string) error {
	if err := client.refreshMsiAuthentication(); err != nil {
		return err
	}
	client.setMsiAuthenticationHeader(request)
	for key, value := range headers {
		request.Header.Set(key, value)
	}
	return nil
}

// rewindMsiRequestBody replaces a consumed request body with a fresh copy via
// GetBody (which http.NewRequest sets for *bytes.Buffer bodies). Requests
// without a body are left untouched.
func rewindMsiRequestBody(request *http.Request) error {
	if request.Body == nil || request.GetBody == nil {
		return nil
	}
	body, err := request.GetBody()
	if err != nil {
		return err
	}
	request.Body = body
	return nil
}

// discardMsiResponse drains and closes a response that is being thrown away,
// so the transport can reuse the connection. Errors are ignored deliberately:
// the content is not needed.
func discardMsiResponse(res *http.Response) {
	_, _ = ioutil.ReadAll(res.Body)
	_ = res.Body.Close()
}